Unverified commit c0ffb6ed authored by Patrick Labatut, committed by GitHub

Do not force using FlashAttention (#58)

Remove hardcoded selection of operator implementation and use xFormers fMHA dispatcher instead.
parent ca58ffcd
@@ -73,11 +73,7 @@ class MemEffAttention(Attention):
 
         q, k, v = unbind(qkv, 2)
 
-        if attn_bias is not None:
-            self_att_op = fmha.MemoryEfficientAttentionFlashAttentionOp
-        else:
-            self_att_op = None
-        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=self_att_op)
+        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
         x = x.reshape([B, N, C])
 
         x = self.proj(x)
...
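
For context, a minimal sketch (not part of the commit) of how the call reads after this change. It assumes xFormers is installed and a CUDA device is available; the tensor shapes and values are illustrative only.

# Sketch: calling xFormers' fMHA entry point without forcing an operator,
# letting the dispatcher pick a backend (FlashAttention, CUTLASS, ...)
# suited to the inputs.
import torch
from xformers.ops import memory_efficient_attention

# Illustrative shapes: batch, tokens, heads, head dim (assumed values).
B, N, H, D = 2, 197, 6, 64
q = torch.randn(B, N, H, D, device="cuda", dtype=torch.half)
k = torch.randn(B, N, H, D, device="cuda", dtype=torch.half)
v = torch.randn(B, N, H, D, device="cuda", dtype=torch.half)

# No op= argument: xFormers selects a supported implementation based on
# dtype, device, and attention bias, instead of hardcoding FlashAttention.
x = memory_efficient_attention(q, k, v, attn_bias=None)
print(x.shape)  # torch.Size([2, 197, 6, 64])

The point of deferring to the dispatcher is robustness: a hardcoded FlashAttention op can reject inputs it does not support (for example, certain dtypes or bias types), whereas the dispatcher falls back to another implementation that does.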