From 44abdbe27c0a5d4a826eff72b7c8ab0203d02328 Mon Sep 17 00:00:00 2001
From: Patrick Labatut <60359573+patricklabatut@users.noreply.github.com>
Date: Sat, 30 Sep 2023 22:29:41 +0200
Subject: [PATCH] Fix interpolate parameters to allow tracing (#247)

Pass scale factor as a tuple of floats to F.interpolate() to allow tracing.
---
 dinov2/models/vision_transformer.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/dinov2/models/vision_transformer.py b/dinov2/models/vision_transformer.py
index d572572..de21210 100644
--- a/dinov2/models/vision_transformer.py
+++ b/dinov2/models/vision_transformer.py
@@ -177,13 +177,16 @@ class DinoVisionTransformer(nn.Module):
         # see discussion at https://github.com/facebookresearch/dino/issues/8
         w0, h0 = w0 + 0.1, h0 + 0.1
 
+        sqrt_N = math.sqrt(N)
+        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
         patch_pos_embed = nn.functional.interpolate(
-            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
-            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
+            scale_factor=(sx, sy),
             mode="bicubic",
         )
 
-        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+        assert int(w0) == patch_pos_embed.shape[-2]
+        assert int(h0) == patch_pos_embed.shape[-1]
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
 
-- 
GitLab