diff --git a/MODEL_CARD.md b/MODEL_CARD.md index d158d5828aec4225c881f3621aec1e06d06c8a6c..21b9bf295c8cab14e782e1a7a1d051be9e501088 100644 --- a/MODEL_CARD.md +++ b/MODEL_CARD.md @@ -1,12 +1,16 @@ # Model Card for DINOv2-S/B/L/g -These are Vision Transformer models trained following the method described in the paper: +These are Vision Transformer models trained following the method described in the papers: "DINOv2: Learning Robust Visual Features without Supervision" +and +"Vision Transformers Need Registers". -We provide 4 models: 1 ViT-g trained from scratch, and 3 ViT-S/B/L models distilled from the ViT-g. +We provide 8 models: +- 1 ViT-g trained from scratch with 3 ViT-S/B/L models distilled from the ViT-g, without registers. +- 1 ViT-g trained from scratch with 3 ViT-S/B/L models distilled from the ViT-g, with registers. ## Model Details -The model takes an image as input and returns a class token and patch tokens. +The model takes an image as input and returns a class token and patch tokens, and optionally 4 register tokens. The embedding dimension is: - 384 for ViT-S. @@ -14,9 +18,9 @@ The embedding dimension is: - 1024 for ViT-L. - 1536 for ViT-g. -The models follow a Transformer architecture, with a patch size of 14. +The models follow a Transformer architecture, with a patch size of 14. In the case of registers, we add 4 register tokens, learned during training, to the input sequence after the patch embedding. -For a 224x224 image, this results in 1 class token + 256 patch tokens. +For a 224x224 image, this results in 1 class token + 256 patch tokens, and optionally 4 register tokens. The models can accept larger images provided the image shapes are multiples of the patch size (14). If this condition is not verified, the model will crop to the closest smaller multiple of the patch size. @@ -63,10 +67,18 @@ Use the code below to get started with the model. ```python import torch + +# DINOv2 dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14') dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14') dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14') dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14') + +# DINOv2 with registers +dinov2_vits14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg') +dinov2_vitb14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg') +dinov2_vitl14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg') +dinov2_vitg14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg') ``` ## Training Details @@ -92,11 +104,11 @@ dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14') ## Evaluation -We refer users to the associated paper for the evaluation protocols. +We refer users to the associated papers for the evaluation protocols. <table> <tr> - <th>model</th> + <th colspan="2"></th> <th colspan="3">ImageNet-1k</th> <th>NYU-Depth v2</th> <th>SUN-RGBD</th> @@ -105,7 +117,8 @@ We refer users to the associated paper for the evaluation protocols. <th>Oxford-H</th> </tr> <tr> - <th rowspan="2">task</th> + <th rowspan="2">model</th> + <th rowspan="2">with <br /> registers</th> <th>classif. (acc)</th> <th>classif. (acc)</th> <th>classif. V2 (acc)</th> @@ -128,6 +141,7 @@ We refer users to the associated paper for the evaluation protocols. 
</tr> <tr> <td>ViT-S/14</td> + <td align="center">:x:</td> <td align="right">79.0%</td> <td align="right">81.1%</td> <td align="right">70.8%</td> @@ -137,8 +151,21 @@ We refer users to the associated paper for the evaluation protocols. <td align="right">69.5%</td> <td align="right">43.2</td> </tr> + <tr> + <td>ViT-S/14</td> + <td align="center">:white_check_mark:</td> + <td align="right">79.1%</td> + <td align="right">80.9%</td> + <td align="right">71.0%</td> + <td align="right">N/A</td> + <td align="right">N/A</td> + <td align="right">N/A</td> + <td align="right">67.6%</td> + <td align="right">39.5</td> + </tr> <tr> <td>ViT-B/14</td> + <td align="center">:x:</td> <td align="right">82.1%</td> <td align="right">84.5%</td> <td align="right">74.9%</td> @@ -147,9 +174,21 @@ We refer users to the associated paper for the evaluation protocols. <td align="right">51.3</td> <td align="right">76.3%</td> <td align="right">49.5</td> + </tr> + <td>ViT-B/14</td> + <td align="center">:white_check_mark:</td> + <td align="right">82.0%</td> + <td align="right">84.6%</td> + <td align="right">75.6%</td> + <td align="right">N/A</td> + <td align="right">N/A</td> + <td align="right">N/A</td> + <td align="right">73.8%</td> + <td align="right">51.0</td> </tr> <tr> <td>ViT-L/14</td> + <td align="center">:x:</td> <td align="right">83.5%</td> <td align="right">86.3%</td> <td align="right">77.6%</td> @@ -159,8 +198,21 @@ We refer users to the associated paper for the evaluation protocols. <td align="right">79.8%</td> <td align="right">54.0</td> </tr> + <tr> + <td>ViT-L/14</td> + <td align="center">:white_check_mark:</td> + <td align="right">83.8%</td> + <td align="right">86.7%</td> + <td align="right">78.5%</td> + <td align="right">N/A</td> + <td align="right">N/A</td> + <td align="right">N/A</td> + <td align="right">80.9%</td> + <td align="right">55.7</td> + </tr> <tr> <td>ViT-g/14</td> + <td align="center">:x:</td> <td align="right">83.5%</td> <td align="right">86.5%</td> <td align="right">78.4%</td> @@ -170,6 +222,19 @@ We refer users to the associated paper for the evaluation protocols. 
<td align="right">81.6%</td> <td align="right">52.3</td> </tr> + <tr> + <tr> + <td>ViT-g/14</td> + <td align="center">:white_check_mark:</td> + <td align="right">83.7%</td> + <td align="right">87.1%</td> + <td align="right">78.8%</td> + <td align="right">N/A</td> + <td align="right">N/A</td> + <td align="right">N/A</td> + <td align="right">81.5%</td> + <td align="right">58.2</td> + </tr> </table> ## Environmental Impact @@ -198,4 +263,10 @@ xFormers 0.0.18 journal={arXiv:2304.07193}, year={2023} } +@misc{darcet2023vitneedreg, + title={Vision Transformers Need Registers}, + author={Darcet, Timothée and Oquab, Maxime and Mairal, Julien and Bojanowski, Piotr}, + journal={arXiv:2309.16588}, + year={2023} +} ``` diff --git a/README.md b/README.md index 3e1a1d5cc552d88f87c3ada28345168abf35333c..8ea1060a3b0e09266efb0d8e395d07a7125906c3 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +:new: [2023-10-26] *Added DINOv2 backbones with registers.* + # DINOv2: Learning Robust Visual Features without Supervision **[Meta AI Research, FAIR](https://ai.facebook.com/research/)** @@ -31,6 +33,7 @@ https://github.com/facebookresearch/dinov2/assets/60359573/f168823e-7922-415a-b4 <tr> <th>model</th> <th># of<br />params</th> + <th>with<br />registers</th> <th>ImageNet<br />k-NN</th> <th>ImageNet<br />linear</th> <th>download</th> @@ -40,31 +43,67 @@ https://github.com/facebookresearch/dinov2/assets/60359573/f168823e-7922-415a-b4 <tr> <td>ViT-S/14 distilled</td> <td align="right">21 M</td> + <td align="center">:x:</td> <td align="right">79.0%</td> <td align="right">81.1%</td> <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth">backbone only</a></td> </tr> + <tr> + <td>ViT-S/14 distilled</td> + <td align="right">21 M</td> + <td align="center">:white_check_mark:</td> + <td align="right">79.1%</td> + <td align="right">80.9%</td> + <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth">backbone only</a></td> + </tr> <tr> <td>ViT-B/14 distilled</td> <td align="right">86 M</td> + <td align="center">:x:</td> <td align="right">82.1%</td> <td align="right">84.5%</td> <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth">backbone only</a></td> </tr> + <tr> + <td>ViT-B/14 distilled</td> + <td align="right">86 M</td> + <td align="center">:white_check_mark:</td> + <td align="right">82.0%</td> + <td align="right">84.6%</td> + <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth">backbone only</a></td> + </tr> <tr> <td>ViT-L/14 distilled</td> <td align="right">300 M</td> + <td align="center">:x:</td> <td align="right">83.5%</td> <td align="right">86.3%</td> <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth">backbone only</a></td> </tr> + <tr> + <td>ViT-L/14 distilled</td> + <td align="right">300 M</td> + <td align="center">:white_check_mark:</td> + <td align="right">83.8%</td> + <td align="right">86.7%</td> + <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth">backbone only</a></td> + </tr> <tr> <td>ViT-g/14</td> <td align="right">1,100 M</td> + <td align="center">:x:</td> <td align="right">83.5%</td> <td align="right">86.5%</td> <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth">backbone only</a></td> </tr> + <tr> + <td>ViT-g/14</td> + <td align="right">1,100 M</td> + <td align="center">:white_check_mark:</td> + <td 
align="right">83.7%</td> + <td align="right">87.1%</td> + <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth">backbone only</a></td> + </tr> </tbody> </table> @@ -77,10 +116,17 @@ A corresponding [model card](MODEL_CARD.md) is included in the repository. ```python import torch +# DINOv2 dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14') dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14') dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14') dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14') + +# DINOv2 with registers +dinov2_vits14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg') +dinov2_vitb14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg') +dinov2_vitl14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg') +dinov2_vitg14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg') ``` ### Pretrained heads - Image classification @@ -89,6 +135,7 @@ dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14') <thead> <tr> <th rowspan="2">backbone</th> + <th rowspan="2">with<br />registers</th> <th>download</th> </tr> <tr> @@ -98,29 +145,62 @@ dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14') <tbody> <tr> <td>ViT-S/14 distilled</td> + <td align="center">:x:</td> <td> linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear_head.pth">1 layer</a>, <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear4_head.pth">4 layers</a>) </td> </tr> + <tr> + <td>ViT-S/14 distilled</td> + <td align="center">:white_check_mark:</td> + <td> + linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_linear_head.pth">1 layer</a>, + <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_linear4_head.pth">4 layers</a>) + </td> + </tr> <tr> <td>ViT-B/14 distilled</td> + <td align="center">:x:</td> <td> linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear_head.pth">1 layer</a>, <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear4_head.pth">4 layers</a>) </tr> + <tr> + <td>ViT-B/14 distilled</td> + <td align="center">:white_check_mark:</td> + <td> + linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_linear_head.pth">1 layer</a>, + <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_linear4_head.pth">4 layers</a>) + </tr> <tr> <td>ViT-L/14 distilled</td> + <td align="center">:x:</td> <td> linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear_head.pth">1 layer</a>, <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear4_head.pth">4 layers</a>) </tr> + <tr> + <td>ViT-L/14 distilled</td> + <td align="center">:white_check_mark:</td> + <td> + linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_linear_head.pth">1 layer</a>, + <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_linear4_head.pth">4 layers</a>) + </tr> <tr> <td>ViT-g/14</td> + <td align="center">:x:</td> <td> linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear_head.pth">1 layer</a>, <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear4_head.pth">4 
layers</a>) </tr> + <tr> + <td>ViT-g/14</td> + <td align="center">:white_check_mark:</td> + <td> + linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_lreg4_inear_head.pth">1 layer</a>, + <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_linear4_head.pth">4 layers</a>) + </tr> </tbody> </table> @@ -129,10 +209,17 @@ The (full) classifier models can be loaded via PyTorch Hub: ```python import torch +# DINOv2 dinov2_vits14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_lc') dinov2_vitb14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_lc') dinov2_vitl14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_lc') dinov2_vitg14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_lc') + +# DINOv2 with registers +dinov2_vits14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg_lc') +dinov2_vitb14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg_lc') +dinov2_vitl14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg_lc') +dinov2_vitg14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg_lc') ``` ### Pretrained heads - Depth estimation @@ -429,29 +516,58 @@ We release the weights from evaluating the different models: <table style="margin: auto"> <tr> <th>model</th> + <th>with<br />registers</th> <th>ImageNet<br />top-1</th> <th>linear evaluation</th> </tr> <tr> <td>ViT-S/14 distilled</td> + <td align="center">:x:</td> <td align="right">81.1%</td> <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear_head.pth">linear head weights</a></td> </tr> + <tr> + <td>ViT-S/14 distilled</td> + <td align="center">:white_check_mark:</td> + <td align="right">80.8%</td> + <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_linear_head.pth">linear head weights</a></td> + </tr> <tr> <td>ViT-B/14 distilled</td> + <td align="center">:x:</td> <td align="right">84.5%</td> <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear_head.pth">linear head weights</a></td> </tr> + <tr> + <td>ViT-B/14 distilled</td> + <td align="center">:white_check_mark:</td> + <td align="right">84.4%</td> + <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_linear_head.pth">linear head weights</a></td> + </tr> <tr> <td>ViT-L/14 distilled</td> + <td align="center">:x:</td> <td align="right">86.3%</td> <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear_head.pth">linear head weights</a></td> </tr> + <tr> + <td>ViT-L/14 distilled</td> + <td align="center">:white_check_mark:</td> + <td align="right">86.5%</td> + <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_linear_head.pth">linear head weights</a></td> + </tr> <tr> <td>ViT-g/14</td> + <td align="center">:x:</td> <td align="right">86.5%</td> <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear_head.pth">linear head weights</a></td> </tr> + <tr> + <td>ViT-g/14</td> + <td align="center">:white_check_mark:</td> + <td align="right">87.0%</td> + <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_linear_head.pth">linear head weights</a></td> + </tr> </table> The performance of the provided pretrained model weights can be evaluated as follows on ImageNet-1k: @@ -493,3 +609,12 @@ If you find this repository useful, please consider giving a star :star: and cit 
year={2023} } ``` + +``` +@misc{darcet2023vitneedreg, + title={Vision Transformers Need Registers}, + author={Darcet, Timothée and Oquab, Maxime and Mairal, Julien and Bojanowski, Piotr}, + journal={arXiv:2309.16588}, + year={2023} +} +``` diff --git a/dinov2/configs/eval/vitb14_reg4_pretrain.yaml b/dinov2/configs/eval/vitb14_reg4_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d53edc04a0761b4b35c147d63e04d55c90092c8f --- /dev/null +++ b/dinov2/configs/eval/vitb14_reg4_pretrain.yaml @@ -0,0 +1,9 @@ +student: + arch: vit_base + patch_size: 14 + num_register_tokens: 4 + interpolate_antialias: true + interpolate_offset: 0.0 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/dinov2/configs/eval/vitg14_reg4_pretrain.yaml b/dinov2/configs/eval/vitg14_reg4_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..15948f8589ea0a6e04717453eb88c18388e7f1b2 --- /dev/null +++ b/dinov2/configs/eval/vitg14_reg4_pretrain.yaml @@ -0,0 +1,10 @@ +student: + arch: vit_giant2 + patch_size: 14 + ffn_layer: swiglufused + num_register_tokens: 4 + interpolate_antialias: true + interpolate_offset: 0.0 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/dinov2/configs/eval/vitl14_reg4_pretrain.yaml b/dinov2/configs/eval/vitl14_reg4_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0e2bc4e7b24b1a64d0369a24927996d0f184e283 --- /dev/null +++ b/dinov2/configs/eval/vitl14_reg4_pretrain.yaml @@ -0,0 +1,9 @@ +student: + arch: vit_large + patch_size: 14 + num_register_tokens: 4 + interpolate_antialias: true + interpolate_offset: 0.0 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/dinov2/configs/eval/vits14_reg4_pretrain.yaml b/dinov2/configs/eval/vits14_reg4_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d25fd638389bfba9220792302dc9dbf5d9a2406a --- /dev/null +++ b/dinov2/configs/eval/vits14_reg4_pretrain.yaml @@ -0,0 +1,9 @@ +student: + arch: vit_small + patch_size: 14 + num_register_tokens: 4 + interpolate_antialias: true + interpolate_offset: 0.0 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/dinov2/configs/ssl_default_config.yaml b/dinov2/configs/ssl_default_config.yaml index a4ef04545ce9d6cc52b5179236008adc8a9bbda2..ccaae1c3174b21bcaf6e803dc861492261e5abe1 100644 --- a/dinov2/configs/ssl_default_config.yaml +++ b/dinov2/configs/ssl_default_config.yaml @@ -80,6 +80,9 @@ student: qkv_bias: true proj_bias: true ffn_bias: true + num_register_tokens: 0 + interpolate_antialias: false + interpolate_offset: 0.1 teacher: momentum_teacher: 0.992 final_momentum_teacher: 1 diff --git a/dinov2/hub/backbones.py b/dinov2/hub/backbones.py index 17e00981f732850f3a4086e01952828fd731ff87..53fe83719d5107eb77a8f25ef1814c3d73446002 100644 --- a/dinov2/hub/backbones.py +++ b/dinov2/hub/backbones.py @@ -23,6 +23,9 @@ def _make_dinov2_model( init_values: float = 1.0, ffn_layer: str = "mlp", block_chunks: int = 0, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs, @@ -35,21 +38,25 @@ def 
_make_dinov2_model( except KeyError: raise AssertionError(f"Unsupported weights: {weights}") - model_name = _make_dinov2_model_name(arch_name, patch_size) + model_base_name = _make_dinov2_model_name(arch_name, patch_size) vit_kwargs = dict( img_size=img_size, patch_size=patch_size, init_values=init_values, ffn_layer=ffn_layer, block_chunks=block_chunks, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, ) vit_kwargs.update(**kwargs) model = vits.__dict__[arch_name](**vit_kwargs) if pretrained: - url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_pretrain.pth" + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") - model.load_state_dict(state_dict, strict=False) + model.load_state_dict(state_dict, strict=True) return model @@ -80,5 +87,70 @@ def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Wei DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. """ return _make_dinov2_model( - arch_name="vit_giant2", ffn_layer="swiglufused", weights=weights, pretrained=pretrained, **kwargs + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + **kwargs, + ) + + +def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_small", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_base", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_large", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. 
+ """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, ) diff --git a/dinov2/hub/classifiers.py b/dinov2/hub/classifiers.py index 636a732c1531271507e6c8fd7569157ca7130ee2..3f0841efa80ab3d564cd320d61da254af182606b 100644 --- a/dinov2/hub/classifiers.py +++ b/dinov2/hub/classifiers.py @@ -19,11 +19,13 @@ class Weights(Enum): def _make_dinov2_linear_classification_head( *, - model_name: str = "dinov2_vitl14", + arch_name: str = "vit_large", + patch_size: int = 14, embed_dim: int = 1024, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, + num_register_tokens: int = 0, **kwargs, ): if layers not in (1, 4): @@ -37,10 +39,12 @@ def _make_dinov2_linear_classification_head( linear_head = nn.Linear((1 + layers) * embed_dim, 1_000) if pretrained: + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) layers_str = str(layers) if layers == 4 else "" - url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_linear{layers_str}_head.pth" + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_linear{layers_str}_head.pth" state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") - linear_head.load_state_dict(state_dict, strict=False) + linear_head.load_state_dict(state_dict, strict=True) return linear_head @@ -85,63 +89,180 @@ def _make_dinov2_linear_classifier( layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, **kwargs, ): - backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs) + backbone = _make_dinov2_model( + arch_name=arch_name, + pretrained=pretrained, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + **kwargs, + ) embed_dim = backbone.embed_dim patch_size = backbone.patch_size - model_name = _make_dinov2_model_name(arch_name, patch_size) linear_head = _make_dinov2_linear_classification_head( - model_name=model_name, + arch_name=arch_name, + patch_size=patch_size, embed_dim=embed_dim, layers=layers, pretrained=pretrained, weights=weights, + num_register_tokens=num_register_tokens, ) return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers) def dinov2_vits14_lc( - *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, ): """ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. 
""" return _make_dinov2_linear_classifier( - arch_name="vit_small", layers=layers, pretrained=pretrained, weights=weights, **kwargs + arch_name="vit_small", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, ) def dinov2_vitb14_lc( - *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, ): """ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. """ return _make_dinov2_linear_classifier( - arch_name="vit_base", layers=layers, pretrained=pretrained, weights=weights, **kwargs + arch_name="vit_base", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, ) def dinov2_vitl14_lc( - *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, ): """ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. """ return _make_dinov2_linear_classifier( - arch_name="vit_large", layers=layers, pretrained=pretrained, weights=weights, **kwargs + arch_name="vit_large", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, ) def dinov2_vitg14_lc( - *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, ): """ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. """ return _make_dinov2_linear_classifier( - arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs + arch_name="vit_giant2", + layers=layers, + ffn_layer="swiglufused", + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vits14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_small", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. 
+ """ + return _make_dinov2_linear_classifier( + arch_name="vit_base", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_large", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_giant2", + layers=layers, + ffn_layer="swiglufused", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, ) diff --git a/dinov2/hub/utils.py b/dinov2/hub/utils.py index e03032ed43c23588ed0fb156c50bd38378333920..9c6641404093652d5a2f19b4cf283d976ec39e64 100644 --- a/dinov2/hub/utils.py +++ b/dinov2/hub/utils.py @@ -14,9 +14,10 @@ import torch.nn.functional as F _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" -def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str: +def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: compact_arch_name = arch_name.replace("_", "")[:4] - return f"dinov2_{compact_arch_name}{patch_size}" + registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" + return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" class CenterPadding(nn.Module): diff --git a/dinov2/models/__init__.py b/dinov2/models/__init__.py index e7c92d9edfd96f69b80a1f1bbc791c8a18508ecf..3fdff20badbd5244bf79f16bf18dd2cb73982265 100644 --- a/dinov2/models/__init__.py +++ b/dinov2/models/__init__.py @@ -23,6 +23,9 @@ def build_model(args, only_teacher=False, img_size=224): qkv_bias=args.qkv_bias, proj_bias=args.proj_bias, ffn_bias=args.ffn_bias, + num_register_tokens=args.num_register_tokens, + interpolate_offset=args.interpolate_offset, + interpolate_antialias=args.interpolate_antialias, ) teacher = vits.__dict__[args.arch](**vit_kwargs) if only_teacher: diff --git a/dinov2/models/vision_transformer.py b/dinov2/models/vision_transformer.py index de212108cff11362f525f36e6678ab388ed58392..c8c3ec277db73bf667660372f07fae4cce6d9b60 100644 --- a/dinov2/models/vision_transformer.py +++ b/dinov2/models/vision_transformer.py @@ -62,6 +62,9 @@ class DinoVisionTransformer(nn.Module): block_fn=Block, ffn_layer="mlp", block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, ): """ Args: @@ -84,6 +87,9 @@ class DinoVisionTransformer(nn.Module): block_fn (nn.Module): transformer block class ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when 
interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings """ super().__init__() norm_layer = partial(nn.LayerNorm, eps=1e-6) @@ -93,12 +99,19 @@ class DinoVisionTransformer(nn.Module): self.n_blocks = depth self.num_heads = num_heads self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) if drop_path_uniform is True: dpr = [drop_path_rate] * depth @@ -159,6 +172,8 @@ class DinoVisionTransformer(nn.Module): def init_weights(self): trunc_normal_(self.pos_embed, std=0.02) nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) named_apply(init_weights_vit_timm, self) def interpolate_pos_encoding(self, x, w, h): @@ -175,7 +190,7 @@ class DinoVisionTransformer(nn.Module): h0 = h // self.patch_size # we add a small number to avoid floating point error in the interpolation # see discussion at https://github.com/facebookresearch/dino/issues/8 - w0, h0 = w0 + 0.1, h0 + 0.1 + w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset sqrt_N = math.sqrt(N) sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N @@ -183,6 +198,7 @@ class DinoVisionTransformer(nn.Module): patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), scale_factor=(sx, sy), mode="bicubic", + antialias=self.interpolate_antialias, ) assert int(w0) == patch_pos_embed.shape[-2] @@ -199,6 +215,16 @@ class DinoVisionTransformer(nn.Module): x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) x = x + self.interpolate_pos_encoding(x, w, h) + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + return x def forward_features_list(self, x_list, masks_list): @@ -213,7 +239,8 @@ class DinoVisionTransformer(nn.Module): output.append( { "x_norm_clstoken": x_norm[:, 0], - "x_norm_patchtokens": x_norm[:, 1:], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], "x_prenorm": x, "masks": masks, } @@ -232,7 +259,8 @@ class DinoVisionTransformer(nn.Module): x_norm = self.norm(x) return { "x_norm_clstoken": x_norm[:, 0], - "x_norm_patchtokens": x_norm[:, 1:], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], "x_prenorm": x, "masks": masks, } @@ -305,7 +333,7 @@ def init_weights_vit_timm(module: nn.Module, name: str = ""): nn.init.zeros_(module.bias) -def vit_small(patch_size=16, **kwargs): +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): model = DinoVisionTransformer( patch_size=patch_size, embed_dim=384, @@ -313,12 +341,13 @@ def vit_small(patch_size=16, **kwargs): num_heads=6, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, **kwargs, ) return 
model -def vit_base(patch_size=16, **kwargs): +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): model = DinoVisionTransformer( patch_size=patch_size, embed_dim=768, @@ -326,12 +355,13 @@ def vit_base(patch_size=16, **kwargs): num_heads=12, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, **kwargs, ) return model -def vit_large(patch_size=16, **kwargs): +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): model = DinoVisionTransformer( patch_size=patch_size, embed_dim=1024, @@ -339,12 +369,13 @@ def vit_large(patch_size=16, **kwargs): num_heads=16, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, **kwargs, ) return model -def vit_giant2(patch_size=16, **kwargs): +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): """ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 """ @@ -355,6 +386,7 @@ def vit_giant2(patch_size=16, **kwargs): num_heads=24, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, **kwargs, ) return model diff --git a/hubconf.py b/hubconf.py index a9fbdc86d5382b25f7637d8d8d82f3d170db3c9a..d3664e2cc4846b065a99eb5080fb598b7b6c9319 100644 --- a/hubconf.py +++ b/hubconf.py @@ -5,7 +5,9 @@ from dinov2.hub.backbones import dinov2_vitb14, dinov2_vitg14, dinov2_vitl14, dinov2_vits14 +from dinov2.hub.backbones import dinov2_vitb14_reg, dinov2_vitg14_reg, dinov2_vitl14_reg, dinov2_vits14_reg from dinov2.hub.classifiers import dinov2_vitb14_lc, dinov2_vitg14_lc, dinov2_vitl14_lc, dinov2_vits14_lc +from dinov2.hub.classifiers import dinov2_vitb14_reg_lc, dinov2_vitg14_reg_lc, dinov2_vitl14_reg_lc, dinov2_vits14_reg_lc from dinov2.hub.depthers import dinov2_vitb14_ld, dinov2_vitg14_ld, dinov2_vitl14_ld, dinov2_vits14_ld from dinov2.hub.depthers import dinov2_vitb14_dd, dinov2_vitg14_dd, dinov2_vitl14_dd, dinov2_vits14_dd
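
Usage sketch for the new hub entry points: a minimal shape check (random input, no ImageNet normalization) showing how the `x_norm_regtokens` output added in this patch sits next to the class and patch tokens. It assumes network access so `torch.hub` can fetch the `dinov2_vits14_reg` weights.

```python
import torch

# Minimal shape check for a register backbone added in this patch.
# Assumes network access: torch.hub downloads dinov2_vits14_reg4_pretrain.pth.
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg')
model.eval()

# Random input is enough to inspect shapes; real inference should use
# ImageNet-normalized images. 224x224 / patch size 14 -> 16x16 = 256 patches.
x = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    out = model.forward_features(x)

print(out["x_norm_clstoken"].shape)    # torch.Size([1, 384])
print(out["x_norm_regtokens"].shape)   # torch.Size([1, 4, 384])
print(out["x_norm_patchtokens"].shape) # torch.Size([1, 256, 384])
```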
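
The checkpoint naming scheme behind those downloads follows the updated helper in `dinov2/hub/utils.py`: the directory keeps the register-free base name, while the file name gains a `_reg4` suffix when register tokens are used. Reproduced below with a short usage demo.

```python
_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"

def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
    compact_arch_name = arch_name.replace("_", "")[:4]  # e.g. "vit_small" -> "vits"
    registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
    return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"

base_name = _make_dinov2_model_name("vit_small", 14)     # "dinov2_vits14"
full_name = _make_dinov2_model_name("vit_small", 14, 4)  # "dinov2_vits14_reg4"

# Backbone checkpoints for the register models live under the base directory:
print(f"{_DINOV2_BASE_URL}/{base_name}/{full_name}_pretrain.pth")
# -> https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth
```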
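
How the 4 register tokens enter and leave the token sequence, sketched with dummy tensors sized for ViT-S/14 at 224x224. In the real model the positional embedding is added to the class and patch tokens before the registers are spliced in, so the registers carry no positional embedding.

```python
import torch

B, N, D, R = 2, 256, 384, 4            # batch, patch tokens, embed dim, registers

cls_token = torch.zeros(B, 1, D)        # class token (already expanded to the batch)
patch_tokens = torch.randn(B, N, D)     # patch embeddings + interpolated pos. embedding
register_tokens = torch.zeros(1, R, D)  # learned nn.Parameter in the real model

# Registers are inserted between the class token and the patch tokens,
# mirroring prepare_tokens_with_masks() in this patch.
x = torch.cat((cls_token, register_tokens.expand(B, -1, -1), patch_tokens), dim=1)
assert x.shape == (B, 1 + R + N, D)

# After the transformer blocks, the outputs are split back the same way:
cls_out = x[:, 0]           # -> "x_norm_clstoken"
reg_out = x[:, 1 : R + 1]   # -> "x_norm_regtokens" (typically discarded downstream)
patch_out = x[:, R + 1 :]   # -> "x_norm_patchtokens"
```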
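
The register backbones also flip the two positional-embedding interpolation options (`interpolate_offset: 0.0` and `interpolate_antialias: true` in the new eval configs, versus the 0.1 offset and no anti-aliasing kept as defaults for the legacy models). A standalone sketch with random weights and a hypothetical helper name, illustrating what those options change:

```python
import math
import torch
import torch.nn.functional as F

def resize_patch_pos_embed(patch_pos_embed, w0, h0, offset=0.0, antialias=True):
    # patch_pos_embed: (1, N, D) with N a square number of patch positions.
    # offset is the small fudge factor added before computing the scale factors
    # (0.1 for the legacy models, 0.0 for the register models); antialias is
    # forwarded to the bicubic resize.
    N, dim = patch_pos_embed.shape[1], patch_pos_embed.shape[2]
    sqrt_N = math.sqrt(N)
    sx, sy = float(w0 + offset) / sqrt_N, float(h0 + offset) / sqrt_N
    resized = F.interpolate(
        patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
        scale_factor=(sx, sy),
        mode="bicubic",
        antialias=antialias,
    )
    return resized.permute(0, 2, 3, 1).reshape(1, -1, dim)

pos = torch.randn(1, 16 * 16, 384)                # 16x16 grid (224-pixel input, patch 14)
print(resize_patch_pos_embed(pos, 37, 37).shape)  # torch.Size([1, 1369, 384]): 37x37 grid (518 pixels)
```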