diff --git a/dinov2/models/vision_transformer.py b/dinov2/models/vision_transformer.py
index 4926108b343f19c58cd11ae093208e6da6488e2b..13b44ae3c4e32f79ed793f3dafc58b0016dbaa99 100644
--- a/dinov2/models/vision_transformer.py
+++ b/dinov2/models/vision_transformer.py
@@ -188,21 +188,25 @@ class DinoVisionTransformer(nn.Module):
         dim = x.shape[-1]
         w0 = w // self.patch_size
         h0 = h // self.patch_size
-        # we add a small number to avoid floating point error in the interpolation
-        # see discussion at https://github.com/facebookresearch/dino/issues/8
-        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
-
-        sqrt_N = math.sqrt(N)
-        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
+        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
+        assert N == M * M, "Expected a square grid of patch position embeddings"
+        kwargs = {}
+        if self.interpolate_offset:
+            # Historical kludge: add a small offset to avoid a floating point error in the interpolation; see https://github.com/facebookresearch/dino/issues/8
+            # Note: still needed for backward compatibility, since the underlying operator uses both the output size and the scale factor
+            sx = float(w0 + self.interpolate_offset) / M
+            sy = float(h0 + self.interpolate_offset) / M
+            kwargs["scale_factor"] = (sx, sy)
+        else:
+            # Simply specify an output size instead of a scale factor
+            kwargs["size"] = (w0, h0)
         patch_pos_embed = nn.functional.interpolate(
-            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
-            scale_factor=(sx, sy),
+            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
             mode="bicubic",
             antialias=self.interpolate_antialias,
+            **kwargs,
         )
-
-        assert int(w0) == patch_pos_embed.shape[-2]
-        assert int(h0) == patch_pos_embed.shape[-1]
+        assert (w0, h0) == patch_pos_embed.shape[-2:]
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
 
@@ -306,7 +310,7 @@ class DinoVisionTransformer(nn.Module):
         if norm:
             outputs = [self.norm(out) for out in outputs]
         class_tokens = [out[:, 0] for out in outputs]
-        outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
+        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
         if reshape:
             B, _, w, h = x.shape
             outputs = [
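            ]

A minimal standalone sketch of the two interpolation paths in the first hunk, not part of the patch itself, assuming PyTorch and hypothetical shapes (N, dim, w0, h0, offset are illustrative values, not taken from the model). It shows that the legacy scale_factor path, padded with the small offset from https://github.com/facebookresearch/dino/issues/8, and the new size path select the same output grid; the interpolated values can still differ slightly, because the operator also uses the scale factor when mapping output coordinates back to the input, which is the backward-compatibility point in the comment above.

import math

import torch
import torch.nn.functional as F

# Hypothetical shapes, chosen only for the demonstration.
N = 256                # number of patch position embeddings (a 16 x 16 grid)
M = int(math.sqrt(N))  # patches per side in the pretraining grid
dim = 8                # embedding dimension, kept small for the demo
w0, h0 = 20, 24        # target patch grid for a larger input image
offset = 0.1           # DINOv2's default interpolate_offset

pos = torch.randn(1, M, M, dim).permute(0, 3, 1, 2)  # (1, dim, M, M)

# Legacy path: scale factors padded with a small offset.
legacy = F.interpolate(
    pos,
    scale_factor=((w0 + offset) / M, (h0 + offset) / M),
    mode="bicubic",
)

# New path: request the output size directly.
direct = F.interpolate(pos, size=(w0, h0), mode="bicubic")

# Both paths land on the same (w0, h0) grid; the values may differ slightly.
assert legacy.shape == direct.shape == (1, dim, w0, h0)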