diff --git a/dinov2/layers/attention.py b/dinov2/layers/attention.py
index c789ebd880f3eecf5c336230b4e3babe66ad2cfc..1f9b0c94b40967dfdff4f261c127cbd21328c905 100644
--- a/dinov2/layers/attention.py
+++ b/dinov2/layers/attention.py
@@ -73,11 +73,7 @@ class MemEffAttention(Attention):
 
         q, k, v = unbind(qkv, 2)
 
-        if attn_bias is not None:
-            self_att_op = fmha.MemoryEfficientAttentionFlashAttentionOp
-        else:
-            self_att_op = None
-        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=self_att_op)
+        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
         x = x.reshape([B, N, C])
 
         x = self.proj(x)