From 942cc2dee424075a8685f87c9ac831d3762d867b Mon Sep 17 00:00:00 2001
From: qasfb <36480216+qasfb@users.noreply.github.com>
Date: Fri, 1 Sep 2023 15:07:26 +0200
Subject: [PATCH] Update implementation of scaled_index_add and
 index_select_cat with fallback

Add a check for the proper size and dtype required by the optimized scaled_index_add and index_select_cat kernels, and fall back to the plain PyTorch implementation otherwise.
---
 dinov2/layers/block.py | 22 ++++++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/dinov2/layers/block.py b/dinov2/layers/block.py
index 930787b..68c6f45 100644
--- a/dinov2/layers/block.py
+++ b/dinov2/layers/block.py
@@ -27,8 +27,26 @@ logger = logging.getLogger("dinov2")
 XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
 try:
     if XFORMERS_ENABLED:
-        from xformers.ops import fmha, scaled_index_add, index_select_cat
-
+        from xformers.ops import scaled_index_add as _scaled_index_add, index_select_cat as _index_select_cat
+
+        def scaled_index_add(input, index, source, scaling, alpha):
+            is_proper_embed_dim = input.shape[-1] % 256 == 0
+            is_float16 = input.dtype == torch.half
+            if is_proper_embed_dim and is_float16:
+                return _scaled_index_add(input, index, source, scaling, alpha)
+            else:
+                return torch.index_add(input, dim=0, source=scaling * source, index=index, alpha=alpha)
+        
+        
+        def index_select_cat(sources, indices):
+            is_proper_embed_dim = all(s.shape[-1] % 256 == 0 for s in sources)
+            is_float16 = all(s.dtype == torch.half for s in sources)
+            if is_proper_embed_dim and is_float16:
+                return _index_select_cat(sources, indices)
+            else:
+                return torch.cat([s[i.long()].flatten() for s, i in zip(sources, indices)], dim=0)
+
+        
         XFORMERS_AVAILABLE = True
         warnings.warn("xFormers is available (Block)")
     else:
-- 
GitLab