PixArt-Sigma Implementation#7654

Merged
yiyixuxu merged 47 commits into huggingface:main from lawrence-cj:main
Apr 24, 2024

Conversation

@lawrence-cj
Contributor

What does this PR do?

This PR

  1. integrates the PixArtSigmaPipeline;
  2. moves use_additional_conditions to the __init__ function of Transformer2DModel;
  3. makes one-step sampling (DMD) possible with PixArtAlphaPipeline;
  4. adds a conversion script that converts a pickled PixArt-Sigma checkpoint into a diffusers safetensors file.

After merging, we can directly run with:

>>> import torch
>>> from diffusers import PixArtSigmaPipeline

>>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-Sigma-XL-2-512-MS" too.
>>> pipe = PixArtSigmaPipeline.from_pretrained("PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16)
>>> # Enable memory optimizations.
>>> # pipe.enable_model_cpu_offload()

>>> prompt = "A small cactus with a happy face in the Sahara desert."
>>> image = pipe(prompt).images[0]

Cc: @sayakpaul @yiyixuxu

@@ -0,0 +1,430 @@
# Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
Member

Help me understand the difference between the PixArt-Alpha and the PixArt-Sigma pipeline implementations? This way, we can decide if we really need to introduce a separate pipeline class here.

Contributor Author

@lawrence-cj lawrence-cj Apr 12, 2024

The main difference is that PixArt-Sigma doesn't have micro-conditions, and PixArt-Sigma's transformer will later use qk_norm and key/value compression. Therefore, the pipeline needs to accept these two changes.
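The config-gated dispatch this implies could be sketched as follows (a minimal illustration with hypothetical names like `FakeConfig` and `prepare_added_cond_kwargs`, not the actual diffusers code):

```python
from dataclasses import dataclass

# Hypothetical sketch: a single pipeline branching on a config flag so that
# Alpha (which uses micro-conditions) and Sigma (which does not) share code.
@dataclass
class FakeConfig:
    use_additional_conditions: bool

def prepare_added_cond_kwargs(config, height, width):
    if not config.use_additional_conditions:
        # PixArt-Sigma path: no micro-conditions.
        return {"resolution": None, "aspect_ratio": None}
    # PixArt-Alpha path: pass resolution and aspect ratio through.
    return {"resolution": (height, width), "aspect_ratio": height / width}

print(prepare_added_cond_kwargs(FakeConfig(False), 1024, 1024))
print(prepare_added_cond_kwargs(FakeConfig(True), 1024, 1024))
```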

Collaborator

Is this the only difference? We can add an if/else statement in the PixArt-Alpha pipeline to support this:

# 6.1 Prepare micro-conditions.
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
if self.transformer.config.sample_size == 128:
    resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
    aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
    resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
    aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)

    if do_classifier_free_guidance:
        resolution = torch.cat([resolution, resolution], dim=0)
        aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)

    added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}

Member

PixArt-Sigma will also have if self.transformer.config.sample_size == 128:, but it doesn't use additional conditions such as resolution and aspect ratio. It could still be checked with self.transformer.config.use_additional_conditions, I believe?

Member

@lawrence-cj

PixArt-Sigma's transformer will use the qk_norm and key/value compression later.

This will affect the Transformer block, right, and not the pipeline itself? So I guess it's fine to club it into a single pipeline.
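For context, qk_norm normalizes queries and keys before the attention dot product, which lives inside the transformer block rather than the pipeline. A rough numpy sketch of where it sits (a generic RMSNorm variant, illustrative only, not the exact PixArt-Sigma implementation):

```python
import numpy as np

# Illustrative sketch of QK-normalization in an attention block.
def rms_norm(x, eps=1e-6):
    return x / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)

rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8))  # 4 tokens, head dim 8
k = rng.normal(size=(4, 8))

# Normalize q and k before the scaled dot product; this bounds the attention
# logits, which helps stability at high resolution.
scores = rms_norm(q) @ rms_norm(k).T / np.sqrt(8)
print(scores.shape)  # (4, 4)
```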

Member

@sayakpaul sayakpaul left a comment

Overall, this is already in very good shape. Thank you!

I think once the comments are addressed and we decide whether the changes warrant a new pipeline class, we will need to:

  • Add a slow test
  • Add documentation

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul sayakpaul requested a review from yiyixuxu April 17, 2024 03:48
Member

@sayakpaul sayakpaul left a comment

I am okay with the PR given:

  • The code quality tests pass
  • We add documentation
  • We add fast tests
  • Slow tests can be added later when other Sigma variants (with compressed KV) are in

@Beinsezii
Contributor

Is it just me, or does the example

pipe = PixArtSigmaPipeline.from_pretrained("PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16)

not work because the Hugging Face repo only has a transformer folder?

I see convert_pixart_sigma_to_diffusers.py has a workaround for it but the main pipeline does not have any special case for using from_pretrained on a transformer-only repository.

@lawrence-cj
Contributor Author

Let me add the remaining components to the Hugging Face repo.

@yiyixuxu
Collaborator

@lawrence-cj
Why do we need to update the Alpha tests? I don't think we changed the model or pipeline in a way that would affect the results.

@lawrence-cj
Contributor Author

I don't know why. Actually, I didn't change anything in Alpha. I guess it may be a historical problem?

@yiyixuxu
Collaborator

@lawrence-cj can we run the PixArt-Alpha pipeline under this PR to see if the generated image output looks ok?

I will look into the test difference to make sure we did not break anything.

vae_scale_factor: int = 8,
resample: str = "lanczos",
do_normalize: bool = False,
do_binarize: bool = True,
Collaborator

Suggested change
do_binarize: bool = True,
do_binarize: bool = False,

resample: str = "lanczos",
do_normalize: bool = False,
do_binarize: bool = True,
do_convert_grayscale: bool = True,
Collaborator

Suggested change
do_convert_grayscale: bool = True,
do_convert_grayscale: bool = False,
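For context on why these defaults matter: do_binarize thresholds pixel values (only appropriate for masks, not for the RGB images PixArt produces), and do_convert_grayscale would discard color. A quick numpy sketch of the binarize effect, assuming a 0.5 threshold (illustrative, not the diffusers VaeImageProcessor code):

```python
import numpy as np

# Illustrative: what binarizing does to a [0, 1]-valued image array.
# Every pixel collapses to 0 or 1, destroying tonal information.
img = np.array([[0.1, 0.4], [0.6, 0.9]])
binarized = np.where(img < 0.5, 0.0, 1.0)
print(binarized.tolist())  # [[0.0, 0.0], [1.0, 1.0]]
```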

@yiyixuxu
Collaborator

@lawrence-cj
I reverted the changes you made to the Alpha test and updated the default arg values for the PixArt image processor #7654 (comment)

Can you run some doc examples for both Sigma and Alpha? Once you confirm the outputs are ok and all our tests pass, I will merge.

@yiyixuxu
Collaborator

yiyixuxu commented Apr 23, 2024

Our Hub is experiencing issues, so most CI failures are not relevant here. (Resolved.)

@yiyixuxu
Collaborator

here are my testing script and outputs:

  • alpha works as expected
  • sigma is currently generating noise (but I'm not sure if I ran it correctly; I just used the alpha pipeline and swapped out the transformer)
  • not sure if the one-step output is expected
# test pixart 

from diffusers import PixArtAlphaPipeline, PixArtSigmaPipeline, Transformer2DModel, DDPMScheduler
import torch
import gc

weight_dtype = torch.float16

# test 1 step
# transformer = Transformer2DModel.from_pretrained("PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512", subfolder="transformer", torch_dtype=weight_dtype)
# scheduler = DDPMScheduler.from_pretrained("PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512", subfolder="scheduler")

# pipe = PixArtAlphaPipeline.from_pretrained(
#         "PixArt-alpha/PixArt-XL-2-1024-MS",
#         transformer=transformer,
#         scheduler=scheduler,
#         torch_dtype=weight_dtype,
#     )
# pipe.enable_model_cpu_offload()

# images = pipe(prompt='dog', timesteps=[400], num_inference_steps=1, output_type="pil",).images[0]
# images.save('yiyi_test_4_out_1step.png')

# del pipe
# torch.cuda.empty_cache()
# gc.collect()
#test pixart alpha

pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()


prompt = "A small cactus with a happy face in the Sahara desert."
images = pipe(prompt=prompt).images[0]
images.save("yiyi_test_4_out_alpha.png")

pipe.transformer = None
torch.cuda.empty_cache()
gc.collect()

# test pixart sigma
transformer = Transformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
    subfolder="transformer",
    torch_dtype=torch.float16,
    use_additional_conditions=False,
)
pipe = PixArtSigmaPipeline.from_pipe(pipe, transformer=transformer, torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt).images[0]
image.save("yiyi_test_4_out_sigma.png")

Outputs: alpha one-step (yiyi_test_4_out_1step), alpha (yiyi_test_4_out_alpha), sigma (yiyi_test_4_out_sigma) [images attached]

def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=1e-3)

# PixArt transformer model does not work with sequential offload so skip it for now
Collaborator

@yiyixuxu yiyixuxu Apr 24, 2024

I'm disabling the sequential offloading test for now because it is not working correctly for transformer models. I created an issue about this: huggingface/accelerate#2701
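As a sketch of that workaround (class and method names here are illustrative, not the actual diffusers test suite), the fast-test class can mark the offload test skipped until the accelerate issue is resolved:

```python
import unittest

# Illustrative sketch: skip the sequential-offload test until the upstream
# accelerate issue (huggingface/accelerate#2701) is fixed.
class PixArtSigmaPipelineFastTests(unittest.TestCase):
    @unittest.skip("PixArt transformer does not work with sequential offload yet")
    def test_sequential_cpu_offload(self):
        self.fail("should never run")

    def test_inference_batch_single_identical(self):
        # Stand-in for the real numerical check with expected_max_diff=1e-3.
        self.assertLessEqual(0.0, 1e-3)

suite = unittest.TestLoader().loadTestsFromTestCase(PixArtSigmaPipelineFastTests)
result = unittest.TextTestRunner(verbosity=0).run(suite)
print(result.testsRun, len(result.skipped))  # 2 1
```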

@lawrence-cj
Contributor Author

lawrence-cj commented Apr 24, 2024

@yiyixuxu For the wrong results of one-step generation, we must set guidance_scale=1 to make it right. Maybe we can add a warning to the codebase?
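The reason guidance_scale=1 matters can be sketched numerically: classifier-free guidance combines predictions as uncond + s * (cond - uncond), and a DMD-distilled model already bakes guidance into its single step, so any s != 1 re-applies guidance and corrupts the output, while s == 1 reduces to the conditional prediction alone (a sketch, not the pipeline code):

```python
# Sketch: the classifier-free guidance combination used by diffusion pipelines.
def cfg(noise_uncond, noise_cond, scale):
    return noise_uncond + scale * (noise_cond - noise_uncond)

# With scale == 1 the unconditional branch cancels out, leaving the model's
# own (already guidance-distilled) conditional prediction.
assert abs(cfg(0.2, 0.5, 1.0) - 0.5) < 1e-12

# A default-like scale (e.g. 7.5) pushes the prediction far outside the range
# the one-step DMD model was distilled for.
assert abs(cfg(0.2, 0.5, 7.5) - 2.45) < 1e-9
print("ok")
```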

And here is the correct one.

# test 1 step
transformer = Transformer2DModel.from_pretrained("PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512", subfolder="transformer", torch_dtype=weight_dtype)
scheduler = DDPMScheduler.from_pretrained("PixArt-alpha/PixArt-Alpha-DMD-XL-2-512x512", subfolder="scheduler")

pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    transformer=transformer,
    scheduler=scheduler,
    torch_dtype=weight_dtype,
)
# pipe.enable_model_cpu_offload()
pipe.to('cuda')

images = pipe(prompt='dog', timesteps=[400], num_inference_steps=1, guidance_scale=1, output_type="pil",).images[0]
images.save('yiyi_test_4_out_1step.png')

del pipe
torch.cuda.empty_cache()
gc.collect()

Generated output: [image attached]

@lawrence-cj
Contributor Author

lawrence-cj commented Apr 24, 2024

With your code, the Sigma pipeline does generate pure noise.
But I use this to get the normal result:

from diffusers import PixArtAlphaPipeline, PixArtSigmaPipeline, Transformer2DModel, DDPMScheduler
import torch


transformer = Transformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
    subfolder="transformer",
    torch_dtype=torch.float16,
    use_additional_conditions=False,
)
pipe = PixArtSigmaPipeline.from_pretrained("PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers", transformer=transformer, torch_dtype=torch.float16)
pipe.to('cuda')
# pipe.enable_model_cpu_offload()
prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt).images[0]
image.save("yiyi_test_4_out_sigma.png")

Generated output: [image attached]

@yiyixuxu
Collaborator

@lawrence-cj ok! If Sigma works correctly, can we update the test for the Sigma pipeline to make sure it passes?
The image processor was not working correctly when you made these tests for Sigma, so I think the expected results are currently wrong and need to be updated.

@yiyixuxu
Collaborator

@lawrence-cj

for the wrong results of one step generation, we must set the guidance_scale=1 to make it right. Maybe, we can add a warning in the code base?

I think a doc example demonstrating the current usage is sufficient; no need for a warning, I think.

@lawrence-cj
Contributor Author

Cool Cool. Let me fix it. Thx!

@lawrence-cj lawrence-cj requested a review from yiyixuxu April 24, 2024 07:44
@yiyixuxu yiyixuxu merged commit 39215aa into huggingface:main Apr 24, 2024
@yiyixuxu
Collaborator

Merged!
For follow-up, can we:

  1. add doc pages
  2. make sure all the repos work with from_pretrained? e.g. these currently only have a transformer folder: PixArt-alpha/PixArt-Sigma-XL-2-2K-MS

@lawrence-cj
Contributor Author

Congrats! 🎉 Thank you so much for your support. @yiyixuxu
Cc: @sayakpaul @badayvedat

add doc pages
make sure all the repos can work with from_pretrained ? e.g. these ones currently only have a transformer folder PixArt-alpha/PixArt-Sigma-XL-2-2K-MS

No problem! I will run all the tests for the original and new repos and also add a doc page.

@sayakpaul
Member

Thanks all.

sayakpaul added a commit that referenced this pull request Dec 23, 2024
* support PixArt-DMD

---------

Co-authored-by: jschen <chenjunsong4@h-partners.com>
Co-authored-by: badayvedat <badayvedat@gmail.com>
Co-authored-by: Vedat Baday <54285744+badayvedat@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail.com>