Skip to content
Snippets Groups Projects
Unverified Commit 44abdbe2 authored by Patrick Labatut's avatar Patrick Labatut Committed by GitHub
Browse files

Fix interpolate parameters to allow tracing (#247)

Pass scale factor as a tuple of floats to F.interpolate() to allow tracing.
parent e7df9fc9
No related branches found
No related tags found
No related merge requests found
...@@ -177,13 +177,16 @@ class DinoVisionTransformer(nn.Module): ...@@ -177,13 +177,16 @@ class DinoVisionTransformer(nn.Module):
# see discussion at https://github.com/facebookresearch/dino/issues/8 # see discussion at https://github.com/facebookresearch/dino/issues/8
w0, h0 = w0 + 0.1, h0 + 0.1 w0, h0 = w0 + 0.1, h0 + 0.1
sqrt_N = math.sqrt(N)
sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
patch_pos_embed = nn.functional.interpolate( patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), scale_factor=(sx, sy),
mode="bicubic", mode="bicubic",
) )
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] assert int(w0) == patch_pos_embed.shape[-2]
assert int(h0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment