Allow dtype to be specified in Flax pipeline#600

Merged
pcuenca merged 3 commits into main from flax_dtype
Sep 21, 2022

@pcuenca pcuenca commented Sep 21, 2022

This replaces #581, which was reviewed by @patil-suraj.

Allowing `dtype` to be overridden on model load may be a temporary solution until #567 is addressed.
Latents are now always created in float32: the denoising loop computes the next step in float32, so creating them in `bfloat16` would fail.
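
As a rough illustration of the dtype mismatch described above, here is a minimal sketch. It assumes the denoising loop is driven by `jax.lax.fori_loop`, which requires the loop carry to keep the same dtype on every iteration. `denoise_step` and `run_loop` are hypothetical stand-ins, not the pipeline's actual functions.

```python
import jax
import jax.numpy as jnp


def denoise_step(i, latents):
    # Hypothetical stand-in for a scheduler step that always computes in float32.
    latents_f32 = latents.astype(jnp.float32)
    noise_pred = latents_f32 * 0.1
    return latents_f32 - noise_pred  # float32 regardless of the input dtype


def run_loop(latents, num_steps=4):
    # fori_loop requires the carry (latents) to have the same dtype going in and out.
    return jax.lax.fori_loop(0, num_steps, denoise_step, latents)


key = jax.random.PRNGKey(0)
shape = (1, 4, 8, 8)

# float32 latents match the float32 step output, so the loop runs.
latents_f32 = jax.random.normal(key, shape=shape, dtype=jnp.float32)
out = run_loop(latents_f32)

# bfloat16 latents do not match the float32 step output, so the carry check fails.
latents_bf16 = jax.random.normal(key, shape=shape, dtype=jnp.bfloat16)
try:
    run_loop(latents_bf16)
    bf16_failed = False
except TypeError:
    bf16_failed = True
```

Creating the latents in float32 keeps the carry dtype stable even when the model weights are loaded in `bfloat16`.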

HuggingFaceDocBuilderDev commented Sep 21, 2022

The documentation is not available anymore as the PR was closed or merged.

         )
         if latents is None:
-            latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype)
+            latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
pcuenca (Member, Author) commented:

The alternative to this, I think, would be to prepare the scheduler parameters using the same dtype as the model. We can do that in a follow-up PR.

Contributor commented:

Sounds good, and agree we should check model dtype here.
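
The follow-up idea discussed in this thread, preparing the scheduler parameters in the model's dtype, might look roughly like the sketch below. The parameter names (`alphas_cumprod`, `timesteps`) and the `cast_floating` helper are assumptions made for illustration. They are not the scheduler's actual state or API.

```python
import jax
import jax.numpy as jnp


def cast_floating(tree, dtype):
    """Cast floating-point leaves of a pytree to `dtype`; leave other leaves as-is."""

    def _cast(leaf):
        leaf = jnp.asarray(leaf)
        if jnp.issubdtype(leaf.dtype, jnp.floating):
            return leaf.astype(dtype)
        return leaf  # e.g. integer timesteps keep their dtype

    return jax.tree_util.tree_map(_cast, tree)


# Hypothetical scheduler parameters, created in float32 by default.
scheduler_params = {
    "alphas_cumprod": jnp.linspace(0.99, 0.01, 10, dtype=jnp.float32),
    "timesteps": jnp.arange(10),
}

model_dtype = jnp.bfloat16
params = cast_floating(scheduler_params, model_dtype)

# With matching dtypes, arithmetic on the latents stays in the model dtype.
latents = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 8, 8), dtype=model_dtype)
scaled = latents * params["alphas_cumprod"][0]
```

With the parameters cast this way, the latents could be created in `self.dtype` without the step output being promoted to float32. Whether the reduced precision is acceptable for the scheduler math is a separate question.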

@pcuenca pcuenca merged commit fb2fbab into main Sep 21, 2022
@pcuenca pcuenca deleted the flax_dtype branch September 21, 2022 08:57
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* Fix typo in docstring.

* Allow dtype to be overridden on model load.

This may be a temporary solution until huggingface#567 is addressed.

* Create latents in float32

The denoising loop always computes the next step in float32, so this
would fail when using `bfloat16`.