diff --git a/dinov2/layers/block.py b/dinov2/layers/block.py
index 930787b262faac4f2264797496faff75ac56b7cc..68c6f45a958fde3ac4604f2b59e52de7648780f1 100644
--- a/dinov2/layers/block.py
+++ b/dinov2/layers/block.py
@@ -27,8 +27,26 @@ logger = logging.getLogger("dinov2")
 XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
 try:
     if XFORMERS_ENABLED:
-        from xformers.ops import fmha, scaled_index_add, index_select_cat
-
+        from xformers.ops import fmha, scaled_index_add as _scaled_index_add, index_select_cat as _index_select_cat
+
+        def scaled_index_add(input, index, source, scaling, alpha):
+            is_proper_embed_dim = input.shape[-1] % 256 == 0
+            is_float16 = input.dtype == torch.half
+            if is_proper_embed_dim and is_float16:
+                return _scaled_index_add(input, index, source, scaling, alpha)
+            else:
+                return torch.index_add(input, dim=0, source=scaling * source, index=index, alpha=alpha)
+
+
+        def index_select_cat(sources, indices):
+            is_proper_embed_dim = all(s.shape[-1] % 256 == 0 for s in sources)
+            is_float16 = all(s.dtype == torch.half for s in sources)
+            if is_proper_embed_dim and is_float16:
+                return _index_select_cat(sources, indices)
+            else:
+                return torch.cat([s[i.long()].flatten() for s, i in zip(sources, indices)], dim=0)
+
+
         XFORMERS_AVAILABLE = True
         warnings.warn("xFormers is available (Block)")
     else:
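The `% 256` / fp16 gate above appears intended to match the cases the fused xFormers kernels support; any other shape or dtype now takes an eager PyTorch path that computes the same result, so every call site in block.py can stay unchanged. A minimal sketch of the fallback behaviour, assuming the patched module is importable as `dinov2.layers.block` (the tensor shapes and variable names below are illustrative, not from the patch):

```python
import torch

from dinov2.layers.block import index_select_cat, scaled_index_add

# float32 inputs with embed_dim=384 (not a multiple of 256), so both
# wrappers take the eager PyTorch branch instead of the fused kernel.
x = torch.randn(4, 16, 384)    # (batch, tokens, embed_dim)
src = torch.randn(2, 16, 384)  # residual-branch output for selected samples
idx = torch.tensor([0, 2])     # which batch elements receive the residual
gamma = torch.randn(384)       # per-channel (LayerScale-style) weights

out = scaled_index_add(x, idx, src, scaling=gamma, alpha=0.5)

# Eager reference: out[i] = x[i] + alpha * gamma * src[k] where idx[k] == i.
ref = x.clone()
ref[idx] += 0.5 * gamma * src
assert torch.allclose(out, ref)

# index_select_cat gathers the selected rows of each flattened source and
# returns them as one flat tensor.
flat = index_select_cat([x.flatten(1)], [idx])
assert torch.equal(flat, x.flatten(1)[idx].flatten())
```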