multiple prediction options in ddpm, ddim #818
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Yes we indeed need this now I think :-) (also for dance diffusion)
patrickvonplaten
left a comment
Generally, this looks good to me :-) We'll definitely need tests here though
To start, I want to make a colab comparing the predictions when training on one scheduler (to make sure it works).
I'm new to contributing and so I'm a little confused about what I should be doing. Should I clone the changes and make a colab to compare with original predictions?
Hey @pie31415, Since you mentioned you were interested in this PR, I think it'd be super useful to do a PR review here :-)
@pie31415 Another really useful thing would be to just verify the implementation from the original papers and links above. This is a pretty tricky port so I will do this too, but it would be hugely useful. For example, I actually think the DDIM implementation is much closer than DDPM. |
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    prediction_type: str = "epsilon",
I think this should go in the `__init__` function, and we've somewhat settled on `predict_epsilon: bool` in terms of naming :-)
Ah ok, now we actually have three types, so we might have to reconsider this choice 😅
But I think it should definitely go in the config of the scheduler and not be an arg of `__call__`.
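To make the three prediction types concrete, here is an illustrative sketch of how each one recovers the predicted clean sample x0. The helper name and signature are assumptions for illustration, not the final diffusers API:

```python
import torch

def predicted_x0(model_output, sample, alpha_prod_t, prediction_type="epsilon"):
    # Illustrative sketch of the three parametrizations discussed in this PR.
    # alpha_prod_t is the cumulative product of alphas at the current timestep.
    alpha_t = alpha_prod_t ** 0.5
    sigma_t = (1 - alpha_prod_t) ** 0.5
    if prediction_type == "epsilon":
        # model predicts the noise eps: x0 = (x_t - sigma * eps) / alpha
        return (sample - sigma_t * model_output) / alpha_t
    if prediction_type == "sample":
        # model predicts x0 directly
        return model_output
    if prediction_type == "velocity":
        # v-prediction (Salimans & Ho, "Progressive Distillation"):
        # v = alpha * eps - sigma * x0  =>  x0 = alpha * x_t - sigma * v
        return alpha_t * sample - sigma_t * model_output
    raise ValueError(f"Unknown prediction_type: {prediction_type}")
```

All three branches agree when fed consistent inputs, which makes a function like this handy for cross-checking the scheduler implementations against each other.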
I actually messaged Nathan about my progress on DDIM v prediction in a separate branch as you commented - crazy timing!
I'll make sure I make these changes in my branch before opening a PR
That's very interesting here actually - @patil-suraj @anton-l could you also take a look? :-)
DDIM will hopefully be ready for review soon too. Results training on it are still a little pixelated, but you can clearly see the shape of a butterfly. I'm guessing I have something not quite right with the variance calculation. Will hopefully have updates here soon!
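For sanity-checking a variance calculation like the one mentioned above, the closed-form DDIM variance from Song et al. (2020), Eq. 16, is a useful standalone reference. This is just a reference sketch, not the PR's actual scheduler code:

```python
import torch

def ddim_variance(alphas_cumprod, t, t_prev, eta=0.0):
    # Standard DDIM variance (Song et al. 2020, Eq. 16):
    # sigma_t^2 = eta^2 * (1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)
    # where a_t = alphas_cumprod[t]. eta = 0 gives deterministic DDIM,
    # eta = 1 recovers DDPM-like stochastic sampling.
    a_t, a_prev = alphas_cumprod[t], alphas_cumprod[t_prev]
    sigma2 = (1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)
    return eta ** 2 * sigma2
```

With eta = 0 the variance vanishes, so pixelated or noisy samples in that setting usually point at the x0/epsilon bookkeeping rather than the variance term itself.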
* v diffusion support for ddpm
* quality and style
* variable name consistency
* missing base case
* pass prediction type along in the pipeline
* put prediction type in scheduler config
* style
* try to train on ddim
* changes to ddim
* ddim v prediction works to train butterflies example
* fix bad merge, style and quality
* try to fix broken doc strings
* second pass
* one more
* white space
* Update src/diffusers/schedulers/scheduling_ddim.py
* remove extra lines
* Update src/diffusers/schedulers/scheduling_ddim.py

Co-authored-by: Ben Glickenhaus <ben@mail.cs.umass.edu>
Co-authored-by: Nathan Lambert <nathan@huggingface.co>
Update for the diffusers team (@patrickvonplaten , @anton-l , @patil-suraj ). We updated DDIM now (promising results), and I'll add tests / fix merge issues this afternoon.
@patrickvonplaten this should be good to go. Now, this leaves only
Lots more good work from @bglick13
    set_alpha_to_one: bool = True,
    variance_type: str = "fixed",
    steps_offset: int = 0,
    prediction_type: Literal["epsilon", "sample", "velocity"] = "epsilon",
Note we currently have a config parameter called `predict_epsilon` that is already used in multiple schedulers:
So we cannot really add this `prediction_type` here without deprecating the other one, and also deprecating arguments like this one:
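One common way to handle such a transition is a small shim that maps the legacy `predict_epsilon` bool onto the new `prediction_type` string while warning callers. This is a hypothetical sketch of the deprecation path being discussed, not the actual diffusers implementation:

```python
import warnings

def resolve_prediction_type(prediction_type=None, predict_epsilon=None):
    # Hypothetical helper: accept both the old and new argument, warn on the
    # deprecated one, and translate it into the new string-valued option.
    if predict_epsilon is not None:
        warnings.warn(
            "`predict_epsilon` is deprecated; use `prediction_type` instead.",
            DeprecationWarning,
        )
        if prediction_type is None:
            prediction_type = "epsilon" if predict_epsilon else "sample"
    return prediction_type or "epsilon"
```

Keeping the mapping in one place means each scheduler's `__init__` can call it once and the rest of the code only ever sees `prediction_type`.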
pcuenca
left a comment
LGTM. Same comments about deprecating predict_epsilon everywhere.
    def expand_to_shape(input, timesteps, shape, device):
        """
        Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast
        nicely with `shape`. Useful for parallelizing operations over `shape[0]` number of diffusion steps at once.
        """
        out = torch.gather(input.to(device), 0, timesteps.to(device))
        reshape = [shape[0]] + [1] * (len(shape) - 1)
        out = out.reshape(*reshape)
        return out
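A quick usage example of this helper (redefined here so the snippet is self-contained): gather one schedule coefficient per batch element, reshaped so it broadcasts against a batch of images.

```python
import torch

def expand_to_shape(input, timesteps, shape, device):
    # As in the PR: index a 1D tensor by `timesteps`, then reshape the result
    # to broadcast against a tensor of shape `shape`.
    out = torch.gather(input.to(device), 0, timesteps.to(device))
    reshape = [shape[0]] + [1] * (len(shape) - 1)
    return out.reshape(*reshape)

alphas = torch.linspace(0.9, 0.1, steps=1000)  # 1D schedule, one value per step
timesteps = torch.tensor([0, 500, 999])        # one timestep per batch element
coeff = expand_to_shape(alphas, timesteps, (3, 3, 64, 64), "cpu")
print(coeff.shape)  # torch.Size([3, 1, 1, 1])
```

The `(3, 1, 1, 1)` result multiplies cleanly against a `(3, 3, 64, 64)` image batch, which is exactly the per-sample-timestep broadcasting the docstring describes.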
How do we feel about moving this to scheduling_utils.py? Maybe get_alpha_sigma as well.
I'm good with this. I had this as a TODO in my mind. Could also be made more elegant, but wasn't 100% sure how yet.
My only concern with these two is:

1. `expand_to_shape` would be the only function like this. It's okay to start the trend.
2. `get_alpha_sigma` won't work with many of the schedulers, so I'm okay with leaving it in the ones that use v-prediction for now.
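For context, here is one plausible definition of `get_alpha_sigma` under the usual v-prediction convention alpha_t^2 + sigma_t^2 = 1. This is an assumption about the helper's shape, not necessarily its exact implementation in this PR:

```python
import torch

def get_alpha_sigma(alphas_cumprod, timesteps):
    # Hypothetical sketch: under the variance-preserving convention,
    # alpha_t = sqrt(alphas_cumprod[t]) and sigma_t = sqrt(1 - alphas_cumprod[t]),
    # so alpha_t^2 + sigma_t^2 = 1 at every timestep.
    a = alphas_cumprod[timesteps]
    return a.sqrt(), (1 - a).sqrt()
```

Since this identity only holds for variance-preserving schedules, a helper like this indeed wouldn't transfer to every scheduler, which supports keeping it local to the v-prediction ones for now.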
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Added more deprecations across the board. I tried to address @patrickvonplaten's comment above, but would like a double check on that!
    )

    # no check on predict_epsilon due to the deprecation flag above
    elif self.prediction_type == "sample" or not self.config.predict_epsilon:
I had to mess with these if statements a little bit to get the tests to pass. All will be much cleaner once it's deprecated.
The code isn't as clear, but you can see some details on model parametrization in the SD 2.0 code here. The option
@patil-suraj @patrickvonplaten @bglick13
Starting to work on the discussion in #778.
Please contribute and leave feedback. This is mostly a placeholder for my work right now as I figure out how to do it.
Some relevant repositories: