the main branch is broken for controlnet training with sdxl #4206

@yutongli

Describe the bug

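The main branch (all PRs up to #4205) fails during ControlNet training with SDXL. It looks like a misalignment of arguments: during validation, the pipeline receives a PIL image where `prompt_2` is expected (see the logs below).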

Reproduction

export MODEL_DIR="stabilityai/stable-diffusion-xl-base-0.9"
export VAE_DIR="madebyollin/sdxl-vae-fp16-fix"
export OUTPUT_DIR="product_train_output_extract_1stbatch_100k_sdxl0.9_1024_lr1"
export CACHE_DIR="/home/ubuntu/lamda-filesystem/custom_cache"

accelerate launch --mixed_precision="fp16" train_controlnet_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_DIR \
  --output_dir=$OUTPUT_DIR \
  --pretrained_vae_model_name_or_path=$VAE_DIR \
  --cache_dir=$CACHE_DIR \
  --dataset_name=all_training_full_extract \
  --image_column="target" \
  --conditioning_image_column="source" \
  --caption_column="prompt" \
  --resolution=1024 \
  --learning_rate=1e-5 \
  --validation_image "./val1_extract_source.jpg" "./val2_extract_source.jpg" "./val3_extract_source.jpg" "./popchange.png" \
  --validation_prompt "a white trash can sitting on a table next to a plant" "a bottle of liquid with flower in it" "a rack with a bunch of shoes on it" "a doll in galaxy" \
  --train_batch_size=1 \
  --gradient_accumulation_steps=12 \
  --tracker_project_name="product_train_output_extract_1stbatch_100k_sdxl0.9_1024_lr1" \
  --num_train_epochs=20 \
  --report_to=wandb \
  --validation_steps=100 \
  --checkpointing_steps=1000 \
  --checkpoints_total_limit=10 \
  --seed=42 \
  --enable_xformers_memory_efficient_attention

Logs

Traceback (most recent call last):
  File "train_controlnet_sdxl.py", line 1248, in <module>
    main(args)
  File "train_controlnet_sdxl.py", line 1212, in main
    image_logs = log_validation(
  File "train_controlnet_sdxl.py", line 126, in log_validation
    image = pipeline(
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/lamda-filesystem/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py", line 782, in __call__
    self.check_inputs(
  File "/home/ubuntu/lamda-filesystem/diffusers/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py", line 450, in check_inputs
    raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
ValueError: `prompt_2` has to be of type `str` or `list` but is <class 'PIL.Image.Image'>
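
The traceback is consistent with a positional-argument mismatch: if the pipeline's second positional parameter is `prompt_2` and the caller passes the validation image second, the image lands in the `prompt_2` slot. Below is a minimal sketch of that failure mode and the keyword-argument workaround. `check_and_run` and `FakeImage` are hypothetical stand-ins made up for illustration, not the diffusers API; only the error message is taken from the log above.

```python
from typing import List, Optional, Union


class FakeImage:
    """Hypothetical stand-in for PIL.Image.Image so the sketch has no dependencies."""


def check_and_run(
    prompt: Union[str, List[str]],
    prompt_2: Optional[Union[str, List[str]]] = None,
    image: Optional[FakeImage] = None,
    num_inference_steps: int = 50,
) -> str:
    """Hypothetical pipeline call whose second positional parameter is `prompt_2`."""
    # Same type check as the one that raised in the log above.
    if prompt_2 is not None and not isinstance(prompt_2, (str, list)):
        raise ValueError(
            f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}"
        )
    if image is None:
        raise ValueError("`image` is required")
    return f"ok: {prompt}"


validation_prompt = "a doll in galaxy"
validation_image = FakeImage()

# Positional call: the image is bound to `prompt_2`, so the type check fails.
try:
    check_and_run(validation_prompt, validation_image, num_inference_steps=20)
    positional_error = None
except ValueError as err:
    positional_error = str(err)

# Keyword call: each argument is bound to the parameter it was meant for.
keyword_result = check_and_run(
    prompt=validation_prompt, image=validation_image, num_inference_steps=20
)

print(positional_error)
print(keyword_result)
```

If this is indeed the cause, passing `prompt=` and `image=` by keyword in the `log_validation` call of `train_controlnet_sdxl.py` may avoid the error; I have not verified this against the current pipeline signature.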

System Info

  • diffusers version: 0.19.0.dev0
  • Platform: Linux-5.15.0-52-generic-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • PyTorch version (GPU?): 2.0.1+cu117 (True)
  • Huggingface_hub version: 0.16.4
  • Transformers version: 4.31.0
  • Accelerate version: 0.21.0
  • xFormers version: 0.0.20
  • Using GPU in script?: Single GPU
  • Using distributed or parallel set-up in script?: Single GPU

Who can help?

@patrickvonplaten @sayakpaul


Labels: bug (Something isn't working)
