From ca58ffcd8763b798db7372fcdfd2d1b95362abca Mon Sep 17 00:00:00 2001
From: Patrick Labatut <60359573+patricklabatut@users.noreply.github.com>
Date: Wed, 26 Apr 2023 02:15:45 +0200
Subject: [PATCH] Fix linear classifier wrapper (#61)

Fix linear classifier wrapper in PyTorch Hub configuration module to support multi-sample batches.
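
A quick sanity check of the batched behavior (a sketch only; it assumes the
facebookresearch/dinov2 hub repo and the dinov2_vitg14_lc entrypoint defined
in this module):

    import torch

    # Load the linear-classifier entrypoint via PyTorch Hub
    # (repo name assumed to be facebookresearch/dinov2).
    model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14_lc", pretrained=True)
    model.eval()

    # Multi-sample batch: 4 RGB images at 224x224 (a multiple of the 14px patch size).
    batch = torch.randn(4, 3, 224, 224)
    with torch.no_grad():
        logits = model(batch)

    # Before this fix, squeeze(0) and concatenation along dim 0 collapsed the
    # batch dimension; now each sample keeps its own row of ImageNet-1k logits.
    print(logits.shape)  # expected: torch.Size([4, 1000])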
---
 hubconf.py | 32 ++++++++++++++++++++------------
 1 file changed, 20 insertions(+), 12 deletions(-)

diff --git a/hubconf.py b/hubconf.py
index 4660b61..69d7600 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -109,21 +109,25 @@ class _LinearClassifierWrapper(nn.Module):
     def forward(self, x):
         if self.layers == 1:
             x = self.backbone.forward_features(x)
-            cls_token = x["x_norm_clstoken"].squeeze(0)
-            patch_tokens = x["x_norm_patchtokens"].squeeze(0)
+            cls_token = x["x_norm_clstoken"]
+            patch_tokens = x["x_norm_patchtokens"]
+            # fmt: off
             linear_input = torch.cat([
                 cls_token,
-                patch_tokens.mean(0)
-            ])
+                patch_tokens.mean(dim=1),
+            ], dim=1)
+            # fmt: on
         elif self.layers == 4:
             x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True)
+            # fmt: off
             linear_input = torch.cat([
-                x[0][1].squeeze(0),
-                x[1][1].squeeze(0),
-                x[2][1].squeeze(0),
-                x[3][1].squeeze(0),
-                x[3][0].squeeze(0).mean(0)
-            ])
+                x[0][1],
+                x[1][1],
+                x[2][1],
+                x[3][1],
+                x[3][0].mean(dim=1),
+            ], dim=1)
+            # fmt: on
         else:
             assert False, f"Unsupported number of layers: {self.layers}"
         return self.linear_head(linear_input)
@@ -141,7 +145,9 @@ def _make_dinov2_linear_classifier(
     embed_dim = backbone.embed_dim
     patch_size = backbone.patch_size
     model_name = _make_dinov2_model_name(arch_name, patch_size)
-    linear_head = _make_dinov2_linear_head(model_name=model_name, embed_dim=embed_dim, layers=layers, pretrained=pretrained)
+    linear_head = _make_dinov2_linear_head(
+        model_name=model_name, embed_dim=embed_dim, layers=layers, pretrained=pretrained
+    )
 
     return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers)
 
@@ -171,4 +177,6 @@ def dinov2_vitg14_lc(*, pretrained: bool = True, **kwargs):
     """
     Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
     """
-    return _make_dinov2_linear_classifier(arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, **kwargs)
+    return _make_dinov2_linear_classifier(
+        arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, **kwargs
+    )
-- 
GitLab