From e7df9fc95de8a43951060aa68db8ed6159cdaa8e Mon Sep 17 00:00:00 2001
From: Patrick Labatut <60359573+patricklabatut@users.noreply.github.com>
Date: Sat, 30 Sep 2023 20:12:02 +0200
Subject: [PATCH] Expose DPT depth models via torch.hub.load() (#238)

Add streamlined versions of the DPT depth heads without the mmcv dependency, so that the depth models can be loaded directly via torch.hub.load().
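
A minimal usage sketch (the new "_dd" entry points are the DPT depthers
added by this patch; weights are downloaded on first use):

    import torch

    # DINOv2 ViT-S/14 backbone + DPT decode head, NYU-pretrained
    depther = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14_dd")
    depther.eval()

    # inference follows the DepthEncoderDecoder API (exact signature may differ);
    # inputs of arbitrary size are center-padded to a multiple of the patch size
    x = torch.randn(1, 3, 518, 518)
    with torch.inference_mode():
        depth = depther.whole_inference(x, img_meta=None, rescale=True)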
---
 dinov2/hub/depth/__init__.py     |   2 +-
 dinov2/hub/depth/decode_heads.py | 466 ++++++++++++++++++++++++++++++-
 dinov2/hub/depthers.py           | 116 +++++++-
 hubconf.py                       |   1 +
 4 files changed, 575 insertions(+), 10 deletions(-)

diff --git a/dinov2/hub/depth/__init__.py b/dinov2/hub/depth/__init__.py
index 1ccf442..91716e5 100644
--- a/dinov2/hub/depth/__init__.py
+++ b/dinov2/hub/depth/__init__.py
@@ -3,5 +3,5 @@
 # This source code is licensed under the Apache License, Version 2.0
 # found in the LICENSE file in the root directory of this source tree.
 
-from .decode_heads import BNHead
+from .decode_heads import BNHead, DPTHead
 from .encoder_decoder import DepthEncoderDecoder
diff --git a/dinov2/hub/depth/decode_heads.py b/dinov2/hub/depth/decode_heads.py
index ca657f8..f455acc 100644
--- a/dinov2/hub/depth/decode_heads.py
+++ b/dinov2/hub/depth/decode_heads.py
@@ -4,6 +4,9 @@
 # found in the LICENSE file in the root directory of this source tree.
 
 import copy
+from functools import partial
+import math
+import warnings
 
 import torch
 import torch.nn as nn
@@ -29,6 +32,8 @@ class DepthBaseDecodeHead(nn.Module):
     Args:
         in_channels (List): Input channels.
         channels (int): Channels after modules, before conv_depth.
+        conv_layer (nn.Module): Conv layer class. Default: None.
+        act_layer (nn.Module): Activation layer class. Default: nn.ReLU.
         loss_decode (dict): Config of decode loss.
             Default: ().
         sampler (dict|None): The config of depth map sampler.
@@ -39,7 +44,7 @@ class DepthBaseDecodeHead(nn.Module):
             Default: 1e-3.
         max_depth (int): Max depth in dataset setting.
             Default: None.
-        norm_cfg (dict|None): Config of norm layers.
+        norm_layer (nn.Module|None): Norm layer class.
             Default: None.
         classify (bool): Whether predict depth in a cls.-reg. manner.
             Default: False.
@@ -56,13 +61,15 @@ class DepthBaseDecodeHead(nn.Module):
     def __init__(
         self,
         in_channels,
+        conv_layer=None,
+        act_layer=nn.ReLU,
         channels=96,
         loss_decode=(),
         sampler=None,
         align_corners=False,
         min_depth=1e-3,
         max_depth=None,
-        norm_cfg=None,
+        norm_layer=None,
         classify=False,
         n_bins=256,
         bins_strategy="UD",
@@ -73,11 +80,13 @@ class DepthBaseDecodeHead(nn.Module):
 
         self.in_channels = in_channels
         self.channels = channels
+        self.conv_layer = conv_layer
+        self.act_layer = act_layer
         self.loss_decode = loss_decode
         self.align_corners = align_corners
         self.min_depth = min_depth
         self.max_depth = max_depth
-        self.norm_cfg = norm_cfg
+        self.norm_layer = norm_layer
         self.classify = classify
         self.n_bins = n_bins
         self.scale_up = scale_up
@@ -285,3 +294,454 @@ class BNHead(DepthBaseDecodeHead):
         output = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
         output = self.depth_pred(output)
         return output
+
+
+class ConvModule(nn.Module):
+    """A conv block that bundles conv/norm/activation layers.
+
+    This block simplifies the usage of convolution layers, which are commonly
+    used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
+    Unlike the mmcv implementation it is adapted from, this streamlined
+    version takes the conv/norm/activation layers directly as classes
+    instead of building them from config dicts.
+
+    Besides, we add some additional features in this module.
+    1. Automatically set `bias` of the conv layer.
+    2. Spectral norm is supported.
+    3. Only the padding modes natively supported by `nn.Conv2d`
+    ('zeros' and 'circular') are accepted.
+
+    Args:
+        in_channels (int): Number of channels in the input feature map.
+            Same as that in ``nn._ConvNd``.
+        out_channels (int): Number of channels produced by the convolution.
+            Same as that in ``nn._ConvNd``.
+        kernel_size (int | tuple[int]): Size of the convolving kernel.
+            Same as that in ``nn._ConvNd``.
+        stride (int | tuple[int]): Stride of the convolution.
+            Same as that in ``nn._ConvNd``.
+        padding (int | tuple[int]): Zero-padding added to both sides of
+            the input. Same as that in ``nn._ConvNd``.
+        dilation (int | tuple[int]): Spacing between kernel elements.
+            Same as that in ``nn._ConvNd``.
+        groups (int): Number of blocked connections from input channels to
+            output channels. Same as that in ``nn._ConvNd``.
+        bias (bool | str): If specified as `auto`, it will be decided by the
+            norm_layer. Bias will be set as True if `norm_layer` is None, otherwise
+            False. Default: "auto".
+        conv_layer (nn.Module): Convolution layer class.
+            Default: nn.Conv2d.
+        norm_layer (nn.Module): Normalization layer. Default: None.
+        act_layer (nn.Module): Activation layer. Default: nn.ReLU.
+        inplace (bool): Whether to use inplace mode for activation.
+            Default: True.
+        with_spectral_norm (bool): Whether use spectral norm in conv module.
+            Default: False.
+        padding_mode (str): Padding mode of the convolution. Only the modes
+            natively supported by `nn.Conv2d` ('zeros' and 'circular') are
+            accepted; any other mode raises an error.
+            Default: 'zeros'.
+        order (tuple[str]): The order of conv/norm/activation layers. It is a
+            sequence of "conv", "norm" and "act". Common examples are
+            ("conv", "norm", "act") and ("act", "conv", "norm").
+            Default: ('conv', 'norm', 'act').
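+
+    Example (a minimal sketch; nn.BatchNorm2d chosen for illustration)::
+
+        block = ConvModule(64, 64, 3, padding=1, norm_layer=nn.BatchNorm2d,
+                           act_layer=nn.ReLU, bias=False, order=("act", "conv", "norm"))
+        out = block(torch.randn(1, 64, 32, 32))  # pre-activation conv block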
+    """
+
+    _abbr_ = "conv_block"
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias="auto",
+        conv_layer=nn.Conv2d,
+        norm_layer=None,
+        act_layer=nn.ReLU,
+        inplace=True,
+        with_spectral_norm=False,
+        padding_mode="zeros",
+        order=("conv", "norm", "act"),
+    ):
+        super(ConvModule, self).__init__()
+        official_padding_mode = ["zeros", "circular"]
+        self.conv_layer = conv_layer
+        self.norm_layer = norm_layer
+        self.act_layer = act_layer
+        self.inplace = inplace
+        self.with_spectral_norm = with_spectral_norm
+        self.with_explicit_padding = padding_mode not in official_padding_mode
+        self.order = order
+        assert isinstance(self.order, tuple) and len(self.order) == 3
+        assert set(order) == set(["conv", "norm", "act"])
+
+        self.with_norm = norm_layer is not None
+        self.with_activation = act_layer is not None
+        # if the conv layer is before a norm layer, bias is unnecessary.
+        if bias == "auto":
+            bias = not self.with_norm
+        self.with_bias = bias
+
+        if self.with_explicit_padding:
+            if padding_mode == "zeros":
+                padding_layer = nn.ZeroPad2d
+            else:
+                raise AssertionError(f"Unsupported padding mode: {padding_mode}")
+            self.pad = padding_layer(padding)
+
+        # reset padding to 0 for conv module
+        conv_padding = 0 if self.with_explicit_padding else padding
+        # build convolution layer
+        self.conv = self.conv_layer(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=conv_padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias,
+        )
+        # export the attributes of self.conv to a higher level for convenience
+        self.in_channels = self.conv.in_channels
+        self.out_channels = self.conv.out_channels
+        self.kernel_size = self.conv.kernel_size
+        self.stride = self.conv.stride
+        self.padding = padding
+        self.dilation = self.conv.dilation
+        self.transposed = self.conv.transposed
+        self.output_padding = self.conv.output_padding
+        self.groups = self.conv.groups
+
+        if self.with_spectral_norm:
+            self.conv = nn.utils.spectral_norm(self.conv)
+
+        # build normalization layers
+        if self.with_norm:
+            # norm layer is after conv layer
+            if order.index("norm") > order.index("conv"):
+                norm_channels = out_channels
+            else:
+                norm_channels = in_channels
+            # instantiate the norm layer and register it under a name that
+            # does not collide with the `norm` property defined below
+            norm = norm_layer(num_features=norm_channels)
+            self.norm_name = "bn"
+            self.add_module(self.norm_name, norm)
+            if self.with_bias:
+                from torch.nn.modules.batchnorm import _BatchNorm
+                from torch.nn.modules.instancenorm import _InstanceNorm
+
+                if isinstance(norm, (_BatchNorm, _InstanceNorm)):
+                    warnings.warn("Unnecessary conv bias before batch/instance norm")
+        else:
+            self.norm_name = None
+
+        # build activation layer
+        if self.with_activation:
+            # some activations (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU) have
+            # no 'inplace' argument; act_layer is a class here, so compare
+            # against the classes instead of using isinstance()
+            if act_layer not in (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU):
+                act_layer = partial(act_layer, inplace=inplace)
+            self.activate = act_layer()
+
+        # Use msra init by default
+        self.init_weights()
+
+    @property
+    def norm(self):
+        if self.norm_name:
+            return getattr(self, self.norm_name)
+        else:
+            return None
+
+    def init_weights(self):
+        # 1. Customized conv layers that define their own ``init_weights()``
+        #    keep their own initialization; ConvModule does not override it.
+        # 2. Conv layers without their own ``init_weights()``, including
+        #    PyTorch's built-in conv layers, are initialized here with the
+        #    default Kaiming init.
+        if not hasattr(self.conv, "init_weights"):
+            if self.with_activation and self.act_layer is nn.LeakyReLU:
+                nonlinearity = "leaky_relu"
+                a = 0.01  # default negative_slope of nn.LeakyReLU
+            else:
+                nonlinearity = "relu"
+                a = 0
+            if hasattr(self.conv, "weight") and self.conv.weight is not None:
+                nn.init.kaiming_normal_(self.conv.weight, a=a, mode="fan_out", nonlinearity=nonlinearity)
+            if hasattr(self.conv, "bias") and self.conv.bias is not None:
+                nn.init.constant_(self.conv.bias, 0)
+        if self.with_norm:
+            if hasattr(self.norm, "weight") and self.norm.weight is not None:
+                nn.init.constant_(self.norm.weight, 1)
+            if hasattr(self.norm, "bias") and self.norm.bias is not None:
+                nn.init.constant_(self.norm.bias, 0)
+
+    def forward(self, x, activate=True, norm=True):
+        for layer in self.order:
+            if layer == "conv":
+                if self.with_explicit_padding:
+                    x = self.pad(x)
+                x = self.conv(x)
+            elif layer == "norm" and norm and self.with_norm:
+                x = self.norm(x)
+            elif layer == "act" and activate and self.with_activation:
+                x = self.activate(x)
+        return x
+
+
+class Interpolate(nn.Module):
+    def __init__(self, scale_factor, mode, align_corners=False):
+        super(Interpolate, self).__init__()
+        self.interp = nn.functional.interpolate
+        self.scale_factor = scale_factor
+        self.mode = mode
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
+        return x
+
+
+class HeadDepth(nn.Module):
+    def __init__(self, features):
+        super(HeadDepth, self).__init__()
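+        # 3x3 conv halving the channels, 2x bilinear upsampling, 3x3 conv +
+        # ReLU down to 32 channels, then a 1x1 conv to a 1-channel depth map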
+        self.head = nn.Sequential(
+            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+        )
+
+    def forward(self, x):
+        x = self.head(x)
+        return x
+
+
+class ReassembleBlocks(nn.Module):
+    """ViTPostProcessBlock, process cls_token in ViT backbone output and
+    rearrange the feature vector to feature map.
+    Args:
+        in_channels (int): ViT feature channels. Default: 768.
+        out_channels (List): output channels of each stage.
+            Default: [96, 192, 384, 768].
+        readout_type (str): Type of readout operation. Default: 'ignore'.
+        patch_size (int): The patch size. Default: 16.
+    """
+
+    def __init__(self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16):
+        super(ReassembleBlocks, self).__init__()
+
+        assert readout_type in ["ignore", "add", "project"]
+        self.readout_type = readout_type
+        self.patch_size = patch_size
+
+        self.projects = nn.ModuleList(
+            [
+                ConvModule(
+                    in_channels=in_channels,
+                    out_channels=out_channel,
+                    kernel_size=1,
+                    act_layer=None,
+                )
+                for out_channel in out_channels
+            ]
+        )
+
+        self.resize_layers = nn.ModuleList(
+            [
+                nn.ConvTranspose2d(
+                    in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
+                ),
+                nn.ConvTranspose2d(
+                    in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
+                ),
+                nn.Identity(),
+                nn.Conv2d(
+                    in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
+                ),
+            ]
+        )
+        if self.readout_type == "project":
+            self.readout_projects = nn.ModuleList()
+            for _ in range(len(self.projects)):
+                self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))
+
+    def forward(self, inputs):
+        assert isinstance(inputs, list)
+        out = []
+        for i, x in enumerate(inputs):
+            assert len(x) == 2
+            x, cls_token = x[0], x[1]
+            feature_shape = x.shape
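+            # 'project': concatenate the cls token to every patch token and
+            # project back down to in_channels; 'add': broadcast-add the cls
+            # token; 'ignore': drop the cls token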
+            if self.readout_type == "project":
+                x = x.flatten(2).permute((0, 2, 1))
+                readout = cls_token.unsqueeze(1).expand_as(x)
+                x = self.readout_projects[i](torch.cat((x, readout), -1))
+                x = x.permute(0, 2, 1).reshape(feature_shape)
+            elif self.readout_type == "add":
+                x = x.flatten(2) + cls_token.unsqueeze(-1)
+                x = x.reshape(feature_shape)
+            else:
+                pass
+            x = self.projects[i](x)
+            x = self.resize_layers[i](x)
+            out.append(x)
+        return out
+
+
+class PreActResidualConvUnit(nn.Module):
+    """ResidualConvUnit, pre-activate residual unit.
+    Args:
+        in_channels (int): number of channels in the input feature map.
+        act_layer (nn.Module): activation layer.
+        norm_layer (nn.Module): norm layer.
+        stride (int): stride of the first block. Default: 1
+        dilation (int): dilation rate for convs layers. Default: 1.
+    """
+
+    def __init__(self, in_channels, act_layer, norm_layer, stride=1, dilation=1):
+        super(PreActResidualConvUnit, self).__init__()
+
+        self.conv1 = ConvModule(
+            in_channels,
+            in_channels,
+            3,
+            stride=stride,
+            padding=dilation,
+            dilation=dilation,
+            norm_layer=norm_layer,
+            act_layer=act_layer,
+            bias=False,
+            order=("act", "conv", "norm"),
+        )
+
+        self.conv2 = ConvModule(
+            in_channels,
+            in_channels,
+            3,
+            padding=1,
+            norm_layer=norm_layer,
+            act_layer=act_layer,
+            bias=False,
+            order=("act", "conv", "norm"),
+        )
+
+    def forward(self, inputs):
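+        # keep a copy of the input for the residual connection (the in-place
+        # activation inside conv1 would otherwise modify it)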
+        inputs_ = inputs.clone()
+        x = self.conv1(inputs)
+        x = self.conv2(x)
+        return x + inputs_
+
+
+class FeatureFusionBlock(nn.Module):
+    """FeatureFusionBlock, merge feature map from different stages.
+    Args:
+        in_channels (int): Input channels.
+        act_layer (nn.Module): activation layer for ResidualConvUnit.
+        norm_layer (nn.Module): normalization layer.
+        expand (bool): Whether expand the channels in post process block.
+            Default: False.
+        align_corners (bool): align_corner setting for bilinear upsample.
+            Default: True.
+    """
+
+    def __init__(self, in_channels, act_layer, norm_layer, expand=False, align_corners=True):
+        super(FeatureFusionBlock, self).__init__()
+
+        self.in_channels = in_channels
+        self.expand = expand
+        self.align_corners = align_corners
+
+        self.out_channels = in_channels
+        if self.expand:
+            self.out_channels = in_channels // 2
+
+        self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_layer=None, bias=True)
+
+        self.res_conv_unit1 = PreActResidualConvUnit(
+            in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer
+        )
+        self.res_conv_unit2 = PreActResidualConvUnit(
+            in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer
+        )
+
+    def forward(self, *inputs):
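+        # inputs[0]: feature from the previous (coarser) fusion stage;
+        # optional inputs[1]: skip feature from the reassemble stage, resized
+        # to match before being merged through the residual unit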
+        x = inputs[0]
+        if len(inputs) == 2:
+            if x.shape != inputs[1].shape:
+                res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)
+            else:
+                res = inputs[1]
+            x = x + self.res_conv_unit1(res)
+        x = self.res_conv_unit2(x)
+        x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
+        x = self.project(x)
+        return x
+
+
+class DPTHead(DepthBaseDecodeHead):
+    """Vision Transformers for Dense Prediction.
+    This head is an implementation of `DPT <https://arxiv.org/abs/2103.13413>`_.
+    Args:
+        embed_dims (int): The embed dimension of the ViT backbone.
+            Default: 768.
+        post_process_channels (List): Out channels of post process conv
+            layers. Default: [96, 192, 384, 768].
+        readout_type (str): Type of readout operation. Default: 'ignore'.
+        patch_size (int): The patch size. Default: 16.
+        expand_channels (bool): Whether expand the channels in post process
+            block. Default: False.
+    """
+
+    def __init__(
+        self,
+        embed_dims=768,
+        post_process_channels=[96, 192, 384, 768],
+        readout_type="ignore",
+        patch_size=16,
+        expand_channels=False,
+        **kwargs,
+    ):
+        super(DPTHead, self).__init__(**kwargs)
+
+        self.expand_channels = expand_channels
+        self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size)
+
+        self.post_process_channels = [
+            channel * 2**i if expand_channels else channel for i, channel in enumerate(post_process_channels)
+        ]
+        self.convs = nn.ModuleList()
+        for channel in self.post_process_channels:
+            self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_layer=None, bias=False))
+        self.fusion_blocks = nn.ModuleList()
+        for _ in range(len(self.convs)):
+            self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_layer, self.norm_layer))
+        self.fusion_blocks[0].res_conv_unit1 = None
+        self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_layer=self.norm_layer)
+        self.num_fusion_blocks = len(self.fusion_blocks)
+        self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
+        self.num_post_process_channels = len(self.post_process_channels)
+        assert self.num_fusion_blocks == self.num_reassemble_blocks
+        assert self.num_reassemble_blocks == self.num_post_process_channels
+        self.conv_depth = HeadDepth(self.channels)
+
+    def forward(self, inputs, img_metas):
+        assert len(inputs) == self.num_reassemble_blocks
+        x = [inp for inp in inputs]
+        x = self.reassemble_blocks(x)
+        x = [self.convs[i](feature) for i, feature in enumerate(x)]
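+        # fuse top-down: start from the coarsest reassembled stage and merge
+        # in progressively finer stages, upsampling 2x at every step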
+        out = self.fusion_blocks[0](x[-1])
+        for i in range(1, len(self.fusion_blocks)):
+            out = self.fusion_blocks[i](out, x[-(i + 1)])
+        out = self.project(out)
+        out = self.depth_pred(out)
+        return out
diff --git a/dinov2/hub/depthers.py b/dinov2/hub/depthers.py
index 246f932..f88b7e9 100644
--- a/dinov2/hub/depthers.py
+++ b/dinov2/hub/depthers.py
@@ -5,12 +5,12 @@
 
 from enum import Enum
 from functools import partial
-from typing import Union
+from typing import Optional, Tuple, Union
 
 import torch
 
 from .backbones import _make_dinov2_model
-from .depth import BNHead, DepthEncoderDecoder
+from .depth import BNHead, DepthEncoderDecoder, DPTHead
 from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding
 
 
@@ -19,10 +19,26 @@ class Weights(Enum):
     KITTI = "KITTI"
 
 
+def _get_depth_range(pretrained: bool, weights: Weights = Weights.NYU) -> Tuple[float, float]:
+    if not pretrained:  # Default
+        return (0.001, 10.0)
+
+    # Pretrained, set according to the training dataset for the provided weights
+    if weights == Weights.KITTI:
+        return (0.001, 80.0)
+
+    if weights == Weights.NYU:
+        return (0.001, 10.0)
+
+    return (0.001, 10.0)
+
+
 def _make_dinov2_linear_depth_head(
     *,
-    embed_dim: int = 1024,
-    layers: int = 4,
+    embed_dim: int,
+    layers: int,
+    min_depth: float,
+    max_depth: float,
     **kwargs,
 ):
     if layers not in (1, 4):
@@ -46,7 +62,7 @@ def _make_dinov2_linear_depth_head(
         channels=embed_dim * len(in_index) * 2,
         align_corners=False,
-        min_depth=0.001,
-        max_depth=10,
+        min_depth=min_depth,
+        max_depth=max_depth,
         loss_decode=(),
     )
 
@@ -57,6 +73,7 @@ def _make_dinov2_linear_depther(
     layers: int = 4,
     pretrained: bool = True,
     weights: Union[Weights, str] = Weights.NYU,
+    depth_range: Optional[Tuple[float, float]] = None,
     **kwargs,
 ):
     if layers not in (1, 4):
@@ -67,15 +84,20 @@ def _make_dinov2_linear_depther(
         except KeyError:
             raise AssertionError(f"Unsupported weights: {weights}")
 
+    if depth_range is None:
+        depth_range = _get_depth_range(pretrained, weights)
+    min_depth, max_depth = depth_range
+
     backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)
 
     embed_dim = backbone.embed_dim
     patch_size = backbone.patch_size
     model_name = _make_dinov2_model_name(arch_name, patch_size)
     linear_depth_head = _make_dinov2_linear_depth_head(
-        arch_name=arch_name,
         embed_dim=embed_dim,
         layers=layers,
+        min_depth=min_depth,
+        max_depth=max_depth,
     )
 
     layer_count = {
@@ -140,3 +162,85 @@ def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union
     return _make_dinov2_linear_depther(
         arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
     )
+
+
+def _make_dinov2_dpt_depth_head(*, embed_dim: int, min_depth: float, max_depth: float):
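+    # reassemble channels double toward the coarsest stage, e.g.
+    # embed_dim=1024 (ViT-L) gives post_process_channels=[128, 256, 512, 1024]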
+    return DPTHead(
+        in_channels=[embed_dim] * 4,
+        channels=256,
+        embed_dims=embed_dim,
+        post_process_channels=[embed_dim // 2 ** (3 - i) for i in range(4)],
+        readout_type="project",
+        min_depth=min_depth,
+        max_depth=max_depth,
+        loss_decode=(),
+    )
+
+
+def _make_dinov2_dpt_depther(
+    *,
+    arch_name: str = "vit_large",
+    pretrained: bool = True,
+    weights: Union[Weights, str] = Weights.NYU,
+    depth_range: Optional[Tuple[float, float]] = None,
+    **kwargs,
+):
+    if isinstance(weights, str):
+        try:
+            weights = Weights[weights]
+        except KeyError:
+            raise AssertionError(f"Unsupported weights: {weights}")
+
+    if depth_range is None:
+        depth_range = _get_depth_range(pretrained, weights)
+    min_depth, max_depth = depth_range
+
+    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)
+
+    model_name = _make_dinov2_model_name(arch_name, backbone.patch_size)
+    dpt_depth_head = _make_dinov2_dpt_depth_head(embed_dim=backbone.embed_dim, min_depth=min_depth, max_depth=max_depth)
+
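+    # intermediate backbone blocks tapped for the four DPT stages,
+    # spaced across the depth of each architecture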
+    out_index = {
+        "vit_small": [2, 5, 8, 11],
+        "vit_base": [2, 5, 8, 11],
+        "vit_large": [4, 11, 17, 23],
+        "vit_giant2": [9, 19, 29, 39],
+    }[arch_name]
+
+    model = DepthEncoderDecoder(backbone=backbone, decode_head=dpt_depth_head)
+    model.backbone.forward = partial(
+        backbone.get_intermediate_layers,
+        n=out_index,
+        reshape=True,
+        return_class_token=True,
+        norm=False,
+    )
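+    # pad inputs so that H and W are multiples of the patch size before the
+    # backbone sees them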
+    model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone.patch_size)(x[0]))
+
+    if pretrained:
+        weights_str = weights.value.lower()
+        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_dpt_head.pth"
+        checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+        if "state_dict" in checkpoint:
+            state_dict = checkpoint["state_dict"]
+        model.load_state_dict(state_dict, strict=False)
+
+    return model
+
+
+def dinov2_vits14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+    return _make_dinov2_dpt_depther(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitb14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+    return _make_dinov2_dpt_depther(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitl14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+    return _make_dinov2_dpt_depther(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitg14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+    return _make_dinov2_dpt_depther(
+        arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
+    )
diff --git a/hubconf.py b/hubconf.py
index d122162..a9fbdc8 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -7,6 +7,7 @@
 from dinov2.hub.backbones import dinov2_vitb14, dinov2_vitg14, dinov2_vitl14, dinov2_vits14
 from dinov2.hub.classifiers import dinov2_vitb14_lc, dinov2_vitg14_lc, dinov2_vitl14_lc, dinov2_vits14_lc
 from dinov2.hub.depthers import dinov2_vitb14_ld, dinov2_vitg14_ld, dinov2_vitl14_ld, dinov2_vits14_ld
+from dinov2.hub.depthers import dinov2_vitb14_dd, dinov2_vitg14_dd, dinov2_vitl14_dd, dinov2_vits14_dd
 
 
 dependencies = ["torch"]
-- 
GitLab