Implement SD3 loss weighting#8528
Conversation
|
thanks @Slickytail checking |
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
@Slickytail can you kindly |
…tribution thanks to @Slickytail
|
I get the following error: |
|
@kashif I will review the changes a bit later but could you test the scripts and see if they are running without errors? |
|
This avoids errors |
|
a question about the loss: why here we firstly convert the model prediction to x_0 to compute loss? |
|
Hi @kashif, it looks like you commited the necessary formatting/style changes, so I'm assuming |
|
@Slickytail yes let's keep all the |
sayakpaul
left a comment
There was a problem hiding this comment.
Thanks a lot for these!
@asomoza has confirmed that these are working nicely.
In a follow up PR, I will add your PR in the comments to honor your contributions :)
Will also make it a little utility function and move it to training_utils.py.
Appreciate your help here!
|
Okay I can confirm that the failing example tests are only with the latest version of |
|
Thanks for the ping, @sayakpaul. In the latest
|
|
I merged https://huggingface.co/datasets/hf-internal-testing/fill10/discussions/1 to fix the CI |
* Add lognorm and cosmap weighting * Implement mode sampling * Update examples/dreambooth/train_dreambooth_lora_sd3.py * Update examples/dreambooth/train_dreambooth_lora_sd3.py * Update examples/dreambooth/train_dreambooth_sd3.py * Update examples/dreambooth/train_dreambooth_sd3.py * Update examples/dreambooth/train_dreambooth_sd3.py * Update examples/dreambooth/train_dreambooth_lora_sd3.py * Update examples/dreambooth/train_dreambooth_sd3.py * Update examples/dreambooth/train_dreambooth_sd3.py * Update examples/dreambooth/train_dreambooth_lora_sd3.py * keep timestamp sampling fully on cpu --------- Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>


The loss-weighting schemes for SD3 training were not implemented correctly, causing all of them to be non-functional. I went ahead and implemented the lognorm and cosmap schemes, just by using the density at those timesteps. Potentially, a better approach would be to sample the timestep according to that density in the first place.
The Mode scheme is much harder to implement -- there's a reason that they didn't include an explicit form for the density in the paper (I couldn't find one...), so I put in an error message if you try to use it for now.
@sayakpaul @kashif