diff --git a/dinov2/models/vision_transformer.py b/dinov2/models/vision_transformer.py index c8c3ec277db73bf667660372f07fae4cce6d9b60..4926108b343f19c58cd11ae093208e6da6488e2b 100644 --- a/dinov2/models/vision_transformer.py +++ b/dinov2/models/vision_transformer.py @@ -306,7 +306,7 @@ class DinoVisionTransformer(nn.Module): if norm: outputs = [self.norm(out) for out in outputs] class_tokens = [out[:, 0] for out in outputs] - outputs = [out[:, 1:] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs] if reshape: B, _, w, h = x.shape outputs = [