Skip to content
Snippets Groups Projects
Unverified Commit 82185b17 authored by Patrick Labatut's avatar Patrick Labatut Committed by GitHub
Browse files

Expose linear depth models via PyTorch Hub (#237)

Add streamlined model versions w/o the mmcv dependency to directly load them via torch.hub.load().
parent b507fbcf
No related branches found
No related tags found
No related merge requests found
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .decode_heads import BNHead
from .encoder_decoder import DepthEncoderDecoder
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import copy
import torch
import torch.nn as nn
from .ops import resize
# XXX: (Untested) replacement for mmcv.imdenormalize()
def _imdenormalize(img, mean, std, to_bgr=True):
import numpy as np
mean = mean.reshape(1, -1).astype(np.float64)
std = std.reshape(1, -1).astype(np.float64)
img = (img * std) + mean
if to_bgr:
img = img[::-1]
return img
class DepthBaseDecodeHead(nn.Module):
"""Base class for BaseDecodeHead.
Args:
in_channels (List): Input channels.
channels (int): Channels after modules, before conv_depth.
loss_decode (dict): Config of decode loss.
Default: ().
sampler (dict|None): The config of depth map sampler.
Default: None.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
min_depth (int): Min depth in dataset setting.
Default: 1e-3.
max_depth (int): Max depth in dataset setting.
Default: None.
norm_cfg (dict|None): Config of norm layers.
Default: None.
classify (bool): Whether predict depth in a cls.-reg. manner.
Default: False.
n_bins (int): The number of bins used in cls. step.
Default: 256.
bins_strategy (str): The discrete strategy used in cls. step.
Default: 'UD'.
norm_strategy (str): The norm strategy on cls. probability
distribution. Default: 'linear'
scale_up (str): Whether predict depth in a scale-up manner.
Default: False.
"""
def __init__(
self,
in_channels,
channels=96,
loss_decode=(),
sampler=None,
align_corners=False,
min_depth=1e-3,
max_depth=None,
norm_cfg=None,
classify=False,
n_bins=256,
bins_strategy="UD",
norm_strategy="linear",
scale_up=False,
):
super(DepthBaseDecodeHead, self).__init__()
self.in_channels = in_channels
self.channels = channels
self.loss_decode = loss_decode
self.align_corners = align_corners
self.min_depth = min_depth
self.max_depth = max_depth
self.norm_cfg = norm_cfg
self.classify = classify
self.n_bins = n_bins
self.scale_up = scale_up
if self.classify:
assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"
self.bins_strategy = bins_strategy
self.norm_strategy = norm_strategy
self.softmax = nn.Softmax(dim=1)
self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
else:
self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
def forward(self, inputs, img_metas):
"""Placeholder of forward function."""
pass
def forward_train(self, img, inputs, img_metas, depth_gt):
"""Forward function for training.
Args:
inputs (list[Tensor]): List of multi-level img features.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`depth/datasets/pipelines/formatting.py:Collect`.
depth_gt (Tensor): GT depth
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
depth_pred = self.forward(inputs, img_metas)
losses = self.losses(depth_pred, depth_gt)
log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
losses.update(**log_imgs)
return losses
def forward_test(self, inputs, img_metas):
"""Forward function for testing.
Args:
inputs (list[Tensor]): List of multi-level img features.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`depth/datasets/pipelines/formatting.py:Collect`.
Returns:
Tensor: Output depth map.
"""
return self.forward(inputs, img_metas)
def depth_pred(self, feat):
"""Prediction each pixel."""
if self.classify:
logit = self.conv_depth(feat)
if self.bins_strategy == "UD":
bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
elif self.bins_strategy == "SID":
bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
# following Adabins, default linear
if self.norm_strategy == "linear":
logit = torch.relu(logit)
eps = 0.1
logit = logit + eps
logit = logit / logit.sum(dim=1, keepdim=True)
elif self.norm_strategy == "softmax":
logit = torch.softmax(logit, dim=1)
elif self.norm_strategy == "sigmoid":
logit = torch.sigmoid(logit)
logit = logit / logit.sum(dim=1, keepdim=True)
output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)
else:
if self.scale_up:
output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
else:
output = self.relu(self.conv_depth(feat)) + self.min_depth
return output
def losses(self, depth_pred, depth_gt):
"""Compute depth loss."""
loss = dict()
depth_pred = resize(
input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
)
if not isinstance(self.loss_decode, nn.ModuleList):
losses_decode = [self.loss_decode]
else:
losses_decode = self.loss_decode
for loss_decode in losses_decode:
if loss_decode.loss_name not in loss:
loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
else:
loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt)
return loss
def log_images(self, img_path, depth_pred, depth_gt, img_meta):
import numpy as np
show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0))
show_img = show_img.numpy().astype(np.float32)
show_img = _imdenormalize(
show_img,
img_meta["img_norm_cfg"]["mean"],
img_meta["img_norm_cfg"]["std"],
img_meta["img_norm_cfg"]["to_rgb"],
)
show_img = np.clip(show_img, 0, 255)
show_img = show_img.astype(np.uint8)
show_img = show_img[:, :, ::-1]
show_img = show_img.transpose(0, 2, 1)
show_img = show_img.transpose(1, 0, 2)
depth_pred = depth_pred / torch.max(depth_pred)
depth_gt = depth_gt / torch.max(depth_gt)
depth_pred_color = copy.deepcopy(depth_pred.detach().cpu())
depth_gt_color = copy.deepcopy(depth_gt.detach().cpu())
return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color}
class BNHead(DepthBaseDecodeHead):
"""Just a batchnorm."""
def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
super().__init__(**kwargs)
self.input_transform = input_transform
self.in_index = in_index
self.upsample = upsample
# self.bn = nn.SyncBatchNorm(self.in_channels)
if self.classify:
self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1)
else:
self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1)
def _transform_inputs(self, inputs):
"""Transform inputs for decoder.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
Tensor: The transformed inputs
"""
if "concat" in self.input_transform:
inputs = [inputs[i] for i in self.in_index]
if "resize" in self.input_transform:
inputs = [
resize(
input=x,
size=[s * self.upsample for s in inputs[0].shape[2:]],
mode="bilinear",
align_corners=self.align_corners,
)
for x in inputs
]
inputs = torch.cat(inputs, dim=1)
elif self.input_transform == "multiple_select":
inputs = [inputs[i] for i in self.in_index]
else:
inputs = inputs[self.in_index]
return inputs
def _forward_feature(self, inputs, img_metas=None, **kwargs):
"""Forward function for feature maps before classifying each pixel with
``self.cls_seg`` fc.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
feats (Tensor): A tensor of shape (batch_size, self.channels,
H, W) which is feature map for last layer of decoder head.
"""
# accept lists (for cls token)
inputs = list(inputs)
for i, x in enumerate(inputs):
if len(x) == 2:
x, cls_token = x[0], x[1]
if len(x.shape) == 2:
x = x[:, :, None, None]
cls_token = cls_token[:, :, None, None].expand_as(x)
inputs[i] = torch.cat((x, cls_token), 1)
else:
x = x[0]
if len(x.shape) == 2:
x = x[:, :, None, None]
inputs[i] = x
x = self._transform_inputs(inputs)
# feats = self.bn(x)
return x
def forward(self, inputs, img_metas=None, **kwargs):
"""Forward function."""
output = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
output = self.depth_pred(output)
return output
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from .ops import resize
def add_prefix(inputs, prefix):
"""Add prefix for dict.
Args:
inputs (dict): The input dict with str keys.
prefix (str): The prefix to add.
Returns:
dict: The dict with keys updated with ``prefix``.
"""
outputs = dict()
for name, value in inputs.items():
outputs[f"{prefix}.{name}"] = value
return outputs
class DepthEncoderDecoder(nn.Module):
"""Encoder Decoder depther.
EncoderDecoder typically consists of backbone and decode_head.
"""
def __init__(self, backbone, decode_head):
super(DepthEncoderDecoder, self).__init__()
self.backbone = backbone
self.decode_head = decode_head
self.align_corners = self.decode_head.align_corners
def extract_feat(self, img):
"""Extract features from images."""
return self.backbone(img)
def encode_decode(self, img, img_metas, rescale=True, size=None):
"""Encode images with backbone and decode into a depth estimation
map of the same size as input."""
x = self.extract_feat(img)
out = self._decode_head_forward_test(x, img_metas)
# crop the pred depth to the certain range.
out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth)
if rescale:
if size is None:
if img_metas is not None:
size = img_metas[0]["ori_shape"][:2]
else:
size = img.shape[2:]
out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners)
return out
def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs):
"""Run forward function and calculate loss for decode head in
training."""
losses = dict()
loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, **kwargs)
losses.update(add_prefix(loss_decode, "decode"))
return losses
def _decode_head_forward_test(self, x, img_metas):
"""Run forward function and calculate loss for decode head in
inference."""
depth_pred = self.decode_head.forward_test(x, img_metas)
return depth_pred
def forward_dummy(self, img):
"""Dummy forward function."""
depth = self.encode_decode(img, None)
return depth
def forward_train(self, img, img_metas, depth_gt, **kwargs):
"""Forward function for training.
Args:
img (Tensor): Input images.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`depth/datasets/pipelines/formatting.py:Collect`.
depth_gt (Tensor): Depth gt
used if the architecture supports depth estimation task.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
x = self.extract_feat(img)
losses = dict()
# the last of x saves the info from neck
loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs)
losses.update(loss_decode)
return losses
def whole_inference(self, img, img_meta, rescale, size=None):
"""Inference with full image."""
return self.encode_decode(img, img_meta, rescale, size=size)
def slide_inference(self, img, img_meta, rescale, stride, crop_size):
"""Inference by sliding-window with overlap.
If h_crop > h_img or w_crop > w_img, the small patch will be used to
decode without padding.
"""
h_stride, w_stride = stride
h_crop, w_crop = crop_size
batch_size, _, h_img, w_img = img.size()
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
preds = img.new_zeros((batch_size, 1, h_img, w_img))
count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
for h_idx in range(h_grids):
for w_idx in range(w_grids):
y1 = h_idx * h_stride
x1 = w_idx * w_stride
y2 = min(y1 + h_crop, h_img)
x2 = min(x1 + w_crop, w_img)
y1 = max(y2 - h_crop, 0)
x1 = max(x2 - w_crop, 0)
crop_img = img[:, :, y1:y2, x1:x2]
depth_pred = self.encode_decode(crop_img, img_meta, rescale)
preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))
count_mat[:, :, y1:y2, x1:x2] += 1
assert (count_mat == 0).sum() == 0
if torch.onnx.is_in_onnx_export():
# cast count_mat to constant while exporting to ONNX
count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
preds = preds / count_mat
return preds
def inference(self, img, img_meta, rescale, size=None, mode="whole"):
"""Inference with slide/whole style.
Args:
img (Tensor): The input image of shape (N, 3, H, W).
img_meta (dict): Image info dict where each dict has: 'img_shape',
'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`depth/datasets/pipelines/formatting.py:Collect`.
rescale (bool): Whether rescale back to original shape.
Returns:
Tensor: The output depth map.
"""
assert mode in ["slide", "whole"]
ori_shape = img_meta[0]["ori_shape"]
assert all(_["ori_shape"] == ori_shape for _ in img_meta)
if mode == "slide":
depth_pred = self.slide_inference(img, img_meta, rescale)
else:
depth_pred = self.whole_inference(img, img_meta, rescale, size=size)
output = depth_pred
flip = img_meta[0]["flip"]
if flip:
flip_direction = img_meta[0]["flip_direction"]
assert flip_direction in ["horizontal", "vertical"]
if flip_direction == "horizontal":
output = output.flip(dims=(3,))
elif flip_direction == "vertical":
output = output.flip(dims=(2,))
return output
def simple_test(self, img, img_meta, rescale=True):
"""Simple test with single image."""
depth_pred = self.inference(img, img_meta, rescale)
if torch.onnx.is_in_onnx_export():
# our inference backend only support 4D output
depth_pred = depth_pred.unsqueeze(0)
return depth_pred
depth_pred = depth_pred.cpu().numpy()
# unravel batch dim
depth_pred = list(depth_pred)
return depth_pred
def aug_test(self, imgs, img_metas, rescale=True):
"""Test with augmentations.
Only rescale=True is supported.
"""
# aug_test rescale all imgs back to ori_shape for now
assert rescale
# to save memory, we get augmented depth logit inplace
depth_pred = self.inference(imgs[0], img_metas[0], rescale)
for i in range(1, len(imgs)):
cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:])
depth_pred += cur_depth_pred
depth_pred /= len(imgs)
depth_pred = depth_pred.cpu().numpy()
# unravel batch dim
depth_pred = list(depth_pred)
return depth_pred
def forward_test(self, imgs, img_metas, **kwargs):
"""
Args:
imgs (List[Tensor]): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxCxHxW,
which contains all images in the batch.
img_metas (List[List[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch.
"""
for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]:
if not isinstance(var, list):
raise TypeError(f"{name} must be a list, but got " f"{type(var)}")
num_augs = len(imgs)
if num_augs != len(img_metas):
raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})")
# all images in the same aug batch all of the same ori_shape and pad
# shape
for img_meta in img_metas:
ori_shapes = [_["ori_shape"] for _ in img_meta]
assert all(shape == ori_shapes[0] for shape in ori_shapes)
img_shapes = [_["img_shape"] for _ in img_meta]
assert all(shape == img_shapes[0] for shape in img_shapes)
pad_shapes = [_["pad_shape"] for _ in img_meta]
assert all(shape == pad_shapes[0] for shape in pad_shapes)
if num_augs == 1:
return self.simple_test(imgs[0], img_metas[0], **kwargs)
else:
return self.aug_test(imgs, img_metas, **kwargs)
def forward(self, img, img_metas, return_loss=True, **kwargs):
"""Calls either :func:`forward_train` or :func:`forward_test` depending
on whether ``return_loss`` is ``True``.
Note this setting will change the expected inputs. When
``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
and List[dict]), and when ``resturn_loss=False``, img and img_meta
should be double nested (i.e. List[Tensor], List[List[dict]]), with
the outer list indicating test time augmentations.
"""
if return_loss:
return self.forward_train(img, img_metas, **kwargs)
else:
return self.forward_test(img, img_metas, **kwargs)
def train_step(self, data_batch, optimizer, **kwargs):
"""The iteration step during training.
This method defines an iteration step during training, except for the
back propagation and optimizer updating, which are done in an optimizer
hook. Note that in some complicated cases or models, the whole process
including back propagation and optimizer updating is also defined in
this method, such as GAN.
Args:
data (dict): The output of dataloader.
optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
runner is passed to ``train_step()``. This argument is unused
and reserved.
Returns:
dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
``num_samples``.
``loss`` is a tensor for back propagation, which can be a
weighted sum of multiple losses.
``log_vars`` contains all the variables to be sent to the
logger.
``num_samples`` indicates the batch size (when the model is
DDP, it means the batch size on each GPU), which is used for
averaging the logs.
"""
losses = self(**data_batch)
# split losses and images
real_losses = {}
log_imgs = {}
for k, v in losses.items():
if "img" in k:
log_imgs[k] = v
else:
real_losses[k] = v
loss, log_vars = self._parse_losses(real_losses)
outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs)
return outputs
def val_step(self, data_batch, **kwargs):
"""The iteration step during validation.
This method shares the same signature as :func:`train_step`, but used
during val epochs. Note that the evaluation after training epochs is
not implemented with this method, but an evaluation hook.
"""
output = self(**data_batch, **kwargs)
return output
@staticmethod
def _parse_losses(losses):
import torch.distributed as dist
"""Parse the raw outputs (losses) of the network.
Args:
losses (dict): Raw output of the network, which usually contain
losses and other necessary information.
Returns:
tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
which may be a weighted sum of all losses, log_vars contains
all the variables to be sent to the logger.
"""
log_vars = OrderedDict()
for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor):
log_vars[loss_name] = loss_value.mean()
elif isinstance(loss_value, list):
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
else:
raise TypeError(f"{loss_name} is not a tensor or list of tensors")
loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key)
log_vars["loss"] = loss
for loss_name, loss_value in log_vars.items():
# reduce loss when distributed training
if dist.is_available() and dist.is_initialized():
loss_value = loss_value.data.clone()
dist.all_reduce(loss_value.div_(dist.get_world_size()))
log_vars[loss_name] = loss_value.item()
return loss, log_vars
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import warnings
import torch.nn.functional as F
def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False):
if warning:
if size is not None and align_corners:
input_h, input_w = tuple(int(x) for x in input.shape[2:])
output_h, output_w = tuple(int(x) for x in size)
if output_h > input_h or output_w > output_h:
if (
(output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1)
and (output_h - 1) % (input_h - 1)
and (output_w - 1) % (input_w - 1)
):
warnings.warn(
f"When align_corners={align_corners}, "
"the output would more aligned if "
f"input size {(input_h, input_w)} is `x+1` and "
f"out size {(output_h, output_w)} is `nx+1`"
)
return F.interpolate(input, size, scale_factor, mode, align_corners)
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from enum import Enum
from functools import partial
from typing import Union
import torch
from .backbones import _make_dinov2_model
from .depth import BNHead, DepthEncoderDecoder
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding
class Weights(Enum):
NYU = "NYU"
KITTI = "KITTI"
def _make_dinov2_linear_depth_head(
*,
embed_dim: int = 1024,
layers: int = 4,
**kwargs,
):
if layers not in (1, 4):
raise AssertionError(f"Unsupported number of layers: {layers}")
if layers == 1:
in_index = [0]
else:
assert layers == 4
in_index = [0, 1, 2, 3]
return BNHead(
classify=True,
n_bins=256,
bins_strategy="UD",
norm_strategy="linear",
upsample=4,
in_channels=[embed_dim] * len(in_index),
in_index=in_index,
input_transform="resize_concat",
channels=embed_dim * len(in_index) * 2,
align_corners=False,
min_depth=0.001,
max_depth=10,
loss_decode=(),
)
def _make_dinov2_linear_depther(
*,
arch_name: str = "vit_large",
layers: int = 4,
pretrained: bool = True,
weights: Union[Weights, str] = Weights.NYU,
**kwargs,
):
if layers not in (1, 4):
raise AssertionError(f"Unsupported number of layers: {layers}")
if isinstance(weights, str):
try:
weights = Weights[weights]
except KeyError:
raise AssertionError(f"Unsupported weights: {weights}")
backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)
embed_dim = backbone.embed_dim
patch_size = backbone.patch_size
model_name = _make_dinov2_model_name(arch_name, patch_size)
linear_depth_head = _make_dinov2_linear_depth_head(
arch_name=arch_name,
embed_dim=embed_dim,
layers=layers,
)
layer_count = {
"vit_small": 12,
"vit_base": 12,
"vit_large": 24,
"vit_giant2": 40,
}[arch_name]
if layers == 4:
out_index = {
"vit_small": [2, 5, 8, 11],
"vit_base": [2, 5, 8, 11],
"vit_large": [4, 11, 17, 23],
"vit_giant2": [9, 19, 29, 39],
}[arch_name]
else:
assert layers == 1
out_index = [layer_count - 1]
model = DepthEncoderDecoder(backbone=backbone, decode_head=linear_depth_head)
model.backbone.forward = partial(
backbone.get_intermediate_layers,
n=out_index,
reshape=True,
return_class_token=True,
norm=False,
)
model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(patch_size)(x[0]))
if pretrained:
layers_str = str(layers) if layers == 4 else ""
weights_str = weights.value.lower()
url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_linear{layers_str}_head.pth"
checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
if "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
model.load_state_dict(state_dict, strict=False)
return model
def dinov2_vits14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
return _make_dinov2_linear_depther(
arch_name="vit_small", layers=layers, pretrained=pretrained, weights=weights, **kwargs
)
def dinov2_vitb14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
return _make_dinov2_linear_depther(
arch_name="vit_base", layers=layers, pretrained=pretrained, weights=weights, **kwargs
)
def dinov2_vitl14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
return _make_dinov2_linear_depther(
arch_name="vit_large", layers=layers, pretrained=pretrained, weights=weights, **kwargs
)
def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
return _make_dinov2_linear_depther(
arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
)
...@@ -3,9 +3,36 @@ ...@@ -3,9 +3,36 @@
# This source code is licensed under the Apache License, Version 2.0 # This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree. # found in the LICENSE file in the root directory of this source tree.
import itertools
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str: def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str:
compact_arch_name = arch_name.replace("_", "")[:4] compact_arch_name = arch_name.replace("_", "")[:4]
return f"dinov2_{compact_arch_name}{patch_size}" return f"dinov2_{compact_arch_name}{patch_size}"
class CenterPadding(nn.Module):
def __init__(self, multiple):
super().__init__()
self.multiple = multiple
def _get_pad(self, size):
new_size = math.ceil(size / self.multiple) * self.multiple
pad_size = new_size - size
pad_size_left = pad_size // 2
pad_size_right = pad_size - pad_size_left
return pad_size_left, pad_size_right
@torch.inference_mode()
def forward(self, x):
pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
output = F.pad(x, pads)
return output
...@@ -6,5 +6,7 @@ ...@@ -6,5 +6,7 @@
from dinov2.hub.backbones import dinov2_vitb14, dinov2_vitg14, dinov2_vitl14, dinov2_vits14 from dinov2.hub.backbones import dinov2_vitb14, dinov2_vitg14, dinov2_vitl14, dinov2_vits14
from dinov2.hub.classifiers import dinov2_vitb14_lc, dinov2_vitg14_lc, dinov2_vitl14_lc, dinov2_vits14_lc from dinov2.hub.classifiers import dinov2_vitb14_lc, dinov2_vitg14_lc, dinov2_vitl14_lc, dinov2_vits14_lc
from dinov2.hub.depthers import dinov2_vitb14_ld, dinov2_vitg14_ld, dinov2_vitl14_ld, dinov2_vits14_ld
dependencies = ["torch"] dependencies = ["torch"]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment