
LoRA Query/Key float16/float32 mismatch when loading safetensors from CivitAI #3748

@codejudas

Description


Describe the bug

I'm trying to load LoRAs following the docs. I've downloaded a couple of LoRAs so far and confirmed that they are marked as fp16 on CivitAI.
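The fp16 label on CivitAI describes the upload, but the tensor dtypes actually stored in the file can also be checked locally. The sketch below assumes only the published safetensors layout (an 8-byte little-endian header length followed by a JSON header that records each tensor's `dtype`) and counts dtypes without loading any weights. `safetensors_dtypes` is a hypothetical helper name, not part of the `safetensors` library.

```python
import json
import struct
from collections import Counter


def safetensors_dtypes(path):
    """Count tensor dtypes recorded in a .safetensors header without loading weights.

    Assumes the safetensors layout: an 8-byte little-endian unsigned header
    length, then a UTF-8 JSON header mapping each tensor name to
    {"dtype", "shape", "data_offsets"}.
    """
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len).decode("utf-8"))
    # "__metadata__" is an optional string-to-string map, not a tensor entry
    header.pop("__metadata__", None)
    return Counter(entry["dtype"] for entry in header.values())
```

For example, `print(safetensors_dtypes("3DMM_V7.safetensors"))` should report only `F16` entries if the file really is fp16. This only checks what is stored on disk, not the dtype of the layers diffusers builds from it at load time.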

Running inference with the LoRA loaded then throws the following error:

Query/Key/Value should all have the same dtype
  query.dtype: torch.float32
  key.dtype  : torch.float16
  value.dtype: torch.float16

Reproduction

Here is the code I'm using (with this LoRA):

import torch
from torch import autocast
from diffusers import StableDiffusionPipeline

# Load SD 1.5 in float16 from the local cache
pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    local_files_only=True,
    cache_dir=".",
    torch_dtype=torch.float16,
    safety_checker=None,
    requires_safety_checker=False,
)
pipeline.to("cuda")
pipeline.enable_xformers_memory_efficient_attention()

# Load the CivitAI LoRA from a local .safetensors file
pipeline.load_lora_weights(
    ".",
    weight_name="3DMM_V7.safetensors",
    torch_dtype=torch.float16,
    cache_dir=".",
    local_files_only=True,
)

with autocast("cuda"):
    images = pipeline(  # This line throws
        prompt="a blue cat, arcane style",
        negative_prompt="",
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        num_images_per_prompt=3,
        generator=None,
        cross_attention_kwargs={"scale": 0.5},  # LoRA scale
    ).images

Logs

Query/Key/Value should all have the same dtype
  query.dtype: torch.float32
  key.dtype  : torch.float16
  value.dtype: torch.float16
Traceback (most recent call last):
  File "/home/devonperoutky/.cache/pypoetry/virtualenvs/easel-w0jPk5Y7-py3.8/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/devonperoutky/.cache/pypoetry/virtualenvs/easel-w0jPk5Y7-py3.8/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 728, in __call__
    noise_pred = self.unet(
  File "/home/devonperoutky/.cache/pypoetry/virtualenvs/easel-w0jPk5Y7-py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/devonperoutky/.cache/pypoetry/virtualenvs/easel-w0jPk5Y7-py3.8/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py", line 797, in forward
    sample, res_samples = downsample_block(
  File "/home/devonperoutky/.cache/pypoetry/virtualenvs/easel-w0jPk5Y7-py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/devonperoutky/.cache/pypoetry/virtualenvs/easel-w0jPk5Y7-py3.8/lib/python3.8/site-packages/diffusers/models/unet_2d_blocks.py", line 924, in forward
    hidden_states = attn(
  File "/home/devonperoutky/.cache/pypoetry/virtualenvs/easel-w0jPk5Y7-py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/devonperoutky/.cache/pypoetry/virtualenvs/easel-w0jPk5Y7-py3.8/lib/python3.8/site-packages/diffusers/models/transformer_2d.py", line 296, in forward
    hidden_states = block(
  File "/home/devonperoutky/.cache/pypoetry/virtualenvs/easel-w0jPk5Y7-py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/devonperoutky/.cache/pypoetry/virtualenvs/easel-w0jPk5Y7-py3.8/lib/python3.8/site-packages/diffusers/models/attention.py", line 160, in forward
    attn_output = self.attn2(
  File "/home/devonperoutky/.cache/pypoetry/virtualenvs/easel-w0jPk5Y7-py3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/devonperoutky/.cache/pypoetry/virtualenvs/easel-w0jPk5Y7-py3.8/lib/python3.8/site-packages/diffusers/models/attention_processor.py", line 320, in forward
    return self.processor(
  File "/home/devonperoutky/.cache/pypoetry/virtualenvs/easel-w0jPk5Y7-py3.8/lib/python3.8/site-packages/diffusers/models/attention_processor.py", line 1221, in __call__
    hidden_states = xformers.ops.memory_efficient_attention(
  File "/home/devonperoutky/.cache/pypoetry/virtualenvs/easel-w0jPk5Y7-py3.8/lib/python3.8/site-packages/xformers/ops/fmha/__init__.py", line 192, in memory_efficient_attention
    return _memory_efficient_attention(
  File "/home/devonperoutky/.cache/pypoetry/virtualenvs/easel-w0jPk5Y7-py3.8/lib/python3.8/site-packages/xformers/ops/fmha/__init__.py", line 290, in _memory_efficient_attention
    return _memory_efficient_attention_forward(
  File "/home/devonperoutky/.cache/pypoetry/virtualenvs/easel-w0jPk5Y7-py3.8/lib/python3.8/site-packages/xformers/ops/fmha/__init__.py", line 303, in _memory_efficient_attention_forward
    inp.validate_inputs()
  File "/home/devonperoutky/.cache/pypoetry/virtualenvs/easel-w0jPk5Y7-py3.8/lib/python3.8/site-packages/xformers/ops/fmha/common.py", line 73, in validate_inputs
    raise ValueError(
ValueError: Query/Key/Value should all have the same dtype
  query.dtype: torch.float32
  key.dtype  : torch.float16
  value.dtype: torch.float16
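xformers' `memory_efficient_attention` refuses to run when query, key, and value differ in dtype, and here the query is float32 while key and value are float16. One way to narrow down where float32 values enter is to check, after `load_lora_weights`, whether any UNet parameters are not float16. This is a diagnostic sketch only: `param_dtype_report` is a hypothetical helper, and it assumes the LoRA layers are registered as submodules of `pipeline.unet` so that `named_parameters()` sees them.

```python
from collections import Counter

import torch
from torch import nn


def param_dtype_report(module: nn.Module, expected: torch.dtype = torch.float16):
    """Count parameter dtypes and list parameter names whose dtype differs from `expected`."""
    counts = Counter(p.dtype for p in module.parameters())
    mismatched = [name for name, p in module.named_parameters() if p.dtype != expected]
    return counts, mismatched
```

For example, `counts, odd = param_dtype_report(pipeline.unet)` followed by `print(counts)` and `print(odd[:10])` would show whether any float32 parameters remain after the LoRA is loaded.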


System Info

Python 3.8.10
Debian 10.2.1-6
diffusers 0.17.0
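Because the error is raised inside xformers, the installed torch and xformers versions are also relevant here (diffusers additionally provides a `diffusers-cli env` command for collecting environment details). A minimal sketch using only the standard library's `importlib.metadata`; `package_versions` and the default package list are illustrative choices, not an official tool.

```python
from importlib.metadata import PackageNotFoundError, version


def package_versions(names=("diffusers", "torch", "xformers", "transformers", "safetensors")):
    """Return the installed version of each distribution, or None if it is not installed."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found
```

Running `print(package_versions())` in the same virtualenv prints one entry per package.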

Metadata

Assignees: No one assigned

Labels: bug (Something isn't working)
