Skip to content
Snippets Groups Projects
Unverified Commit 2302b6bf authored by qasfb's avatar qasfb Committed by GitHub
Browse files

Update vision_transformer.py (#331)

Account for register tokens in get_intermediate_layers
parent da4b3825
Branches
No related tags found
No related merge requests found
...@@ -306,7 +306,7 @@ class DinoVisionTransformer(nn.Module): ...@@ -306,7 +306,7 @@ class DinoVisionTransformer(nn.Module):
if norm: if norm:
outputs = [self.norm(out) for out in outputs] outputs = [self.norm(out) for out in outputs]
class_tokens = [out[:, 0] for out in outputs] class_tokens = [out[:, 0] for out in outputs]
outputs = [out[:, 1:] for out in outputs] outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
if reshape: if reshape:
B, _, w, h = x.shape B, _, w, h = x.shape
outputs = [ outputs = [
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment