From e1277af2ba9496fbadf7aec6eba56e8d882d1e35 Mon Sep 17 00:00:00 2001
From: Patrick Labatut <60359573+patricklabatut@users.noreply.github.com>
Date: Thu, 22 Feb 2024 19:10:54 +0100
Subject: [PATCH] [bug] Fix interpolation of positional embeddings (#378)

Use size instead of scale factor to specify the output size of nn.functional.interpolate(): this avoids rounding issues that could lead to a mismatched output size, and consistently generates the same output size as the previous kludge (from facebookresearch/dino#8).
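
For illustration only (not part of the patch; the shapes below are made up), a minimal sketch of the rounding pitfall and of the size-based call:

    import math

    import torch
    import torch.nn.functional as F

    # With scale_factor=, interpolate() derives the output size as
    # floor(in_size * scale); a ratio that is not exactly representable
    # in floating point can come up one short. A classic instance:
    assert math.floor(49 * (1 / 49)) == 0  # 0.999...9 floors to 0, not 1

    # Passing size= pins the output shape exactly, with no rounding:
    grid = torch.randn(1, 384, 14, 14)  # hypothetical 14x14 grid of 384-d embeddings
    out = F.interpolate(grid, size=(30, 30), mode="bicubic", antialias=True)
    assert out.shape[-2:] == (30, 30)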
---
 dinov2/models/vision_transformer.py | 28 ++++++++++++++++------------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/dinov2/models/vision_transformer.py b/dinov2/models/vision_transformer.py
index 4926108..13b44ae 100644
--- a/dinov2/models/vision_transformer.py
+++ b/dinov2/models/vision_transformer.py
@@ -188,21 +188,25 @@ class DinoVisionTransformer(nn.Module):
         dim = x.shape[-1]
         w0 = w // self.patch_size
         h0 = h // self.patch_size
-        # we add a small number to avoid floating point error in the interpolation
-        # see discussion at https://github.com/facebookresearch/dino/issues/8
-        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
-
-        sqrt_N = math.sqrt(N)
-        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
+        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
+        assert N == M * M
+        kwargs = {}
+        if self.interpolate_offset:
+            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
+            # Note: still needed for backward compatibility, as the underlying operators use both the output size and the scale factors
+            sx = float(w0 + self.interpolate_offset) / M
+            sy = float(h0 + self.interpolate_offset) / M
+            kwargs["scale_factor"] = (sx, sy)
+        else:
+            # Simply specify an output size instead of a scale factor
+            kwargs["size"] = (w0, h0)
         patch_pos_embed = nn.functional.interpolate(
-            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
-            scale_factor=(sx, sy),
+            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
             mode="bicubic",
             antialias=self.interpolate_antialias,
+            **kwargs,
         )
-
-        assert int(w0) == patch_pos_embed.shape[-2]
-        assert int(h0) == patch_pos_embed.shape[-1]
+        assert (w0, h0) == patch_pos_embed.shape[-2:]
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
 
@@ -306,7 +310,7 @@ class DinoVisionTransformer(nn.Module):
         if norm:
             outputs = [self.norm(out) for out in outputs]
         class_tokens = [out[:, 0] for out in outputs]
-        outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
+        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
         if reshape:
             B, _, w, h = x.shape
             outputs = [
-- 
GitLab