From 9a4564ce5ebfe66a37fd16c6a233fb04ffb0a752 Mon Sep 17 00:00:00 2001
From: Patrick Labatut <60359573+patricklabatut@users.noreply.github.com>
Date: Wed, 27 Sep 2023 17:06:03 +0200
Subject: [PATCH] Rework PyTorch Hub support code (#202)

Rework the torch.hub.load() support code into a dedicated dinov2.hub package so that shared functions can be reused and more models can eventually be exposed.
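
Example usage via torch.hub (a sketch, assuming the repository is published
on PyTorch Hub as facebookresearch/dinov2 and that the entrypoint names
match those defined in hubconf.py):

    import torch

    # plain ViT-S/14 backbone with LVD-142M pretrained weights
    backbone = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")

    # same backbone plus a pretrained ImageNet-1k linear classification head (1 or 4 layers)
    classifier = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14_lc", layers=4)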
---
 dinov2/hub/__init__.py    |   4 +
 dinov2/hub/backbones.py   |  85 ++++++++++++++++++
 dinov2/hub/classifiers.py | 150 ++++++++++++++++++++++++++++++++
 dinov2/hub/utils.py       |  14 +++
 hubconf.py                | 175 +-------------------------------------
 5 files changed, 255 insertions(+), 173 deletions(-)
 create mode 100644 dinov2/hub/__init__.py
 create mode 100644 dinov2/hub/backbones.py
 create mode 100644 dinov2/hub/classifiers.py
 create mode 100644 dinov2/hub/utils.py

diff --git a/dinov2/hub/__init__.py b/dinov2/hub/__init__.py
new file mode 100644
index 0000000..b88da6b
--- /dev/null
+++ b/dinov2/hub/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
diff --git a/dinov2/hub/backbones.py b/dinov2/hub/backbones.py
new file mode 100644
index 0000000..17e0098
--- /dev/null
+++ b/dinov2/hub/backbones.py
@@ -0,0 +1,85 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+from typing import Union
+
+import torch
+
+from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
+
+
+class Weights(Enum):
+    LVD142M = "LVD142M"
+
+
+def _make_dinov2_model(
+    *,
+    arch_name: str = "vit_large",
+    img_size: int = 518,
+    patch_size: int = 14,
+    init_values: float = 1.0,
+    ffn_layer: str = "mlp",
+    block_chunks: int = 0,
+    pretrained: bool = True,
+    weights: Union[Weights, str] = Weights.LVD142M,
+    **kwargs,
+):
+    from ..models import vision_transformer as vits
+
+    if isinstance(weights, str):
+        try:
+            weights = Weights[weights]
+        except KeyError:
+            raise AssertionError(f"Unsupported weights: {weights}")
+
+    model_name = _make_dinov2_model_name(arch_name, patch_size)
+    vit_kwargs = dict(
+        img_size=img_size,
+        patch_size=patch_size,
+        init_values=init_values,
+        ffn_layer=ffn_layer,
+        block_chunks=block_chunks,
+    )
+    vit_kwargs.update(**kwargs)
+    model = vits.__dict__[arch_name](**vit_kwargs)
+
+    if pretrained:
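+        # checkpoints are published as <model_name>/<model_name>_pretrain.pth under the DINOv2 base URL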
+        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_pretrain.pth"
+        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+        model.load_state_dict(state_dict, strict=False)
+
+    return model
+
+
+def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+    """
+    DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
+    """
+    return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+    """
+    DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
+    """
+    return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+    """
+    DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
+    """
+    return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+    """
+    DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
+    """
+    return _make_dinov2_model(
+        arch_name="vit_giant2", ffn_layer="swiglufused", weights=weights, pretrained=pretrained, **kwargs
+    )
diff --git a/dinov2/hub/classifiers.py b/dinov2/hub/classifiers.py
new file mode 100644
index 0000000..636a732
--- /dev/null
+++ b/dinov2/hub/classifiers.py
@@ -0,0 +1,150 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+from typing import Union
+
+import torch
+import torch.nn as nn
+
+from .backbones import _make_dinov2_model
+from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
+
+
+class Weights(Enum):
+    IMAGENET1K = "IMAGENET1K"
+
+
+def _make_dinov2_linear_classification_head(
+    *,
+    model_name: str = "dinov2_vitl14",
+    embed_dim: int = 1024,
+    layers: int = 4,
+    pretrained: bool = True,
+    weights: Union[Weights, str] = Weights.IMAGENET1K,
+    **kwargs,
+):
+    if layers not in (1, 4):
+        raise AssertionError(f"Unsupported number of layers: {layers}")
+    if isinstance(weights, str):
+        try:
+            weights = Weights[weights]
+        except KeyError:
+            raise AssertionError(f"Unsupported weights: {weights}")
+
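+    # (1 + layers) * embed_dim: one class token per block used, plus an averaged patch-token vector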
+    linear_head = nn.Linear((1 + layers) * embed_dim, 1_000)
+
+    if pretrained:
+        layers_str = str(layers) if layers == 4 else ""
+        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_linear{layers_str}_head.pth"
+        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+        linear_head.load_state_dict(state_dict, strict=False)
+
+    return linear_head
+
+
+class _LinearClassifierWrapper(nn.Module):
+    def __init__(self, *, backbone: nn.Module, linear_head: nn.Module, layers: int = 4):
+        super().__init__()
+        self.backbone = backbone
+        self.linear_head = linear_head
+        self.layers = layers
+
+    def forward(self, x):
+        if self.layers == 1:
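+            # concatenate the final class token with the averaged patch tokens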
+            x = self.backbone.forward_features(x)
+            cls_token = x["x_norm_clstoken"]
+            patch_tokens = x["x_norm_patchtokens"]
+            # fmt: off
+            linear_input = torch.cat([
+                cls_token,
+                patch_tokens.mean(dim=1),
+            ], dim=1)
+            # fmt: on
+        elif self.layers == 4:
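+            # class tokens from the last four blocks, plus averaged patch tokens of the last block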
+            x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True)
+            # fmt: off
+            linear_input = torch.cat([
+                x[0][1],
+                x[1][1],
+                x[2][1],
+                x[3][1],
+                x[3][0].mean(dim=1),
+            ], dim=1)
+            # fmt: on
+        else:
+            assert False, f"Unsupported number of layers: {self.layers}"
+        return self.linear_head(linear_input)
+
+
+def _make_dinov2_linear_classifier(
+    *,
+    arch_name: str = "vit_large",
+    layers: int = 4,
+    pretrained: bool = True,
+    weights: Union[Weights, str] = Weights.IMAGENET1K,
+    **kwargs,
+):
+    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)
+
+    embed_dim = backbone.embed_dim
+    patch_size = backbone.patch_size
+    model_name = _make_dinov2_model_name(arch_name, patch_size)
+    linear_head = _make_dinov2_linear_classification_head(
+        model_name=model_name,
+        embed_dim=embed_dim,
+        layers=layers,
+        pretrained=pretrained,
+        weights=weights,
+    )
+
+    return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers)
+
+
+def dinov2_vits14_lc(
+    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+    """
+    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+    """
+    return _make_dinov2_linear_classifier(
+        arch_name="vit_small", layers=layers, pretrained=pretrained, weights=weights, **kwargs
+    )
+
+
+def dinov2_vitb14_lc(
+    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+    """
+    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+    """
+    return _make_dinov2_linear_classifier(
+        arch_name="vit_base", layers=layers, pretrained=pretrained, weights=weights, **kwargs
+    )
+
+
+def dinov2_vitl14_lc(
+    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+    """
+    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+    """
+    return _make_dinov2_linear_classifier(
+        arch_name="vit_large", layers=layers, pretrained=pretrained, weights=weights, **kwargs
+    )
+
+
+def dinov2_vitg14_lc(
+    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
+):
+    """
+    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
+    """
+    return _make_dinov2_linear_classifier(
+        arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
+    )
diff --git a/dinov2/hub/utils.py b/dinov2/hub/utils.py
new file mode 100644
index 0000000..f1829b4
--- /dev/null
+++ b/dinov2/hub/utils.py
@@ -0,0 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch.nn as nn
+
+_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
+
+
+def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str:
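+    # e.g. ("vit_small", 14) -> "dinov2_vits14"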
+    compact_arch_name = arch_name.replace("_", "")[:4]
+    return f"dinov2_{compact_arch_name}{patch_size}"
diff --git a/hubconf.py b/hubconf.py
index fc6a4ee..b3b4483 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -3,179 +3,8 @@
 # This source code is licensed under the Apache License, Version 2.0
 # found in the LICENSE file in the root directory of this source tree.
 
-import torch
-import torch.nn as nn
 
+from dinov2.hub.backbones import dinov2_vitb14, dinov2_vitg14, dinov2_vitl14, dinov2_vits14
+from dinov2.hub.classifiers import dinov2_vitb14_lc, dinov2_vitg14_lc, dinov2_vitl14_lc, dinov2_vits14_lc
 
 dependencies = ["torch"]
-
-
-_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
-
-
-def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str:
-    compact_arch_name = arch_name.replace("_", "")[:4]
-    return f"dinov2_{compact_arch_name}{patch_size}"
-
-
-def _make_dinov2_model(
-    *,
-    arch_name: str = "vit_large",
-    img_size: int = 518,
-    patch_size: int = 14,
-    init_values: float = 1.0,
-    ffn_layer: str = "mlp",
-    block_chunks: int = 0,
-    pretrained: bool = True,
-    **kwargs,
-):
-    from dinov2.models import vision_transformer as vits
-
-    model_name = _make_dinov2_model_name(arch_name, patch_size)
-    vit_kwargs = dict(
-        img_size=img_size,
-        patch_size=patch_size,
-        init_values=init_values,
-        ffn_layer=ffn_layer,
-        block_chunks=block_chunks,
-    )
-    vit_kwargs.update(**kwargs)
-    model = vits.__dict__[arch_name](**vit_kwargs)
-
-    if pretrained:
-        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_pretrain.pth"
-        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
-        model.load_state_dict(state_dict, strict=False)
-
-    return model
-
-
-def dinov2_vits14(*, pretrained: bool = True, **kwargs):
-    """
-    DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
-    """
-    return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, **kwargs)
-
-
-def dinov2_vitb14(*, pretrained: bool = True, **kwargs):
-    """
-    DINOv2 ViT-B/14 model pretrained on the LVD-142M dataset.
-    """
-    return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, **kwargs)
-
-
-def dinov2_vitl14(*, pretrained: bool = True, **kwargs):
-    """
-    DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
-    """
-    return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, **kwargs)
-
-
-def dinov2_vitg14(*, pretrained: bool = True, **kwargs):
-    """
-    DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
-    """
-    return _make_dinov2_model(arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, **kwargs)
-
-
-def _make_dinov2_linear_head(
-    *,
-    model_name: str = "dinov2_vitl14",
-    embed_dim: int = 1024,
-    layers: int = 4,
-    pretrained: bool = True,
-    **kwargs,
-):
-    assert layers in (1, 4), f"Unsupported number of layers: {layers}"
-    linear_head = nn.Linear((1 + layers) * embed_dim, 1_000)
-
-    if pretrained:
-        layers_str = str(layers) if layers == 4 else ""
-        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_linear{layers_str}_head.pth"
-        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
-        linear_head.load_state_dict(state_dict, strict=False)
-
-    return linear_head
-
-
-class _LinearClassifierWrapper(nn.Module):
-    def __init__(self, *, backbone: nn.Module, linear_head: nn.Module, layers: int = 4):
-        super().__init__()
-        self.backbone = backbone
-        self.linear_head = linear_head
-        self.layers = layers
-
-    def forward(self, x):
-        if self.layers == 1:
-            x = self.backbone.forward_features(x)
-            cls_token = x["x_norm_clstoken"]
-            patch_tokens = x["x_norm_patchtokens"]
-            # fmt: off
-            linear_input = torch.cat([
-                cls_token,
-                patch_tokens.mean(dim=1),
-            ], dim=1)
-            # fmt: on
-        elif self.layers == 4:
-            x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True)
-            # fmt: off
-            linear_input = torch.cat([
-                x[0][1],
-                x[1][1],
-                x[2][1],
-                x[3][1],
-                x[3][0].mean(dim=1),
-            ], dim=1)
-            # fmt: on
-        else:
-            assert False, f"Unsupported number of layers: {self.layers}"
-        return self.linear_head(linear_input)
-
-
-def _make_dinov2_linear_classifier(
-    *,
-    arch_name: str = "vit_large",
-    layers: int = 4,
-    pretrained: bool = True,
-    **kwargs,
-):
-    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)
-
-    embed_dim = backbone.embed_dim
-    patch_size = backbone.patch_size
-    model_name = _make_dinov2_model_name(arch_name, patch_size)
-    linear_head = _make_dinov2_linear_head(
-        model_name=model_name, embed_dim=embed_dim, layers=layers, pretrained=pretrained
-    )
-
-    return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers)
-
-
-def dinov2_vits14_lc(*, layers: int = 4, pretrained: bool = True, **kwargs):
-    """
-    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
-    """
-    return _make_dinov2_linear_classifier(arch_name="vit_small", layers=layers, pretrained=pretrained, **kwargs)
-
-
-def dinov2_vitb14_lc(*, layers: int = 4, pretrained: bool = True, **kwargs):
-    """
-    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
-    """
-    return _make_dinov2_linear_classifier(arch_name="vit_base", layers=layers, pretrained=pretrained, **kwargs)
-
-
-def dinov2_vitl14_lc(*, layers: int = 4, pretrained: bool = True, **kwargs):
-    """
-    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
-    """
-    return _make_dinov2_linear_classifier(arch_name="vit_large", layers=layers, pretrained=pretrained, **kwargs)
-
-
-def dinov2_vitg14_lc(*, layers: int = 4, pretrained: bool = True, **kwargs):
-    """
-    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
-    """
-    return _make_dinov2_linear_classifier(
-        arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, **kwargs
-    )
-- 
GitLab