[Core] Introduce class variants for Transformer2DModel #7647
Conversation
Transformer2Model → Transformer2DModel
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Is the plan here to eventually map the `Transformer2DModel` to these variants? Also, how feasible is it to break it up into model-specific variants rather than input-specific variants? e.g.
Yeah, that's the plan.
Feasible, but I am not sure if we have enough such transformer-based pipelines yet. Most of them vary across very few things (such as the norm type and a cross-attention layer). I think there is a fair trade-off to be had when deciding which variant to use. If there are too many arguments that change, it is better to use a dedicated class (like we did for the private model). If not, rely on an existing variant that depends on the input type.
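For readers skimming this thread, here is a rough illustrative sketch (not the actual diffusers source; names and bodies are simplified) of the difference between input-specific branching inside one class and a dedicated per-model variant:

```python
# Illustrative sketch only - simplified from the discussion above, not the real implementation.
class UnifiedTransformer2D:
    """One class that branches on the *input type* (continuous latents, discrete/vectorized
    latents, or patched inputs), roughly how the monolithic Transformer2DModel behaves."""

    def __init__(self, in_channels=None, patch_size=None, num_vector_embeds=None):
        self.is_input_continuous = in_channels is not None and patch_size is None
        self.is_input_vectorized = num_vector_embeds is not None
        self.is_input_patches = in_channels is not None and patch_size is not None

    def forward(self, hidden_states):
        if self.is_input_continuous:
            ...  # conv/linear projection of continuous latents
        elif self.is_input_vectorized:
            ...  # embedding lookup for discrete latents (e.g. VQ-Diffusion)
        elif self.is_input_patches:
            ...  # patchify + positional embeddings (e.g. DiT, PixArt)


class DedicatedPatchTransformer2D:
    """A model-specific variant fixes those choices in its constructor instead of branching."""

    def __init__(self, in_channels=4, patch_size=2):
        self.in_channels = in_channels
        self.patch_size = patch_size
```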
@DN6 done. I think I have addressed all your comments. LMK.
@DN6 resolved your comment on the location of `_fetch_remapped_cls_from_config`.
LGTM. cc: @yiyixuxu in case you want to take a look too. |
yiyixuxu
left a comment
nice!!
I left a comment - let me know if it is a concern, and feel free to merge if it's not or once you have addressed it.
```python
shift, scale = (
    self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
```
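For anyone following the shapes in the snippet above, a minimal standalone sketch (batch size and hidden size are made up) of how the broadcasted add and `chunk` produce the per-sample shift/scale used for modulation:

```python
import torch

batch, seq_len, dim = 2, 16, 8
scale_shift_table = torch.randn(2, dim)      # learned table: row 0 -> shift, row 1 -> scale
embedded_timestep = torch.randn(batch, dim)  # per-sample timestep embedding
hidden_states = torch.randn(batch, seq_len, dim)

# (1, 2, dim) + (batch, 1, dim) broadcasts to (batch, 2, dim); chunk -> two (batch, 1, dim) tensors
shift, scale = (scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
out = hidden_states * (1 + scale) + shift    # broadcasts over the sequence dimension
print(shift.shape, scale.shape, out.shape)   # (2, 1, 8), (2, 1, 8), (2, 16, 8)
```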
Ohh ok, so these tests should also fail on the current implementation - I don't think this refactor introduced any change that would cause them to fail, no?
```python
del module.proj_attn
```
```python
class LegacyModelMixin(ModelMixin):
```
| return {"hidden_states": hidden_states, "timestep": timesteps, "class_labels": class_label_ids} | ||
|
|
||
```python
@property
def input_shape(self):
```
Are these properties used at all? Maybe we can leverage them so we don't have to specify it in test_output?
Can be in a separate PR if it makes sense.
Yeah good idea. Can look into a little "input_shape" refactor in a future PR :)
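A hypothetical sketch of what such a refactor could look like - the mixin name and `get_init_dict` helper below are made up for illustration; only `input_shape` comes from the diff:

```python
import torch

class ModelTesterSketch:
    """Hypothetical test mixin: derive the dummy input from the model's `input_shape`
    property instead of hard-coding the shape in every test."""

    model_class = None  # set by the concrete test case

    def prepare_dummy_input(self, batch_size=2):
        model = self.model_class(**self.get_init_dict())  # get_init_dict is a made-up helper
        sample = torch.randn(batch_size, *model.input_shape)
        return {"hidden_states": sample, "timestep": torch.tensor([1] * batch_size)}
```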
* init for patches
* finish patched model.
* continuous transformer
* vectorized transformer2d.
* style.
* inits.
* fix-copies.
* introduce DiTTransformer2DModel.
* fixes
* use REMAPPING as suggested by @DN6
* better logging.
* add pixart transformer model.
* inits.
* caption_channels.
* attention masking.
* fix use_additional_conditions.
* remove print.
* debug
* flatten
* fix: assertion for sigma
* handle remapping for modeling_utils
* add tests for dit transformer2d
* quality
* placeholder for pixart tests
* pixart tests
* add _no_split_modules
* add docs.
* check
* check
* check
* check
* fix tests
* fix tests
* move Transformer output to modeling_output
* move errors better and bring back use_additional_conditions attribute.
* add unnecessary things from DiT.
* clean up pixart
* fix remapping
* fix device_map things in pixart2d.
* replace Transformer2DModel with appropriate classes in dit, pixart tests
* empty
* legacy mixin classes./
* use a remapping dict for fetching class names.
* change to specifc model types in the pipeline implementations.
* move _fetch_remapped_cls_from_config to modeling_loading_utils.py
* fix dependency problems.
* add deprecation note.
```python
def test_pixart_512_without_resolution_binning(self):
    generator = torch.manual_seed(0)

    transformer = Transformer2DModel.from_pretrained(
```
We should have kept this test - can we add it back and name it test_pixart_512_without_resolution_binning_legacy_class or something like this?
And make sure to have a similar slow test for DiT.
In the future, I think we should always keep the test with the legacy class name, no? So that we can make sure that everything still works fine with the old API.
cc @DN6
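Something along these lines, perhaps - a sketch rather than the final test; the checkpoint ID and assertion are placeholders:

```python
def test_pixart_512_without_resolution_binning_legacy_class(self):
    # Loading through the legacy class name should transparently return the remapped
    # variant, so the old API keeps working.
    transformer = Transformer2DModel.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-512x512", subfolder="transformer", torch_dtype=torch.float16
    )
    self.assertEqual(transformer.__class__.__name__, "PixArtTransformer2DModel")
```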
Hello, I need to use a deprecated model (vq-diffusion) now. Due to version issues, Transformer2DModel has been mapped to two variants, but these two variants are slightly different from the original vq-diffusion (specifically, different types of norms are used). Directly loading the pretrained model causes from_pretrained of the LegacyModelMixin class to fall into a recursive call until it overflows. If DiTTransformer2DModel uses ada_norm, an error is raised: NotImplementedError: Forward pass is not implemented when ...
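For context, a rough sketch (not the actual diffusers code) of the remapping idea described in this PR, which also illustrates why a config whose norm_type matches neither variant (such as vq-diffusion's ada_norm) has nothing sensible to remap to:

```python
# Sketch of the remapping dict + lookup; the real logic lives in the legacy from_pretrained path.
_CLASS_REMAPPING_SKETCH = {
    "Transformer2DModel": {
        "ada_norm_zero": "DiTTransformer2DModel",       # DiT-style checkpoints
        "ada_norm_single": "PixArtTransformer2DModel",  # PixArt-style checkpoints
    }
}

def fetch_remapped_cls_name(config: dict, old_class_name: str) -> str:
    # If norm_type matches neither entry (e.g. "ada_norm" from vq-diffusion),
    # fall back to the legacy class itself instead of looping.
    remapped = _CLASS_REMAPPING_SKETCH.get(old_class_name, {}).get(config.get("norm_type"))
    return remapped or old_class_name
```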
What does this PR do?
Introduces two variants of `Transformer2DModel`:

* `DiTTransformer2DModel`
* `PixArtTransformer2DModel`

For the other instances where `Transformer2DModel` is used, they should later be turned into blocks as they shouldn't be inheriting from `ModelMixin` (this has been discussed internally).

TODO:
(Will be tackled after I get an initial review)
Some comments are in-line.
LMK.
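For completeness, a minimal usage sketch of the two new classes (checkpoint IDs are illustrative and the calls download weights from the Hub):

```python
import torch
from diffusers import DiTTransformer2DModel, PixArtTransformer2DModel

# Load the DiT transformer with its dedicated class.
dit = DiTTransformer2DModel.from_pretrained(
    "facebook/DiT-XL-2-256", subfolder="transformer", torch_dtype=torch.float16
)

# Load the PixArt-alpha transformer with its dedicated class.
pixart = PixArtTransformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="transformer", torch_dtype=torch.float16
)
```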