[SDXL DreamBooth LoRA] add support for text encoder fine-tuning#4097
Merged
Conversation
Member (Author)

@patrickvonplaten @williamberman I think I have addressed all your comments.
I would suggest taking another, deeper look.
```python
        return StableDiffusionXLPipelineOutput(images=image)
```

```python
# Override to properly handle the loading and unloading of the additional text encoder.
def compute_embeddings(prompt, text_encoders, tokenizers):
    ...

def compute_time_ids():
    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
    original_size = (args.resolution, args.resolution)
```
Contributor

Suggested change:

```python
original_size = (args.resolution, args.resolution)
```

This should ideally be the original size of the passed image (before resizing), but it is OK to leave as is for now.

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
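To make the review comment concrete: SDXL conditions on a tuple of micro-conditioning "time ids" built from the original size, the crop offsets, and the target size. The sketch below illustrates that assembly; the parameter names are illustrative, and the training script hard-wires `args.resolution` for all sizes, which is why the reviewer notes the pre-resize size would be preferable.

```python
# Hedged sketch of how SDXL micro-conditioning time ids are assembled,
# following the idea behind StableDiffusionXLPipeline._get_add_time_ids.
# Signature and names here are illustrative, not the script's exact code.

def compute_time_ids(original_size, crops_coords_top_left, target_size):
    # SDXL conditions on (orig_h, orig_w, crop_top, crop_left, target_h, target_w).
    # Ideally original_size is the image size *before* resizing, per the review note.
    return list(original_size + crops_coords_top_left + target_size)

time_ids = compute_time_ids((1024, 1024), (0, 0), (1024, 1024))
```

With `args.resolution = 1024` and no cropping, this yields `[1024, 1024, 0, 0, 1024, 1024]`.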
Contributor

@williamberman ok for you?
Comment on lines +840 to +863
```python
def save_lora_weights(
    self,
    save_directory: Union[str, os.PathLike],
    unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
    text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
    text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
    is_main_process: bool = True,
    weight_name: str = None,
    save_function: Callable = None,
    safe_serialization: bool = False,
):
    state_dict = {}

    def pack_weights(layers, prefix):
        layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
        layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
        return layers_state_dict

    state_dict.update(pack_weights(unet_lora_layers, "unet"))

    if text_encoder_lora_layers and text_encoder_2_lora_layers:
        state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
        state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
```
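The prefixing logic in `pack_weights` is easy to check in isolation. The sketch below uses plain dicts in place of torch modules and tensors (an assumption made purely so the snippet runs without torch); the real method also accepts an `nn.Module` and calls its `state_dict()` first.

```python
# Illustrative version of the state_dict packing above, with plain dicts
# standing in for torch tensors so only the key-prefixing logic remains.

def pack_weights(layers, prefix):
    # `layers` is assumed to already be a name -> weight dict here.
    return {f"{prefix}.{name}": param for name, param in layers.items()}

state_dict = {}
state_dict.update(pack_weights({"lora_A.weight": [0.1]}, "unet"))
state_dict.update(pack_weights({"lora_B.weight": [0.2]}, "text_encoder"))
state_dict.update(pack_weights({"lora_B.weight": [0.3]}, "text_encoder_2"))
```

Each component's keys end up namespaced (`unet.…`, `text_encoder.…`, `text_encoder_2.…`), which is what lets a single file carry LoRA layers for the UNet and both SDXL text encoders.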
Comment on lines 430 to 436
```python
class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images.
    """

    def __init__(
```
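Since this PR moves tokenization out of the dataset (prompts are pre-tokenized), `__getitem__` only needs to return pre-processed images. The sketch below shows that interface with plain lists standing in for image tensors; the class name and fields are illustrative, not the script's exact implementation.

```python
# Minimal sketch of the DreamBoothDataset interface after this PR:
# no tokenizer inside the dataset, just instance (and optional class) images.

class DreamBoothDatasetSketch:
    def __init__(self, instance_images, class_images=None):
        self.instance_images = instance_images
        self.class_images = class_images or []

    def __len__(self):
        # Length covers the larger of the two image pools.
        return max(len(self.instance_images), len(self.class_images))

    def __getitem__(self, index):
        # Wrap around with modulo so the smaller pool repeats.
        example = {"instance_images": self.instance_images[index % len(self.instance_images)]}
        if self.class_images:
            example["class_images"] = self.class_images[index % len(self.class_images)]
        return example
```

When prior-preservation class images are used, wrapping with modulo lets the instance and class pools differ in size without index errors.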
Member (Author)

Thanks all for your suggestions.
orpatashnik pushed a commit to orpatashnik/diffusers that referenced this pull request on Aug 1, 2023
…ingface#4097) * Allow low precision sd xl * finish * finish * feat: initial draft for supporting text encoder lora finetuning for SDXL DreamBooth * fix: variable assignments. * add: autocast block. * add debugging * vae dtype hell * fix: vae dtype hell. * fix: vae dtype hell 3. * clean up * lora text encoder loader. * fix: unwrapping models. * add: tests. * docs. * handle unexpected keys. * fix vae dtype in the final inference. * fix scope problem. * fix: save_model_card args. * initialize: prefix to None. * fix: dtype issues. * apply gixes. * debgging. * debugging * debugging * debugging * debugging * debugging * add: fast tests. * pre-tokenize. * address: will's comments. * fix: loader and tests. * fix: dataloader. * simplify dataloader. * length. * simplification. * make style && make quality * simplify state_dict munging * fix: tests. * fix: state_dict packing. * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request on Dec 25, 2023
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request on Apr 26, 2024
This PR adds support for text encoder fine-tuning in the DreamBooth LoRA script for SDXL.
Summary of the changes:
To help us maintain sanity, I tested the current training script under three settings:
Artifacts:
Artifacts:
Artifacts:
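The `save_lora_weights` packing shown in the diff implies a symmetric step at load time: splitting a combined state dict back into per-component dicts by prefix. A minimal sketch of that idea is below; it is illustrative only — the actual loader in diffusers does considerably more (key remapping, unexpected-key handling), and `unpack_weights` is a hypothetical helper, not a library function.

```python
# Counterpart sketch of the prefix-based packing: recover each component's
# weights from a combined LoRA state dict. Plain lists stand in for tensors.

def unpack_weights(state_dict, prefix):
    tag = prefix + "."
    return {k[len(tag):]: v for k, v in state_dict.items() if k.startswith(tag)}

packed = {
    "unet.lora_A.weight": [0.1],
    "text_encoder.lora_B.weight": [0.2],
    "text_encoder_2.lora_B.weight": [0.3],
}
unet_sd = unpack_weights(packed, "unet")
te_sd = unpack_weights(packed, "text_encoder")
te2_sd = unpack_weights(packed, "text_encoder_2")
```

Note that matching on `prefix + "."` (not the bare prefix) is what keeps `text_encoder` keys from swallowing `text_encoder_2` keys.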