""" ImageGenerator module"""
import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline,EulerAncestralDiscreteScheduler
from diffusers.schedulers import DPMSolverMultistepScheduler
class ImageGenerator:
    """Text-to-image generator built on the Stable Cascade prior + decoder pipelines.

    All settings are plain instance attributes assigned in ``__init__``; adjust
    them on the instance before calling ``load_model`` / ``inference``.
    """

    def __init__(self):
        # Local checkpoint directories for the bf16 prior and decoder.
        self.prior_path = './models/image_generation/prior-bf16'
        self.decoder_path = './models/image_generation/decoder-bf16'
        # Output resolution requested from the prior stage.
        self.img_width = 1024
        self.img_height = 1024
        # Guidance scales passed to each stage.
        self.prior_cfg = 4
        self.decoder_cfg = 1.1
        # Number of denoising steps for each stage.
        self.prior_num_steps = 20
        self.decoder_num_steps = 10
        self.num_images_per_prompt = 1
        # Scheduler class that load_model() installs on BOTH pipelines (each
        # gets its own instance). Set to None to keep the schedulers the
        # pipelines were loaded with.
        self.scheduler_cls = EulerAncestralDiscreteScheduler
        # Populated by load_model().
        self.prior = None
        self.decoder = None

    def load_model(self):
        """Load both pipelines to CUDA in bfloat16, install the scheduler, then warm up.

        NOTE(review): the Stable Cascade pipelines in diffusers 0.27.x appear to
        be written around their stock scheduler. Swapping in another scheduler
        (e.g. EulerAncestralDiscreteScheduler) is reported to fail on that
        version and to need a newer diffusers release. Confirm against the
        installed version, or set ``self.scheduler_cls = None`` before loading.
        """
        self.prior = StableCascadePriorPipeline.from_pretrained(
            self.prior_path, torch_dtype=torch.bfloat16,
        ).to('cuda')
        self.decoder = StableCascadeDecoderPipeline.from_pretrained(
            self.decoder_path, torch_dtype=torch.bfloat16
        ).to('cuda')
        if self.scheduler_cls is not None:
            # Schedulers keep per-run state (timestep schedule, step index),
            # so the two stages must not share a single instance.
            self.prior.scheduler = self.scheduler_cls()
            self.decoder.scheduler = self.scheduler_cls()
        self._warmup()

    def _warmup(self):
        """Run one short, fixed generation through both stages; the output is discarded."""
        prior_output = self.prior(
            prompt="a cat",
            height=1024,
            width=1024,
            negative_prompt="",
            guidance_scale=4,
            num_images_per_prompt=1,
            num_inference_steps=10,
        )
        self.decoder(
            image_embeddings=prior_output.image_embeddings.to(torch.bfloat16),
            prompt="a cat",
            negative_prompt="",
            guidance_scale=1.1,
            output_type="pil",
            num_inference_steps=3,
        )

    def inference(self, prompt, negative_prompt):
        """Generate image(s) for ``prompt`` using the loaded prior and decoder.

        Args:
            prompt: Text prompt passed to both stages.
            negative_prompt: Negative prompt passed to both stages.

        Returns:
            The decoder pipeline's output, requested with ``output_type="pil"``
            (the ``__main__`` script reads ``.images[0]`` from it).

        Raises:
            Any exception raised by either pipeline. It is printed and then
            re-raised unchanged.
        """
        try:
            prior_output = self.prior(
                prompt=prompt,
                height=self.img_height,
                width=self.img_width,
                negative_prompt=negative_prompt,
                guidance_scale=self.prior_cfg,
                num_images_per_prompt=self.num_images_per_prompt,
                num_inference_steps=self.prior_num_steps,
            )
            # Cast the prior's embeddings to the decoder's bf16 dtype.
            return self.decoder(
                image_embeddings=prior_output.image_embeddings.to(torch.bfloat16),
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=self.decoder_cfg,
                output_type="pil",
                num_inference_steps=self.decoder_num_steps,
            )
        except Exception as e:
            # Re-raise: swallowing the error would leave nothing to return and
            # hide the real failure behind an UnboundLocalError.
            print('error : ', e)
            raise
if __name__ == '__main__':
    # Smoke run: build the generator, load and warm up both pipelines,
    # render one image for a fixed prompt and write it to result.png.
    image_generator = ImageGenerator()
    image_generator.load_model()
    output = image_generator.inference("a man", "")
    first_image = output.images[0]
    first_image.save('result.png')
When I don't change the sampler, everything works fine.
Hi
Following this link, I want to try other samplers with Stable Cascade, but I get the error below.
My inference code:
When I don't change the sampler, everything works fine.
My diffusers package version is 0.27.2.
Thanks