From ebc1cba1096de0a5388527817cf1e5d4d3896699 Mon Sep 17 00:00:00 2001
From: Patrick Labatut <60359573+patricklabatut@users.noreply.github.com>
Date: Wed, 30 Aug 2023 17:20:47 +0200
Subject: [PATCH] Allow disabling xFormers via environment variable (#180)

Allow disabling the use of xFormers (for inference) by simply setting
the XFORMERS_DISABLED environment variable
---
 dinov2/layers/attention.py    | 19 ++++++++++++++-----
 dinov2/layers/block.py        | 19 ++++++++++++++-----
 dinov2/layers/swiglu_ffn.py   | 14 ++++++++++++--
 dinov2/train/ssl_meta_arch.py |  6 ++----
 4 files changed, 42 insertions(+), 16 deletions(-)

diff --git a/dinov2/layers/attention.py b/dinov2/layers/attention.py
index 1f9b0c9..35e38e4 100644
--- a/dinov2/layers/attention.py
+++ b/dinov2/layers/attention.py
@@ -9,6 +9,8 @@
 #   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
 
 import logging
+import os
+import warnings
 
 from torch import Tensor
 from torch import nn
@@ -17,13 +19,19 @@ from torch import nn
 logger = logging.getLogger("dinov2")
 
 
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
 try:
-    from xformers.ops import memory_efficient_attention, unbind, fmha
-
-    XFORMERS_AVAILABLE = True
+    if XFORMERS_ENABLED:
+        from xformers.ops import memory_efficient_attention, unbind
+
+        XFORMERS_AVAILABLE = True
+        warnings.warn("xFormers is available (Attention)")
+    else:
+        warnings.warn("xFormers is disabled (Attention)")
+        raise ImportError
 except ImportError:
-    logger.warning("xFormers not available")
     XFORMERS_AVAILABLE = False
+    warnings.warn("xFormers is not available (Attention)")
 
 
 class Attention(nn.Module):
@@ -65,7 +73,8 @@ class Attention(nn.Module):
 class MemEffAttention(Attention):
     def forward(self, x: Tensor, attn_bias=None) -> Tensor:
         if not XFORMERS_AVAILABLE:
-            assert attn_bias is None, "xFormers is required for nested tensors usage"
+            if attn_bias is not None:
+                raise AssertionError("xFormers is required for using nested tensors")
             return super().forward(x)
 
         B, N, C = x.shape
diff --git a/dinov2/layers/block.py b/dinov2/layers/block.py
index 25488f5..1ac9b40 100644
--- a/dinov2/layers/block.py
+++ b/dinov2/layers/block.py
@@ -9,7 +9,9 @@
 #   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
 
 import logging
+import os
 from typing import Callable, List, Any, Tuple, Dict
+import warnings
 
 import torch
 from torch import nn, Tensor
@@ -23,15 +25,21 @@ from .mlp import Mlp
 logger = logging.getLogger("dinov2")
 
 
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
 try:
-    from xformers.ops import fmha
-    from xformers.ops import scaled_index_add, index_select_cat
+    if XFORMERS_ENABLED:
+        from xformers.ops import fmha, scaled_index_add, index_select_cat
 
-    XFORMERS_AVAILABLE = True
+        XFORMERS_AVAILABLE = True
+        warnings.warn("xFormers is available (Block)")
+    else:
+        warnings.warn("xFormers is disabled (Block)")
+        raise ImportError
 except ImportError:
-    logger.warning("xFormers not available")
     XFORMERS_AVAILABLE = False
+    warnings.warn("xFormers is not available (Block)")
+
 
 
 class Block(nn.Module):
     def __init__(
@@ -246,7 +254,8 @@ class NestedTensorBlock(Block):
         if isinstance(x_or_x_list, Tensor):
             return super().forward(x_or_x_list)
         elif isinstance(x_or_x_list, list):
-            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
+            if not XFORMERS_AVAILABLE:
+                raise AssertionError("xFormers is required for using nested tensors")
             return self.forward_nested(x_or_x_list)
         else:
             raise AssertionError
diff --git a/dinov2/layers/swiglu_ffn.py b/dinov2/layers/swiglu_ffn.py
index b3324b2..f840ded 100644
--- a/dinov2/layers/swiglu_ffn.py
+++ b/dinov2/layers/swiglu_ffn.py
@@ -4,7 +4,9 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import os
 from typing import Callable, Optional
+import warnings
 
 from torch import Tensor, nn
 import torch.nn.functional as F
@@ -33,14 +35,22 @@ class SwiGLUFFN(nn.Module):
         return self.w3(hidden)
 
 
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
 try:
-    from xformers.ops import SwiGLU
+    if XFORMERS_ENABLED:
+        from xformers.ops import SwiGLU
 
-    XFORMERS_AVAILABLE = True
+        XFORMERS_AVAILABLE = True
+        warnings.warn("xFormers is available (SwiGLU)")
+    else:
+        warnings.warn("xFormers is disabled (SwiGLU)")
+        raise ImportError
 except ImportError:
     SwiGLU = SwiGLUFFN
     XFORMERS_AVAILABLE = False
+    warnings.warn("xFormers is not available (SwiGLU)")
+
 
 
 class SwiGLUFFNFused(SwiGLU):
     def __init__(
diff --git a/dinov2/train/ssl_meta_arch.py b/dinov2/train/ssl_meta_arch.py
index 86d0c24..9eeb3cf 100644
--- a/dinov2/train/ssl_meta_arch.py
+++ b/dinov2/train/ssl_meta_arch.py
@@ -19,13 +19,11 @@ from dinov2.fsdp import get_fsdp_wrapper, ShardedGradScaler, get_fsdp_modules, r
 from dinov2.models.vision_transformer import BlockChunk
 
+
 try:
     from xformers.ops import fmha
-
-    XFORMERS_AVAILABLE = True
 except ImportError:
-    XFORMERS_AVAILABLE = False
-assert XFORMERS_AVAILABLE, "xFormers is required for DINOv2 training"
+    raise AssertionError("xFormers is required for training")
 
 logger = logging.getLogger("dinov2")
 
 
-- 
GitLab
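
Usage sketch for the switch introduced by this patch, assuming a Python environment where dinov2 (with the patch applied) and PyTorch are importable; the layer sizes below (dim=384, num_heads=6) are illustrative only. XFORMERS_DISABLED takes effect only if it is set before the dinov2 layer modules are first imported, since the check runs at import time, and only the presence of the variable matters, not its value.

    import os

    # Must be set before the first import of dinov2.layers.*: the
    # XFORMERS_DISABLED check runs at module import time.
    os.environ["XFORMERS_DISABLED"] = "1"  # only the variable's presence matters

    import torch
    from dinov2.layers.attention import MemEffAttention, XFORMERS_AVAILABLE

    print(XFORMERS_AVAILABLE)  # False: the xFormers import path was skipped

    # With xFormers disabled, MemEffAttention falls back to the plain
    # Attention.forward, as long as no attn_bias is passed.
    attn = MemEffAttention(dim=384, num_heads=6).eval()
    x = torch.randn(1, 197, 384)
    with torch.no_grad():
        out = attn(x)
    print(out.shape)  # torch.Size([1, 197, 384])

Training is unaffected by the variable: dinov2/train/ssl_meta_arch.py still raises at import if xFormers is missing, as shown in the last hunk above.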