diff --git a/hubconf.py b/hubconf.py index 4660b61c991084990aabb34d4d6640944a4222c4..69d7600a195afb2bf23d9e8a5d6ddb5a3fb4de99 100644 --- a/hubconf.py +++ b/hubconf.py @@ -109,21 +109,25 @@ class _LinearClassifierWrapper(nn.Module): def forward(self, x): if self.layers == 1: x = self.backbone.forward_features(x) - cls_token = x["x_norm_clstoken"].squeeze(0) - patch_tokens = x["x_norm_patchtokens"].squeeze(0) + cls_token = x["x_norm_clstoken"] + patch_tokens = x["x_norm_patchtokens"] + # fmt: off linear_input = torch.cat([ cls_token, - patch_tokens.mean(0) - ]) + patch_tokens.mean(dim=1), + ], dim=1) + # fmt: on elif self.layers == 4: x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True) + # fmt: off linear_input = torch.cat([ - x[0][1].squeeze(0), - x[1][1].squeeze(0), - x[2][1].squeeze(0), - x[3][1].squeeze(0), - x[3][0].squeeze(0).mean(0) - ]) + x[0][1], + x[1][1], + x[2][1], + x[3][1], + x[3][0].mean(dim=1), + ], dim=1) + # fmt: on else: assert False, f"Unsupported number of layers: {self.layers}" return self.linear_head(linear_input) @@ -141,7 +145,9 @@ def _make_dinov2_linear_classifier( embed_dim = backbone.embed_dim patch_size = backbone.patch_size model_name = _make_dinov2_model_name(arch_name, patch_size) - linear_head = _make_dinov2_linear_head(model_name=model_name, embed_dim=embed_dim, layers=layers, pretrained=pretrained) + linear_head = _make_dinov2_linear_head( + model_name=model_name, embed_dim=embed_dim, layers=layers, pretrained=pretrained + ) return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers) @@ -171,4 +177,6 @@ def dinov2_vitg14_lc(*, pretrained: bool = True, **kwargs): """ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. """ - return _make_dinov2_linear_classifier(arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, **kwargs) + return _make_dinov2_linear_classifier( + arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, **kwargs + )