PixArt-Sigma Implementation #7654
Conversation
@@ -0,0 +1,430 @@
# Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
Can you help me understand the difference between the PixArt-Alpha and the PixArt-Sigma pipeline implementations? That way, we can decide whether we really need to introduce a separate pipeline class here.
The main difference is that PixArt-Sigma doesn't have micro-conditions, and PixArt-Sigma's transformer will later use `qk_norm` and key/value compression. Therefore, the pipeline needs to accommodate these two changes.
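For intuition, here is a toy numpy sketch of those two transformer-level changes — RMS-normalizing queries and keys (`qk_norm`) and shrinking the key/value sequence (key/value compression). All names here are hypothetical; this is not the diffusers or PixArt-Sigma implementation:

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    # Normalize each vector to unit RMS, as qk_norm does for queries/keys.
    return x / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)

def compress_tokens(x, ratio=2):
    # Key/value compression: average-pool adjacent tokens to shrink the KV sequence.
    n, d = x.shape
    return x[: n - n % ratio].reshape(-1, ratio, d).mean(axis=1)

def attention(q, k, v, use_qk_norm=True, kv_compression_ratio=2):
    # Single-head scaled dot-product attention with the two optional tweaks.
    if use_qk_norm:
        q, k = rms_norm(q), rms_norm(k)
    if kv_compression_ratio > 1:
        k = compress_tokens(k, kv_compression_ratio)
        v = compress_tokens(v, kv_compression_ratio)
    scores = q @ k.T / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((8, 16)) for _ in range(3))
out = attention(q, k, v)
print(out.shape)  # (8, 16): output length follows the queries; KV length was halved internally
```

Because both tweaks live inside the attention block, only the transformer weights change — the denoising loop the pipeline runs is untouched.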
Is this the only difference? We could add an if/else statement in the PixArt-Alpha pipeline to support this:
```python
# 6.1 Prepare micro-conditions.
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
if self.transformer.config.sample_size == 128:
    resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
    aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
    resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
    aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
    if do_classifier_free_guidance:
        resolution = torch.cat([resolution, resolution], dim=0)
        aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
    added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
```
PixArt-Sigma will have `if self.transformer.config.sample_size == 128:` as well, but it doesn't use the additional conditions such as resolution and aspect ratio. It could still be gated on `self.transformer.config.use_additional_conditions`, I believe?
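As a sketch of that idea — gating the micro-conditions on a single config flag so one pipeline serves both models — here is a hypothetical helper (names mirror the pipeline snippet, but this is not the actual diffusers code):

```python
def prepare_added_cond_kwargs(use_additional_conditions, height, width,
                              batch_size, do_classifier_free_guidance):
    if not use_additional_conditions:
        # PixArt-Sigma path: no resolution/aspect-ratio embeddings.
        return {"resolution": None, "aspect_ratio": None}
    # PixArt-Alpha path: repeat the micro-conditions per sample, and duplicate
    # them when classifier-free guidance doubles the batch.
    resolution = [[height, width]] * batch_size
    aspect_ratio = [[height / width]] * batch_size
    if do_classifier_free_guidance:
        resolution = resolution * 2
        aspect_ratio = aspect_ratio * 2
    return {"resolution": resolution, "aspect_ratio": aspect_ratio}

kwargs = prepare_added_cond_kwargs(True, 1024, 768, 2, True)
print(len(kwargs["resolution"]))  # 4: a batch of 2, doubled for CFG
```

With this shape, the only branch the shared pipeline needs is the flag lookup, rather than per-model subclasses.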
> PixArt-Sigma's transformer will use the qk_norm and key/value compression later.

This will affect the Transformer block and not the pipeline itself, right? So I guess it's fine to fold it into a single pipeline?
sayakpaul
left a comment
Overall, this is already looking to be in very good shape. Thank you!
I think once the comments are addressed and we decide whether the changes warrant a new pipeline class, we will need to:
- Add a slow test
- Add documentation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
sayakpaul
left a comment
I am okay with the PR given:
- The code quality tests pass
- We add documentation
- We add fast tests
- Slow tests can be added later when other Sigma variants (with compressed KV) are in
Is it just me, or does the example `pipe = PixArtSigmaPipeline.from_pretrained("PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16)` not work because the Hugging Face repo only has a `transformer` folder? I see:
Let me add the remaining packages to the Hugging Face repo.
@lawrence-cj
I don't know why. Actually, I didn't change anything in Alpha. I guess it may be a historical problem?
@lawrence-cj can we run the PixArt-Alpha pipeline under this PR to see if the generated image output looks OK? I will look into the test difference to make sure we did not break anything.
src/diffusers/image_processor.py
Outdated
```python
vae_scale_factor: int = 8,
resample: str = "lanczos",
do_normalize: bool = False,
do_binarize: bool = True,
```

Suggested change:

```diff
- do_binarize: bool = True,
+ do_binarize: bool = False,
```
src/diffusers/image_processor.py
Outdated
```python
resample: str = "lanczos",
do_normalize: bool = False,
do_binarize: bool = True,
do_convert_grayscale: bool = True,
```

Suggested change:

```diff
- do_convert_grayscale: bool = True,
+ do_convert_grayscale: bool = False,
```
@lawrence-cj can you run some doc examples for both Sigma and Alpha? Once we confirm the outputs are OK and all our tests pass, I will merge.
Here are my testing script and outputs:
```python
# test pixart
from diffusers import PixArtAlphaPipeline, PixArtSigmaPipeline, Transformer2DModel, DDPMScheduler
import torch
import gc

weight_dtype = torch.float16

# test 1 step
# transformer = Transformer2DModel.from_pretrained("PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512", subfolder="transformer", torch_dtype=weight_dtype)
# scheduler = DDPMScheduler.from_pretrained("PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512", subfolder="scheduler")
# pipe = PixArtAlphaPipeline.from_pretrained(
#     "PixArt-alpha/PixArt-XL-2-1024-MS",
#     transformer=transformer,
#     scheduler=scheduler,
#     torch_dtype=weight_dtype,
# )
# pipe.enable_model_cpu_offload()
# images = pipe(prompt='dog', timesteps=[400], num_inference_steps=1, output_type="pil").images[0]
# images.save('yiyi_test_4_out_1step.png')
# del pipe
# torch.cuda.empty_cache()
# gc.collect()

# test pixart alpha
pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
prompt = "A small cactus with a happy face in the Sahara desert."
images = pipe(prompt=prompt).images[0]
images.save("yiyi_test_4_out_alpha.png")

pipe.transformer = None
torch.cuda.empty_cache()
gc.collect()

# test pixart sigma
transformer = Transformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
    subfolder="transformer",
    torch_dtype=torch.float16,
    use_additional_conditions=False,
)
pipe = PixArtSigmaPipeline.from_pipe(pipe, transformer=transformer, torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt).images[0]
image.save("yiyi_test_4_out_sigma.png")
```
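The script above builds the Sigma pipeline out of the Alpha pipeline's already-loaded components, swapping only the transformer. The component-reuse pattern behind that step can be sketched with a toy class (hypothetical names; not the diffusers `from_pipe` implementation):

```python
class ToyPipeline:
    def __init__(self, vae, text_encoder, transformer):
        self.vae = vae
        self.text_encoder = text_encoder
        self.transformer = transformer

    @classmethod
    def from_pipe(cls, other, **overrides):
        # Reuse the donor pipeline's components, overriding only what changed
        # (here, the transformer), so shared weights are not reloaded from disk.
        components = {
            "vae": other.vae,
            "text_encoder": other.text_encoder,
            "transformer": other.transformer,
        }
        components.update(overrides)
        return cls(**components)

alpha = ToyPipeline(vae="vae", text_encoder="t5", transformer="alpha-dit")
sigma = ToyPipeline.from_pipe(alpha, transformer="sigma-dit")
print(sigma.vae, sigma.transformer)  # vae sigma-dit
```

Since Alpha and Sigma share the text encoder (and, for some checkpoints, the VAE), this avoids downloading and holding two copies of the large shared weights.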
```python
def test_inference_batch_single_identical(self):
    self._test_inference_batch_single_identical(expected_max_diff=1e-3)
```

```python
# PixArt transformer model does not work with sequential offload so skip it for now
```
I'm disabling the sequential offloading test for now because it is not working correctly for transformer models. I created an issue about this: huggingface/accelerate#2701
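For intuition on why sequential offload is the harder case: it shuttles individual submodules on and off the accelerator around each forward call, instead of moving whole models per pipeline stage. A toy accounting sketch (assumed behavior, not accelerate's real hooks):

```python
def model_offload_peak(submodule_sizes):
    # enable_model_cpu_offload(): the whole model is moved to the accelerator
    # for its forward pass, so peak residency is the full model size.
    return sum(submodule_sizes)

def sequential_offload_peak(submodule_sizes):
    # enable_sequential_cpu_offload(): submodules visit the accelerator one at
    # a time, so peak residency is only the largest single submodule.
    return max(submodule_sizes)

sizes = [120, 40, 40]  # hypothetical parameter counts (millions)
print(model_offload_peak(sizes), sequential_offload_peak(sizes))  # 200 120
```

The trade-off is that the per-submodule hooks have to fire in exactly the order the model's forward pass visits them, which is where a transformer model can trip them up.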
@yiyixuxu For the wrong results of one-step generation, we must set `guidance_scale=1`. And here is the correct one:

```python
# test 1 step
transformer = Transformer2DModel.from_pretrained("PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512", subfolder="transformer", torch_dtype=weight_dtype)
scheduler = DDPMScheduler.from_pretrained("PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512", subfolder="scheduler")
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    transformer=transformer,
    scheduler=scheduler,
    torch_dtype=weight_dtype,
)
# pipe.enable_model_cpu_offload()
pipe.to('cuda')
images = pipe(prompt='dog', timesteps=[400], num_inference_steps=1, guidance_scale=1, output_type="pil").images[0]
images.save('yiyi_test_4_out_1step.png')
del pipe
torch.cuda.empty_cache()
gc.collect()
```
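For context on why `guidance_scale=1` matters here: diffusers pipelines conventionally enable classifier-free guidance only when the scale exceeds 1, so a distilled one-step (DMD) model must run with guidance disabled. A minimal sketch of that convention (hypothetical helper names):

```python
def do_classifier_free_guidance(guidance_scale):
    # The common diffusers convention: CFG is active only for scales above 1.
    return guidance_scale > 1.0

def guided_noise(noise_uncond, noise_text, guidance_scale):
    # CFG blends the unconditional and text-conditioned predictions; with a
    # scale of exactly 1 it reduces to the text-conditioned prediction alone.
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)

print(do_classifier_free_guidance(1.0))  # False
print(guided_noise(0.0, 2.0, 1.0))       # 2.0
```

With the default scale (above 1), the pipeline would double the batch and extrapolate between the two predictions, which a one-step distilled model was never trained to handle.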
With your code, the Sigma pipeline generates pure noise.

```python
from diffusers import PixArtAlphaPipeline, PixArtSigmaPipeline, Transformer2DModel, DDPMScheduler
import torch

transformer = Transformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
    subfolder="transformer",
    torch_dtype=torch.float16,
    use_additional_conditions=False,
)
pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
    transformer=transformer,
    torch_dtype=torch.float16,
).to('cuda')
# pipe.enable_model_cpu_offload()
prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt).images[0]
image.save("yiyi_test_4_out_sigma.png")
```
@lawrence-cj ok! If Sigma works correctly, can we update the test for the Sigma pipeline to make sure it passes?
I think a doc example to demonstrate current usage is sufficient; no need for a warning.
Cool Cool. Let me fix it. Thx!
merged!
Congrats! 🎉 Thank you so much for your support, @yiyixuxu.
No problem! I will do all the tests for the original and new repos and also add a doc page.
Thanks all.
* support PixArt-DMD

---------

Co-authored-by: jschen <chenjunsong4@h-partners.com>
Co-authored-by: badayvedat <badayvedat@gmail.com>
Co-authored-by: Vedat Baday <54285744+badayvedat@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail,com>





What does this PR do?

This PR:
- adds `PixArtSigmaPipeline`
- adds `use_additional_conditions` to the `__init__` function of `Transformer2DModel`, to make loading with `PixArtAlphaPipeline` possible
- provides the `diffusers` safetensors file

After merging, we can directly run with:
Cc: @sayakpaul @yiyixuxu