[Core] refactor transformers 2d into multiple init variants.#7491

Merged
sayakpaul merged 30 commits into main from transformers-2d-refactor-ii
Apr 3, 2024
Conversation

@sayakpaul
Member

@sayakpaul sayakpaul commented Mar 27, 2024

What does this PR do?

For a new Transformer variant, we should do "Transformer2DModelForXXX" going forward.

I would love an initial review of the design.

Adjacent to #7489. I believe that, between that PR and this one, the class will read as more modular.

@sayakpaul sayakpaul requested review from DN6 and yiyixuxu March 27, 2024 11:16
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul sayakpaul marked this pull request as draft March 27, 2024 12:30
@sayakpaul sayakpaul removed request for DN6 and yiyixuxu March 27, 2024 12:30
@sayakpaul sayakpaul requested review from DN6 and yiyixuxu March 28, 2024 01:55
@sayakpaul sayakpaul removed request for DN6 and yiyixuxu March 28, 2024 03:39
@sayakpaul sayakpaul changed the title from "[Core] refactor transformers 2d into multiple legacy variants." to "[Core] refactor transformers 2d into multiple init variants." Mar 28, 2024
@sayakpaul sayakpaul requested review from DN6 and yiyixuxu March 28, 2024 04:42
Collaborator

@DN6 DN6 left a comment


I like the init methods approach. Left a couple of comments.

@sayakpaul sayakpaul requested review from DN6 and yiyixuxu March 29, 2024 07:52
@sayakpaul
Member Author

@DN6 @yiyixuxu all the comments have been addressed. PTAL.

@sayakpaul
Member Author

sayakpaul commented Mar 29, 2024

@DN6 @yiyixuxu the failing tests (8 of them) are a little flaky in nature. These failures happen ONLY when the input type is continuous.

They're failing because the parameters of the Transformer2DModel change depending on the position at which its internal components (transformer_blocks, proj_in, etc.) are initialized, EVEN WHEN seeded.

Earlier, the self.transformer_blocks init sat between self.proj_in and self.proj_out. This PR doesn't preserve that positioning, as the final part of the main __init__() initializes self.transformer_blocks. To confirm, I put the init of self.transformer_blocks after

self.proj_in = torch.nn.Conv2d(self.in_channels, self.inner_dim, kernel_size=1, stride=1, padding=0)

and before

if self.use_linear_projection:

All the tests pass after this change.

You might ask why this does not happen for the other input types (patched and vectorized). It's because we only configure Transformer2DModel for continuous input types in tests/models/test_layers_utils.py (which is where the 8 tests are failing).
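To illustrate the mechanism (with hypothetical stand-in modules, not the actual Transformer2DModel layers): under a fixed seed, each module's parameters consume draws from the global RNG in creation order, so reordering module creation changes which draws each parameter receives.

```python
import torch

# Order A: the "block" is created between proj_in and proj_out
# (the pre-refactor positioning).
torch.manual_seed(0)
a_proj_in = torch.nn.Conv2d(4, 8, kernel_size=1)
a_block = torch.nn.Linear(8, 8)
a_proj_out = torch.nn.Conv2d(8, 4, kernel_size=1)

# Order B: the "block" is created last (as at the end of the
# refactored __init__()).
torch.manual_seed(0)
b_proj_in = torch.nn.Conv2d(4, 8, kernel_size=1)
b_proj_out = torch.nn.Conv2d(8, 4, kernel_size=1)
b_block = torch.nn.Linear(8, 8)

# proj_in consumes the same RNG draws in both orders, so it matches...
print(torch.equal(a_proj_in.weight, b_proj_in.weight))    # True
# ...but every module created after the reordering point differs.
print(torch.equal(a_proj_out.weight, b_proj_out.weight))  # False
```

So identical seeds only guarantee identical parameters while the creation order stays identical, which is exactly why the seeded tests break.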

I see two potential solutions:

  • Change the assertion values in the tests
  • Add the self.transformer_blocks init in _init_continuous_input(), as suggested above.

I think we shouldn't do pt. 2 and should settle for pt. 1. The tests are mild in nature and don't seem to be testing anything critical.

LMK.

@yiyixuxu
Collaborator

I see two potential solutions:

Change the assertion values in the tests
Add the self.transformer_blocks init in _init_continuous_input(), as suggested above.

I left a comment here: #7491 (comment).
I don't think we should move the transformer init inside these _init* methods just to get the tests to pass, but I would actually be in favor of doing it because it is easier to read IMO: depending on the input type, the user can just go to one of the _init* methods instead of going back and forth.
But I don't feel strongly about this; I think it is a nice clean-up regardless, so I will leave it to you guys to decide :)

@sayakpaul
Member Author

but I would actually be in favor of doing it because it is easier to read IMO: depending on the input type, the user can just go to one of the _init* methods instead of going back and forth

This hits the point home. So, I will follow that. Thanks!

@yiyixuxu
Collaborator

yiyixuxu commented Mar 30, 2024

@sayakpaul @DN6
should we just spin out Transformer2DModelPatched and Transformer2DModelVectorized? we are pretty much already there - I think most of the work is already done.

@sayakpaul
Member Author

Move out the classes to their own modules? I think it might be better done after #7489.

@yiyixuxu
Collaborator

yiyixuxu commented Mar 30, 2024

Move out the classes to their own modules? I think it might be better done after #7489.

Oh sure.
Let me know what you think @DN6: we are already creating 3 separate __init__ and forward paths; they are practically already 3 different classes. I think the only models that use patched are PixArt and DiT; vectorized is only used for VQ Diffusion, but we will have to double-check.

@sayakpaul
Member Author

I have an idea of the kind of refactoring @yiyixuxu is suggesting. I would prefer doing it in a separate PR, clubbing it with #7489. Would that be okay?

@DN6
Collaborator

DN6 commented Apr 1, 2024

@sayakpaul @DN6 should we just spin out Transformer2DModelPatched and Transformer2DModelVectorized? we are pretty much already there - I think most of the work is already done.

Yeah I'm cool with spinning it out into dedicated classes. If possible, could they have model specific names e.g. PixartTransformer2D?

We can handle it in a separate PR if you prefer @sayakpaul

Would we need to merge this PR if it's just going to be changed soonish? Can we just take the work done here and in #7489 and open a new PR?

@yiyixuxu is right that we are essentially creating three different inits and forwards within the same object, so the signal here is that we should break it up into three different models.

@sayakpaul
Member Author

sayakpaul commented Apr 1, 2024

@DN6

I would still prefer to handle the spinning out portion in #7489 and merge this PR.

If possible, could they have model specific names e.g. PixartTransformer2D?

I think we can gradually tackle these phasing out cases. More like a coarse-to-fine kind of approach.

Regarding #7491 (comment), I think it's a matter of preference, so I would like to elaborate a bit on why I tend toward the duplication route for this particular case.

The current PR allows the Transformer2DModel class to have the following structure:

class Transformer2DModel:
    def __init__(self, ...):
        # init common attributes that are shared across the board
        self.in_channels = ...
        self.out_channels = ...

        # handle initialization of specific attributes, modules, etc.
        if input_type == "continuous":
            self._init_continuous_inputs()
        elif input_type == "patched":
            self._init_patched_inputs()
        elif input_type == "vectorized":
            self._init_vectorized_inputs()

    def _init_continuous_inputs(self):
        # initialize attributes specific to continuous inputs.
        ...

        # initialize blocks specific to continuous inputs.
        # input block.
        ...

        # transformer blocks.
        self.transformer_blocks = ... # <- duplicate code for other init methods too.

        # output blocks.
        ...

    # rest of the init methods follow the above structure too. 

Now, if we move out the transformer block initialization, it would look like:

class Transformer2DModel:
    def __init__(self, ...):
        # init common attributes that are shared across the board
        self.in_channels = ...
        self.out_channels = ...

        # handle initialization of specific attributes, modules, etc.
        if input_type == "continuous":
            self._init_continuous_inputs()
        elif input_type == "patched":
            self._init_patched_inputs()
        elif input_type == "vectorized":
            self._init_vectorized_inputs()

        # transformer blocks.
        self.transformer_blocks = ...

    def _init_continuous_inputs(self):
        # initialize attributes specific to continuous inputs.
        ...

        # initialize blocks specific to continuous inputs.
        # input block.
        ...

        # output blocks.
        ...

    # rest of the init methods follow the above structure too. 

I think it breaks the linearity of the reading flow, plus it breaks the tests (which are minor, I agree).

Hope this provides more reasoning.

Collaborator

@DN6 DN6 left a comment


LGTM 👍🏽

@sayakpaul sayakpaul merged commit a9a5b14 into main Apr 3, 2024
@sayakpaul sayakpaul deleted the transformers-2d-refactor-ii branch April 3, 2024 07:26
noskill pushed a commit to noskill/diffusers that referenced this pull request Apr 5, 2024
…face#7491)

* refactor transformers 2d into multiple legacy variants.

* fix: init.

* fix recursive init.

* add inits.

* make transformer block creation more modular.

* complete refactor.

* remove forward

* debug

* remove legacy blocks and refactor within the module itself.

* remove print

* guard caption projection

* remove fetcher.

* reduce the number of args.

* fix: norm_type

* group variables that are shared.

* remove _get_transformer_blocks

* harmonize the init function signatures.

* transformer_blocks to common

* repeat .
sayakpaul added a commit that referenced this pull request Dec 23, 2024