diff --git a/dinov2/layers/attention.py b/dinov2/layers/attention.py
index 1f9b0c94b40967dfdff4f261c127cbd21328c905..35e38e4657b4453470e88f6933d39fcab7119287 100644
--- a/dinov2/layers/attention.py
+++ b/dinov2/layers/attention.py
@@ -9,6 +9,8 @@
 #   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
 
 import logging
+import os
+import warnings
 
 from torch import Tensor
 from torch import nn
@@ -17,13 +19,19 @@ from torch import nn
 logger = logging.getLogger("dinov2")
 
 
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
 try:
-    from xformers.ops import memory_efficient_attention, unbind, fmha
-
-    XFORMERS_AVAILABLE = True
+    if XFORMERS_ENABLED:
+        from xformers.ops import memory_efficient_attention, unbind
+
+        XFORMERS_AVAILABLE = True
+        warnings.warn("xFormers is available (Attention)")
+    else:
+        warnings.warn("xFormers is disabled (Attention)")
+        raise ImportError
 except ImportError:
-    logger.warning("xFormers not available")
     XFORMERS_AVAILABLE = False
+    warnings.warn("xFormers is not available (Attention)")
 
 
 class Attention(nn.Module):
@@ -65,7 +73,8 @@ class Attention(nn.Module):
 class MemEffAttention(Attention):
     def forward(self, x: Tensor, attn_bias=None) -> Tensor:
         if not XFORMERS_AVAILABLE:
-            assert attn_bias is None, "xFormers is required for nested tensors usage"
+            if attn_bias is not None:
+                raise AssertionError("xFormers is required for using nested tensors")
             return super().forward(x)
 
         B, N, C = x.shape
diff --git a/dinov2/layers/block.py b/dinov2/layers/block.py
index 25488f57cc0ad3c692f86b62555f6668e2a66db1..1ac9b40971e6d13454e4cc6ad24e5d0a1d9cf268 100644
--- a/dinov2/layers/block.py
+++ b/dinov2/layers/block.py
@@ -9,7 +9,9 @@
 #   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
 
 import logging
+import os
 from typing import Callable, List, Any, Tuple, Dict
+import warnings
 
 import torch
 from torch import nn, Tensor
@@ -23,15 +25,21 @@ from .mlp import Mlp
 logger = logging.getLogger("dinov2")
 
 
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
 try:
-    from xformers.ops import fmha
-    from xformers.ops import scaled_index_add, index_select_cat
+    if XFORMERS_ENABLED:
+        from xformers.ops import fmha, scaled_index_add, index_select_cat
 
-    XFORMERS_AVAILABLE = True
+        XFORMERS_AVAILABLE = True
+        warnings.warn("xFormers is available (Block)")
+    else:
+        warnings.warn("xFormers is disabled (Block)")
+        raise ImportError
 except ImportError:
-    logger.warning("xFormers not available")
     XFORMERS_AVAILABLE = False
+    warnings.warn("xFormers is not available (Block)")
+
 
 class Block(nn.Module):
     def __init__(
@@ -246,7 +254,8 @@ class NestedTensorBlock(Block):
         if isinstance(x_or_x_list, Tensor):
             return super().forward(x_or_x_list)
         elif isinstance(x_or_x_list, list):
-            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
+            if not XFORMERS_AVAILABLE:
+                raise AssertionError("xFormers is required for using nested tensors")
             return self.forward_nested(x_or_x_list)
         else:
             raise AssertionError
diff --git a/dinov2/layers/swiglu_ffn.py b/dinov2/layers/swiglu_ffn.py
index b3324b266fb0a50ccf8c3a0ede2ae10ac4dfa03e..f840ded440ccbf3d1f1c019ab3440f64194ba46a 100644
--- a/dinov2/layers/swiglu_ffn.py
+++ b/dinov2/layers/swiglu_ffn.py
@@ -4,7 +4,9 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import os
 from typing import Callable, Optional
+import warnings
 
 from torch import Tensor, nn
 import torch.nn.functional as F
@@ -33,14 +35,22 @@ class SwiGLUFFN(nn.Module):
         return self.w3(hidden)
 
 
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
 try:
-    from xformers.ops import SwiGLU
+    if XFORMERS_ENABLED:
+        from xformers.ops import SwiGLU
 
-    XFORMERS_AVAILABLE = True
+        XFORMERS_AVAILABLE = True
+        warnings.warn("xFormers is available (SwiGLU)")
+    else:
+        warnings.warn("xFormers is disabled (SwiGLU)")
+        raise ImportError
 except ImportError:
     SwiGLU = SwiGLUFFN
     XFORMERS_AVAILABLE = False
+    warnings.warn("xFormers is not available (SwiGLU)")
+
 
 class SwiGLUFFNFused(SwiGLU):
     def __init__(
diff --git a/dinov2/train/ssl_meta_arch.py b/dinov2/train/ssl_meta_arch.py
index 86d0c2413f9abc61953d0e12b43a5a843d97d244..9eeb3cf00b7ea42a301424809801d009cdd79f02 100644
--- a/dinov2/train/ssl_meta_arch.py
+++ b/dinov2/train/ssl_meta_arch.py
@@ -19,13 +19,11 @@ from dinov2.fsdp import get_fsdp_wrapper, ShardedGradScaler, get_fsdp_modules, r
 
 from dinov2.models.vision_transformer import BlockChunk
 
+
 try:
     from xformers.ops import fmha
-
-    XFORMERS_AVAILABLE = True
 except ImportError:
-    XFORMERS_AVAILABLE = False
-assert XFORMERS_AVAILABLE, "xFormers is required for DINOv2 training"
+    raise AssertionError("xFormers is required for training")
 
 
 logger = logging.getLogger("dinov2")
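Usage sketch (not part of the patch, just an illustration of the checks added above): each layer module reads the XFORMERS_DISABLED environment variable at import time, and the check is os.environ.get("XFORMERS_DISABLED") is None, so any value disables the xFormers code paths. The variable therefore has to be set before the dinov2 modules are first imported:

    import os

    # Must run before any dinov2 module is imported; the flag is only read at import time.
    os.environ["XFORMERS_DISABLED"] = "1"  # any value works, only the key's presence is checked

    from dinov2.layers.attention import MemEffAttention  # warns and falls back to the plain PyTorch path

Note that dinov2/train/ssl_meta_arch.py does not consult this variable and still requires xFormers to be installed, so the switch only affects the layer modules.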