T5Attention support for cross-attention #2654
patrickvonplaten merged 18 commits into huggingface:main
Conversation
Fix use of AttnProcessor2_0 for cross attention with mask
Can you give a bit more background on what issue is fixed here? I'm not so sure about this tbh.
Of course, my bad! The issue is that the shape of the mask returned by With this change at least the function doesn't complain... however the outputs vs.
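The thread truncates both the name of the mask helper and the output comparison, so the following is an illustrative sketch only, not the actual diffusers code: `F.scaled_dot_product_attention` expects `attn_mask` to be broadcastable to `(batch, heads, q_len, k_len)`, and a mask flattened to `(batch * heads, 1, k_len)` fails that check. The shapes and the helper name below are assumptions.

```python
# Illustrative sketch (hypothetical helper, not the diffusers code):
# SDPA wants attn_mask broadcastable to (batch, heads, q_len, k_len);
# a flattened (batch * heads, 1, k_len) mask makes it complain.

def reshape_mask_for_sdpa(mask_shape, batch, heads):
    """Map a flattened (batch*heads, 1, k_len) mask shape to the 4-D
    (batch, heads, 1, k_len) shape that broadcasts inside SDPA."""
    bh, q_len, k_len = mask_shape
    assert bh == batch * heads, "mask batch dim must equal batch * heads"
    return (batch, heads, q_len, k_len)

print(reshape_mask_for_sdpa((2 * 8, 1, 77), batch=2, heads=8))
# -> (2, 8, 1, 77)
```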
Thanks @Birch-san, I am happy to close this in view of your PR. I also need to add two extra flags for
Cool, this works, thanks a lot for making the changes @kashif!
@Birch-san - think we could adapt your PR after this quite easily no?
```python
# but only if it has the `scale` argument
if processor is None:
    processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
if torch.torch_version.TorchVersion(torch.__version__) >= (2, 1, 0):
```
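The commit list mentions "check if it has scale argument" as an alternative to the version comparison above. A hedged sketch of that feature-probe pattern, using only the standard library (the stand-in functions below are illustrative, not `torch` itself; note that the real `F.scaled_dot_product_attention` is a C-level builtin, so `inspect.signature` can raise on it):

```python
import inspect

def supports_scale_kwarg(fn) -> bool:
    # True if `fn` accepts a `scale` keyword argument.
    # Probing the signature avoids hard-coding a torch version number,
    # but builtins without introspectable signatures raise, so fall back.
    try:
        return "scale" in inspect.signature(fn).parameters
    except (TypeError, ValueError):
        return False

# Stand-ins mimicking the 2.0 and 2.1 SDPA signatures (illustrative only):
def sdpa_2_0(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
    pass

def sdpa_2_1(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
    pass

print(supports_scale_kwarg(sdpa_2_0))  # False
print(supports_scale_kwarg(sdpa_2_1))  # True
```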
Let's revert this, we don't need 2.1, 2.0 is enough and I think the logic before was good.
Right, but then `scaled_dot_product_attention` in 2.0 has no `scale` argument, which is what I would need... but yes, I can deal with that in the pipeline?
Ah I see, ok I think it's fine if Torch 2.0 doesn't work yet for the spectrogram model. Let's maybe just advertise it with the previous PyTorch version and see if the community tries it out on PyTorch 2.0.
Ok cool, reverting... I can deal with it, or I can also check if `attn.scale == 1` and not do this... which is only for spectrogram for now?
```diff
  # the output of sdp = (batch, num_heads, seq_len, head_dim)
  hidden_states = F.scaled_dot_product_attention(
-     query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+     query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
  )
```
Is this backwards compatible?
Yes, since if `scale=None` the default scale is used, i.e. 1/sqrt(D), but the `scale` argument only works in the 2.1 nightly.
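That fallback is easy to check with a toy model. A minimal pure-Python sketch (not torch) of single-query attention weights, mirroring the documented SDPA behavior that `scale=None` falls back to `1/sqrt(head_dim)` — so passing `scale=attn.scale` is backwards compatible whenever `attn.scale` equals that default:

```python
import math

def attn_weights(q, keys, scale=None):
    """Softmax attention weights for one query vector; scale=None -> 1/sqrt(d)."""
    d = len(q)
    s = (1.0 / math.sqrt(d)) if scale is None else scale
    scores = [s * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)                      # subtract max for numerical stability
    exps = [math.exp(x - m) for x in scores]
    z = sum(exps)
    return [e / z for e in exps]

q = [1.0, 2.0, 3.0, 4.0]
keys = [[0.5, 0.1, 0.2, 0.3], [0.9, 0.8, 0.7, 0.6]]
implicit = attn_weights(q, keys)                          # default 1/sqrt(4)
explicit = attn_weights(q, keys, scale=1 / math.sqrt(4))  # same thing, spelled out
assert implicit == explicit
```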
@kashif can you also run all the slow tests for:
So that we can be sure that nothing is broken
Ok sure, reverting and running slow tests... give me a few!
ran slow tests... all failures are of this example:
patrickvonplaten left a comment
Very cool! Thanks for the PR @kashif :-)
Ok, thanks! Will add fast tests to spectrogram diffusion!
* fix AttnProcessor2_0: Fix use of AttnProcessor2_0 for cross attention with mask
* added scale_qk and out_bias flags
* fixed for xformers
* check if it has scale argument
* Update cross_attention.py
* check torch version
* fix sliced attn
* style
* set scale
* fix test
* fixed addedKV processor
* revert back AttnProcessor2_0
* if missing if
* fix inner_dim

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Added support for implementing T5Attention in the processors. Needed for #1044
Tested on PyTorch 2.0 RC
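The commit list adds `scale_qk` and `out_bias` flags, which fits how T5 attention differs from standard attention: T5 leaves the Q·K logits unscaled (the scaling is folded into initialization) and uses bias-free projections. A hypothetical pure-Python sketch of what the `scale_qk` flag buys a generic processor (the function name and shapes are illustrative, not the diffusers API):

```python
import math

def qk_scores(q, keys, scale_qk=True):
    # Single-query dot-product logits. Standard attention multiplies by
    # 1/sqrt(head_dim); T5-style attention (scale_qk=False) leaves them raw.
    d = len(q)
    s = 1.0 / math.sqrt(d) if scale_qk else 1.0
    return [s * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]

q = [2.0, 2.0, 2.0, 2.0]
keys = [[1.0, 1.0, 1.0, 1.0]]
print(qk_scores(q, keys, scale_qk=True))   # [4.0] -> scaled by 1/sqrt(4)
print(qk_scores(q, keys, scale_qk=False))  # [8.0] -> raw dot product, T5-style
```

An `out_bias=False` flag would analogously drop the bias term on the output projection, matching T5's bias-free linear layers.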