Commit ca58ffcd authored by Patrick Labatut, committed by GitHub

Fix linear classifier wrapper (#61)

Fix linear classifier wrapper in PyTorch Hub configuration module to support multi-sample batches.
parent 3e7e278d
@@ -109,21 +109,25 @@ class _LinearClassifierWrapper(nn.Module):
     def forward(self, x):
         if self.layers == 1:
             x = self.backbone.forward_features(x)
-            cls_token = x["x_norm_clstoken"].squeeze(0)
-            patch_tokens = x["x_norm_patchtokens"].squeeze(0)
+            cls_token = x["x_norm_clstoken"]
+            patch_tokens = x["x_norm_patchtokens"]
+            # fmt: off
             linear_input = torch.cat([
                 cls_token,
-                patch_tokens.mean(0)
-            ])
+                patch_tokens.mean(dim=1),
+            ], dim=1)
+            # fmt: on
         elif self.layers == 4:
             x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True)
+            # fmt: off
             linear_input = torch.cat([
-                x[0][1].squeeze(0),
-                x[1][1].squeeze(0),
-                x[2][1].squeeze(0),
-                x[3][1].squeeze(0),
-                x[3][0].squeeze(0).mean(0)
-            ])
+                x[0][1],
+                x[1][1],
+                x[2][1],
+                x[3][1],
+                x[3][0].mean(dim=1),
+            ], dim=1)
+            # fmt: on
         else:
             assert False, f"Unsupported number of layers: {self.layers}"
         return self.linear_head(linear_input)
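Why the change works (context, not part of the commit): for a batch of B images, `forward_features` returns "x_norm_clstoken" with shape (B, D) and "x_norm_patchtokens" with shape (B, N, D). The old `.squeeze(0)` calls only dropped the batch dimension when B == 1, so `mean(0)` and the unqualified `torch.cat` operated on per-sample axes in that case alone; for B > 1, `squeeze(0)` was a no-op and `mean(0)` averaged across the batch. The fixed code keeps the batch dimension and reduces/concatenates along the token and feature axes instead. A minimal shape walk-through of the 1-layer path, with illustrative sizes:

```python
import torch

B, N, D = 4, 256, 1536  # illustrative: batch, patch tokens, embed dim
cls_token = torch.randn(B, D)        # like x["x_norm_clstoken"]
patch_tokens = torch.randn(B, N, D)  # like x["x_norm_patchtokens"]

# Fixed path: average over the token axis, concatenate along features.
linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
assert linear_input.shape == (B, 2 * D)  # one feature vector per sample

# Old path: for B > 1, squeeze(0) is a no-op and mean(0) averages over
# the batch, so the concatenation yields a (B + N, D) tensor that mixes
# samples and no longer matches the linear head's input size.
broken = torch.cat([cls_token.squeeze(0), patch_tokens.squeeze(0).mean(0)])
assert broken.shape == (B + N, D)
```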
@@ -141,7 +145,9 @@ def _make_dinov2_linear_classifier(
     embed_dim = backbone.embed_dim
     patch_size = backbone.patch_size
     model_name = _make_dinov2_model_name(arch_name, patch_size)
-    linear_head = _make_dinov2_linear_head(model_name=model_name, embed_dim=embed_dim, layers=layers, pretrained=pretrained)
+    linear_head = _make_dinov2_linear_head(
+        model_name=model_name, embed_dim=embed_dim, layers=layers, pretrained=pretrained
+    )
     return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers)
@@ -171,4 +177,6 @@ def dinov2_vitg14_lc(*, pretrained: bool = True, **kwargs):
     """
     Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
     """
-    return _make_dinov2_linear_classifier(arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, **kwargs)
+    return _make_dinov2_linear_classifier(
+        arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, **kwargs
+    )