Skip to content
Snippets Groups Projects
Unverified Commit e7df9fc9 authored by Patrick Labatut's avatar Patrick Labatut Committed by GitHub
Browse files

Expose DPT depth models via torch.hub.load() (#238)

Add streamlined model versions w/o the mmcv dependency to directly load them via torch.hub.load().
parent 82185b17
No related branches found
No related tags found
No related merge requests found
......@@ -3,5 +3,5 @@
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .decode_heads import BNHead
from .decode_heads import BNHead, DPTHead
from .encoder_decoder import DepthEncoderDecoder
This diff is collapsed.
......@@ -5,12 +5,12 @@
from enum import Enum
from functools import partial
from typing import Union
from typing import Optional, Tuple, Union
import torch
from .backbones import _make_dinov2_model
from .depth import BNHead, DepthEncoderDecoder
from .depth import BNHead, DepthEncoderDecoder, DPTHead
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding
......@@ -19,10 +19,26 @@ class Weights(Enum):
KITTI = "KITTI"
def _get_depth_range(pretrained: bool, weights: Weights = Weights.NYU) -> Tuple[float, float]:
if not pretrained: # Default
return (0.001, 10.0)
# Pretrained, set according to the training dataset for the provided weights
if weights == Weights.KITTI:
return (0.001, 80.0)
if weights == Weights.NYU:
return (0.001, 10.0)
return (0.001, 10.0)
def _make_dinov2_linear_depth_head(
*,
embed_dim: int = 1024,
layers: int = 4,
embed_dim: int,
layers: int,
min_depth: float,
max_depth: float,
**kwargs,
):
if layers not in (1, 4):
......@@ -46,7 +62,7 @@ def _make_dinov2_linear_depth_head(
channels=embed_dim * len(in_index) * 2,
align_corners=False,
min_depth=0.001,
max_depth=10,
max_depth=80,
loss_decode=(),
)
......@@ -57,6 +73,7 @@ def _make_dinov2_linear_depther(
layers: int = 4,
pretrained: bool = True,
weights: Union[Weights, str] = Weights.NYU,
depth_range: Optional[Tuple[float, float]] = None,
**kwargs,
):
if layers not in (1, 4):
......@@ -67,15 +84,20 @@ def _make_dinov2_linear_depther(
except KeyError:
raise AssertionError(f"Unsupported weights: {weights}")
if depth_range is None:
depth_range = _get_depth_range(pretrained, weights)
min_depth, max_depth = depth_range
backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)
embed_dim = backbone.embed_dim
patch_size = backbone.patch_size
model_name = _make_dinov2_model_name(arch_name, patch_size)
linear_depth_head = _make_dinov2_linear_depth_head(
arch_name=arch_name,
embed_dim=embed_dim,
layers=layers,
min_depth=min_depth,
max_depth=max_depth,
)
layer_count = {
......@@ -140,3 +162,85 @@ def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union
return _make_dinov2_linear_depther(
arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
)
def _make_dinov2_dpt_depth_head(*, embed_dim: int, min_depth: float, max_depth: float):
return DPTHead(
in_channels=[embed_dim] * 4,
channels=256,
embed_dims=embed_dim,
post_process_channels=[embed_dim // 2 ** (3 - i) for i in range(4)],
readout_type="project",
min_depth=min_depth,
max_depth=max_depth,
loss_decode=(),
)
def _make_dinov2_dpt_depther(
*,
arch_name: str = "vit_large",
pretrained: bool = True,
weights: Union[Weights, str] = Weights.NYU,
depth_range: Optional[Tuple[float, float]] = None,
**kwargs,
):
if isinstance(weights, str):
try:
weights = Weights[weights]
except KeyError:
raise AssertionError(f"Unsupported weights: {weights}")
if depth_range is None:
depth_range = _get_depth_range(pretrained, weights)
min_depth, max_depth = depth_range
backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)
model_name = _make_dinov2_model_name(arch_name, backbone.patch_size)
dpt_depth_head = _make_dinov2_dpt_depth_head(embed_dim=backbone.embed_dim, min_depth=min_depth, max_depth=max_depth)
out_index = {
"vit_small": [2, 5, 8, 11],
"vit_base": [2, 5, 8, 11],
"vit_large": [4, 11, 17, 23],
"vit_giant2": [9, 19, 29, 39],
}[arch_name]
model = DepthEncoderDecoder(backbone=backbone, decode_head=dpt_depth_head)
model.backbone.forward = partial(
backbone.get_intermediate_layers,
n=out_index,
reshape=True,
return_class_token=True,
norm=False,
)
model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone.patch_size)(x[0]))
if pretrained:
weights_str = weights.value.lower()
url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_dpt_head.pth"
checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
if "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
model.load_state_dict(state_dict, strict=False)
return model
def dinov2_vits14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
return _make_dinov2_dpt_depther(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
def dinov2_vitb14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
return _make_dinov2_dpt_depther(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
def dinov2_vitl14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
return _make_dinov2_dpt_depther(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
def dinov2_vitg14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
return _make_dinov2_dpt_depther(
arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
)
......@@ -7,6 +7,7 @@
from dinov2.hub.backbones import dinov2_vitb14, dinov2_vitg14, dinov2_vitl14, dinov2_vits14
from dinov2.hub.classifiers import dinov2_vitb14_lc, dinov2_vitg14_lc, dinov2_vitl14_lc, dinov2_vits14_lc
from dinov2.hub.depthers import dinov2_vitb14_ld, dinov2_vitg14_ld, dinov2_vitl14_ld, dinov2_vits14_ld
from dinov2.hub.depthers import dinov2_vitb14_dd, dinov2_vitg14_dd, dinov2_vitl14_dd, dinov2_vits14_dd
dependencies = ["torch"]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment