[Core] refactor transformers 2d into multiple init variants.#7491
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
DN6
left a comment
I like the init methods approach. Left a couple of comments.
@DN6 @yiyixuxu the failing tests (8) are a little shaky in nature. These failures happen ONLY when the input type is continuous. They're failing because the parameters of the … All the tests pass after this change. You might ask why this does not happen for the other input types (patched and vectorized)? Because we only configure … I see two potential solutions:
I think we shouldn't do pt. 2 and should settle with pt. 1. The tests are mild in nature and don't seem to be testing anything critical. LMK.
I left a comment here: #7491 (comment)
This hits the point home. So, I will follow that. Thanks!
@sayakpaul @DN6 |
Move out the classes to their own modules? I think it might be better done after #7489. |
oh sure, |
Yeah I'm cool with spinning it out into dedicated classes. If possible, could they have model-specific names, e.g. PixartTransformer2D? We can handle it in a separate PR if you prefer.

@sayakpaul Would we need to merge this PR if it's just going to be changed soonish? Can we just take the work done here and in #7489 and open a new PR?

@yiyixuxu is right that we are essentially creating three different inits and forwards within the same object, so the signal here is that we should break it up into three different models.
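For illustration, the dedicated-class split proposed above could look roughly like this. This is a minimal, torch-free sketch with hypothetical names (the real diffusers classes would subclass `nn.Module` and `ModelMixin`; plain strings stand in for modules here):

```python
# Hypothetical sketch of splitting Transformer2DModel into per-variant classes.
# Class names are illustrative, not the final diffusers API.

class Transformer2DBase:
    def __init__(self, in_channels, out_channels):
        # attributes shared across all variants
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.transformer_blocks = []  # common transformer blocks


class ContinuousTransformer2DModel(Transformer2DBase):
    def __init__(self, in_channels, out_channels):
        super().__init__(in_channels, out_channels)
        self.proj_in = "conv_in"  # continuous-specific input block (stand-in)


class PatchedTransformer2DModel(Transformer2DBase):
    def __init__(self, in_channels, out_channels):
        super().__init__(in_channels, out_channels)
        self.pos_embed = "patch_embed"  # patch-specific embedding (stand-in)
```

Each variant then owns a single, linear `__init__` and `forward`, instead of branching on an `input_type` flag.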
I would still prefer to handle the spinning out portion in #7489 and merge this PR.
I think we can gradually tackle these phasing-out cases, more like a coarse-to-fine kind of approach. Regarding #7491 (comment), I think it's a matter of preference, so I would like to elaborate a bit on why I am leaning toward the duplicating route for this particular case. The current PR allows the following:

```python
class Transformer2DModel:
    def __init__(self, ...):
        # init common attributes that are shared across the board
        self.in_channels = ...
        self.out_channels = ...

        # handle initialization of specific attributes, modules, etc.
        if input_type == "continuous":
            self._init_continuous_inputs()
        elif input_type == "patched":
            self._init_patched_inputs()
        elif input_type == "vectorized":
            self._init_vectorized_inputs()

    def _init_continuous_inputs(self):
        # initialize attributes specific to continuous inputs.
        ...
        # initialize blocks specific to continuous inputs.
        # input block.
        ...
        # transformer blocks.
        self.transformer_blocks = ...  # <- duplicate code for the other init methods too.
        # output blocks.
        ...

    # rest of the init methods follow the above structure too.
```

Now, if we move out the transformer block initialization, it would look like:

```python
class Transformer2DModel:
    def __init__(self, ...):
        # init common attributes that are shared across the board
        self.in_channels = ...
        self.out_channels = ...

        # handle initialization of specific attributes, modules, etc.
        if input_type == "continuous":
            self._init_continuous_inputs()
        elif input_type == "patched":
            self._init_patched_inputs()
        elif input_type == "vectorized":
            self._init_vectorized_inputs()

        # transformer blocks.
        self.transformer_blocks = ...

    def _init_continuous_inputs(self):
        # initialize attributes specific to continuous inputs.
        ...
        # initialize blocks specific to continuous inputs.
        # input block.
        ...
        # output blocks.
        ...

    # rest of the init methods follow the above structure too.
```

I think it breaks the linearity of the reading flow, plus it breaks the tests (which are minor, I agree). Hope this provides more reasoning.
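To make the trade-off concrete, here is a runnable, torch-free toy version of the dispatch-on-`input_type` pattern sketched above. Attribute values are plain strings standing in for real modules, and the block names are illustrative, not the actual diffusers internals:

```python
# Toy sketch of the init-dispatch pattern; strings stand in for nn.Module blocks.

class Transformer2DModel:
    def __init__(self, input_type, in_channels, out_channels):
        # common attributes shared across the board
        self.in_channels = in_channels
        self.out_channels = out_channels
        # dispatch variant-specific initialization
        if input_type == "continuous":
            self._init_continuous_inputs()
        elif input_type == "patched":
            self._init_patched_inputs()
        elif input_type == "vectorized":
            self._init_vectorized_inputs()
        else:
            raise ValueError(f"unknown input_type: {input_type!r}")

    def _init_continuous_inputs(self):
        self.proj_in = "conv_in"              # continuous-specific input block
        self.transformer_blocks = ["block"]   # duplicated across the init methods
        self.proj_out = "conv_out"

    def _init_patched_inputs(self):
        self.pos_embed = "patch_embed"        # patch-specific embedding
        self.transformer_blocks = ["block"]
        self.proj_out = "linear_out"

    def _init_vectorized_inputs(self):
        self.latent_image_embedding = "vq_embed"
        self.transformer_blocks = ["block"]
        self.out = "logits_head"


model = Transformer2DModel("patched", in_channels=4, out_channels=8)
```

Each `_init_*` method stays self-contained and linear to read, at the cost of repeating the `transformer_blocks` setup in all three branches.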
…face#7491)

* refactor transformers 2d into multiple legacy variants.
* fix: init.
* fix recursive init.
* add inits.
* make transformer block creation more modular.
* complete refactor.
* remove forward
* debug
* remove legacy blocks and refactor within the module itself.
* remove print
* guard caption projection
* remove fetcher.
* reduce the number of args.
* fix: norm_type
* group variables that are shared.
* remove _get_transformer_blocks
* harmonize the init function signatures.
* transformer_blocks to common
* repeat .
What does this PR do?
For a new Transformer variant, we should do "Transformer2DModelForXXX" going forward.
I would love an initial review of the design.
Adjacent to #7489. I believe that, together, that PR and the current one will make the class read more modular.