diff --git a/dinov2/models/vision_transformer.py b/dinov2/models/vision_transformer.py
index d572572be5f0154ae391bff28f82d095178b4a34..de212108cff11362f525f36e6678ab388ed58392 100644
--- a/dinov2/models/vision_transformer.py
+++ b/dinov2/models/vision_transformer.py
@@ -177,13 +177,16 @@ class DinoVisionTransformer(nn.Module):
         # see discussion at https://github.com/facebookresearch/dino/issues/8
         w0, h0 = w0 + 0.1, h0 + 0.1
 
+        sqrt_N = math.sqrt(N)
+        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
         patch_pos_embed = nn.functional.interpolate(
-            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
-            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
+            scale_factor=(sx, sy),
             mode="bicubic",
         )
 
-        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+        assert int(w0) == patch_pos_embed.shape[-2]
+        assert int(h0) == patch_pos_embed.shape[-1]
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)