From 9c7e3245797cf7be5b5729445a4af1272bd610df Mon Sep 17 00:00:00 2001 From: Patrick Labatut <60359573+patricklabatut@users.noreply.github.com> Date: Fri, 27 Oct 2023 15:15:10 +0200 Subject: [PATCH] Add new backbones trained with registers (#282) Add new backbones (and matching linear classification heads) trained with 4 registers following [Vision Transformers Need Registers](https://arxiv.org/abs/2309.16588). --- MODEL_CARD.md | 87 +++++++++- README.md | 125 +++++++++++++++ dinov2/configs/eval/vitb14_reg4_pretrain.yaml | 9 ++ dinov2/configs/eval/vitg14_reg4_pretrain.yaml | 10 ++ dinov2/configs/eval/vitl14_reg4_pretrain.yaml | 9 ++ dinov2/configs/eval/vits14_reg4_pretrain.yaml | 9 ++ dinov2/configs/ssl_default_config.yaml | 3 + dinov2/hub/backbones.py | 80 +++++++++- dinov2/hub/classifiers.py | 149 ++++++++++++++++-- dinov2/hub/utils.py | 5 +- dinov2/models/__init__.py | 3 + dinov2/models/vision_transformer.py | 46 +++++- hubconf.py | 2 + 13 files changed, 502 insertions(+), 35 deletions(-) create mode 100644 dinov2/configs/eval/vitb14_reg4_pretrain.yaml create mode 100644 dinov2/configs/eval/vitg14_reg4_pretrain.yaml create mode 100644 dinov2/configs/eval/vitl14_reg4_pretrain.yaml create mode 100644 dinov2/configs/eval/vits14_reg4_pretrain.yaml diff --git a/MODEL_CARD.md b/MODEL_CARD.md index d158d58..21b9bf2 100644 --- a/MODEL_CARD.md +++ b/MODEL_CARD.md @@ -1,12 +1,16 @@ # Model Card for DINOv2-S/B/L/g -These are Vision Transformer models trained following the method described in the paper: +These are Vision Transformer models trained following the method described in the papers: "DINOv2: Learning Robust Visual Features without Supervision" +and +"Vision Transformers Need Registers". -We provide 4 models: 1 ViT-g trained from scratch, and 3 ViT-S/B/L models distilled from the ViT-g. +We provide 8 models: +- 1 ViT-g trained from scratch with 3 ViT-S/B/L models distilled from the ViT-g, without registers. +- 1 ViT-g trained from scratch with 3 ViT-S/B/L models distilled from the ViT-g, with registers. ## Model Details -The model takes an image as input and returns a class token and patch tokens. +The model takes an image as input and returns a class token and patch tokens, and optionally 4 register tokens. The embedding dimension is: - 384 for ViT-S. @@ -14,9 +18,9 @@ The embedding dimension is: - 1024 for ViT-L. - 1536 for ViT-g. -The models follow a Transformer architecture, with a patch size of 14. +The models follow a Transformer architecture, with a patch size of 14. In the case of registers, we add 4 register tokens, learned during training, to the input sequence after the patch embedding. -For a 224x224 image, this results in 1 class token + 256 patch tokens. +For a 224x224 image, this results in 1 class token + 256 patch tokens, and optionally 4 register tokens. The models can accept larger images provided the image shapes are multiples of the patch size (14). If this condition is not verified, the model will crop to the closest smaller multiple of the patch size. @@ -63,10 +67,18 @@ Use the code below to get started with the model. 
```python import torch + +# DINOv2 dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14') dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14') dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14') dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14') + +# DINOv2 with registers +dinov2_vits14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg') +dinov2_vitb14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg') +dinov2_vitl14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg') +dinov2_vitg14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg') ``` ## Training Details @@ -92,11 +104,11 @@ dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14') ## Evaluation -We refer users to the associated paper for the evaluation protocols. +We refer users to the associated papers for the evaluation protocols. <table> <tr> - <th>model</th> + <th colspan="2"></th> <th colspan="3">ImageNet-1k</th> <th>NYU-Depth v2</th> <th>SUN-RGBD</th> @@ -105,7 +117,8 @@ We refer users to the associated paper for the evaluation protocols. <th>Oxford-H</th> </tr> <tr> - <th rowspan="2">task</th> + <th rowspan="2">model</th> + <th rowspan="2">with <br /> registers</th> <th>classif. (acc)</th> <th>classif. (acc)</th> <th>classif. V2 (acc)</th> @@ -128,6 +141,7 @@ We refer users to the associated paper for the evaluation protocols. </tr> <tr> <td>ViT-S/14</td> + <td align="center">:x:</td> <td align="right">79.0%</td> <td align="right">81.1%</td> <td align="right">70.8%</td> @@ -137,8 +151,21 @@ We refer users to the associated paper for the evaluation protocols. <td align="right">69.5%</td> <td align="right">43.2</td> </tr> + <tr> + <td>ViT-S/14</td> + <td align="center">:white_check_mark:</td> + <td align="right">79.1%</td> + <td align="right">80.9%</td> + <td align="right">71.0%</td> + <td align="right">N/A</td> + <td align="right">N/A</td> + <td align="right">N/A</td> + <td align="right">67.6%</td> + <td align="right">39.5</td> + </tr> <tr> <td>ViT-B/14</td> + <td align="center">:x:</td> <td align="right">82.1%</td> <td align="right">84.5%</td> <td align="right">74.9%</td> @@ -147,9 +174,21 @@ We refer users to the associated paper for the evaluation protocols. <td align="right">51.3</td> <td align="right">76.3%</td> <td align="right">49.5</td> + </tr> + <td>ViT-B/14</td> + <td align="center">:white_check_mark:</td> + <td align="right">82.0%</td> + <td align="right">84.6%</td> + <td align="right">75.6%</td> + <td align="right">N/A</td> + <td align="right">N/A</td> + <td align="right">N/A</td> + <td align="right">73.8%</td> + <td align="right">51.0</td> </tr> <tr> <td>ViT-L/14</td> + <td align="center">:x:</td> <td align="right">83.5%</td> <td align="right">86.3%</td> <td align="right">77.6%</td> @@ -159,8 +198,21 @@ We refer users to the associated paper for the evaluation protocols. 
<td align="right">79.8%</td> <td align="right">54.0</td> </tr> + <tr> + <td>ViT-L/14</td> + <td align="center">:white_check_mark:</td> + <td align="right">83.8%</td> + <td align="right">86.7%</td> + <td align="right">78.5%</td> + <td align="right">N/A</td> + <td align="right">N/A</td> + <td align="right">N/A</td> + <td align="right">80.9%</td> + <td align="right">55.7</td> + </tr> <tr> <td>ViT-g/14</td> + <td align="center">:x:</td> <td align="right">83.5%</td> <td align="right">86.5%</td> <td align="right">78.4%</td> @@ -170,6 +222,19 @@ We refer users to the associated paper for the evaluation protocols. <td align="right">81.6%</td> <td align="right">52.3</td> </tr> + <tr> + <tr> + <td>ViT-g/14</td> + <td align="center">:white_check_mark:</td> + <td align="right">83.7%</td> + <td align="right">87.1%</td> + <td align="right">78.8%</td> + <td align="right">N/A</td> + <td align="right">N/A</td> + <td align="right">N/A</td> + <td align="right">81.5%</td> + <td align="right">58.2</td> + </tr> </table> ## Environmental Impact @@ -198,4 +263,10 @@ xFormers 0.0.18 journal={arXiv:2304.07193}, year={2023} } +@misc{darcet2023vitneedreg, + title={Vision Transformers Need Registers}, + author={Darcet, Timothée and Oquab, Maxime and Mairal, Julien and Bojanowski, Piotr}, + journal={arXiv:2309.16588}, + year={2023} +} ``` diff --git a/README.md b/README.md index 3e1a1d5..8ea1060 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +:new: [2023-10-26] *Added DINOv2 backbones with registers.* + # DINOv2: Learning Robust Visual Features without Supervision **[Meta AI Research, FAIR](https://ai.facebook.com/research/)** @@ -31,6 +33,7 @@ https://github.com/facebookresearch/dinov2/assets/60359573/f168823e-7922-415a-b4 <tr> <th>model</th> <th># of<br />params</th> + <th>with<br />registers</th> <th>ImageNet<br />k-NN</th> <th>ImageNet<br />linear</th> <th>download</th> @@ -40,31 +43,67 @@ https://github.com/facebookresearch/dinov2/assets/60359573/f168823e-7922-415a-b4 <tr> <td>ViT-S/14 distilled</td> <td align="right">21 M</td> + <td align="center">:x:</td> <td align="right">79.0%</td> <td align="right">81.1%</td> <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth">backbone only</a></td> </tr> + <tr> + <td>ViT-S/14 distilled</td> + <td align="right">21 M</td> + <td align="center">:white_check_mark:</td> + <td align="right">79.1%</td> + <td align="right">80.9%</td> + <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth">backbone only</a></td> + </tr> <tr> <td>ViT-B/14 distilled</td> <td align="right">86 M</td> + <td align="center">:x:</td> <td align="right">82.1%</td> <td align="right">84.5%</td> <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth">backbone only</a></td> </tr> + <tr> + <td>ViT-B/14 distilled</td> + <td align="right">86 M</td> + <td align="center">:white_check_mark:</td> + <td align="right">82.0%</td> + <td align="right">84.6%</td> + <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth">backbone only</a></td> + </tr> <tr> <td>ViT-L/14 distilled</td> <td align="right">300 M</td> + <td align="center">:x:</td> <td align="right">83.5%</td> <td align="right">86.3%</td> <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth">backbone only</a></td> </tr> + <tr> + <td>ViT-L/14 distilled</td> + <td align="right">300 M</td> + <td align="center">:white_check_mark:</td> + <td align="right">83.8%</td> 
+ <td align="right">86.7%</td> + <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth">backbone only</a></td> + </tr> <tr> <td>ViT-g/14</td> <td align="right">1,100 M</td> + <td align="center">:x:</td> <td align="right">83.5%</td> <td align="right">86.5%</td> <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth">backbone only</a></td> </tr> + <tr> + <td>ViT-g/14</td> + <td align="right">1,100 M</td> + <td align="center">:white_check_mark:</td> + <td align="right">83.7%</td> + <td align="right">87.1%</td> + <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth">backbone only</a></td> + </tr> </tbody> </table> @@ -77,10 +116,17 @@ A corresponding [model card](MODEL_CARD.md) is included in the repository. ```python import torch +# DINOv2 dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14') dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14') dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14') dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14') + +# DINOv2 with registers +dinov2_vits14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg') +dinov2_vitb14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg') +dinov2_vitl14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg') +dinov2_vitg14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg') ``` ### Pretrained heads - Image classification @@ -89,6 +135,7 @@ dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14') <thead> <tr> <th rowspan="2">backbone</th> + <th rowspan="2">with<br />registers</th> <th>download</th> </tr> <tr> @@ -98,29 +145,62 @@ dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14') <tbody> <tr> <td>ViT-S/14 distilled</td> + <td align="center">:x:</td> <td> linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear_head.pth">1 layer</a>, <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear4_head.pth">4 layers</a>) </td> </tr> + <tr> + <td>ViT-S/14 distilled</td> + <td align="center">:white_check_mark:</td> + <td> + linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_linear_head.pth">1 layer</a>, + <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_linear4_head.pth">4 layers</a>) + </td> + </tr> <tr> <td>ViT-B/14 distilled</td> + <td align="center">:x:</td> <td> linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear_head.pth">1 layer</a>, <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear4_head.pth">4 layers</a>) </tr> + <tr> + <td>ViT-B/14 distilled</td> + <td align="center">:white_check_mark:</td> + <td> + linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_linear_head.pth">1 layer</a>, + <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_linear4_head.pth">4 layers</a>) + </tr> <tr> <td>ViT-L/14 distilled</td> + <td align="center">:x:</td> <td> linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear_head.pth">1 layer</a>, <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear4_head.pth">4 layers</a>) </tr> + <tr> + <td>ViT-L/14 distilled</td> + <td 
align="center">:white_check_mark:</td> + <td> + linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_linear_head.pth">1 layer</a>, + <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_linear4_head.pth">4 layers</a>) + </tr> <tr> <td>ViT-g/14</td> + <td align="center">:x:</td> <td> linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear_head.pth">1 layer</a>, <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear4_head.pth">4 layers</a>) </tr> + <tr> + <td>ViT-g/14</td> + <td align="center">:white_check_mark:</td> + <td> + linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_linear_head.pth">1 layer</a>, + <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_linear4_head.pth">4 layers</a>) + </tr> </tbody> </table> @@ -129,10 +209,17 @@ The (full) classifier models can be loaded via PyTorch Hub: ```python import torch +# DINOv2 dinov2_vits14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_lc') dinov2_vitb14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_lc') dinov2_vitl14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_lc') dinov2_vitg14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_lc') + +# DINOv2 with registers +dinov2_vits14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg_lc') +dinov2_vitb14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg_lc') +dinov2_vitl14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg_lc') +dinov2_vitg14_reg_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg_lc') ``` ### Pretrained heads - Depth estimation @@ -429,29 +516,58 @@ We release the weights from evaluating the different models: <table style="margin: auto"> <tr> <th>model</th> + <th>with<br />registers</th> <th>ImageNet<br />top-1</th> <th>linear evaluation</th> </tr> <tr> <td>ViT-S/14 distilled</td> + <td align="center">:x:</td> <td align="right">81.1%</td> <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear_head.pth">linear head weights</a></td> </tr> + <tr> + <td>ViT-S/14 distilled</td> + <td align="center">:white_check_mark:</td> + <td align="right">80.8%</td> + <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_linear_head.pth">linear head weights</a></td> + </tr> <tr> <td>ViT-B/14 distilled</td> + <td align="center">:x:</td> <td align="right">84.5%</td> <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear_head.pth">linear head weights</a></td> </tr> + <tr> + <td>ViT-B/14 distilled</td> + <td align="center">:white_check_mark:</td> + <td align="right">84.4%</td> + <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_linear_head.pth">linear head weights</a></td> + </tr> <tr> <td>ViT-L/14 distilled</td> + <td align="center">:x:</td> <td align="right">86.3%</td> <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear_head.pth">linear head weights</a></td> </tr> + <tr> + <td>ViT-L/14 distilled</td> + <td align="center">:white_check_mark:</td> + <td align="right">86.5%</td> + <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_linear_head.pth">linear head weights</a></td> + </tr> <tr> <td>ViT-g/14</td> + <td align="center">:x:</td> <td align="right">86.5%</td> <td><a
href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear_head.pth">linear head weights</a></td> </tr> + <tr> + <td>ViT-g/14</td> + <td align="center">:white_check_mark:</td> + <td align="right">87.0%</td> + <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_linear_head.pth">linear head weights</a></td> + </tr> </table> The performance of the provided pretrained model weights can be evaluated as follows on ImageNet-1k: @@ -493,3 +609,12 @@ If you find this repository useful, please consider giving a star :star: and cit year={2023} } ``` + +``` +@misc{darcet2023vitneedreg, + title={Vision Transformers Need Registers}, + author={Darcet, Timothée and Oquab, Maxime and Mairal, Julien and Bojanowski, Piotr}, + journal={arXiv:2309.16588}, + year={2023} +} +``` diff --git a/dinov2/configs/eval/vitb14_reg4_pretrain.yaml b/dinov2/configs/eval/vitb14_reg4_pretrain.yaml new file mode 100644 index 0000000..d53edc0 --- /dev/null +++ b/dinov2/configs/eval/vitb14_reg4_pretrain.yaml @@ -0,0 +1,9 @@ +student: + arch: vit_base + patch_size: 14 + num_register_tokens: 4 + interpolate_antialias: true + interpolate_offset: 0.0 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/dinov2/configs/eval/vitg14_reg4_pretrain.yaml b/dinov2/configs/eval/vitg14_reg4_pretrain.yaml new file mode 100644 index 0000000..15948f8 --- /dev/null +++ b/dinov2/configs/eval/vitg14_reg4_pretrain.yaml @@ -0,0 +1,10 @@ +student: + arch: vit_giant2 + patch_size: 14 + ffn_layer: swiglufused + num_register_tokens: 4 + interpolate_antialias: true + interpolate_offset: 0.0 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/dinov2/configs/eval/vitl14_reg4_pretrain.yaml b/dinov2/configs/eval/vitl14_reg4_pretrain.yaml new file mode 100644 index 0000000..0e2bc4e --- /dev/null +++ b/dinov2/configs/eval/vitl14_reg4_pretrain.yaml @@ -0,0 +1,9 @@ +student: + arch: vit_large + patch_size: 14 + num_register_tokens: 4 + interpolate_antialias: true + interpolate_offset: 0.0 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/dinov2/configs/eval/vits14_reg4_pretrain.yaml b/dinov2/configs/eval/vits14_reg4_pretrain.yaml new file mode 100644 index 0000000..d25fd63 --- /dev/null +++ b/dinov2/configs/eval/vits14_reg4_pretrain.yaml @@ -0,0 +1,9 @@ +student: + arch: vit_small + patch_size: 14 + num_register_tokens: 4 + interpolate_antialias: true + interpolate_offset: 0.0 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/dinov2/configs/ssl_default_config.yaml b/dinov2/configs/ssl_default_config.yaml index a4ef045..ccaae1c 100644 --- a/dinov2/configs/ssl_default_config.yaml +++ b/dinov2/configs/ssl_default_config.yaml @@ -80,6 +80,9 @@ student: qkv_bias: true proj_bias: true ffn_bias: true + num_register_tokens: 0 + interpolate_antialias: false + interpolate_offset: 0.1 teacher: momentum_teacher: 0.992 final_momentum_teacher: 1 diff --git a/dinov2/hub/backbones.py b/dinov2/hub/backbones.py index 17e0098..53fe837 100644 --- a/dinov2/hub/backbones.py +++ b/dinov2/hub/backbones.py @@ -23,6 +23,9 @@ def _make_dinov2_model( init_values: float = 1.0, ffn_layer: str = "mlp", block_chunks: int = 0, + num_register_tokens: int = 
0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs, @@ -35,21 +38,25 @@ def _make_dinov2_model( except KeyError: raise AssertionError(f"Unsupported weights: {weights}") - model_name = _make_dinov2_model_name(arch_name, patch_size) + model_base_name = _make_dinov2_model_name(arch_name, patch_size) vit_kwargs = dict( img_size=img_size, patch_size=patch_size, init_values=init_values, ffn_layer=ffn_layer, block_chunks=block_chunks, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, ) vit_kwargs.update(**kwargs) model = vits.__dict__[arch_name](**vit_kwargs) if pretrained: - url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_pretrain.pth" + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") - model.load_state_dict(state_dict, strict=False) + model.load_state_dict(state_dict, strict=True) return model @@ -80,5 +87,70 @@ def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Wei DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. """ return _make_dinov2_model( - arch_name="vit_giant2", ffn_layer="swiglufused", weights=weights, pretrained=pretrained, **kwargs + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + **kwargs, + ) + + +def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_small", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_base", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_large", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. 
+ """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, ) diff --git a/dinov2/hub/classifiers.py b/dinov2/hub/classifiers.py index 636a732..3f0841e 100644 --- a/dinov2/hub/classifiers.py +++ b/dinov2/hub/classifiers.py @@ -19,11 +19,13 @@ class Weights(Enum): def _make_dinov2_linear_classification_head( *, - model_name: str = "dinov2_vitl14", + arch_name: str = "vit_large", + patch_size: int = 14, embed_dim: int = 1024, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, + num_register_tokens: int = 0, **kwargs, ): if layers not in (1, 4): @@ -37,10 +39,12 @@ def _make_dinov2_linear_classification_head( linear_head = nn.Linear((1 + layers) * embed_dim, 1_000) if pretrained: + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) layers_str = str(layers) if layers == 4 else "" - url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_linear{layers_str}_head.pth" + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_linear{layers_str}_head.pth" state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") - linear_head.load_state_dict(state_dict, strict=False) + linear_head.load_state_dict(state_dict, strict=True) return linear_head @@ -85,63 +89,180 @@ def _make_dinov2_linear_classifier( layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, **kwargs, ): - backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs) + backbone = _make_dinov2_model( + arch_name=arch_name, + pretrained=pretrained, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + **kwargs, + ) embed_dim = backbone.embed_dim patch_size = backbone.patch_size - model_name = _make_dinov2_model_name(arch_name, patch_size) linear_head = _make_dinov2_linear_classification_head( - model_name=model_name, + arch_name=arch_name, + patch_size=patch_size, embed_dim=embed_dim, layers=layers, pretrained=pretrained, weights=weights, + num_register_tokens=num_register_tokens, ) return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers) def dinov2_vits14_lc( - *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, ): """ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. 
""" return _make_dinov2_linear_classifier( - arch_name="vit_small", layers=layers, pretrained=pretrained, weights=weights, **kwargs + arch_name="vit_small", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, ) def dinov2_vitb14_lc( - *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, ): """ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. """ return _make_dinov2_linear_classifier( - arch_name="vit_base", layers=layers, pretrained=pretrained, weights=weights, **kwargs + arch_name="vit_base", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, ) def dinov2_vitl14_lc( - *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, ): """ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. """ return _make_dinov2_linear_classifier( - arch_name="vit_large", layers=layers, pretrained=pretrained, weights=weights, **kwargs + arch_name="vit_large", + layers=layers, + pretrained=pretrained, + weights=weights, + **kwargs, ) def dinov2_vitg14_lc( - *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs + *, + layers: int = 4, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.IMAGENET1K, + **kwargs, ): """ Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. """ return _make_dinov2_linear_classifier( - arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs + arch_name="vit_giant2", + layers=layers, + ffn_layer="swiglufused", + pretrained=pretrained, + weights=weights, + **kwargs, + ) + + +def dinov2_vits14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_small", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. 
+ """ + return _make_dinov2_linear_classifier( + arch_name="vit_base", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_large", + layers=layers, + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg_lc( + *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs +): + """ + Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone with registers (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k. + """ + return _make_dinov2_linear_classifier( + arch_name="vit_giant2", + layers=layers, + ffn_layer="swiglufused", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, ) diff --git a/dinov2/hub/utils.py b/dinov2/hub/utils.py index e03032e..9c66414 100644 --- a/dinov2/hub/utils.py +++ b/dinov2/hub/utils.py @@ -14,9 +14,10 @@ import torch.nn.functional as F _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" -def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str: +def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: compact_arch_name = arch_name.replace("_", "")[:4] - return f"dinov2_{compact_arch_name}{patch_size}" + registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" + return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" class CenterPadding(nn.Module): diff --git a/dinov2/models/__init__.py b/dinov2/models/__init__.py index e7c92d9..3fdff20 100644 --- a/dinov2/models/__init__.py +++ b/dinov2/models/__init__.py @@ -23,6 +23,9 @@ def build_model(args, only_teacher=False, img_size=224): qkv_bias=args.qkv_bias, proj_bias=args.proj_bias, ffn_bias=args.ffn_bias, + num_register_tokens=args.num_register_tokens, + interpolate_offset=args.interpolate_offset, + interpolate_antialias=args.interpolate_antialias, ) teacher = vits.__dict__[args.arch](**vit_kwargs) if only_teacher: diff --git a/dinov2/models/vision_transformer.py b/dinov2/models/vision_transformer.py index de21210..c8c3ec2 100644 --- a/dinov2/models/vision_transformer.py +++ b/dinov2/models/vision_transformer.py @@ -62,6 +62,9 @@ class DinoVisionTransformer(nn.Module): block_fn=Block, ffn_layer="mlp", block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, ): """ Args: @@ -84,6 +87,9 @@ class DinoVisionTransformer(nn.Module): block_fn (nn.Module): transformer block class ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings """ super().__init__() norm_layer = partial(nn.LayerNorm,
eps=1e-6) @@ -93,12 +99,19 @@ class DinoVisionTransformer(nn.Module): self.n_blocks = depth self.num_heads = num_heads self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) if drop_path_uniform is True: dpr = [drop_path_rate] * depth @@ -159,6 +172,8 @@ class DinoVisionTransformer(nn.Module): def init_weights(self): trunc_normal_(self.pos_embed, std=0.02) nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) named_apply(init_weights_vit_timm, self) def interpolate_pos_encoding(self, x, w, h): @@ -175,7 +190,7 @@ class DinoVisionTransformer(nn.Module): h0 = h // self.patch_size # we add a small number to avoid floating point error in the interpolation # see discussion at https://github.com/facebookresearch/dino/issues/8 - w0, h0 = w0 + 0.1, h0 + 0.1 + w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset sqrt_N = math.sqrt(N) sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N @@ -183,6 +198,7 @@ class DinoVisionTransformer(nn.Module): patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), scale_factor=(sx, sy), mode="bicubic", + antialias=self.interpolate_antialias, ) assert int(w0) == patch_pos_embed.shape[-2] @@ -199,6 +215,16 @@ class DinoVisionTransformer(nn.Module): x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) x = x + self.interpolate_pos_encoding(x, w, h) + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + return x def forward_features_list(self, x_list, masks_list): @@ -213,7 +239,8 @@ class DinoVisionTransformer(nn.Module): output.append( { "x_norm_clstoken": x_norm[:, 0], - "x_norm_patchtokens": x_norm[:, 1:], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], "x_prenorm": x, "masks": masks, } @@ -232,7 +259,8 @@ class DinoVisionTransformer(nn.Module): x_norm = self.norm(x) return { "x_norm_clstoken": x_norm[:, 0], - "x_norm_patchtokens": x_norm[:, 1:], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], "x_prenorm": x, "masks": masks, } @@ -305,7 +333,7 @@ def init_weights_vit_timm(module: nn.Module, name: str = ""): nn.init.zeros_(module.bias) -def vit_small(patch_size=16, **kwargs): +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): model = DinoVisionTransformer( patch_size=patch_size, embed_dim=384, @@ -313,12 +341,13 @@ def vit_small(patch_size=16, **kwargs): num_heads=6, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, **kwargs, ) return model -def vit_base(patch_size=16, **kwargs): +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): model = DinoVisionTransformer( patch_size=patch_size, embed_dim=768, @@ -326,12 +355,13 
@@ def vit_base(patch_size=16, **kwargs): num_heads=12, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, **kwargs, ) return model -def vit_large(patch_size=16, **kwargs): +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): model = DinoVisionTransformer( patch_size=patch_size, embed_dim=1024, @@ -339,12 +369,13 @@ def vit_large(patch_size=16, **kwargs): num_heads=16, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, **kwargs, ) return model -def vit_giant2(patch_size=16, **kwargs): +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): """ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 """ @@ -355,6 +386,7 @@ def vit_giant2(patch_size=16, **kwargs): num_heads=24, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, **kwargs, ) return model diff --git a/hubconf.py b/hubconf.py index a9fbdc8..d3664e2 100644 --- a/hubconf.py +++ b/hubconf.py @@ -5,7 +5,9 @@ from dinov2.hub.backbones import dinov2_vitb14, dinov2_vitg14, dinov2_vitl14, dinov2_vits14 +from dinov2.hub.backbones import dinov2_vitb14_reg, dinov2_vitg14_reg, dinov2_vitl14_reg, dinov2_vits14_reg from dinov2.hub.classifiers import dinov2_vitb14_lc, dinov2_vitg14_lc, dinov2_vitl14_lc, dinov2_vits14_lc +from dinov2.hub.classifiers import dinov2_vitb14_reg_lc, dinov2_vitg14_reg_lc, dinov2_vitl14_reg_lc, dinov2_vits14_reg_lc from dinov2.hub.depthers import dinov2_vitb14_ld, dinov2_vitg14_ld, dinov2_vitl14_ld, dinov2_vits14_ld from dinov2.hub.depthers import dinov2_vitb14_dd, dinov2_vitg14_dd, dinov2_vitl14_dd, dinov2_vits14_dd -- GitLab
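Reviewer note (not part of the patch): below is a minimal usage sketch for the new register backbones, assuming a working PyTorch install and network access for `torch.hub`. The entrypoint name and the `forward_features` output keys follow the code added in `hubconf.py` and `dinov2/models/vision_transformer.py` above; the printed shapes assume the ViT-S/14 embedding dimension of 384 and a 224x224 input (256 patches of 14x14).

```python
import torch

# Load the new ViT-S/14 backbone with 4 registers through the hub entrypoint added by this patch.
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg')
model.eval()

# Placeholder input; in practice use a real image normalized with ImageNet statistics.
image = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    features = model.forward_features(image)

# The register tokens are returned separately from the class and patch tokens.
print(features["x_norm_clstoken"].shape)     # torch.Size([1, 384])
print(features["x_norm_regtokens"].shape)    # torch.Size([1, 4, 384])
print(features["x_norm_patchtokens"].shape)  # torch.Size([1, 256, 384])
```

Because `prepare_tokens_with_masks` inserts the register tokens right after the class token, `forward_features` slices them out with `x_norm[:, 1 : self.num_register_tokens + 1]` and returns the patch tokens as `x_norm[:, self.num_register_tokens + 1 :]`, so downstream code that only consumes `x_norm_patchtokens` keeps working unchanged.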