diff --git a/dinov2/hub/depth/__init__.py b/dinov2/hub/depth/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ccf4423e17458a7f7b486eb5477a398ab47a28d
--- /dev/null
+++ b/dinov2/hub/depth/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .decode_heads import BNHead
+from .encoder_decoder import DepthEncoderDecoder
diff --git a/dinov2/hub/depth/decode_heads.py b/dinov2/hub/depth/decode_heads.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca657f807fc6ae8a0a4171376d973b0b4bf06079
--- /dev/null
+++ b/dinov2/hub/depth/decode_heads.py
@@ -0,0 +1,287 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import copy
+
+import torch
+import torch.nn as nn
+
+from .ops import resize
+
+
+# XXX: (Untested) replacement for mmcv.imdenormalize()
+def _imdenormalize(img, mean, std, to_bgr=True):
+    import numpy as np
+
+    mean = mean.reshape(1, -1).astype(np.float64)
+    std = std.reshape(1, -1).astype(np.float64)
+    img = (img * std) + mean
+    if to_bgr:
+        # reverse the channel axis (assumes HWC layout); reversing axis 0 would flip rows instead
+        img = img[..., ::-1]
+    return img
+
+
+class DepthBaseDecodeHead(nn.Module):
+    """Base class for depth estimation decode heads.
+
+    Args:
+        in_channels (List): Input channels.
+        channels (int): Channels after modules, before conv_depth.
+        loss_decode (nn.Module | Sequence[nn.Module]): Loss module(s) used
+            during training. Default: ().
+        sampler (dict|None): The config of depth map sampler.
+            Default: None.
+        align_corners (bool): align_corners argument of F.interpolate.
+            Default: False.
+        min_depth (float): Min depth in dataset setting.
+            Default: 1e-3.
+        max_depth (float|None): Max depth in dataset setting.
+            Default: None.
+        norm_cfg (dict|None): Config of norm layers.
+            Default: None.
+        classify (bool): Whether predict depth in a cls.-reg. manner.
+            Default: False.
+        n_bins (int): The number of bins used in cls. step.
+            Default: 256.
+        bins_strategy (str): The discrete strategy used in cls. step.
+            Default: 'UD'.
+        norm_strategy (str): The norm strategy on cls. probability
+            distribution. Default: 'linear'
+        scale_up (bool): Whether to predict depth in a scale-up manner.
+            Default: False.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        channels=96,
+        loss_decode=(),
+        sampler=None,
+        align_corners=False,
+        min_depth=1e-3,
+        max_depth=None,
+        norm_cfg=None,
+        classify=False,
+        n_bins=256,
+        bins_strategy="UD",
+        norm_strategy="linear",
+        scale_up=False,
+    ):
+        super(DepthBaseDecodeHead, self).__init__()
+
+        self.in_channels = in_channels
+        self.channels = channels
+        self.loss_decode = loss_decode
+        self.align_corners = align_corners
+        self.min_depth = min_depth
+        self.max_depth = max_depth
+        self.norm_cfg = norm_cfg
+        self.classify = classify
+        self.n_bins = n_bins
+        self.scale_up = scale_up
+
+        if self.classify:
+            assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
+            assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"
+
+            self.bins_strategy = bins_strategy
+            self.norm_strategy = norm_strategy
+            self.softmax = nn.Softmax(dim=1)
+            self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
+        else:
+            self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)
+
+        self.relu = nn.ReLU()
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, inputs, img_metas):
+        """Placeholder of forward function."""
+        pass
+
+    def forward_train(self, img, inputs, img_metas, depth_gt):
+        """Forward function for training.
+        Args:
+            img (Tensor): Input images.
+            inputs (list[Tensor]): List of multi-level img features.
+            img_metas (list[dict]): List of image info dict where each dict
+                has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                `depth/datasets/pipelines/formatting.py:Collect`.
+            depth_gt (Tensor): GT depth
+
+        Returns:
+            dict[str, Tensor]: a dictionary of loss components
+        """
+        depth_pred = self.forward(inputs, img_metas)
+        losses = self.losses(depth_pred, depth_gt)
+
+        log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
+        losses.update(**log_imgs)
+
+        return losses
+
+    def forward_test(self, inputs, img_metas):
+        """Forward function for testing.
+        Args:
+            inputs (list[Tensor]): List of multi-level img features.
+            img_metas (list[dict]): List of image info dict where each dict
+                has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                `depth/datasets/pipelines/formatting.py:Collect`.
+
+        Returns:
+            Tensor: Output depth map.
+        """
+        return self.forward(inputs, img_metas)
+
+    def depth_pred(self, feat):
+        """Predict depth at each pixel."""
+        if self.classify:
+            logit = self.conv_depth(feat)
+
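+            # UD: uniformly spaced bins over [min_depth, max_depth]; SID: log-spaced bins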
+            if self.bins_strategy == "UD":
+                bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
+            elif self.bins_strategy == "SID":
+                bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
+
+            # following Adabins, default linear
+            if self.norm_strategy == "linear":
+                logit = torch.relu(logit)
+                eps = 0.1
+                logit = logit + eps
+                logit = logit / logit.sum(dim=1, keepdim=True)
+            elif self.norm_strategy == "softmax":
+                logit = torch.softmax(logit, dim=1)
+            elif self.norm_strategy == "sigmoid":
+                logit = torch.sigmoid(logit)
+                logit = logit / logit.sum(dim=1, keepdim=True)
+
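+            # final depth = expectation of the bin values under the per-pixel probability distribution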
+            output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)
+
+        else:
+            if self.scale_up:
+                output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
+            else:
+                output = self.relu(self.conv_depth(feat)) + self.min_depth
+        return output
+
+    def losses(self, depth_pred, depth_gt):
+        """Compute depth loss."""
+        loss = dict()
+        depth_pred = resize(
+            input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
+        )
+        if not isinstance(self.loss_decode, nn.ModuleList):
+            losses_decode = [self.loss_decode]
+        else:
+            losses_decode = self.loss_decode
+        for loss_decode in losses_decode:
+            if loss_decode.loss_name not in loss:
+                loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
+            else:
+                loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt)
+        return loss
+
+    def log_images(self, img, depth_pred, depth_gt, img_meta):
+        import numpy as np
+
+        show_img = copy.deepcopy(img.detach().cpu().permute(1, 2, 0))
+        show_img = show_img.numpy().astype(np.float32)
+        show_img = _imdenormalize(
+            show_img,
+            img_meta["img_norm_cfg"]["mean"],
+            img_meta["img_norm_cfg"]["std"],
+            img_meta["img_norm_cfg"]["to_rgb"],
+        )
+        show_img = np.clip(show_img, 0, 255)
+        show_img = show_img.astype(np.uint8)
+        show_img = show_img[:, :, ::-1]
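+        # HWC -> CHW for image logging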
+        show_img = show_img.transpose(0, 2, 1)
+        show_img = show_img.transpose(1, 0, 2)
+
+        depth_pred = depth_pred / torch.max(depth_pred)
+        depth_gt = depth_gt / torch.max(depth_gt)
+
+        depth_pred_color = copy.deepcopy(depth_pred.detach().cpu())
+        depth_gt_color = copy.deepcopy(depth_gt.detach().cpu())
+
+        return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color}
+
+
+class BNHead(DepthBaseDecodeHead):
+    """Minimal decode head: concatenates the (resized) input features and predicts depth.
+
+    The batchnorm of the original implementation is disabled here.
+    """
+
+    def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
+        super().__init__(**kwargs)
+        self.input_transform = input_transform
+        self.in_index = in_index
+        self.upsample = upsample
+        # self.bn = nn.SyncBatchNorm(self.in_channels)
+        if self.classify:
+            self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1)
+        else:
+            self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1)
+
+    def _transform_inputs(self, inputs):
+        """Transform inputs for decoder.
+        Args:
+            inputs (list[Tensor]): List of multi-level img features.
+        Returns:
+            Tensor: The transformed inputs
+        """
+
+        if "concat" in self.input_transform:
+            inputs = [inputs[i] for i in self.in_index]
+            if "resize" in self.input_transform:
+                inputs = [
+                    resize(
+                        input=x,
+                        size=[s * self.upsample for s in inputs[0].shape[2:]],
+                        mode="bilinear",
+                        align_corners=self.align_corners,
+                    )
+                    for x in inputs
+                ]
+            inputs = torch.cat(inputs, dim=1)
+        elif self.input_transform == "multiple_select":
+            inputs = [inputs[i] for i in self.in_index]
+        else:
+            inputs = inputs[self.in_index]
+
+        return inputs
+
+    def _forward_feature(self, inputs, img_metas=None, **kwargs):
+        """Forward function producing the feature map that is fed to
+        ``self.conv_depth`` for per-pixel depth prediction.
+        Args:
+            inputs (list[Tensor]): List of multi-level img features.
+        Returns:
+            feats (Tensor): A tensor of shape (batch_size, self.channels,
+                H, W) which is feature map for last layer of decoder head.
+        """
+        # accept lists (for cls token)
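+        # each element may be a (patch_tokens, cls_token) pair, e.g. from
+        # get_intermediate_layers(..., return_class_token=True)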
+        inputs = list(inputs)
+        for i, x in enumerate(inputs):
+            if len(x) == 2:
+                x, cls_token = x[0], x[1]
+                if len(x.shape) == 2:
+                    x = x[:, :, None, None]
+                cls_token = cls_token[:, :, None, None].expand_as(x)
+                inputs[i] = torch.cat((x, cls_token), 1)
+            else:
+                x = x[0]
+                if len(x.shape) == 2:
+                    x = x[:, :, None, None]
+                inputs[i] = x
+        x = self._transform_inputs(inputs)
+        # feats = self.bn(x)
+        return x
+
+    def forward(self, inputs, img_metas=None, **kwargs):
+        """Forward function."""
+        output = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
+        output = self.depth_pred(output)
+        return output
diff --git a/dinov2/hub/depth/encoder_decoder.py b/dinov2/hub/depth/encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb29ced67957a336e763b0e7c90c0eeaea36fea8
--- /dev/null
+++ b/dinov2/hub/depth/encoder_decoder.py
@@ -0,0 +1,351 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .ops import resize
+
+
+def add_prefix(inputs, prefix):
+    """Add prefix for dict.
+
+    Args:
+        inputs (dict): The input dict with str keys.
+        prefix (str): The prefix to add.
+
+    Returns:
+        dict: The dict with keys updated with ``prefix``.
+    """
+
+    outputs = dict()
+    for name, value in inputs.items():
+        outputs[f"{prefix}.{name}"] = value
+
+    return outputs
+
+
+class DepthEncoderDecoder(nn.Module):
+    """Encoder-decoder depth estimator.
+
+    DepthEncoderDecoder typically consists of a backbone and a decode_head.
+    """
+
+    def __init__(self, backbone, decode_head):
+        super(DepthEncoderDecoder, self).__init__()
+
+        self.backbone = backbone
+        self.decode_head = decode_head
+        self.align_corners = self.decode_head.align_corners
+
+    def extract_feat(self, img):
+        """Extract features from images."""
+        return self.backbone(img)
+
+    def encode_decode(self, img, img_metas, rescale=True, size=None):
+        """Encode images with backbone and decode into a depth estimation
+        map of the same size as input."""
+        x = self.extract_feat(img)
+        out = self._decode_head_forward_test(x, img_metas)
+        # clamp the predicted depth to the valid [min_depth, max_depth] range
+        out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth)
+        if rescale:
+            if size is None:
+                if img_metas is not None:
+                    size = img_metas[0]["ori_shape"][:2]
+                else:
+                    size = img.shape[2:]
+            out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners)
+        return out
+
+    def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs):
+        """Run forward function and calculate loss for decode head in
+        training."""
+        losses = dict()
+        loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, **kwargs)
+        losses.update(add_prefix(loss_decode, "decode"))
+        return losses
+
+    def _decode_head_forward_test(self, x, img_metas):
+        """Run the decode head forward function in inference."""
+        depth_pred = self.decode_head.forward_test(x, img_metas)
+        return depth_pred
+
+    def forward_dummy(self, img):
+        """Dummy forward function."""
+        depth = self.encode_decode(img, None)
+
+        return depth
+
+    def forward_train(self, img, img_metas, depth_gt, **kwargs):
+        """Forward function for training.
+
+        Args:
+            img (Tensor): Input images.
+            img_metas (list[dict]): List of image info dict where each dict
+                has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                `depth/datasets/pipelines/formatting.py:Collect`.
+            depth_gt (Tensor): Ground-truth depth, used if the architecture
+                supports the depth estimation task.
+
+        Returns:
+            dict[str, Tensor]: a dictionary of loss components
+        """
+
+        x = self.extract_feat(img)
+
+        losses = dict()
+
+        loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs)
+
+        losses.update(loss_decode)
+
+        return losses
+
+    def whole_inference(self, img, img_meta, rescale, size=None):
+        """Inference with full image."""
+        return self.encode_decode(img, img_meta, rescale, size=size)
+
+    def slide_inference(self, img, img_meta, rescale, stride, crop_size):
+        """Inference by sliding-window with overlap.
+
+        If h_crop > h_img or w_crop > w_img, the small patch is decoded without
+        padding.
+        """
+
+        h_stride, w_stride = stride
+        h_crop, w_crop = crop_size
+        batch_size, _, h_img, w_img = img.size()
+        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
+        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
+        preds = img.new_zeros((batch_size, 1, h_img, w_img))
+        count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
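+        # accumulate crop predictions; count_mat tracks per-pixel coverage so overlaps are averaged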
+        for h_idx in range(h_grids):
+            for w_idx in range(w_grids):
+                y1 = h_idx * h_stride
+                x1 = w_idx * w_stride
+                y2 = min(y1 + h_crop, h_img)
+                x2 = min(x1 + w_crop, w_img)
+                y1 = max(y2 - h_crop, 0)
+                x1 = max(x2 - w_crop, 0)
+                crop_img = img[:, :, y1:y2, x1:x2]
+                depth_pred = self.encode_decode(crop_img, img_meta, rescale)
+                preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))
+
+                count_mat[:, :, y1:y2, x1:x2] += 1
+        assert (count_mat == 0).sum() == 0
+        if torch.onnx.is_in_onnx_export():
+            # cast count_mat to constant while exporting to ONNX
+            count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
+        preds = preds / count_mat
+        return preds
+
+    def inference(self, img, img_meta, rescale, size=None, mode="whole", stride=None, crop_size=None):
+        """Inference with slide/whole style.
+
+        Args:
+            img (Tensor): The input image of shape (N, 3, H, W).
+            img_meta (dict): Image info dict where each dict has: 'img_shape',
+                'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                `depth/datasets/pipelines/formatting.py:Collect`.
+            rescale (bool): Whether rescale back to original shape.
+
+        Returns:
+            Tensor: The output depth map.
+        """
+
+        assert mode in ["slide", "whole"]
+        ori_shape = img_meta[0]["ori_shape"]
+        assert all(_["ori_shape"] == ori_shape for _ in img_meta)
+        if mode == "slide":
+            depth_pred = self.slide_inference(img, img_meta, rescale, stride=stride, crop_size=crop_size)
+        else:
+            depth_pred = self.whole_inference(img, img_meta, rescale, size=size)
+        output = depth_pred
+        flip = img_meta[0]["flip"]
+        if flip:
+            flip_direction = img_meta[0]["flip_direction"]
+            assert flip_direction in ["horizontal", "vertical"]
+            if flip_direction == "horizontal":
+                output = output.flip(dims=(3,))
+            elif flip_direction == "vertical":
+                output = output.flip(dims=(2,))
+
+        return output
+
+    def simple_test(self, img, img_meta, rescale=True):
+        """Simple test with single image."""
+        depth_pred = self.inference(img, img_meta, rescale)
+        if torch.onnx.is_in_onnx_export():
+            # our inference backend only supports 4D output
+            depth_pred = depth_pred.unsqueeze(0)
+            return depth_pred
+        depth_pred = depth_pred.cpu().numpy()
+        # unravel batch dim
+        depth_pred = list(depth_pred)
+        return depth_pred
+
+    def aug_test(self, imgs, img_metas, rescale=True):
+        """Test with augmentations.
+
+        Only rescale=True is supported.
+        """
+        # aug_test rescale all imgs back to ori_shape for now
+        assert rescale
+        # to save memory, accumulate the augmented depth predictions in place
+        depth_pred = self.inference(imgs[0], img_metas[0], rescale)
+        for i in range(1, len(imgs)):
+            cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:])
+            depth_pred += cur_depth_pred
+        depth_pred /= len(imgs)
+        depth_pred = depth_pred.cpu().numpy()
+        # unravel batch dim
+        depth_pred = list(depth_pred)
+        return depth_pred
+
+    def forward_test(self, imgs, img_metas, **kwargs):
+        """
+        Args:
+            imgs (List[Tensor]): the outer list indicates test-time
+                augmentations and inner Tensor should have a shape NxCxHxW,
+                which contains all images in the batch.
+            img_metas (List[List[dict]]): the outer list indicates test-time
+                augs (multiscale, flip, etc.) and the inner list indicates
+                images in a batch.
+        """
+        for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]:
+            if not isinstance(var, list):
+                raise TypeError(f"{name} must be a list, but got " f"{type(var)}")
+        num_augs = len(imgs)
+        if num_augs != len(img_metas):
+            raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})")
+        # all images in the same aug batch should have the same ori_shape and
+        # pad shape
+        for img_meta in img_metas:
+            ori_shapes = [_["ori_shape"] for _ in img_meta]
+            assert all(shape == ori_shapes[0] for shape in ori_shapes)
+            img_shapes = [_["img_shape"] for _ in img_meta]
+            assert all(shape == img_shapes[0] for shape in img_shapes)
+            pad_shapes = [_["pad_shape"] for _ in img_meta]
+            assert all(shape == pad_shapes[0] for shape in pad_shapes)
+
+        if num_augs == 1:
+            return self.simple_test(imgs[0], img_metas[0], **kwargs)
+        else:
+            return self.aug_test(imgs, img_metas, **kwargs)
+
+    def forward(self, img, img_metas, return_loss=True, **kwargs):
+        """Calls either :func:`forward_train` or :func:`forward_test` depending
+        on whether ``return_loss`` is ``True``.
+
+        Note this setting will change the expected inputs. When
+        ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
+        and List[dict]), and when ``return_loss=False``, img and img_meta
+        should be double nested (i.e.  List[Tensor], List[List[dict]]), with
+        the outer list indicating test time augmentations.
+        """
+        if return_loss:
+            return self.forward_train(img, img_metas, **kwargs)
+        else:
+            return self.forward_test(img, img_metas, **kwargs)
+
+    def train_step(self, data_batch, optimizer, **kwargs):
+        """The iteration step during training.
+
+        This method defines an iteration step during training, except for the
+        back propagation and optimizer updating, which are done in an optimizer
+        hook. Note that in some complicated cases or models, the whole process
+        including back propagation and optimizer updating is also defined in
+        this method, such as GAN.
+
+        Args:
+            data_batch (dict): The output of dataloader.
+            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
+                runner is passed to ``train_step()``. This argument is unused
+                and reserved.
+
+        Returns:
+            dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
+                ``num_samples``.
+                ``loss`` is a tensor for back propagation, which can be a
+                weighted sum of multiple losses.
+                ``log_vars`` contains all the variables to be sent to the
+                logger.
+                ``num_samples`` indicates the batch size (when the model is
+                DDP, it means the batch size on each GPU), which is used for
+                averaging the logs.
+        """
+        losses = self(**data_batch)
+
+        # split losses and images
+        real_losses = {}
+        log_imgs = {}
+        for k, v in losses.items():
+            if "img" in k:
+                log_imgs[k] = v
+            else:
+                real_losses[k] = v
+
+        loss, log_vars = self._parse_losses(real_losses)
+
+        outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs)
+
+        return outputs
+
+    def val_step(self, data_batch, **kwargs):
+        """The iteration step during validation.
+
+        This method shares the same signature as :func:`train_step`, but is
+        used during val epochs. Note that the evaluation after training epochs
+        is not implemented with this method, but with an evaluation hook.
+        """
+        output = self(**data_batch, **kwargs)
+        return output
+
+    @staticmethod
+    def _parse_losses(losses):
+        """Parse the raw outputs (losses) of the network.
+
+        Args:
+            losses (dict): Raw output of the network, which usually contain
+                losses and other necessary information.
+
+        Returns:
+            tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
+                which may be a weighted sum of all losses, log_vars contains
+                all the variables to be sent to the logger.
+        """
+        import torch.distributed as dist
+
+        log_vars = OrderedDict()
+        for loss_name, loss_value in losses.items():
+            if isinstance(loss_value, torch.Tensor):
+                log_vars[loss_name] = loss_value.mean()
+            elif isinstance(loss_value, list):
+                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
+            else:
+                raise TypeError(f"{loss_name} is not a tensor or list of tensors")
+
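+        # only entries whose key contains "loss" contribute to the total training objective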
+        loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key)
+
+        log_vars["loss"] = loss
+        for loss_name, loss_value in log_vars.items():
+            # reduce loss when distributed training
+            if dist.is_available() and dist.is_initialized():
+                loss_value = loss_value.data.clone()
+                dist.all_reduce(loss_value.div_(dist.get_world_size()))
+            log_vars[loss_name] = loss_value.item()
+
+        return loss, log_vars
diff --git a/dinov2/hub/depth/ops.py b/dinov2/hub/depth/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..15880ee0cb7652d4b41c489b927bf6a156b40e5e
--- /dev/null
+++ b/dinov2/hub/depth/ops.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import warnings
+
+import torch.nn.functional as F
+
+
+def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False):
+    if warning:
+        if size is not None and align_corners:
+            input_h, input_w = tuple(int(x) for x in input.shape[2:])
+            output_h, output_w = tuple(int(x) for x in size)
+            if output_h > input_h or output_w > input_w:
+                if (
+                    (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1)
+                    and (output_h - 1) % (input_h - 1)
+                    and (output_w - 1) % (input_w - 1)
+                ):
+                    warnings.warn(
+                        f"When align_corners={align_corners}, "
+                        "the output would be more aligned if "
+                        f"input size {(input_h, input_w)} is `x+1` and "
+                        f"out size {(output_h, output_w)} is `nx+1`"
+                    )
+    return F.interpolate(input, size, scale_factor, mode, align_corners)
diff --git a/dinov2/hub/depthers.py b/dinov2/hub/depthers.py
new file mode 100644
index 0000000000000000000000000000000000000000..246f9328a033cd4ef83df0181e5881740415952e
--- /dev/null
+++ b/dinov2/hub/depthers.py
@@ -0,0 +1,142 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+from functools import partial
+from typing import Union
+
+import torch
+
+from .backbones import _make_dinov2_model
+from .depth import BNHead, DepthEncoderDecoder
+from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding
+
+
+class Weights(Enum):
+    NYU = "NYU"
+    KITTI = "KITTI"
+
+
+def _make_dinov2_linear_depth_head(
+    *,
+    embed_dim: int = 1024,
+    layers: int = 4,
+    **kwargs,
+):
+    if layers not in (1, 4):
+        raise AssertionError(f"Unsupported number of layers: {layers}")
+
+    if layers == 1:
+        in_index = [0]
+    else:
+        assert layers == 4
+        in_index = [0, 1, 2, 3]
+
+    return BNHead(
+        classify=True,
+        n_bins=256,
+        bins_strategy="UD",
+        norm_strategy="linear",
+        upsample=4,
+        in_channels=[embed_dim] * len(in_index),
+        in_index=in_index,
+        input_transform="resize_concat",
+        channels=embed_dim * len(in_index) * 2,
+        align_corners=False,
+        min_depth=0.001,
+        max_depth=10,
+        loss_decode=(),
+    )
+
+
+def _make_dinov2_linear_depther(
+    *,
+    arch_name: str = "vit_large",
+    layers: int = 4,
+    pretrained: bool = True,
+    weights: Union[Weights, str] = Weights.NYU,
+    **kwargs,
+):
+    if layers not in (1, 4):
+        raise AssertionError(f"Unsupported number of layers: {layers}")
+    if isinstance(weights, str):
+        try:
+            weights = Weights[weights]
+        except KeyError:
+            raise AssertionError(f"Unsupported weights: {weights}")
+
+    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)
+
+    embed_dim = backbone.embed_dim
+    patch_size = backbone.patch_size
+    model_name = _make_dinov2_model_name(arch_name, patch_size)
+    linear_depth_head = _make_dinov2_linear_depth_head(
+        arch_name=arch_name,
+        embed_dim=embed_dim,
+        layers=layers,
+    )
+
+    layer_count = {
+        "vit_small": 12,
+        "vit_base": 12,
+        "vit_large": 24,
+        "vit_giant2": 40,
+    }[arch_name]
+
+    if layers == 4:
+        out_index = {
+            "vit_small": [2, 5, 8, 11],
+            "vit_base": [2, 5, 8, 11],
+            "vit_large": [4, 11, 17, 23],
+            "vit_giant2": [9, 19, 29, 39],
+        }[arch_name]
+    else:
+        assert layers == 1
+        out_index = [layer_count - 1]
+
+    model = DepthEncoderDecoder(backbone=backbone, decode_head=linear_depth_head)
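+    # Replace the backbone forward so it returns the selected intermediate feature maps
+    # (with their class tokens) in the layout BNHead expects, and pad inputs to a multiple
+    # of the patch size before they reach the ViT.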
+    model.backbone.forward = partial(
+        backbone.get_intermediate_layers,
+        n=out_index,
+        reshape=True,
+        return_class_token=True,
+        norm=False,
+    )
+    model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(patch_size)(x[0]))
+
+    if pretrained:
+        layers_str = str(layers) if layers == 4 else ""
+        weights_str = weights.value.lower()
+        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_linear{layers_str}_head.pth"
+        checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+        if "state_dict" in checkpoint:
+            state_dict = checkpoint["state_dict"]
+        else:
+            state_dict = checkpoint
+        model.load_state_dict(state_dict, strict=False)
+
+    return model
+
+
+def dinov2_vits14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+    return _make_dinov2_linear_depther(
+        arch_name="vit_small", layers=layers, pretrained=pretrained, weights=weights, **kwargs
+    )
+
+
+def dinov2_vitb14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+    return _make_dinov2_linear_depther(
+        arch_name="vit_base", layers=layers, pretrained=pretrained, weights=weights, **kwargs
+    )
+
+
+def dinov2_vitl14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+    return _make_dinov2_linear_depther(
+        arch_name="vit_large", layers=layers, pretrained=pretrained, weights=weights, **kwargs
+    )
+
+
+def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
+    return _make_dinov2_linear_depther(
+        arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
+    )
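+
+
+# A minimal usage sketch (hedged, not part of the checked-in API): it assumes the entry points
+# above are exposed through this repo's hubconf.py and that `img` is an RGB tensor of shape
+# (1, 3, H, W), already normalized with the same statistics as the pretraining data.
+#
+#   depther = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_ld", layers=4, weights="NYU")
+#   depther.eval()
+#   with torch.inference_mode():
+#       depth = depther.whole_inference(img, img_meta=None, rescale=True)  # (1, 1, H, W)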
diff --git a/dinov2/hub/utils.py b/dinov2/hub/utils.py
index 46805993337e48b35f3933804db26278088c7965..e03032ed43c23588ed0fb156c50bd38378333920 100644
--- a/dinov2/hub/utils.py
+++ b/dinov2/hub/utils.py
@@ -3,9 +3,36 @@
 # This source code is licensed under the Apache License, Version 2.0
 # found in the LICENSE file in the root directory of this source tree.
 
+import itertools
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
 _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
 
 
 def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str:
     compact_arch_name = arch_name.replace("_", "")[:4]
     return f"dinov2_{compact_arch_name}{patch_size}"
+
+
+class CenterPadding(nn.Module):
+    def __init__(self, multiple):
+        super().__init__()
+        self.multiple = multiple
+
+    def _get_pad(self, size):
+        new_size = math.ceil(size / self.multiple) * self.multiple
+        pad_size = new_size - size
+        pad_size_left = pad_size // 2
+        pad_size_right = pad_size - pad_size_left
+        return pad_size_left, pad_size_right
+
+    @torch.inference_mode()
+    def forward(self, x):
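+        # F.pad takes pad amounts for the last dimension first, hence the reversed spatial dims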
+        pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
+        output = F.pad(x, pads)
+        return output
diff --git a/hubconf.py b/hubconf.py
index b3b44837300181bd1c901e8ce2cf0ab2fdc27775..d1221627b2e822e12dbbba5d4980f16acfeb45be 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -6,5 +6,7 @@
 
 from dinov2.hub.backbones import dinov2_vitb14, dinov2_vitg14, dinov2_vitl14, dinov2_vits14
 from dinov2.hub.classifiers import dinov2_vitb14_lc, dinov2_vitg14_lc, dinov2_vitl14_lc, dinov2_vits14_lc
+from dinov2.hub.depthers import dinov2_vitb14_ld, dinov2_vitg14_ld, dinov2_vitl14_ld, dinov2_vits14_ld
+
 
 dependencies = ["torch"]