Conversation
The documentation is not available anymore as the PR was closed or merged.
cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu()

# cast to float32 as numpy does not support bfloat16
if image_embeds.dtype == torch.bfloat16:
I don't think we need an if statement here, as .float() should always work and be correct (e.g. we can't do fp16 on CPU either, and calling numpy() will move it to fp32 anyway).
@patrickvonplaten why would numpy() move the arrays to fp32? I thought that numpy arrays support fp16?
I believe we should use .half() instead which should work on CPU also?
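For context on the cast being discussed, here is a minimal sketch (not code from this PR, just standard torch/numpy behavior as I understand it): numpy has no bfloat16 dtype, so calling .numpy() directly on a bf16 tensor fails, while casting with .float() first works regardless of device.

```python
import torch

t = torch.ones(2, 2, dtype=torch.bfloat16)

# t.numpy() would raise a TypeError here, since numpy has no bfloat16 dtype.
# Casting with .float() first works on CPU as well (unlike some fp16 CPU ops).
arr = t.float().numpy()
```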
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
NouamaneTazi left a comment:
Unfortunately some of these precision modifications will add more unrolled_elementwise_kernel<direct_copy_kernel_cuda> kernels, but at least it doesn't require a CPU-GPU sync and doesn't happen inside a loop. So LGTM :-)
* support bf16 for stable diffusion
* fix typo
* address review comments
Currently stable diffusion (or diffusers in general) doesn't work with bf16, as nearest upsampling in torch is not supported in bf16. Minimal code to reproduce:
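The original reproduction snippet is not included above; the following is a hypothetical sketch (not the author's original code) of the failure mode and of the cast-to-fp32 workaround the PR describes, using a made-up tensor shape:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8, dtype=torch.bfloat16)

# On torch builds without a bf16 nearest-neighbor kernel,
# F.interpolate(x, scale_factor=2, mode="nearest") raises a RuntimeError.
# The workaround in the spirit of this PR: upsample in fp32, cast back to bf16.
y = F.interpolate(x.float(), scale_factor=2, mode="nearest").to(x.dtype)
```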
Additionally, in pipelines we need to cast the images to fp32, as bf16 is not yet supported in numpy.

This is a draft PR to enable bf16 training/inference for stable diffusion by casting input to fp32 where bf16 is not supported. Not sure if this is the right way; curious to hear your feedback @patrickvonplaten @NouamaneTazi
fixes #771