[Core] refactor transformer_2d forward logic into meaningful conditions.#7489

Merged
sayakpaul merged 20 commits into main from transformers-2d-refactor on Apr 10, 2024
Conversation

@sayakpaul
Member

@sayakpaul sayakpaul commented Mar 27, 2024

What does this PR do?

Refactors the forward() of Transformer2DModel for easier readability.

More specifically, the PR refactors the forward() method of Transformer2DModel to have the following unified structure:

def forward(self):
    # Handle attention masking.
    ...

    # 1. Input-level operations.
    if self.is_input_continuous:
        ... = self._operate_on_continuous_inputs(...)
    elif self.is_input_vectorized:
        ... = self.latent_image_embedding(...)  # vectorized inputs only need this.
    elif self.is_input_patches:
        ... = self._operate_on_patched_inputs(...)

    # 2. Blocks (common across all the variants).
    for block in self.transformer_blocks:
        ...

    # 3. Outputs.
    if self.is_input_continuous:
        output = self._get_output_for_continuous_inputs(...)
    elif self.is_input_vectorized:
        output = self._get_output_for_vectorized_inputs(...)
    elif self.is_input_patches:
        output = self._get_output_for_patched_inputs(...)

    return output
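To make the dispatch pattern above concrete, here is a minimal, self-contained sketch. The method names mirror the PR, but this toy class is a hypothetical stand-in with placeholder math, not the actual diffusers implementation:

```python
# Toy sketch of the three-branch forward() structure from the PR.
# All helper bodies are placeholders; only the control flow matters.

class ToyTransformer2D:
    def __init__(self, input_type):
        assert input_type in {"continuous", "vectorized", "patches"}
        self.is_input_continuous = input_type == "continuous"
        self.is_input_vectorized = input_type == "vectorized"
        self.is_input_patches = input_type == "patches"

    def _operate_on_continuous_inputs(self, x):
        return [v * 2 for v in x]  # placeholder for proj_in / norm

    def latent_image_embedding(self, x):
        return [v + 1 for v in x]  # placeholder embedding lookup

    def _operate_on_patched_inputs(self, x):
        return [v - 1 for v in x]  # placeholder patch embedding

    def forward(self, x):
        # 1. Input-level operations.
        if self.is_input_continuous:
            hidden = self._operate_on_continuous_inputs(x)
        elif self.is_input_vectorized:
            hidden = self.latent_image_embedding(x)
        elif self.is_input_patches:
            hidden = self._operate_on_patched_inputs(x)

        # 2. Blocks (common across all variants).
        for _ in range(2):  # stand-in for self.transformer_blocks
            hidden = [v + 10 for v in hidden]

        # 3. Outputs (one branch per input type, elided here).
        return hidden

print(ToyTransformer2D("continuous").forward([1, 2]))  # [22, 24]
```

The point of the refactor is that only steps 1 and 3 vary by input type; the transformer blocks in step 2 are shared by all three paths.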

About the possibility of spinning out separate transformer classes for patched and vectorized inputs, what's the consensus? Currently, DiT and PixArt-Alpha use patched inputs. So, for those checkpoints, are we thinking of just updating the configs to use PatchedTransformer2DModel? (The same could be done for VectorizedTransformer2DModel.)

If that's the case, I think it could be quite a breaking change. Many folks use the DiT and PixArt models, especially in light of the many Open-Sora initiatives. Introducing a change like this could be problematic. Should we consider throwing a deprecation warning instead?

@yiyixuxu @DN6 please let me know your thoughts.

@sayakpaul sayakpaul requested review from DN6 and yiyixuxu March 27, 2024 09:50
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul
Member Author

@DN6 @yiyixuxu a gentle ping.

Collaborator

@DN6 DN6 left a comment


LGTM. Left a comment, but it's mostly a nit. Is the failing test related?

@sayakpaul
Member Author

@DN6 WDYT?

About the possibility of spinning out separate transformer classes for patched and vectorized inputs, what's the consensus? Currently, DiT and PixArt-Alpha use patched inputs. So, for those checkpoints, are we thinking of just updating the configs to use PatchedTransformer2DModel? (The same could be done for VectorizedTransformer2DModel.)

(refer to the OP)

@yiyixuxu
Collaborator

yiyixuxu commented Apr 8, 2024

I'm cool with this PR once you and @DN6 are happy with this!

If that's the case, I think it could be quite a breaking change. Many folks use the DiT and PixArt models, especially in light of the many Open-Sora initiatives. Introducing a change like this could be problematic. Should we consider throwing a deprecation warning instead?

For this, yes, I agree it is a highly used block, and I think we should give more thought to avoiding breaking changes.
We can potentially deprecate Transformer2DModel and use three different classes instead: ContinuousTransformer2DModel, PatchedTransformer2DModel, and VectorizedTransformer2DModel.
Just spinning out the patched and vectorized variants and sending a deprecation message for them also works, I think.

would love to hear your thoughts @DN6
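One lightweight shape the deprecation path could take, sketched with the standard warnings module. The three class names are the hypothetical ones from this discussion (not a committed API), and the config keys used for routing (patch_size, num_vector_embeds) are assumptions based on how the input type is currently inferred:

```python
import warnings

# Hypothetical dedicated classes, one per input type (names from the discussion).
class ContinuousTransformer2DModel:
    def __init__(self, **config): self.config = config

class PatchedTransformer2DModel:
    def __init__(self, **config): self.config = config

class VectorizedTransformer2DModel:
    def __init__(self, **config): self.config = config

def Transformer2DModel(**config):
    """Legacy factory: warn, then route to the new class based on the config."""
    warnings.warn(
        "Transformer2DModel is deprecated; use ContinuousTransformer2DModel, "
        "PatchedTransformer2DModel, or VectorizedTransformer2DModel directly.",
        FutureWarning,
        stacklevel=2,
    )
    if config.get("patch_size") is not None:
        return PatchedTransformer2DModel(**config)
    if config.get("num_vector_embeds") is not None:
        return VectorizedTransformer2DModel(**config)
    return ContinuousTransformer2DModel(**config)
```

This keeps old configs loadable (the legacy name still resolves to a working model) while nudging users toward the split classes before the old entry point is removed.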

@sayakpaul
Member Author

Cool. Will work on the deprecation in a future PR then as we discuss the best design.

@sayakpaul
Member Author

@DN6 I had to revert your suggestion because height and width are calculated differently for continuous and patched inputs. Hope that's okay.

Gonna merge the PR after the CI is green.
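The height/width difference mentioned above can be illustrated with a small sketch. The shapes and the patch-grid division reflect the general DiT-style patching scheme; the helper name and the patch_size value are illustrative, not code from the PR:

```python
# Illustrative only: why the spatial dims differ between the two input paths.

def spatial_dims(shape, is_patched, patch_size=2):
    """Return the (height, width) the transformer works at for a (B, C, H, W) input."""
    h, w = shape[-2], shape[-1]
    if is_patched:
        # Patched inputs are tokenized into (H // p) * (W // p) patches,
        # so the transformer operates at the patch-grid resolution.
        return h // patch_size, w // patch_size
    # Continuous inputs keep the full latent resolution.
    return h, w

print(spatial_dims((1, 4, 32, 32), is_patched=False))  # (32, 32)
print(spatial_dims((1, 4, 32, 32), is_patched=True))   # (16, 16)
```

Because the two paths disagree on these values, a single shared computation (as the reverted suggestion proposed) would break one of them.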

@sayakpaul sayakpaul merged commit 44f6b85 into main Apr 10, 2024
@sayakpaul sayakpaul deleted the transformers-2d-refactor branch April 10, 2024 03:03
sayakpaul added a commit that referenced this pull request Dec 23, 2024
…ions. (#7489)

* refactor transformer_2d forward logic into meaningful conditions.

* Empty-Commit

* fix: _operate_on_patched_inputs

* fix: _operate_on_patched_inputs

* check

* fix: patch output computation block.

* fix: _operate_on_patched_inputs.

* remove print.

* move operations to blocks.

* more readability neats.

* empty commit

* Apply suggestions from code review

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Revert "Apply suggestions from code review"

This reverts commit 12178b1.

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>