From 0fc8ca382c6ffeb10570df51abf135f93c4c0a5e Mon Sep 17 00:00:00 2001 From: George Korepanov Date: Wed, 19 Jul 2023 11:56:54 +0000 Subject: [PATCH 01/11] Do not force VAE upcast if custom VAE is used --- examples/controlnet/train_controlnet_sdxl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 6be07a38056f..d321b7534338 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -835,6 +835,7 @@ def main(args): vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision, + force_upcast=True if args.pretrained_vae_name_or_path is None else False, ) unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision From 2ad7b7e3f03d92bbbf4b337f72405daa2e718c6e Mon Sep 17 00:00:00 2001 From: George Korepanov Date: Wed, 19 Jul 2023 11:57:29 +0000 Subject: [PATCH 02/11] Support switching cfg in controlnet pipeline --- .../controlnet/pipeline_controlnet_sd_xl.py | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 29d153ba0485..2c5cdb056dbe 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -911,12 +911,25 @@ def __call__( add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype ) + + if guess_mode and do_classifier_free_guidance: + # we will infer ControlNet only for the conditional batch + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if do_classifier_free_guidance: + # expand all inputs if we are doing 
classifier free guidance prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + if not guess_mode: + # if guess_mode is off, controlnet inputs are the same as inputs of base unet model + controlnet_prompt_embeds = prompt_embeds + controlnet_added_cond_kwargs = added_cond_kwargs + prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) @@ -926,25 +939,21 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + scaled_latents = self.scheduler.scale_model_input(latents, t) + latent_model_input = torch.cat([scaled_latents] * 2) if do_classifier_free_guidance else scaled_latents # controlnet(s) inference if guess_mode and do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. 
- control_model_input = latents - control_model_input = self.scheduler.scale_model_input(control_model_input, t) - controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + control_model_input = scaled_latents else: - control_model_input = latent_model_input - controlnet_prompt_embeds = prompt_embeds + latent_model_input if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] else: cond_scale = controlnet_conditioning_scale * controlnet_keep[i] - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, @@ -952,7 +961,7 @@ def __call__( controlnet_cond=image, conditioning_scale=cond_scale, guess_mode=guess_mode, - added_cond_kwargs=added_cond_kwargs, + added_cond_kwargs=controlnet_added_cond_kwargs, return_dict=False, ) @@ -1002,7 +1011,8 @@ def __call__( latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + with torch.autocast(enabled=False, device_type="cuda"): + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: image = latents return StableDiffusionXLPipelineOutput(images=image) From 179a7feca0016712883fcd1cce64c5ebb21989d5 Mon Sep 17 00:00:00 2001 From: George Korepanov Date: Wed, 19 Jul 2023 12:07:11 +0000 Subject: [PATCH 03/11] typo --- src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 2c5cdb056dbe..6adee71d79e3 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -947,7 +947,7 @@ def 
__call__( # Infer ControlNet only for the conditional batch. control_model_input = scaled_latents else: - latent_model_input + control_model_input = latent_model_input if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] From 6ed6a3fcece511046145afa8b48827e0986390e2 Mon Sep 17 00:00:00 2001 From: George Korepanov Date: Wed, 19 Jul 2023 12:24:01 +0000 Subject: [PATCH 04/11] Fixes from review --- examples/controlnet/train_controlnet_sdxl.py | 1 - .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 7 +++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index d321b7534338..6be07a38056f 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -835,7 +835,6 @@ def main(args): vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision, - force_upcast=True if args.pretrained_vae_name_or_path is None else False, ) unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 6adee71d79e3..0ba7539d90d2 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -911,7 +911,7 @@ def __call__( add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype ) - + if guess_mode and do_classifier_free_guidance: # we will infer ControlNet only for the conditional batch controlnet_prompt_embeds = prompt_embeds @@ -940,7 +940,7 @@ def __call__( for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance 
scaled_latents = self.scheduler.scale_model_input(latents, t) - latent_model_input = torch.cat([scaled_latents] * 2) if do_classifier_free_guidance else scaled_latents + latent_model_input = torch.cat([scaled_latents] * 2) if do_classifier_free_guidance else scaled_latents # controlnet(s) inference if guess_mode and do_classifier_free_guidance: @@ -1011,8 +1011,7 @@ def __call__( latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) if not output_type == "latent": - with torch.autocast(enabled=False, device_type="cuda"): - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: image = latents return StableDiffusionXLPipelineOutput(images=image) From 5be9e23be734d95d8757fb01c5c91e589738c961 Mon Sep 17 00:00:00 2001 From: George Korepanov Date: Wed, 2 Aug 2023 15:08:48 +0000 Subject: [PATCH 05/11] Fix tests --- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 0ba7539d90d2..14b5fd539cd7 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -912,6 +912,10 @@ def __call__( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype ) + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + if guess_mode and do_classifier_free_guidance: # we will infer ControlNet only for the conditional batch controlnet_prompt_embeds = prompt_embeds @@ -930,10 +934,6 @@ def __call__( controlnet_prompt_embeds = prompt_embeds controlnet_added_cond_kwargs = added_cond_kwargs - prompt_embeds = 
prompt_embeds.to(device) - add_text_embeds = add_text_embeds.to(device) - add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) - # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: From acf965c96e049709a46d3ceda1b38cbc2daea243 Mon Sep 17 00:00:00 2001 From: George Korepanov Date: Wed, 2 Aug 2023 15:09:06 +0000 Subject: [PATCH 06/11] Add tests for guess mode and no cfg in controlnet SDXL --- tests/pipelines/controlnet/test_controlnet_sdxl.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index 2f76be926971..e25016f6285f 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -258,3 +258,17 @@ def test_stable_diffusion_xl_multi_prompts(self): # ensure the results are not equal assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4 + + +class ControlNetPipelineSDXLGuessModeFastTests(ControlNetPipelineSDXLFastTests): + def get_dummy_inputs(self, device, seed=0): + inputs = super().get_dummy_inputs(device, seed=seed) + inputs["guess_mode"] = True + return inputs + + +class ControlNetPipelineSDXLNoCFGFastTests(ControlNetPipelineSDXLFastTests): + def get_dummy_inputs(self, device, seed=0): + inputs = super().get_dummy_inputs(device, seed=seed) + inputs["guidance_scale"] = 1.0 + return inputs From 33db8e55eeb4d6985f3cd3a4eedcef5cc4ba02b7 Mon Sep 17 00:00:00 2001 From: George Korepanov Date: Wed, 2 Aug 2023 16:15:33 +0000 Subject: [PATCH 07/11] Fix case with unbound local --- .../controlnet/pipeline_controlnet_sd_xl.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py 
b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 14b5fd539cd7..5a092f27e0dd 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -916,8 +916,10 @@ def __call__( add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) - if guess_mode and do_classifier_free_guidance: - # we will infer ControlNet only for the conditional batch + use_controlnet_conditional_branch_only = ( + guess_mode and do_classifier_free_guidance + ) + if use_controlnet_conditional_branch_only: controlnet_prompt_embeds = prompt_embeds controlnet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} @@ -929,8 +931,8 @@ def __call__( added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - if not guess_mode: - # if guess_mode is off, controlnet inputs are the same as inputs of base unet model + if not use_controlnet_conditional_branch_only: + # controlnet inputs are the same as inputs of base unet model controlnet_prompt_embeds = prompt_embeds controlnet_added_cond_kwargs = added_cond_kwargs @@ -943,8 +945,7 @@ def __call__( latent_model_input = torch.cat([scaled_latents] * 2) if do_classifier_free_guidance else scaled_latents # controlnet(s) inference - if guess_mode and do_classifier_free_guidance: - # Infer ControlNet only for the conditional batch. 
+ if use_controlnet_conditional_branch_only: control_model_input = scaled_latents else: control_model_input = latent_model_input From 46ee54dfc270db6457d4064624e3a6945fd0266c Mon Sep 17 00:00:00 2001 From: George Korepanov Date: Wed, 2 Aug 2023 16:35:41 +0000 Subject: [PATCH 08/11] Disable irrelevant test --- tests/pipelines/controlnet/test_controlnet_sdxl.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index e25016f6285f..a53d2c2b88c2 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -272,3 +272,6 @@ def get_dummy_inputs(self, device, seed=0): inputs = super().get_dummy_inputs(device, seed=seed) inputs["guidance_scale"] = 1.0 return inputs + + def test_stable_diffusion_xl_multi_prompts(self): + pass From 79cff4855555826ad4ce7f627040f3d4cec577a8 Mon Sep 17 00:00:00 2001 From: George Korepanov Date: Wed, 2 Aug 2023 16:41:30 +0000 Subject: [PATCH 09/11] style --- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 5a092f27e0dd..8440606735ec 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -916,9 +916,7 @@ def __call__( add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) - use_controlnet_conditional_branch_only = ( - guess_mode and do_classifier_free_guidance - ) + use_controlnet_conditional_branch_only = guess_mode and do_classifier_free_guidance if use_controlnet_conditional_branch_only: controlnet_prompt_embeds = prompt_embeds controlnet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": 
add_time_ids} From f59a1dbb5932bc1fc4682d6c67bb32f78ee0199b Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 14 Aug 2023 08:58:49 +0530 Subject: [PATCH 10/11] style. --- tests/pipelines/controlnet/test_controlnet_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index 78ff84df3658..3589559dabe6 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -258,7 +258,7 @@ def test_stable_diffusion_xl_multi_prompts(self): # ensure the results are not equal assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4 - + # copied from test_stable_diffusion_xl.py def test_stable_diffusion_xl_prompt_embeds(self): components = self.get_dummy_components() From 87b7cfe780e27e36d06241967e63bf1c72faf15f Mon Sep 17 00:00:00 2001 From: George Korepanov Date: Fri, 18 Aug 2023 18:39:55 +0000 Subject: [PATCH 11/11] Add tests for all combinations of guess mode and cfg scale --- .../controlnet/test_controlnet_sdxl.py | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index 3589559dabe6..be1d857a2f15 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -14,6 +14,8 @@ # limitations under the License. 
import unittest +from parameterized import parameterized_class +import itertools import numpy as np import torch @@ -46,6 +48,10 @@ enable_full_determinism() +@parameterized_class( + ("guess_mode", "guidance_scale"), + list(itertools.product([False, True], [1.0, 6.0])) +) class ControlNetPipelineSDXLFastTests( PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase ): @@ -55,6 +61,12 @@ class ControlNetPipelineSDXLFastTests( image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + @classmethod + def setUpClass(cls): + if cls == ControlNetPipelineSDXLFastTests: + raise unittest.SkipTest("`parameterized_class` bug, see https://github.com/wolever/parameterized/issues/119") + super().setUpClass() + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( @@ -157,7 +169,8 @@ def get_dummy_inputs(self, device, seed=0): "prompt": "A painting of a squirrel eating a burger", "generator": generator, "num_inference_steps": 2, - "guidance_scale": 6.0, + "guidance_scale": self.guidance_scale, + "guess_mode": self.guess_mode, "output_type": "numpy", "image": image, } @@ -233,6 +246,9 @@ def test_stable_diffusion_xl_multi_prompts(self): # ensure the results are not equal assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4 + if self.guidance_scale <= 1: # negative prompt has no effect without CFG + return + # manually set a negative_prompt inputs = self.get_dummy_inputs(torch_device) inputs["negative_prompt"] = "negative prompt" @@ -297,20 +313,3 @@ def test_stable_diffusion_xl_prompt_embeds(self): # make sure that it's equal assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 - - -class ControlNetPipelineSDXLGuessModeFastTests(ControlNetPipelineSDXLFastTests): - def get_dummy_inputs(self, device, seed=0): - inputs = super().get_dummy_inputs(device, seed=seed) - inputs["guess_mode"] = True - return inputs - - -class 
ControlNetPipelineSDXLNoCFGFastTests(ControlNetPipelineSDXLFastTests): - def get_dummy_inputs(self, device, seed=0): - inputs = super().get_dummy_inputs(device, seed=seed) - inputs["guidance_scale"] = 1.0 - return inputs - - def test_stable_diffusion_xl_multi_prompts(self): - pass