diff --git a/dinov2/layers/attention.py b/dinov2/layers/attention.py
index c789ebd880f3eecf5c336230b4e3babe66ad2cfc..1f9b0c94b40967dfdff4f261c127cbd21328c905 100644
--- a/dinov2/layers/attention.py
+++ b/dinov2/layers/attention.py
@@ -73,11 +73,7 @@ class MemEffAttention(Attention):
 
         q, k, v = unbind(qkv, 2)
 
-        if attn_bias is not None:
-            self_att_op = fmha.MemoryEfficientAttentionFlashAttentionOp
-        else:
-            self_att_op = None
-        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=self_att_op)
+        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
         x = x.reshape([B, N, C])
 
         x = self.proj(x)
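For context, below is a minimal standalone sketch of the xformers call this diff simplifies. It assumes xformers is installed and a CUDA device is available; the tensor sizes are illustrative only and do not come from the diff. Passing op=None (the default) lets xformers dispatch to the best available kernel for the given inputs, which is why pinning fmha.MemoryEfficientAttentionFlashAttentionOp is no longer needed.

    # Sketch only: illustrative shapes, assumes xformers + CUDA are available.
    import torch
    from xformers.ops import memory_efficient_attention

    # xformers expects (batch, seq_len, num_heads, head_dim) inputs.
    B, N, H, D = 2, 16, 8, 64
    q = torch.randn(B, N, H, D, device="cuda", dtype=torch.float16)
    k = torch.randn(B, N, H, D, device="cuda", dtype=torch.float16)
    v = torch.randn(B, N, H, D, device="cuda", dtype=torch.float16)

    # No explicit op: xformers picks a suitable attention kernel for these
    # inputs (and for attn_bias, when one is supplied), matching the diff.
    out = memory_efficient_attention(q, k, v, attn_bias=None)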