[Debugging] Stable Diffusion #592
Closed
Labels: stale (Issues that haven't received updates)
Description
To quickly debug new Stable Diffusion functionality, you can use this tiny dummy model:
For PyTorch:
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe")
For Flax:
from diffusers import FlaxStableDiffusionPipeline
from jax import pmap
import numpy as np
import jax
from flax.jax_utils import replicate
from flax.training.common_utils import shard
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe")
# or load the same checkpoint from its PyTorch weights (ideally test both):
# pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe", from_pt=True)
# saving should also work:
# pipeline.save_pretrained("./tiny-stable-diffusion-pipe", params=params)
prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 4
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
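# first try debug mode on a single device (no pmap):
# sample = pipeline(prompt_ids, params, prng_seed, num_inference_steps, debug=True)
# next, run data-parallel on all devices (e.g. the 8 cores of a TPU v2-8);
# argument 3 (num_inference_steps) is static and broadcast to every device
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))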
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)  # one key per device
prompt_ids = shard(prompt_ids)
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
# merge the device and per-device batch axes before converting to PIL images
images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
import ipdb; ipdb.set_trace()  # optional breakpoint to inspect images_pil (requires ipdb)
print("works!")
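To see how the sharding steps in the Flax script fit together without downloading anything, here is a minimal sketch that runs on any number of devices, including a single CPU. `toy_generate`, `shard`, and `replicate` below are hypothetical stand-ins written for illustration: they mimic only the array shapes involved, not the actual diffusers or Flax implementations. The point it shows is that every per-device input needs a leading axis equal to the device count, which is why the number of PRNG keys should match `jax.device_count()` rather than a fixed 8.

```python
import jax
import jax.numpy as jnp
import numpy as np

n_devices = jax.local_device_count()


def shard(x):
    # Split the leading (global batch) axis into (n_devices, per_device_batch).
    # Illustrative stand-in for flax.training.common_utils.shard on one array.
    return x.reshape((n_devices, -1) + x.shape[1:])


def replicate(tree):
    # Give every device its own copy of each parameter by stacking along a
    # new leading axis. Illustrative stand-in for flax.jax_utils.replicate.
    return jax.tree_util.tree_map(lambda x: jnp.stack([x] * n_devices), tree)


def toy_generate(prompt_ids, params, rng, num_steps):
    # Hypothetical stand-in for pipeline.__call__: a fake "denoising" loop.
    # num_steps must be a concrete Python int for range(), which is why it
    # is marked static for pmap below.
    latents = jax.random.normal(rng, prompt_ids.shape[:1] + (4,))
    for _ in range(num_steps):
        latents = latents * params["scale"] + params["shift"]
    return latents


# Argument 3 (num_steps) is broadcast to all devices as a static value.
p_generate = jax.pmap(toy_generate, static_broadcasted_argnums=(3,))

num_samples = n_devices  # one prompt per device, as in the script above
prompt_ids = np.zeros((num_samples, 77), dtype=np.int32)  # dummy token ids
params = {"scale": jnp.float32(0.5), "shift": jnp.float32(1.0)}
rng = jax.random.split(jax.random.PRNGKey(0), n_devices)  # one key per device

out = p_generate(shard(prompt_ids), replicate(params), rng, 4)
print(out.shape)  # (n_devices, 1, 4): device axis, per-device batch, features

# Merge the device and per-device batch axes, as the script does before numpy_to_pil.
flat = np.asarray(out.reshape((num_samples,) + out.shape[-1:]))
print(flat.shape)  # (n_devices, 4)
```

If the number of keys passed to `jax.random.split` did not equal the device count, `pmap` would reject the inputs because their leading axes would not match.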