Skip to content

UniPC Multistep fix tensor dtype/device on order=3#7532

Merged
yiyixuxu merged 3 commits intohuggingface:mainfrom
Beinsezii:unipc_fp16
Apr 3, 2024
Merged

UniPC Multistep fix tensor dtype/device on order=3#7532
yiyixuxu merged 3 commits intohuggingface:mainfrom
Beinsezii:unipc_fp16

Conversation

@Beinsezii
Copy link
Copy Markdown
Contributor

@Beinsezii Beinsezii commented Mar 31, 2024

As I found in #7517, UniPC with order=3 was erring on step 3 due to one of the tensors not being the correct dtype and device. This PR corrects that and updates the fp16 test to check every combination of order, solver, prediction to ensure they all work against half precision.

I verified working on both FP16 and BF16 on CUDA, and the updated test will correctly fail if you checkout to the first commit 036e33e before I fixed the rhos_p typings.

The test might be a bit over-comprehensive as it stands. Right now I don't think its possible for all the parameters to create incompatible tensors but in my mind by just checking everything it should catch future refactors.

Should merge cleanly on top of #7531

@yiyixuxu

It wasn't catching errs on order==3. Might be excessive?
For completions sake. Probably overkill?
@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Copy Markdown
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you!

@yiyixuxu yiyixuxu merged commit 19ab04f into huggingface:main Apr 3, 2024
@Beinsezii Beinsezii deleted the unipc_fp16 branch April 3, 2024 21:01
noskill pushed a commit to noskill/diffusers that referenced this pull request Apr 5, 2024
* UniPC UTs iterate solvers on FP16

It wasn't catching errs on order==3. Might be excessive?

* UniPC Multistep fix tensor dtype/device on order=3

* UniPC UTs Add v_pred to fp16 test iter

For completions sake. Probably overkill?
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* UniPC UTs iterate solvers on FP16

It wasn't catching errs on order==3. Might be excessive?

* UniPC Multistep fix tensor dtype/device on order=3

* UniPC UTs Add v_pred to fp16 test iter

For completions sake. Probably overkill?
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants