Merged
Changes from all commits
47 commits
adf23f3
support PixArt-DMD
Apr 12, 2024
f9b184c
add PixArtSigmaPipeline
Apr 12, 2024
1732580
add converting file
Apr 12, 2024
5cbf701
move `use_additional_conditions` to the `__init__` function.
Apr 12, 2024
95c65d0
remove unused flag
lawrence-cj Apr 12, 2024
6392d2d
remove unused package
lawrence-cj Apr 12, 2024
a0301cf
fix: circular import
badayvedat Apr 12, 2024
e755d19
bug fixed
lawrence-cj Apr 13, 2024
c1c56a7
Merge branch 'main' into main
badayvedat Apr 14, 2024
b80b9ce
Merge pull request #12 from badayvedat/main
lawrence-cj Apr 15, 2024
bd3720e
Merge branch 'main' into main
sayakpaul Apr 15, 2024
645d9b6
Merge branch 'huggingface:main' into main
lawrence-cj Apr 17, 2024
deee7aa
make style
lawrence-cj Apr 17, 2024
b32c361
add a warning about `use_additional_conditions`
lawrence-cj Apr 17, 2024
bde7bb3
Merge branch 'huggingface:main' into main
lawrence-cj Apr 18, 2024
e785009
add test files
lawrence-cj Apr 18, 2024
249be0c
make style
lawrence-cj Apr 18, 2024
416bfb5
Merge branch 'main' into main
sayakpaul Apr 19, 2024
3142bfd
Merge branch 'huggingface:main' into main
lawrence-cj Apr 22, 2024
0d8273d
1. fixed inheritance from DiffusionPipeline
lawrence-cj Apr 22, 2024
a62b103
Update src/diffusers/models/transformers/transformer_2d.py
lawrence-cj Apr 23, 2024
0d23b19
Update src/diffusers/models/transformers/transformer_2d.py
lawrence-cj Apr 23, 2024
f260a67
Update src/diffusers/models/transformers/transformer_2d.py
lawrence-cj Apr 23, 2024
115ab91
Update src/diffusers/models/transformers/transformer_2d.py
lawrence-cj Apr 23, 2024
971cdd3
Update src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
lawrence-cj Apr 23, 2024
11d9512
Merge branch 'huggingface:main' into main
lawrence-cj Apr 23, 2024
7e66f16
add copy from info
lawrence-cj Apr 23, 2024
86f1a6e
Update src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
lawrence-cj Apr 23, 2024
2f16746
add PixArtImageProcessor and remove the relative code in alpha and si…
lawrence-cj Apr 23, 2024
35769dd
make style
lawrence-cj Apr 23, 2024
e5107bf
combine PixArtImageProcessor and VAEImageProcessor
lawrence-cj Apr 23, 2024
b611190
make fix-copies
lawrence-cj Apr 23, 2024
4220c8a
make fix-copies again
lawrence-cj Apr 23, 2024
ff94f3f
copies
Apr 23, 2024
d8818c1
Update src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
lawrence-cj Apr 23, 2024
b956721
fast test passed
lawrence-cj Apr 23, 2024
d6d94bb
Merge branch 'huggingface:main' into main
lawrence-cj Apr 23, 2024
e57c032
style
Apr 23, 2024
64d0f92
alpha fast test passed
lawrence-cj Apr 23, 2024
3f1cc34
Merge branch 'huggingface:main' into main
lawrence-cj Apr 23, 2024
f62f0a5
update vae image processor defaults
Apr 23, 2024
d73b926
Merge branch 'main' of github.com:lawrence-cj/diffusers into pix
Apr 23, 2024
bc7cb56
Revert "alpha fast test passed"
Apr 23, 2024
537d325
empty
Apr 23, 2024
ab55762
skip the sequential offload tests
Apr 23, 2024
65b17cb
Merge branch 'huggingface:main' into main
lawrence-cj Apr 24, 2024
b08342e
sigma fast test passed & make style
lawrence-cj Apr 24, 2024
223 changes: 223 additions & 0 deletions scripts/convert_pixart_sigma_to_diffusers.py
@@ -0,0 +1,223 @@
import argparse
import os

import torch
from transformers import T5EncoderModel, T5Tokenizer

from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtSigmaPipeline, Transformer2DModel


ckpt_id = "PixArt-alpha"
# https://github.com/PixArt-alpha/PixArt-sigma/blob/dd087141864e30ec44f12cb7448dd654be065e88/scripts/inference.py#L158
interpolation_scale = {256: 0.5, 512: 1, 1024: 2, 2048: 4}


def main(args):
    all_state_dict = torch.load(args.orig_ckpt_path)
    state_dict = all_state_dict.pop("state_dict")
    converted_state_dict = {}

    # Patch embeddings.
    converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
    converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")

    # Caption projection.
    converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
    converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
    converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
    converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")

    # AdaLN-single LN
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
        "t_embedder.mlp.0.weight"
    )
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
        "t_embedder.mlp.2.weight"
    )
    converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")

    if args.micro_condition:
        # Resolution.
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.weight"] = state_dict.pop(
            "csize_embedder.mlp.0.weight"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.bias"] = state_dict.pop(
            "csize_embedder.mlp.0.bias"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.weight"] = state_dict.pop(
            "csize_embedder.mlp.2.weight"
        )
        converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.bias"] = state_dict.pop(
            "csize_embedder.mlp.2.bias"
        )
        # Aspect ratio.
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.weight"] = state_dict.pop(
            "ar_embedder.mlp.0.weight"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.bias"] = state_dict.pop(
            "ar_embedder.mlp.0.bias"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.weight"] = state_dict.pop(
            "ar_embedder.mlp.2.weight"
        )
        converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.bias"] = state_dict.pop(
            "ar_embedder.mlp.2.bias"
        )
    # Shared norm.
    converted_state_dict["adaln_single.linear.weight"] = state_dict.pop("t_block.1.weight")
    converted_state_dict["adaln_single.linear.bias"] = state_dict.pop("t_block.1.bias")

    for depth in range(28):
        # Transformer blocks.
        converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
            f"blocks.{depth}.scale_shift_table"
        )
        # Attention is all you need 🤘

        # Self attention.
        q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.bias"), 3, dim=0)
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias
        # Projection.
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
            f"blocks.{depth}.attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
            f"blocks.{depth}.attn.proj.bias"
        )
        if args.qk_norm:
            converted_state_dict[f"transformer_blocks.{depth}.attn1.q_norm.weight"] = state_dict.pop(
                f"blocks.{depth}.attn.q_norm.weight"
            )
            converted_state_dict[f"transformer_blocks.{depth}.attn1.q_norm.bias"] = state_dict.pop(
                f"blocks.{depth}.attn.q_norm.bias"
            )
            converted_state_dict[f"transformer_blocks.{depth}.attn1.k_norm.weight"] = state_dict.pop(
                f"blocks.{depth}.attn.k_norm.weight"
            )
            converted_state_dict[f"transformer_blocks.{depth}.attn1.k_norm.bias"] = state_dict.pop(
                f"blocks.{depth}.attn.k_norm.bias"
            )

        # Feed-forward.
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc1.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc1.bias"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc2.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict.pop(
            f"blocks.{depth}.mlp.fc2.bias"
        )

        # Cross-attention.
        q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
        q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
        k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
        k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)

        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias

        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
            f"blocks.{depth}.cross_attn.proj.weight"
        )
        converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
            f"blocks.{depth}.cross_attn.proj.bias"
        )

    # Final block.
    converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
    converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
    converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")

    # PixArt XL/2
    transformer = Transformer2DModel(
        sample_size=args.image_size // 8,
        num_layers=28,
        attention_head_dim=72,
        in_channels=4,
        out_channels=8,
        patch_size=2,
        attention_bias=True,
        num_attention_heads=16,
        cross_attention_dim=1152,
        activation_fn="gelu-approximate",
        num_embeds_ada_norm=1000,
        norm_type="ada_norm_single",
        norm_elementwise_affine=False,
        norm_eps=1e-6,
        caption_channels=4096,
        interpolation_scale=interpolation_scale[args.image_size],
        use_additional_conditions=args.micro_condition,
    )
    transformer.load_state_dict(converted_state_dict, strict=True)

    assert transformer.pos_embed.pos_embed is not None
    try:
        state_dict.pop("y_embedder.y_embedding")
        state_dict.pop("pos_embed")
    except Exception as e:
        print(f"Skipping {str(e)}")
        pass
    assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"

    num_model_params = sum(p.numel() for p in transformer.parameters())
    print(f"Total number of transformer parameters: {num_model_params}")

    if args.only_transformer:
        transformer.save_pretrained(os.path.join(args.dump_path, "transformer"))
    else:
        # pixart-Sigma vae link: https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers/tree/main/vae
        vae = AutoencoderKL.from_pretrained(f"{ckpt_id}/pixart_sigma_sdxlvae_T5_diffusers", subfolder="vae")

        scheduler = DPMSolverMultistepScheduler()

        tokenizer = T5Tokenizer.from_pretrained(f"{ckpt_id}/pixart_sigma_sdxlvae_T5_diffusers", subfolder="tokenizer")
        text_encoder = T5EncoderModel.from_pretrained(
            f"{ckpt_id}/pixart_sigma_sdxlvae_T5_diffusers", subfolder="text_encoder"
        )

        pipeline = PixArtSigmaPipeline(
            tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
        )

        pipeline.save_pretrained(args.dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--micro_condition", action="store_true", help="If use Micro-condition in PixArtMS structure during training."
    )
    parser.add_argument("--qk_norm", action="store_true", help="If use qk norm during training.")
    parser.add_argument(
        "--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
    )
    parser.add_argument(
        "--image_size",
        default=1024,
        type=int,
        choices=[256, 512, 1024, 2048],
        required=False,
        help="Image size of pretrained model, 256, 512, 1024, or 2048.",
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
    parser.add_argument("--only_transformer", default=True, type=bool, required=True)

    args = parser.parse_args()
    main(args)
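
When the script saves the full pipeline (the `else` branch above), the `--dump_path` directory loads like any other diffusers pipeline. A minimal usage sketch, assuming a converted 1024px checkpoint; the output path and prompt are placeholders, not part of this PR:

import torch

from diffusers import PixArtSigmaPipeline

# Load the folder written by --dump_path (hypothetical local path).
pipe = PixArtSigmaPipeline.from_pretrained("output/pixart_sigma_1024", torch_dtype=torch.float16)
pipe.to("cuda")

image = pipe(prompt="A small cactus with a happy face in the Sahara desert").images[0]
image.save("cactus.png")
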
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -261,6 +261,7 @@
            "PaintByExamplePipeline",
            "PIAPipeline",
            "PixArtAlphaPipeline",
+            "PixArtSigmaPipeline",
            "SemanticStableDiffusionPipeline",
            "ShapEImg2ImgPipeline",
            "ShapEPipeline",
@@ -637,6 +638,7 @@
            PaintByExamplePipeline,
            PIAPipeline,
            PixArtAlphaPipeline,
+            PixArtSigmaPipeline,
            SemanticStableDiffusionPipeline,
            ShapEImg2ImgPipeline,
            ShapEPipeline,
74 changes: 74 additions & 0 deletions src/diffusers/image_processor.py
@@ -994,3 +994,77 @@ def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value
        )

        return mask_downsample


class PixArtImageProcessor(VaeImageProcessor):
    """
    Image processor for PixArt image resize and crop.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can
            accept `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this
            factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image to [-1,1].
        do_binarize (`bool`, *optional*, defaults to `False`):
            Whether to binarize the image to 0/1.
        do_convert_rgb (`bool`, *optional*, defaults to `False`):
            Whether to convert the images to RGB format.
        do_convert_grayscale (`bool`, *optional*, defaults to `False`):
            Whether to convert the images to grayscale format.
    """

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = True,
        do_binarize: bool = False,
        do_convert_grayscale: bool = False,
    ):
        super().__init__(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            resample=resample,
            do_normalize=do_normalize,
            do_binarize=do_binarize,
            do_convert_grayscale=do_convert_grayscale,
        )

    @staticmethod
    def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
        """Returns binned height and width."""
        ar = float(height / width)
        closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
        default_hw = ratios[closest_ratio]
        return int(default_hw[0]), int(default_hw[1])

    @staticmethod
    def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
        orig_height, orig_width = samples.shape[2], samples.shape[3]

        # Check if resizing is needed
        if orig_height != new_height or orig_width != new_width:
            ratio = max(new_height / orig_height, new_width / orig_width)
            resized_width = int(orig_width * ratio)
            resized_height = int(orig_height * ratio)

            # Resize
            samples = F.interpolate(
                samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
            )

            # Center Crop
            start_x = (resized_width - new_width) // 2
            end_x = start_x + new_width
            start_y = (resized_height - new_height) // 2
            end_y = start_y + new_height
            samples = samples[:, :, start_y:end_y, start_x:end_x]

        return samples
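
`classify_height_width_bin` snaps a requested resolution to the nearest aspect-ratio bucket, and `resize_and_crop_tensor` scales a generated batch up and center-crops it back to the exact requested size. A small self-contained sketch; the three-entry `ratios` table here is made up for illustration (the real tables are the `ASPECT_RATIO_*_BIN` dicts in the pipeline files):

import torch

from diffusers.image_processor import PixArtImageProcessor

# Toy ratio table: "height/width" -> (height, width).
ratios = {"0.5": (704, 1408), "1.0": (1024, 1024), "2.0": (1408, 704)}

# 800x1200 has aspect ratio ~0.67, so it snaps to the 0.5 bucket.
height, width = PixArtImageProcessor.classify_height_width_bin(800, 1200, ratios)
print(height, width)  # 704 1408

# Scale-then-center-crop a batch from 512x512 to the requested 640x512.
samples = torch.randn(1, 3, 512, 512)
out = PixArtImageProcessor.resize_and_crop_tensor(samples, new_width=512, new_height=640)
print(out.shape)  # torch.Size([1, 3, 640, 512])
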
9 changes: 7 additions & 2 deletions src/diffusers/models/transformers/transformer_2d.py
@@ -100,6 +100,7 @@ def __init__(
        attention_type: str = "default",
        caption_channels: int = None,
        interpolation_scale: float = None,
+        use_additional_conditions: Optional[bool] = None,
    ):
        super().__init__()

@@ -124,6 +125,12 @@
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.gradient_checkpointing = False
+        if use_additional_conditions is None:
+            if norm_type == "ada_norm_single" and sample_size == 128:
+                use_additional_conditions = True
+            else:
+                use_additional_conditions = False
+        self.use_additional_conditions = use_additional_conditions

        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
        # Define whether input is continuous or discrete depending on configuration
@@ -305,9 +312,7 @@ def _init_patched_inputs(self, norm_type):

        # PixArt-Alpha blocks.
        self.adaln_single = None
-        self.use_additional_conditions = False
        if self.config.norm_type == "ada_norm_single":
-            self.use_additional_conditions = self.config.sample_size == 128
            # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
            # additional conditions until we find better name
            self.adaln_single = AdaLayerNormSingle(
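
Net effect of the hunks above: the sample-size heuristic (additional conditions on only for `ada_norm_single` at `sample_size == 128`, i.e. PixArt-Alpha 1024px) now applies only when the caller leaves the new argument unset, so PixArt-Sigma can disable micro-conditioning at any resolution. A condensed sketch of that resolution logic; the helper name is illustrative, not part of the diff:

def resolve_use_additional_conditions(norm_type, sample_size, use_additional_conditions=None):
    # Explicit argument wins; otherwise fall back to the old sample-size heuristic.
    if use_additional_conditions is not None:
        return use_additional_conditions
    return norm_type == "ada_norm_single" and sample_size == 128

assert resolve_use_additional_conditions("ada_norm_single", 128) is True   # PixArt-Alpha 1024px
assert resolve_use_additional_conditions("ada_norm_single", 256, use_additional_conditions=False) is False  # PixArt-Sigma 2K
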
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/__init__.py
@@ -187,7 +187,7 @@
    _import_structure["musicldm"] = ["MusicLDMPipeline"]
    _import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
    _import_structure["pia"] = ["PIAPipeline"]
-    _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline"]
+    _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
    _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
    _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
    _import_structure["stable_cascade"] = [
@@ -450,7 +450,7 @@
        from .musicldm import MusicLDMPipeline
        from .paint_by_example import PaintByExamplePipeline
        from .pia import PIAPipeline
-        from .pixart_alpha import PixArtAlphaPipeline
+        from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
        from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
        from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
        from .stable_cascade import (
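
With both registration points updated, the new pipeline imports from the top level exactly like the Alpha one:

from diffusers import PixArtAlphaPipeline, PixArtSigmaPipeline
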
9 changes: 8 additions & 1 deletion src/diffusers/pipelines/pixart_alpha/__init__.py
@@ -23,6 +23,7 @@
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_pixart_alpha"] = ["PixArtAlphaPipeline"]
+    _import_structure["pipeline_pixart_sigma"] = ["PixArtSigmaPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
@@ -32,7 +33,13 @@
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
-        from .pipeline_pixart_alpha import PixArtAlphaPipeline
+        from .pipeline_pixart_alpha import (
+            ASPECT_RATIO_256_BIN,
+            ASPECT_RATIO_512_BIN,
+            ASPECT_RATIO_1024_BIN,
+            PixArtAlphaPipeline,
+        )
+        from .pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN, PixArtSigmaPipeline

else:
    import sys
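
The aspect-ratio bin tables now ship alongside the pipelines, which makes resolution binning usable outside the pipelines too. A sketch combining them with `PixArtImageProcessor`; importing from the concrete submodule is an assumption chosen here to sidestep what the lazy `__init__` re-exports at runtime:

from diffusers.image_processor import PixArtImageProcessor
from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN

# Snap an arbitrary 2000x3000 request to the nearest 2K training bucket.
height, width = PixArtImageProcessor.classify_height_width_bin(2000, 3000, ratios=ASPECT_RATIO_2048_BIN)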