Skip to content
Snippets Groups Projects
Unverified Commit ebc1cba1 authored by Patrick Labatut's avatar Patrick Labatut Committed by GitHub
Browse files

Allow disabling xFormers via environment variable (#180)

Allow disabling the use of xFormers (for inference) by simply setting the XFORMERS_DISABLED environment variable
parent be7e5725
No related branches found
No related tags found
No related merge requests found
......@@ -9,6 +9,8 @@
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
import logging
import os
import warnings
from torch import Tensor
from torch import nn
......@@ -17,13 +19,19 @@ from torch import nn
logger = logging.getLogger("dinov2")
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
from xformers.ops import memory_efficient_attention, unbind, fmha
XFORMERS_AVAILABLE = True
if XFORMERS_ENABLED:
from xformers.ops import memory_efficient_attention, unbind
XFORMERS_AVAILABLE = True
warnings.warn("xFormers is available (Attention)")
else:
warnings.warn("xFormers is disabled (Attention)")
raise ImportError
except ImportError:
logger.warning("xFormers not available")
XFORMERS_AVAILABLE = False
warnings.warn("xFormers is not available (Attention)")
class Attention(nn.Module):
......@@ -65,7 +73,8 @@ class Attention(nn.Module):
class MemEffAttention(Attention):
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
if not XFORMERS_AVAILABLE:
assert attn_bias is None, "xFormers is required for nested tensors usage"
if attn_bias is not None:
raise AssertionError("xFormers is required for using nested tensors")
return super().forward(x)
B, N, C = x.shape
......
......@@ -9,7 +9,9 @@
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
import logging
import os
from typing import Callable, List, Any, Tuple, Dict
import warnings
import torch
from torch import nn, Tensor
......@@ -23,15 +25,21 @@ from .mlp import Mlp
logger = logging.getLogger("dinov2")
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
from xformers.ops import fmha
from xformers.ops import scaled_index_add, index_select_cat
if XFORMERS_ENABLED:
from xformers.ops import fmha, scaled_index_add, index_select_cat
XFORMERS_AVAILABLE = True
XFORMERS_AVAILABLE = True
warnings.warn("xFormers is available (Block)")
else:
warnings.warn("xFormers is disabled (Block)")
raise ImportError
except ImportError:
logger.warning("xFormers not available")
XFORMERS_AVAILABLE = False
warnings.warn("xFormers is not available (Block)")
class Block(nn.Module):
def __init__(
......@@ -246,7 +254,8 @@ class NestedTensorBlock(Block):
if isinstance(x_or_x_list, Tensor):
return super().forward(x_or_x_list)
elif isinstance(x_or_x_list, list):
assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
if not XFORMERS_AVAILABLE:
raise AssertionError("xFormers is required for using nested tensors")
return self.forward_nested(x_or_x_list)
else:
raise AssertionError
......@@ -4,7 +4,9 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from typing import Callable, Optional
import warnings
from torch import Tensor, nn
import torch.nn.functional as F
......@@ -33,14 +35,22 @@ class SwiGLUFFN(nn.Module):
return self.w3(hidden)
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
from xformers.ops import SwiGLU
if XFORMERS_ENABLED:
from xformers.ops import SwiGLU
XFORMERS_AVAILABLE = True
XFORMERS_AVAILABLE = True
warnings.warn("xFormers is available (SwiGLU)")
else:
warnings.warn("xFormers is disabled (SwiGLU)")
raise ImportError
except ImportError:
SwiGLU = SwiGLUFFN
XFORMERS_AVAILABLE = False
warnings.warn("xFormers is not available (SwiGLU)")
class SwiGLUFFNFused(SwiGLU):
def __init__(
......
......@@ -19,13 +19,11 @@ from dinov2.fsdp import get_fsdp_wrapper, ShardedGradScaler, get_fsdp_modules, r
from dinov2.models.vision_transformer import BlockChunk
try:
from xformers.ops import fmha
XFORMERS_AVAILABLE = True
except ImportError:
XFORMERS_AVAILABLE = False
assert XFORMERS_AVAILABLE, "xFormers is required for DINOv2 training"
raise AssertionError("xFormers is required for training")
logger = logging.getLogger("dinov2")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment