PixArt-Sigma Implementation #7654
Merged · 47 commits
- `adf23f3` support PixArt-DMD
- `f9b184c` add PixArtSigmaPipeline
- `1732580` add converting file
- `5cbf701` move `use_additional_conditions` to the `__init__` function.
- `95c65d0` remove unused flag
- `6392d2d` remove unused package (lawrence-cj)
- `a0301cf` fix: circular import (lawrence-cj)
- `e755d19` bug fixed (badayvedat)
- `c1c56a7` Merge branch 'main' into main (lawrence-cj)
- `b80b9ce` Merge pull request #12 from badayvedat/main (badayvedat)
- `bd3720e` Merge branch 'main' into main (lawrence-cj)
- `645d9b6` Merge branch 'huggingface:main' into main (sayakpaul)
- `deee7aa` make style (lawrence-cj)
- `b32c361` add a warning about `use_additional_conditions` (lawrence-cj)
- `bde7bb3` Merge branch 'huggingface:main' into main (lawrence-cj)
- `e785009` add test files (lawrence-cj)
- `249be0c` make style (lawrence-cj)
- `416bfb5` Merge branch 'main' into main (lawrence-cj)
- `3142bfd` Merge branch 'huggingface:main' into main (sayakpaul)
- `0d8273d` 1. fixed inheritance from DiffusionPipeline (lawrence-cj)
- `a62b103` Update src/diffusers/models/transformers/transformer_2d.py (lawrence-cj)
- `0d23b19` Update src/diffusers/models/transformers/transformer_2d.py (lawrence-cj)
- `f260a67` Update src/diffusers/models/transformers/transformer_2d.py (lawrence-cj)
- `115ab91` Update src/diffusers/models/transformers/transformer_2d.py (lawrence-cj)
- `971cdd3` Update src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py (lawrence-cj)
- `11d9512` Merge branch 'huggingface:main' into main (lawrence-cj)
- `7e66f16` add copy from info (lawrence-cj)
- `86f1a6e` Update src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py (lawrence-cj)
- `2f16746` add PixArtImageProcessor and remove the relative code in alpha and si… (lawrence-cj)
- `35769dd` make style (lawrence-cj)
- `e5107bf` combine PixArtImageProcessor and VAEImageProcessor (lawrence-cj)
- `b611190` make fix-copies (lawrence-cj)
- `4220c8a` make fix-copies again (lawrence-cj)
- `ff94f3f` copies (lawrence-cj)
- `d8818c1` Update src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
- `b956721` fast test passed (lawrence-cj)
- `d6d94bb` Merge branch 'huggingface:main' into main (lawrence-cj)
- `e57c032` style (lawrence-cj)
- `64d0f92` alpha fast test passed
- `3f1cc34` Merge branch 'huggingface:main' into main (lawrence-cj)
- `f62f0a5` update vae image processor defaults (lawrence-cj)
- `d73b926` Merge branch 'main' of github.com:lawrence-cj/diffusers into pix
- `bc7cb56` Revert "alpha fast test passed"
- `537d325` empty
- `ab55762` skip the sequential offload tests
- `65b17cb` Merge branch 'huggingface:main' into main
- `b08342e` sigma fast test passed & make style (lawrence-cj)
New file (`@@ -0,0 +1,223 @@`), the PixArt-Sigma checkpoint conversion script:

```python
import argparse
import os

import torch
from transformers import T5EncoderModel, T5Tokenizer

from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtSigmaPipeline, Transformer2DModel


ckpt_id = "PixArt-alpha"
# https://github.com/PixArt-alpha/PixArt-sigma/blob/dd087141864e30ec44f12cb7448dd654be065e88/scripts/inference.py#L158
interpolation_scale = {256: 0.5, 512: 1, 1024: 2, 2048: 4}


def main(args):
    all_state_dict = torch.load(args.orig_ckpt_path)
    state_dict = all_state_dict.pop("state_dict")
    converted_state_dict = {}

    # Patch embeddings.
    converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
    converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")

    # Caption projection.
    converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
    converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
    converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
    converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")

    # AdaLN-single LN
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
        "t_embedder.mlp.0.weight"
    )
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
        "t_embedder.mlp.2.weight"
    )
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")

    if args.micro_condition:
        # Resolution.
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.weight"] = state_dict.pop(
            "csize_embedder.mlp.0.weight"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.bias"] = state_dict.pop(
            "csize_embedder.mlp.0.bias"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.weight"] = state_dict.pop(
            "csize_embedder.mlp.2.weight"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.bias"] = state_dict.pop(
            "csize_embedder.mlp.2.bias"
        )
        # Aspect ratio.
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.weight"] = state_dict.pop(
            "ar_embedder.mlp.0.weight"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.bias"] = state_dict.pop(
            "ar_embedder.mlp.0.bias"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.weight"] = state_dict.pop(
            "ar_embedder.mlp.2.weight"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.bias"] = state_dict.pop(
            "ar_embedder.mlp.2.bias"
        )
    # Shared norm.
    converted_state_dict["adaln_single.linear.weight"] = state_dict.pop("t_block.1.weight")
    converted_state_dict["adaln_single.linear.bias"] = state_dict.pop("t_block.1.bias")

    for depth in range(28):
        # Transformer blocks.
        converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
            f"blocks.{depth}.scale_shift_table"
        )
        # Attention is all you need 🤘

        # Self attention.
        q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.bias"), 3, dim=0)
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias
        # Projection.
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
            f"blocks.{depth}.attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
            f"blocks.{depth}.attn.proj.bias"
        )
        if args.qk_norm:
            converted_state_dict[f"transformer_blocks.{depth}.attn1.q_norm.weight"] = state_dict.pop(
                f"blocks.{depth}.attn.q_norm.weight"
            )
            converted_state_dict[f"transformer_blocks.{depth}.attn1.q_norm.bias"] = state_dict.pop(
                f"blocks.{depth}.attn.q_norm.bias"
            )
            converted_state_dict[f"transformer_blocks.{depth}.attn1.k_norm.weight"] = state_dict.pop(
                f"blocks.{depth}.attn.k_norm.weight"
            )
            converted_state_dict[f"transformer_blocks.{depth}.attn1.k_norm.bias"] = state_dict.pop(
                f"blocks.{depth}.attn.k_norm.bias"
            )

        # Feed-forward.
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc1.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc1.bias"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc2.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc2.bias"
        )

        # Cross-attention.
        q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
        q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
        k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
        k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)

        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias

        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
            f"blocks.{depth}.cross_attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
            f"blocks.{depth}.cross_attn.proj.bias"
        )

    # Final block.
    converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
    converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
    converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")

    # PixArt XL/2
    transformer = Transformer2DModel(
        sample_size=args.image_size // 8,
        num_layers=28,
        attention_head_dim=72,
        in_channels=4,
        out_channels=8,
        patch_size=2,
        attention_bias=True,
        num_attention_heads=16,
        cross_attention_dim=1152,
        activation_fn="gelu-approximate",
        num_embeds_ada_norm=1000,
        norm_type="ada_norm_single",
        norm_elementwise_affine=False,
        norm_eps=1e-6,
        caption_channels=4096,
        interpolation_scale=interpolation_scale[args.image_size],
        use_additional_conditions=args.micro_condition,
    )
    transformer.load_state_dict(converted_state_dict, strict=True)

    assert transformer.pos_embed.pos_embed is not None
    try:
        state_dict.pop("y_embedder.y_embedding")
        state_dict.pop("pos_embed")
    except Exception as e:
        print(f"Skipping {str(e)}")
    assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"

    num_model_params = sum(p.numel() for p in transformer.parameters())
    print(f"Total number of transformer parameters: {num_model_params}")

    if args.only_transformer:
        transformer.save_pretrained(os.path.join(args.dump_path, "transformer"))
    else:
        # pixart-Sigma vae link: https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers/tree/main/vae
        vae = AutoencoderKL.from_pretrained(f"{ckpt_id}/pixart_sigma_sdxlvae_T5_diffusers", subfolder="vae")

        scheduler = DPMSolverMultistepScheduler()

        tokenizer = T5Tokenizer.from_pretrained(f"{ckpt_id}/pixart_sigma_sdxlvae_T5_diffusers", subfolder="tokenizer")
        text_encoder = T5EncoderModel.from_pretrained(
            f"{ckpt_id}/pixart_sigma_sdxlvae_T5_diffusers", subfolder="text_encoder"
        )

        pipeline = PixArtSigmaPipeline(
            tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
        )

        pipeline.save_pretrained(args.dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--micro_condition", action="store_true", help="If use Micro-condition in PixArtMS structure during training."
    )
    parser.add_argument("--qk_norm", action="store_true", help="If use qk norm during training.")
    parser.add_argument(
        "--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
    )
    parser.add_argument(
        "--image_size",
        default=1024,
        type=int,
        choices=[256, 512, 1024, 2048],
        required=False,
        help="Image size of pretrained model, 256, 512, 1024, or 2048.",
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
    parser.add_argument("--only_transformer", default=True, type=bool, required=True)

    args = parser.parse_args()
    main(args)
```
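The only non-trivial tensor surgery in the conversion is the fused-to-split attention remapping: the original checkpoint stores the self-attention q/k/v projection (and the cross-attention k/v projection) stacked along the output dimension, and `torch.chunk` along `dim=0` recovers the separate `to_q`/`to_k`/`to_v` weights that diffusers expects. A minimal, self-contained sketch of why this split is lossless (toy `hidden` size and random weights for illustration; the real PixArt blocks use a hidden size of 1152 from 16 heads of dimension 72):

```python
import torch

hidden = 8  # toy size for illustration only

# A fused qkv linear layer stores q, k, and v stacked along dim 0:
# weight is (3 * hidden, hidden), bias is (3 * hidden,).
qkv_weight = torch.randn(3 * hidden, hidden)
qkv_bias = torch.randn(3 * hidden)

# Chunking into three equal pieces along dim 0 recovers the separate projections.
q, k, v = torch.chunk(qkv_weight, 3, dim=0)
q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0)
assert q.shape == k.shape == v.shape == (hidden, hidden)

# The fused projection and the concatenated split projections agree.
x = torch.randn(2, hidden)
fused = x @ qkv_weight.T + qkv_bias
split = torch.cat([x @ q.T + q_bias, x @ k.T + k_bias, x @ v.T + v_bias], dim=-1)
assert torch.allclose(fused, split, atol=1e-6)

# Cross-attention fuses only k and v, so a 2-way chunk recovers them.
kv_weight = torch.randn(2 * hidden, hidden)
k2, v2 = torch.chunk(kv_weight, 2, dim=0)
assert k2.shape == v2.shape == (hidden, hidden)
```

The same pattern applies to the biases, which is why the script chunks `attn.qkv.bias` three ways and `cross_attn.kv_linear.bias` two ways.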
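One detail worth flagging in the argument parser: `--only_transformer` is declared with `type=bool`. argparse applies `type` to the raw command-line string, and `bool()` of any non-empty string (including `"False"`) is `True`, so this flag cannot actually be switched off from the CLI. A small demonstration, with a hypothetical `str2bool` converter shown as the usual workaround (not part of the PR):

```python
import argparse

# `type=bool` is applied to the raw string: bool("False") is True,
# because any non-empty string is truthy.
parser = argparse.ArgumentParser()
parser.add_argument("--only_transformer", default=True, type=bool)
args = parser.parse_args(["--only_transformer", "False"])
assert args.only_transformer is True  # "False" did not turn it off


# A common workaround: parse the string explicitly.
def str2bool(value: str) -> bool:
    return value.lower() in ("1", "true", "yes")


parser2 = argparse.ArgumentParser()
parser2.add_argument("--only_transformer", default=True, type=str2bool)
assert parser2.parse_args(["--only_transformer", "False"]).only_transformer is False
```

An `action="store_true"`/`store_false` pair, as used for `--micro_condition` and `--qk_norm` above, avoids the issue entirely.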