diff --git a/dinov2/eval/segmentation_m2f/__init__.py b/dinov2/eval/segmentation_m2f/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6c678fdf8f1dee14d7cf9be70af14e6f9a1441c3 --- /dev/null +++ b/dinov2/eval/segmentation_m2f/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .core import * # noqa: F403 +from .models import * # noqa: F403 +from .ops import * # noqa: F403 diff --git a/dinov2/eval/segmentation_m2f/core/__init__.py b/dinov2/eval/segmentation_m2f/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..92599806fbd221c1418d179892a0f46dc0b7d4db --- /dev/null +++ b/dinov2/eval/segmentation_m2f/core/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from mmseg.core.evaluation import * # noqa: F403 +from mmseg.core.seg import * # noqa: F403 + +from .anchor import * # noqa: F403 +from .box import * # noqa: F403 +from .utils import * # noqa: F403 diff --git a/dinov2/eval/segmentation_m2f/core/anchor/__init__.py b/dinov2/eval/segmentation_m2f/core/anchor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e71ac4d6e01462221ae01aa16d0e1231cda7e2e7 --- /dev/null +++ b/dinov2/eval/segmentation_m2f/core/anchor/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .point_generator import MlvlPointGenerator # noqa: F403 diff --git a/dinov2/eval/segmentation_m2f/core/anchor/builder.py b/dinov2/eval/segmentation_m2f/core/anchor/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..6dba90e22de76d2f23a86d3c057f196d55a99690 --- /dev/null +++ b/dinov2/eval/segmentation_m2f/core/anchor/builder.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import warnings + +from mmcv.utils import Registry, build_from_cfg + +PRIOR_GENERATORS = Registry("Generator for anchors and points") + +ANCHOR_GENERATORS = PRIOR_GENERATORS + + +def build_prior_generator(cfg, default_args=None): + return build_from_cfg(cfg, PRIOR_GENERATORS, default_args) + + +def build_anchor_generator(cfg, default_args=None): + warnings.warn("``build_anchor_generator`` would be deprecated soon, please use " "``build_prior_generator`` ") + return build_prior_generator(cfg, default_args=default_args) diff --git a/dinov2/eval/segmentation_m2f/core/anchor/point_generator.py b/dinov2/eval/segmentation_m2f/core/anchor/point_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..574d71939080e22284fe99087fb2e7336657bd97 --- /dev/null +++ b/dinov2/eval/segmentation_m2f/core/anchor/point_generator.py @@ -0,0 +1,205 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
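+
+# Usage sketch (illustrative only; the featmap sizes below are arbitrary):
+# the generator is registered in PRIOR_GENERATORS and is typically built from
+# a config via ``build_prior_generator``, e.g.
+#
+#     generator = MlvlPointGenerator(strides=[8, 16, 32], offset=0.5)
+#     priors = generator.grid_priors(
+#         [(64, 64), (32, 32), (16, 16)], device="cpu")
+#
+# which returns one (N, 2) tensor of (x, y) point centers per feature level.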
+ +import numpy as np +import torch +from torch.nn.modules.utils import _pair + +from .builder import PRIOR_GENERATORS + + +@PRIOR_GENERATORS.register_module() +class MlvlPointGenerator: + """Standard points generator for multi-level (Mlvl) feature maps in 2D + points-based detectors. + + Args: + strides (list[int] | list[tuple[int, int]]): Strides of anchors + in multiple feature levels in order (w, h). + offset (float): The offset of points, the value is normalized with + corresponding stride. Defaults to 0.5. + """ + + def __init__(self, strides, offset=0.5): + self.strides = [_pair(stride) for stride in strides] + self.offset = offset + + @property + def num_levels(self): + """int: number of feature levels that the generator will be applied""" + return len(self.strides) + + @property + def num_base_priors(self): + """list[int]: The number of priors (points) at a point + on the feature grid""" + return [1 for _ in range(len(self.strides))] + + def _meshgrid(self, x, y, row_major=True): + yy, xx = torch.meshgrid(y, x) + if row_major: + # warning .flatten() would cause error in ONNX exporting + # have to use reshape here + return xx.reshape(-1), yy.reshape(-1) + + else: + return yy.reshape(-1), xx.reshape(-1) + + def grid_priors(self, featmap_sizes, dtype=torch.float32, device="cuda", with_stride=False): + """Generate grid points of multiple feature levels. + + Args: + featmap_sizes (list[tuple]): List of feature map sizes in + multiple feature levels, each size arrange as + as (h, w). + dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32. + device (str): The device where the anchors will be put on. + with_stride (bool): Whether to concatenate the stride to + the last dimension of points. + + Return: + list[torch.Tensor]: Points of multiple feature levels. + The sizes of each tensor should be (N, 2) when with stride is + ``False``, where N = width * height, width and height + are the sizes of the corresponding feature level, + and the last dimension 2 represent (coord_x, coord_y), + otherwise the shape should be (N, 4), + and the last dimension 4 represent + (coord_x, coord_y, stride_w, stride_h). + """ + + assert self.num_levels == len(featmap_sizes) + multi_level_priors = [] + for i in range(self.num_levels): + priors = self.single_level_grid_priors( + featmap_sizes[i], level_idx=i, dtype=dtype, device=device, with_stride=with_stride + ) + multi_level_priors.append(priors) + return multi_level_priors + + def single_level_grid_priors(self, featmap_size, level_idx, dtype=torch.float32, device="cuda", with_stride=False): + """Generate grid Points of a single level. + + Note: + This function is usually called by method ``self.grid_priors``. + + Args: + featmap_size (tuple[int]): Size of the feature maps, arrange as + (h, w). + level_idx (int): The index of corresponding feature map level. + dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32. + device (str, optional): The device the tensor will be put on. + Defaults to 'cuda'. + with_stride (bool): Concatenate the stride to the last dimension + of points. + + Return: + Tensor: Points of single feature levels. + The shape of tensor should be (N, 2) when with stride is + ``False``, where N = width * height, width and height + are the sizes of the corresponding feature level, + and the last dimension 2 represent (coord_x, coord_y), + otherwise the shape should be (N, 4), + and the last dimension 4 represent + (coord_x, coord_y, stride_w, stride_h). 
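+
+        Example:
+            >>> # Illustrative shapes/values for a single stride-4 level.
+            >>> generator = MlvlPointGenerator(strides=[4])
+            >>> priors = generator.single_level_grid_priors(
+            ...     (2, 3), level_idx=0, device='cpu')
+            >>> priors.shape
+            torch.Size([6, 2])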
+ """ + feat_h, feat_w = featmap_size + stride_w, stride_h = self.strides[level_idx] + shift_x = (torch.arange(0, feat_w, device=device) + self.offset) * stride_w + # keep featmap_size as Tensor instead of int, so that we + # can convert to ONNX correctly + shift_x = shift_x.to(dtype) + + shift_y = (torch.arange(0, feat_h, device=device) + self.offset) * stride_h + # keep featmap_size as Tensor instead of int, so that we + # can convert to ONNX correctly + shift_y = shift_y.to(dtype) + shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) + if not with_stride: + shifts = torch.stack([shift_xx, shift_yy], dim=-1) + else: + # use `shape[0]` instead of `len(shift_xx)` for ONNX export + stride_w = shift_xx.new_full((shift_xx.shape[0],), stride_w).to(dtype) + stride_h = shift_xx.new_full((shift_yy.shape[0],), stride_h).to(dtype) + shifts = torch.stack([shift_xx, shift_yy, stride_w, stride_h], dim=-1) + all_points = shifts.to(device) + return all_points + + def valid_flags(self, featmap_sizes, pad_shape, device="cuda"): + """Generate valid flags of points of multiple feature levels. + + Args: + featmap_sizes (list(tuple)): List of feature map sizes in + multiple feature levels, each size arrange as + as (h, w). + pad_shape (tuple(int)): The padded shape of the image, + arrange as (h, w). + device (str): The device where the anchors will be put on. + + Return: + list(torch.Tensor): Valid flags of points of multiple levels. + """ + assert self.num_levels == len(featmap_sizes) + multi_level_flags = [] + for i in range(self.num_levels): + point_stride = self.strides[i] + feat_h, feat_w = featmap_sizes[i] + h, w = pad_shape[:2] + valid_feat_h = min(int(np.ceil(h / point_stride[1])), feat_h) + valid_feat_w = min(int(np.ceil(w / point_stride[0])), feat_w) + flags = self.single_level_valid_flags((feat_h, feat_w), (valid_feat_h, valid_feat_w), device=device) + multi_level_flags.append(flags) + return multi_level_flags + + def single_level_valid_flags(self, featmap_size, valid_size, device="cuda"): + """Generate the valid flags of points of a single feature map. + + Args: + featmap_size (tuple[int]): The size of feature maps, arrange as + as (h, w). + valid_size (tuple[int]): The valid size of the feature maps. + The size arrange as as (h, w). + device (str, optional): The device where the flags will be put on. + Defaults to 'cuda'. + + Returns: + torch.Tensor: The valid flags of each points in a single level \ + feature map. + """ + feat_h, feat_w = featmap_size + valid_h, valid_w = valid_size + assert valid_h <= feat_h and valid_w <= feat_w + valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device) + valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device) + valid_x[:valid_w] = 1 + valid_y[:valid_h] = 1 + valid_xx, valid_yy = self._meshgrid(valid_x, valid_y) + valid = valid_xx & valid_yy + return valid + + def sparse_priors(self, prior_idxs, featmap_size, level_idx, dtype=torch.float32, device="cuda"): + """Generate sparse points according to the ``prior_idxs``. + + Args: + prior_idxs (Tensor): The index of corresponding anchors + in the feature map. + featmap_size (tuple[int]): feature map size arrange as (w, h). + level_idx (int): The level index of corresponding feature + map. + dtype (obj:`torch.dtype`): Date type of points. Defaults to + ``torch.float32``. + device (obj:`torch.device`): The device where the points is + located. + Returns: + Tensor: Anchor with shape (N, 2), N should be equal to + the length of ``prior_idxs``. And last dimension + 2 represent (coord_x, coord_y). 
+ """ + height, width = featmap_size + x = (prior_idxs % width + self.offset) * self.strides[level_idx][0] + y = ((prior_idxs // width) % height + self.offset) * self.strides[level_idx][1] + prioris = torch.stack([x, y], 1).to(dtype) + prioris = prioris.to(device) + return prioris diff --git a/dinov2/eval/segmentation_m2f/core/box/__init__.py b/dinov2/eval/segmentation_m2f/core/box/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bf35a613f81acd77ecab2dfb75a722fa8e5c0787 --- /dev/null +++ b/dinov2/eval/segmentation_m2f/core/box/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .builder import * # noqa: F403 +from .samplers import MaskPseudoSampler # noqa: F403 diff --git a/dinov2/eval/segmentation_m2f/core/box/builder.py b/dinov2/eval/segmentation_m2f/core/box/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..9538c0de3db682c2b111b085a8a1ce321c76a9ff --- /dev/null +++ b/dinov2/eval/segmentation_m2f/core/box/builder.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from mmcv.utils import Registry, build_from_cfg + +BBOX_SAMPLERS = Registry("bbox_sampler") +BBOX_CODERS = Registry("bbox_coder") + + +def build_sampler(cfg, **default_args): + """Builder of box sampler.""" + return build_from_cfg(cfg, BBOX_SAMPLERS, default_args) + + +def build_bbox_coder(cfg, **default_args): + """Builder of box coder.""" + return build_from_cfg(cfg, BBOX_CODERS, default_args) diff --git a/dinov2/eval/segmentation_m2f/core/box/samplers/__init__.py b/dinov2/eval/segmentation_m2f/core/box/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..19c363e3fabc365d92aeaf1e78189d710db279e9 --- /dev/null +++ b/dinov2/eval/segmentation_m2f/core/box/samplers/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .mask_pseudo_sampler import MaskPseudoSampler # noqa: F403 diff --git a/dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py b/dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..c45cec3ed7af5b49bb54b92d6e6bcf59b06b4c99 --- /dev/null +++ b/dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +from abc import ABCMeta, abstractmethod + +import torch + +from .sampling_result import SamplingResult + + +class BaseSampler(metaclass=ABCMeta): + """Base class of samplers.""" + + def __init__(self, num, pos_fraction, neg_pos_ub=-1, add_gt_as_proposals=True, **kwargs): + self.num = num + self.pos_fraction = pos_fraction + self.neg_pos_ub = neg_pos_ub + self.add_gt_as_proposals = add_gt_as_proposals + self.pos_sampler = self + self.neg_sampler = self + + @abstractmethod + def _sample_pos(self, assign_result, num_expected, **kwargs): + """Sample positive samples.""" + pass + + @abstractmethod + def _sample_neg(self, assign_result, num_expected, **kwargs): + """Sample negative samples.""" + pass + + def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None, **kwargs): + """Sample positive and negative bboxes. + + This is a simple implementation of bbox sampling given candidates, + assigning results and ground truth bboxes. + + Args: + assign_result (:obj:`AssignResult`): Bbox assigning results. + bboxes (Tensor): Boxes to be sampled from. + gt_bboxes (Tensor): Ground truth bboxes. + gt_labels (Tensor, optional): Class labels of ground truth bboxes. + + Returns: + :obj:`SamplingResult`: Sampling result. + + Example: + >>> from mmdet.core.bbox import RandomSampler + >>> from mmdet.core.bbox import AssignResult + >>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes + >>> rng = ensure_rng(None) + >>> assign_result = AssignResult.random(rng=rng) + >>> bboxes = random_boxes(assign_result.num_preds, rng=rng) + >>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng) + >>> gt_labels = None + >>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1, + >>> add_gt_as_proposals=False) + >>> self = self.sample(assign_result, bboxes, gt_bboxes, gt_labels) + """ + if len(bboxes.shape) < 2: + bboxes = bboxes[None, :] + + bboxes = bboxes[:, :4] + + gt_flags = bboxes.new_zeros((bboxes.shape[0],), dtype=torch.uint8) + if self.add_gt_as_proposals and len(gt_bboxes) > 0: + if gt_labels is None: + raise ValueError("gt_labels must be given when add_gt_as_proposals is True") + bboxes = torch.cat([gt_bboxes, bboxes], dim=0) + assign_result.add_gt_(gt_labels) + gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8) + gt_flags = torch.cat([gt_ones, gt_flags]) + + num_expected_pos = int(self.num * self.pos_fraction) + pos_inds = self.pos_sampler._sample_pos(assign_result, num_expected_pos, bboxes=bboxes, **kwargs) + # We found that sampled indices have duplicated items occasionally. + # (may be a bug of PyTorch) + pos_inds = pos_inds.unique() + num_sampled_pos = pos_inds.numel() + num_expected_neg = self.num - num_sampled_pos + if self.neg_pos_ub >= 0: + _pos = max(1, num_sampled_pos) + neg_upper_bound = int(self.neg_pos_ub * _pos) + if num_expected_neg > neg_upper_bound: + num_expected_neg = neg_upper_bound + neg_inds = self.neg_sampler._sample_neg(assign_result, num_expected_neg, bboxes=bboxes, **kwargs) + neg_inds = neg_inds.unique() + + sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags) + return sampling_result diff --git a/dinov2/eval/segmentation_m2f/core/box/samplers/mask_pseudo_sampler.py b/dinov2/eval/segmentation_m2f/core/box/samplers/mask_pseudo_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..3e67ea61ed0fd65cca0addde1893a3c1e176bf15 --- /dev/null +++ b/dinov2/eval/segmentation_m2f/core/box/samplers/mask_pseudo_sampler.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py
+
+import torch
+
+from ..builder import BBOX_SAMPLERS
+from .base_sampler import BaseSampler
+from .mask_sampling_result import MaskSamplingResult
+
+
+@BBOX_SAMPLERS.register_module()
+class MaskPseudoSampler(BaseSampler):
+    """A pseudo sampler that does not actually sample."""
+
+    def __init__(self, **kwargs):
+        pass
+
+    def _sample_pos(self, **kwargs):
+        """Sample positive samples."""
+        raise NotImplementedError
+
+    def _sample_neg(self, **kwargs):
+        """Sample negative samples."""
+        raise NotImplementedError
+
+    def sample(self, assign_result, masks, gt_masks, **kwargs):
+        """Directly returns the positive and negative indices of samples.
+
+        Args:
+            assign_result (:obj:`AssignResult`): Mask assigning results.
+            masks (torch.Tensor): Predicted masks.
+            gt_masks (torch.Tensor): Ground truth masks.
+        Returns:
+            :obj:`MaskSamplingResult`: Sampling result.
+        """
+        pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
+        neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()
+        gt_flags = masks.new_zeros(masks.shape[0], dtype=torch.uint8)
+        sampling_result = MaskSamplingResult(pos_inds, neg_inds, masks, gt_masks, assign_result, gt_flags)
+        return sampling_result
diff --git a/dinov2/eval/segmentation_m2f/core/box/samplers/mask_sampling_result.py b/dinov2/eval/segmentation_m2f/core/box/samplers/mask_sampling_result.py
new file mode 100644
index 0000000000000000000000000000000000000000..270ffd35a5f120dd0560a7fea7fe83ef0bab66bb
--- /dev/null
+++ b/dinov2/eval/segmentation_m2f/core/box/samplers/mask_sampling_result.py
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
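+
+# Usage sketch (illustrative only; the tensor names are placeholders): this
+# result is produced by ``MaskPseudoSampler.sample`` from an assignment over
+# per-query masks, e.g.
+#
+#     sampler = MaskPseudoSampler()
+#     result = sampler.sample(assign_result, pred_masks, gt_masks)
+#     matched_preds = result.pos_masks       # sampled positive predictions
+#     matched_targets = result.pos_gt_masks  # their matched ground-truth masks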
+ +# References: +# https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py + +import torch + +from .sampling_result import SamplingResult + + +class MaskSamplingResult(SamplingResult): + """Mask sampling result.""" + + def __init__(self, pos_inds, neg_inds, masks, gt_masks, assign_result, gt_flags): + self.pos_inds = pos_inds + self.neg_inds = neg_inds + self.pos_masks = masks[pos_inds] + self.neg_masks = masks[neg_inds] + self.pos_is_gt = gt_flags[pos_inds] + + self.num_gts = gt_masks.shape[0] + self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + + if gt_masks.numel() == 0: + # hack for index error case + assert self.pos_assigned_gt_inds.numel() == 0 + self.pos_gt_masks = torch.empty_like(gt_masks) + else: + self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :] + + if assign_result.labels is not None: + self.pos_gt_labels = assign_result.labels[pos_inds] + else: + self.pos_gt_labels = None + + @property + def masks(self): + """torch.Tensor: concatenated positive and negative boxes""" + return torch.cat([self.pos_masks, self.neg_masks]) + + def __nice__(self): + data = self.info.copy() + data["pos_masks"] = data.pop("pos_masks").shape + data["neg_masks"] = data.pop("neg_masks").shape + parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())] + body = " " + ",\n ".join(parts) + return "{\n" + body + "\n}" + + @property + def info(self): + """Returns a dictionary of info about the object.""" + return { + "pos_inds": self.pos_inds, + "neg_inds": self.neg_inds, + "pos_masks": self.pos_masks, + "neg_masks": self.neg_masks, + "pos_is_gt": self.pos_is_gt, + "num_gts": self.num_gts, + "pos_assigned_gt_inds": self.pos_assigned_gt_inds, + } diff --git a/dinov2/eval/segmentation_m2f/core/box/samplers/sampling_result.py b/dinov2/eval/segmentation_m2f/core/box/samplers/sampling_result.py new file mode 100644 index 0000000000000000000000000000000000000000..aaee3fe55aeb8c6da7edefbbd382d94b67b6a6b4 --- /dev/null +++ b/dinov2/eval/segmentation_m2f/core/box/samplers/sampling_result.py @@ -0,0 +1,152 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch + + +class SamplingResult: + """Bbox sampling result. 
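+
+    Stores the indices of the sampled positive and negative boxes, the boxes
+    themselves, and their matched ground-truth boxes and labels.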
+ + Example: + >>> # xdoctest: +IGNORE_WANT + >>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA + >>> self = SamplingResult.random(rng=10) + >>> print(f'self = {self}') + self = <SamplingResult({ + 'neg_bboxes': torch.Size([12, 4]), + 'neg_inds': tensor([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12]), + 'num_gts': 4, + 'pos_assigned_gt_inds': tensor([], dtype=torch.int64), + 'pos_bboxes': torch.Size([0, 4]), + 'pos_inds': tensor([], dtype=torch.int64), + 'pos_is_gt': tensor([], dtype=torch.uint8) + })> + """ + + def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags): + self.pos_inds = pos_inds + self.neg_inds = neg_inds + self.pos_bboxes = bboxes[pos_inds] + self.neg_bboxes = bboxes[neg_inds] + self.pos_is_gt = gt_flags[pos_inds] + + self.num_gts = gt_bboxes.shape[0] + self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + + if gt_bboxes.numel() == 0: + # hack for index error case + assert self.pos_assigned_gt_inds.numel() == 0 + self.pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4) + else: + if len(gt_bboxes.shape) < 2: + gt_bboxes = gt_bboxes.view(-1, 4) + + self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds.long(), :] + + if assign_result.labels is not None: + self.pos_gt_labels = assign_result.labels[pos_inds] + else: + self.pos_gt_labels = None + + @property + def bboxes(self): + """torch.Tensor: concatenated positive and negative boxes""" + return torch.cat([self.pos_bboxes, self.neg_bboxes]) + + def to(self, device): + """Change the device of the data inplace. + + Example: + >>> self = SamplingResult.random() + >>> print(f'self = {self.to(None)}') + >>> # xdoctest: +REQUIRES(--gpu) + >>> print(f'self = {self.to(0)}') + """ + _dict = self.__dict__ + for key, value in _dict.items(): + if isinstance(value, torch.Tensor): + _dict[key] = value.to(device) + return self + + def __nice__(self): + data = self.info.copy() + data["pos_bboxes"] = data.pop("pos_bboxes").shape + data["neg_bboxes"] = data.pop("neg_bboxes").shape + parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())] + body = " " + ",\n ".join(parts) + return "{\n" + body + "\n}" + + @property + def info(self): + """Returns a dictionary of info about the object.""" + return { + "pos_inds": self.pos_inds, + "neg_inds": self.neg_inds, + "pos_bboxes": self.pos_bboxes, + "neg_bboxes": self.neg_bboxes, + "pos_is_gt": self.pos_is_gt, + "num_gts": self.num_gts, + "pos_assigned_gt_inds": self.pos_assigned_gt_inds, + } + + @classmethod + def random(cls, rng=None, **kwargs): + """ + Args: + rng (None | int | numpy.random.RandomState): seed or state. + kwargs (keyword arguments): + - num_preds: number of predicted boxes + - num_gts: number of true boxes + - p_ignore (float): probability of a predicted box assigned to \ + an ignored truth. + - p_assigned (float): probability of a predicted box not being \ + assigned. + - p_use_label (float | bool): with labels or not. + + Returns: + :obj:`SamplingResult`: Randomly generated sampling result. + + Example: + >>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA + >>> self = SamplingResult.random() + >>> print(self.__dict__) + """ + from mmdet.core.bbox import demodata + from mmdet.core.bbox.assigners.assign_result import AssignResult + from mmdet.core.bbox.samplers.random_sampler import RandomSampler + + rng = demodata.ensure_rng(rng) + + # make probabalistic? 
+        num = 32
+        pos_fraction = 0.5
+        neg_pos_ub = -1
+
+        assign_result = AssignResult.random(rng=rng, **kwargs)
+
+        # Note we could just compute an assignment
+        bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng)
+        gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng)
+
+        if rng.rand() > 0.2:
+            # sometimes algorithms squeeze their data, be robust to that
+            gt_bboxes = gt_bboxes.squeeze()
+            bboxes = bboxes.squeeze()
+
+        if assign_result.labels is None:
+            gt_labels = None
+        else:
+            gt_labels = None  # both branches yield None: labels are deliberately not propagated here
+
+        if gt_labels is None:
+            add_gt_as_proposals = False
+        else:
+            add_gt_as_proposals = True  # make probabilistic?
+
+        sampler = RandomSampler(
+            num, pos_fraction, neg_pos_ub=neg_pos_ub, add_gt_as_proposals=add_gt_as_proposals, rng=rng
+        )
+        self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels)
+        return self
diff --git a/dinov2/eval/segmentation_m2f/core/utils/__init__.py b/dinov2/eval/segmentation_m2f/core/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cdc9e19352f50bc2d5433c412ff71186c5df019
--- /dev/null
+++ b/dinov2/eval/segmentation_m2f/core/utils/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .dist_utils import reduce_mean
+from .misc import add_prefix, multi_apply
diff --git a/dinov2/eval/segmentation_m2f/core/utils/dist_utils.py b/dinov2/eval/segmentation_m2f/core/utils/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dfed42da821cd94e31b663d86b20b8f09799b30
--- /dev/null
+++ b/dinov2/eval/segmentation_m2f/core/utils/dist_utils.py
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch.distributed as dist
+
+
+def reduce_mean(tensor):
+    """Obtain the mean of tensor on different GPUs."""
+    if not (dist.is_available() and dist.is_initialized()):
+        return tensor
+    tensor = tensor.clone()
+    dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
+    return tensor
diff --git a/dinov2/eval/segmentation_m2f/core/utils/misc.py b/dinov2/eval/segmentation_m2f/core/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..e07579e7b182b62153e81fe637ffd0f3081ef2a3
--- /dev/null
+++ b/dinov2/eval/segmentation_m2f/core/utils/misc.py
@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from functools import partial
+
+
+def multi_apply(func, *args, **kwargs):
+    """Apply function to a list of arguments.
+
+    Note:
+        This function applies ``func`` to multiple inputs and maps
+        the multiple outputs of ``func`` into different lists. Each
+        list contains the same type of outputs corresponding
+        to different inputs.
+
+    Args:
+        func (Function): A function that will be applied to a list of
+            arguments
+
+    Returns:
+        tuple(list): A tuple containing multiple lists, each of which \
+            contains one kind of result returned by the function
+    """
+    pfunc = partial(func, **kwargs) if kwargs else func
+    map_results = map(pfunc, *args)
+    return tuple(map(list, zip(*map_results)))
+
+
+def add_prefix(inputs, prefix):
+    """Add prefix for dict.
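+
+    Example:
+        >>> add_prefix({'acc': 0.9}, 'decode')
+        {'decode.acc': 0.9}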
+ + Args: + inputs (dict): The input dict with str keys. + prefix (str): The prefix to add. + + Returns: + + dict: The dict with keys updated with ``prefix``. + """ + + outputs = dict() + for name, value in inputs.items(): + outputs[f"{prefix}.{name}"] = value + + return outputs diff --git a/dinov2/eval/segmentation_m2f/models/__init__.py b/dinov2/eval/segmentation_m2f/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed89bb0064d82b4360af020798eab3d2f5a47937 --- /dev/null +++ b/dinov2/eval/segmentation_m2f/models/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .backbones import * # noqa: F403 +from .builder import MASK_ASSIGNERS, MATCH_COST, TRANSFORMER, build_assigner, build_match_cost +from .decode_heads import * # noqa: F403 +from .losses import * # noqa: F403 +from .plugins import * # noqa: F403 +from .segmentors import * # noqa: F403 diff --git a/dinov2/eval/segmentation_m2f/models/backbones/__init__.py b/dinov2/eval/segmentation_m2f/models/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c4bf73bcbcee710676f81cb6517ae787f4d61cc6 --- /dev/null +++ b/dinov2/eval/segmentation_m2f/models/backbones/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .vit_adapter import ViTAdapter diff --git a/dinov2/eval/segmentation_m2f/models/backbones/adapter_modules.py b/dinov2/eval/segmentation_m2f/models/backbones/adapter_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..26bfdf8f6ae6c107d22d61985cce34d4b5ce275f --- /dev/null +++ b/dinov2/eval/segmentation_m2f/models/backbones/adapter_modules.py @@ -0,0 +1,442 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
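+
+# Note on the token layout used below (derived from the code, for reference):
+# the convolutional feature ``c`` concatenates three pyramid levels with
+# strides 8, 16 and 32. With n tokens at stride 32 this gives
+# 16n + 4n + n = 21n tokens in total, which is why ``DWConv.forward`` splits
+# its input at 16n and 20n before reshaping each chunk back to a 2-D map.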
+ +from functools import partial + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp + +from ...ops.modules import MSDeformAttn +from .drop_path import DropPath + + +def get_reference_points(spatial_shapes, device): + reference_points_list = [] + for lvl, (H_, W_) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), + torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device), + ) + ref_y = ref_y.reshape(-1)[None] / H_ + ref_x = ref_x.reshape(-1)[None] / W_ + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] + return reference_points + + +def deform_inputs(x, patch_size): + bs, c, h, w = x.shape + spatial_shapes = torch.as_tensor( + [(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], dtype=torch.long, device=x.device + ) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = get_reference_points([(h // patch_size, w // patch_size)], x.device) + deform_inputs1 = [reference_points, spatial_shapes, level_start_index] + + spatial_shapes = torch.as_tensor([(h // patch_size, w // patch_size)], dtype=torch.long, device=x.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = get_reference_points([(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], x.device) + deform_inputs2 = [reference_points, spatial_shapes, level_start_index] + + return deform_inputs1, deform_inputs2 + + +class ConvFFN(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x, H, W): + x = self.fc1(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + n = N // 21 + x1 = x[:, 0 : 16 * n, :].transpose(1, 2).view(B, C, H * 2, W * 2).contiguous() + x2 = x[:, 16 * n : 20 * n, :].transpose(1, 2).view(B, C, H, W).contiguous() + x3 = x[:, 20 * n :, :].transpose(1, 2).view(B, C, H // 2, W // 2).contiguous() + x1 = self.dwconv(x1).flatten(2).transpose(1, 2) + x2 = self.dwconv(x2).flatten(2).transpose(1, 2) + x3 = self.dwconv(x3).flatten(2).transpose(1, 2) + x = torch.cat([x1, x2, x3], dim=1) + return x + + +class Extractor(nn.Module): + def __init__( + self, + dim, + num_heads=6, + n_points=4, + n_levels=1, + deform_ratio=1.0, + with_cffn=True, + cffn_ratio=0.25, + drop=0.0, + drop_path=0.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + with_cp=False, + ): + super().__init__() + self.query_norm = norm_layer(dim) + self.feat_norm = norm_layer(dim) + self.attn = MSDeformAttn( + d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio + ) + self.with_cffn = with_cffn + self.with_cp = with_cp + if with_cffn: + self.ffn = ConvFFN(in_features=dim, 
hidden_features=int(dim * cffn_ratio), drop=drop) + self.ffn_norm = norm_layer(dim) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, query, reference_points, feat, spatial_shapes, level_start_index, H, W): + def _inner_forward(query, feat): + + attn = self.attn( + self.query_norm(query), reference_points, self.feat_norm(feat), spatial_shapes, level_start_index, None + ) + query = query + attn + + if self.with_cffn: + query = query + self.drop_path(self.ffn(self.ffn_norm(query), H, W)) + return query + + if self.with_cp and query.requires_grad: + query = cp.checkpoint(_inner_forward, query, feat) + else: + query = _inner_forward(query, feat) + + return query + + +class Injector(nn.Module): + def __init__( + self, + dim, + num_heads=6, + n_points=4, + n_levels=1, + deform_ratio=1.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + init_values=0.0, + with_cp=False, + ): + super().__init__() + self.with_cp = with_cp + self.query_norm = norm_layer(dim) + self.feat_norm = norm_layer(dim) + self.attn = MSDeformAttn( + d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio + ) + self.gamma = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + + def forward(self, query, reference_points, feat, spatial_shapes, level_start_index): + def _inner_forward(query, feat): + + attn = self.attn( + self.query_norm(query), reference_points, self.feat_norm(feat), spatial_shapes, level_start_index, None + ) + return query + self.gamma * attn + + if self.with_cp and query.requires_grad: + query = cp.checkpoint(_inner_forward, query, feat) + else: + query = _inner_forward(query, feat) + + return query + + +class InteractionBlock(nn.Module): + def __init__( + self, + dim, + num_heads=6, + n_points=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + drop=0.0, + drop_path=0.0, + with_cffn=True, + cffn_ratio=0.25, + init_values=0.0, + deform_ratio=1.0, + extra_extractor=False, + with_cp=False, + ): + super().__init__() + + self.injector = Injector( + dim=dim, + n_levels=3, + num_heads=num_heads, + init_values=init_values, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cp=with_cp, + ) + self.extractor = Extractor( + dim=dim, + n_levels=1, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp, + ) + if extra_extractor: + self.extra_extractors = nn.Sequential( + *[ + Extractor( + dim=dim, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp, + ) + for _ in range(2) + ] + ) + else: + self.extra_extractors = None + + def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks): + x = self.injector( + query=x, + reference_points=deform_inputs1[0], + feat=c, + spatial_shapes=deform_inputs1[1], + level_start_index=deform_inputs1[2], + ) + for idx, blk in enumerate(blocks): + x = blk(x, H_toks, W_toks) + c = self.extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H_c, + W=W_c, + ) + if self.extra_extractors is not None: + for extractor in self.extra_extractors: + c = extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + 
level_start_index=deform_inputs2[2], + H=H_c, + W=W_c, + ) + return x, c + + +class InteractionBlockWithCls(nn.Module): + def __init__( + self, + dim, + num_heads=6, + n_points=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + drop=0.0, + drop_path=0.0, + with_cffn=True, + cffn_ratio=0.25, + init_values=0.0, + deform_ratio=1.0, + extra_extractor=False, + with_cp=False, + ): + super().__init__() + + self.injector = Injector( + dim=dim, + n_levels=3, + num_heads=num_heads, + init_values=init_values, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cp=with_cp, + ) + self.extractor = Extractor( + dim=dim, + n_levels=1, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp, + ) + if extra_extractor: + self.extra_extractors = nn.Sequential( + *[ + Extractor( + dim=dim, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + drop=drop, + drop_path=drop_path, + with_cp=with_cp, + ) + for _ in range(2) + ] + ) + else: + self.extra_extractors = None + + def forward(self, x, c, cls, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks): + x = self.injector( + query=x, + reference_points=deform_inputs1[0], + feat=c, + spatial_shapes=deform_inputs1[1], + level_start_index=deform_inputs1[2], + ) + x = torch.cat((cls, x), dim=1) + for idx, blk in enumerate(blocks): + x = blk(x, H_toks, W_toks) + cls, x = ( + x[ + :, + :1, + ], + x[ + :, + 1:, + ], + ) + c = self.extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H_c, + W=W_c, + ) + if self.extra_extractors is not None: + for extractor in self.extra_extractors: + c = extractor( + query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H_c, + W=W_c, + ) + return x, c, cls + + +class SpatialPriorModule(nn.Module): + def __init__(self, inplanes=64, embed_dim=384, with_cp=False): + super().__init__() + self.with_cp = with_cp + + self.stem = nn.Sequential( + *[ + nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ] + ) + self.conv2 = nn.Sequential( + *[ + nn.Conv2d(inplanes, 2 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(2 * inplanes), + nn.ReLU(inplace=True), + ] + ) + self.conv3 = nn.Sequential( + *[ + nn.Conv2d(2 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(4 * inplanes), + nn.ReLU(inplace=True), + ] + ) + self.conv4 = nn.Sequential( + *[ + nn.Conv2d(4 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(4 * inplanes), + nn.ReLU(inplace=True), + ] + ) + self.fc1 = nn.Conv2d(inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + self.fc2 = nn.Conv2d(2 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + self.fc3 = nn.Conv2d(4 * inplanes, embed_dim, 
kernel_size=1, stride=1, padding=0, bias=True) + self.fc4 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + + def forward(self, x): + def _inner_forward(x): + c1 = self.stem(x) + c2 = self.conv2(c1) + c3 = self.conv3(c2) + c4 = self.conv4(c3) + c1 = self.fc1(c1) + c2 = self.fc2(c2) + c3 = self.fc3(c3) + c4 = self.fc4(c4) + + bs, dim, _, _ = c1.shape + # c1 = c1.view(bs, dim, -1).transpose(1, 2) # 4s + c2 = c2.view(bs, dim, -1).transpose(1, 2) # 8s + c3 = c3.view(bs, dim, -1).transpose(1, 2) # 16s + c4 = c4.view(bs, dim, -1).transpose(1, 2) # 32s + + return c1, c2, c3, c4 + + if self.with_cp and x.requires_grad: + outs = cp.checkpoint(_inner_forward, x) + else: + outs = _inner_forward(x) + return outs diff --git a/dinov2/eval/segmentation_m2f/models/backbones/drop_path.py b/dinov2/eval/segmentation_m2f/models/backbones/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..864eb8738c44652d12b979fc811503f21cbb00dd --- /dev/null +++ b/dinov2/eval/segmentation_m2f/models/backbones/drop_path.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float = 0.0): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/dinov2/eval/segmentation_m2f/models/backbones/vit.py b/dinov2/eval/segmentation_m2f/models/backbones/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..8a147570451bd2fbd016ddfafbbfa33035cbd4f8 --- /dev/null +++ b/dinov2/eval/segmentation_m2f/models/backbones/vit.py @@ -0,0 +1,552 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +"""Vision Transformer (ViT) in PyTorch. + +A PyTorch implement of Vision Transformers as described in: + +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' + - https://arxiv.org/abs/2010.11929 + +`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` + - https://arxiv.org/abs/2106.10270 + +The official jax code is released and available at https://github.com/google-research/vision_transformer + +DeiT model defs and weights from https://github.com/facebookresearch/deit, +paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + +Acknowledgments: +* The paper authors for releasing code and weights, thanks! +* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... 
check it out +for some einops/einsum fun +* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT +* Bert reference code checks against Huggingface Transformers and Tensorflow Bert + +Hacked together by / Copyright 2021 Ross Wightman +""" +import logging +import math +from functools import partial +from itertools import repeat +from typing import Callable, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.runner import BaseModule, load_checkpoint +from mmseg.ops import resize +from mmseg.utils import get_root_logger +from torch import Tensor + +from .drop_path import DropPath + + +def to_2tuple(x): + return tuple(repeat(x, 2)) + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + swiglu_hidden_features = int(2 * hidden_features / 3) + align_as = 8 + swiglu_hidden_features = (swiglu_hidden_features + align_as - 1) // align_as * align_as + self.w1 = nn.Linear(in_features, swiglu_hidden_features) + self.w2 = nn.Linear(in_features, swiglu_hidden_features) + self.w3 = nn.Linear(swiglu_hidden_features, out_features) + + def forward(self, x: Tensor) -> Tensor: + x1 = self.w1(x) + x2 = self.w2(x) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding.""" + + def __init__( + self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, bias=True + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x.shape + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x, H, W + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, 
H, W): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor, H, W) -> Tensor: + from xformers.ops import memory_efficient_attention, unbind + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowedAttention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0, window_size=14, pad_mode="constant" + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.window_size = window_size + self.pad_mode = pad_mode + + def forward(self, x, H, W): + B, N, C = x.shape + N_ = self.window_size * self.window_size + H_ = math.ceil(H / self.window_size) * self.window_size + W_ = math.ceil(W / self.window_size) * self.window_size + + qkv = self.qkv(x) # [B, N, C] + qkv = qkv.transpose(1, 2).reshape(B, C * 3, H, W) # [B, C, H, W] + qkv = F.pad(qkv, [0, W_ - W, 0, H_ - H], mode=self.pad_mode) + + qkv = F.unfold( + qkv, kernel_size=(self.window_size, self.window_size), stride=(self.window_size, self.window_size) + ) + B, C_kw_kw, L = qkv.shape # L - the num of windows + qkv = qkv.reshape(B, C * 3, N_, L).permute(0, 3, 2, 1) # [B, L, N_, C] + qkv = qkv.reshape(B, L, N_, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + # q,k,v [B, L, num_head, N_, C/num_head] + attn = (q @ k.transpose(-2, -1)) * self.scale # [B, L, num_head, N_, N_] + # 
if self.mask: + # attn = attn * mask + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) # [B, L, num_head, N_, N_] + # attn @ v = [B, L, num_head, N_, C/num_head] + x = (attn @ v).permute(0, 2, 4, 3, 1).reshape(B, C_kw_kw // 3, L) + + x = F.fold( + x, + output_size=(H_, W_), + kernel_size=(self.window_size, self.window_size), + stride=(self.window_size, self.window_size), + ) # [B, C, H_, W_] + x = x[:, :, :H, :W].reshape(B, C, N).transpose(-1, -2) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +# class WindowedAttention(nn.Module): +# def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., window_size=14, pad_mode="constant"): +# super().__init__() +# self.num_heads = num_heads +# head_dim = dim // num_heads +# self.scale = head_dim ** -0.5 +# +# self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) +# self.attn_drop = nn.Dropout(attn_drop) +# self.proj = nn.Linear(dim, dim) +# self.proj_drop = nn.Dropout(proj_drop) +# self.window_size = window_size +# self.pad_mode = pad_mode +# +# def forward(self, x, H, W): +# B, N, C = x.shape +# +# N_ = self.window_size * self.window_size +# H_ = math.ceil(H / self.window_size) * self.window_size +# W_ = math.ceil(W / self.window_size) * self.window_size +# x = x.view(B, H, W, C) +# x = F.pad(x, [0, 0, 0, W_ - W, 0, H_- H], mode=self.pad_mode) +# +# x = window_partition(x, window_size=self.window_size)# nW*B, window_size, window_size, C +# x = x.view(-1, N_, C) +# +# qkv = self.qkv(x).view(-1, N_, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) +# q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) +# attn = (q @ k.transpose(-2, -1)) * self.scale # [B, L, num_head, N_, N_] +# attn = attn.softmax(dim=-1) +# attn = self.attn_drop(attn) # [B, L, num_head, N_, N_] +# x = (attn @ v).transpose(1, 2).reshape(-1, self.window_size, self.window_size, C) +# +# x = window_reverse(x, self.window_size, H_, W_) +# x = x[:, :H, :W, :].reshape(B, N, C).contiguous() +# x = self.proj(x) +# x = self.proj_drop(x) +# return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + windowed=False, + window_size=14, + pad_mode="constant", + layer_scale=False, + with_cp=False, + ffn_layer=Mlp, + memeff=False, + ): + super().__init__() + self.with_cp = with_cp + self.norm1 = norm_layer(dim) + if windowed: + self.attn = WindowedAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + window_size=window_size, + pad_mode=pad_mode, + ) + elif memeff: + self.attn = MemEffAttention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop + ) + else: + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.layer_scale = layer_scale + if layer_scale: + self.gamma1 = nn.Parameter(torch.ones((dim)), requires_grad=True) + self.gamma2 = nn.Parameter(torch.ones((dim)), requires_grad=True) + + def forward(self, x, H, W): + def _inner_forward(x): + if self.layer_scale: + x = x + 
self.drop_path(self.gamma1 * self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + + return x + + +class TIMMVisionTransformer(BaseModule): + """Vision Transformer. + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + + Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` + - https://arxiv.org/abs/2012.12877 + """ + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + layer_scale=True, + embed_layer=PatchEmbed, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + window_attn=False, + window_size=14, + pretrained=None, + with_cp=False, + pre_norm=False, + ffn_type="mlp", + memeff=False, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + pretrained: (str): pretrained path + """ + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + self.norm_layer = norm_layer + self.act_layer = act_layer + self.pretrain_size = img_size + self.drop_path_rate = drop_path_rate + self.drop_rate = drop_rate + self.patch_size = patch_size + + window_attn = [window_attn] * depth if not isinstance(window_attn, list) else window_attn + window_size = [window_size] * depth if not isinstance(window_size, list) else window_size + logging.info("window attention:", window_attn) + logging.info("window size:", window_size) + logging.info("layer scale:", layer_scale) + + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, bias=not pre_norm + ) + num_patches = self.patch_embed.num_patches + + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + ffn_types = {"mlp": Mlp, "swiglu": SwiGLUFFN} + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.Sequential( + *[ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + windowed=window_attn[i], + window_size=window_size[i], + layer_scale=layer_scale, + with_cp=with_cp, + ffn_layer=ffn_types[ffn_type], + memeff=memeff, + ) + 
+                for i in range(depth)
+            ]
+        )
+
+        self.norm = norm_layer(embed_dim)  # final norm, applied in forward_features below
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        # For CLIP
+        if pre_norm:
+            norm_pre = norm_layer(embed_dim)
+            self.norm_pre = norm_pre
+        else:
+            self.norm_pre = nn.Identity()
+        self.init_weights(pretrained)
+
+    def init_weights(self, pretrained=None):
+        if isinstance(pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(self, pretrained, map_location="cpu", strict=False, logger=logger)
+
+    def forward_features(self, x):
+        x, H, W = self.patch_embed(x)
+        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        x = torch.cat((cls_token, x), dim=1)
+        x = self.pos_drop(x + self.pos_embed)
+
+        # For CLIP
+        x = self.norm_pre(x)
+
+        for blk in self.blocks:
+            x = blk(x, H, W)
+        x = self.norm(x)
+        return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        return x
+
+    @staticmethod
+    def resize_pos_embed(pos_embed, input_shape, pos_shape, mode):
+        """Resize pos_embed weights.
+
+        Resize pos_embed using the given interpolation method.
+        Args:
+            pos_embed (torch.Tensor): Position embedding weights.
+            input_shape (tuple): Tuple for (downsampled input image height,
+                downsampled input image width).
+            pos_shape (tuple): The resolution of the downsampled original
+                training image.
+            mode (str): Algorithm used for upsampling:
+                ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
+                ``'trilinear'``.
+        Return:
+            torch.Tensor: The resized pos_embed of shape [B, L_new, C]
+        """
+        assert pos_embed.ndim == 3, "shape of pos_embed must be [B, L, C]"
+        pos_h, pos_w = pos_shape
+        # keep dim for easy deployment
+        cls_token_weight = pos_embed[:, 0:1]
+        pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w) :]
+        pos_embed_weight = pos_embed_weight.reshape(1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
+        pos_embed_weight = resize(pos_embed_weight, size=input_shape, align_corners=False, mode=mode)
+        pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
+        pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
+        return pos_embed
diff --git a/dinov2/eval/segmentation_m2f/models/backbones/vit_adapter.py b/dinov2/eval/segmentation_m2f/models/backbones/vit_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebc4f0f65e04ed764464d141607b3b2073220f6b
--- /dev/null
+++ b/dinov2/eval/segmentation_m2f/models/backbones/vit_adapter.py
@@ -0,0 +1,217 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
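The `resize_pos_embed` helper above follows the usual recipe: split off the class-token embedding, fold the patch embeddings back into their 2D grid, interpolate the grid, and re-flatten. A minimal self-contained sketch of the same steps, using `F.interpolate` directly in place of mmseg's `resize` wrapper; the 14x14 source grid and 32x32 target grid are illustrative values, not taken from this diff:

import torch
import torch.nn.functional as F

pos_embed = torch.randn(1, 1 + 14 * 14, 768)                # [cls] + 14x14 patch grid
cls_tok, patches = pos_embed[:, :1], pos_embed[:, 1:]
grid = patches.reshape(1, 14, 14, 768).permute(0, 3, 1, 2)  # (1, C, 14, 14)
grid = F.interpolate(grid, size=(32, 32), mode="bicubic", align_corners=False)
patches = grid.flatten(2).transpose(1, 2)                   # (1, 32*32, C)
resized = torch.cat((cls_tok, patches), dim=1)
assert resized.shape == (1, 1 + 32 * 32, 768)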
+ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmseg.models.builder import BACKBONES +from torch.nn.init import normal_ + +from ...ops.modules import MSDeformAttn +from .adapter_modules import InteractionBlock, InteractionBlockWithCls, SpatialPriorModule, deform_inputs +from .vit import TIMMVisionTransformer + + +@BACKBONES.register_module() +class ViTAdapter(TIMMVisionTransformer): + def __init__( + self, + pretrain_size=224, + num_heads=12, + conv_inplane=64, + n_points=4, + deform_num_heads=6, + init_values=0.0, + interaction_indexes=None, + with_cffn=True, + cffn_ratio=0.25, + deform_ratio=1.0, + add_vit_feature=True, + pretrained=None, + use_extra_extractor=True, + freeze_vit=False, + use_cls=True, + with_cp=False, + *args, + **kwargs + ): + + super().__init__(num_heads=num_heads, pretrained=pretrained, with_cp=with_cp, *args, **kwargs) + if freeze_vit: + for param in self.parameters(): + param.requires_grad = False + + # self.num_classes = 80 + self.use_cls = use_cls + if not self.use_cls: + self.cls_token = None + self.num_block = len(self.blocks) + self.pretrain_size = (pretrain_size, pretrain_size) + self.interaction_indexes = interaction_indexes + self.add_vit_feature = add_vit_feature + embed_dim = self.embed_dim + + block_fn = InteractionBlockWithCls if use_cls else InteractionBlock + + self.level_embed = nn.Parameter(torch.zeros(3, embed_dim)) + self.spm = SpatialPriorModule(inplanes=conv_inplane, embed_dim=embed_dim, with_cp=False) + self.interactions = nn.Sequential( + *[ + block_fn( + dim=embed_dim, + num_heads=deform_num_heads, + n_points=n_points, + init_values=init_values, + drop_path=self.drop_path_rate, + norm_layer=self.norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + extra_extractor=((True if i == len(interaction_indexes) - 1 else False) and use_extra_extractor), + with_cp=with_cp, + ) + for i in range(len(interaction_indexes)) + ] + ) + self.up = nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2) + self.norm1 = nn.SyncBatchNorm(embed_dim) + self.norm2 = nn.SyncBatchNorm(embed_dim) + self.norm3 = nn.SyncBatchNorm(embed_dim) + self.norm4 = nn.SyncBatchNorm(embed_dim) + + self.up.apply(self._init_weights) + self.spm.apply(self._init_weights) + self.interactions.apply(self._init_weights) + self.apply(self._init_deform_weights) + normal_(self.level_embed) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + torch.nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def _get_pos_embed(self, pos_embed, H, W): + pos_embed = pos_embed.reshape( + 1, self.pretrain_size[0] // self.patch_size, self.pretrain_size[1] // self.patch_size, -1 + ).permute(0, 3, 1, 2) + pos_embed = ( + F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False) + .reshape(1, -1, H * W) + .permute(0, 2, 1) + ) + return pos_embed + + def _init_deform_weights(self, m): + if isinstance(m, MSDeformAttn): + m._reset_parameters() + + def _add_level_embed(self, c2, c3, c4): + c2 = c2 + self.level_embed[0] + c3 = c3 + self.level_embed[1] + 
c4 = c4 + self.level_embed[2] + return c2, c3, c4 + + def forward(self, x): + deform_inputs1, deform_inputs2 = deform_inputs(x, self.patch_size) + + # SPM forward + c1, c2, c3, c4 = self.spm(x) + c2, c3, c4 = self._add_level_embed(c2, c3, c4) + c = torch.cat([c2, c3, c4], dim=1) + + # Patch Embedding forward + H_c, W_c = x.shape[2] // 16, x.shape[3] // 16 + x, H_toks, W_toks = self.patch_embed(x) + # print("H_toks, W_toks =", H_toks, W_toks) + bs, n, dim = x.shape + pos_embed = self._get_pos_embed(self.pos_embed[:, 1:], H_toks, W_toks) + if self.use_cls: + cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_token, x), dim=1) + pos_embed = torch.cat((self.pos_embed[:, :1], pos_embed), dim=1) + x = self.pos_drop(x + pos_embed) + # For CLIP + x = self.norm_pre(x) + + # Interaction + if self.use_cls: + cls, x = ( + x[ + :, + :1, + ], + x[ + :, + 1:, + ], + ) + outs = list() + for i, layer in enumerate(self.interactions): + indexes = self.interaction_indexes[i] + if self.use_cls: + x, c, cls = layer( + x, + c, + cls, + self.blocks[indexes[0] : indexes[-1] + 1], + deform_inputs1, + deform_inputs2, + H_c, + W_c, + H_toks, + W_toks, + ) + else: + x, c = layer( + x, + c, + self.blocks[indexes[0] : indexes[-1] + 1], + deform_inputs1, + deform_inputs2, + H_c, + W_c, + H_toks, + W_toks, + ) + outs.append(x.transpose(1, 2).view(bs, dim, H_toks, W_toks).contiguous()) + + # Split & Reshape + c2 = c[:, 0 : c2.size(1), :] + c3 = c[:, c2.size(1) : c2.size(1) + c3.size(1), :] + c4 = c[:, c2.size(1) + c3.size(1) :, :] + + c2 = c2.transpose(1, 2).view(bs, dim, H_c * 2, W_c * 2).contiguous() + c3 = c3.transpose(1, 2).view(bs, dim, H_c, W_c).contiguous() + c4 = c4.transpose(1, 2).view(bs, dim, H_c // 2, W_c // 2).contiguous() + c1 = self.up(c2) + c1 + + if self.add_vit_feature: + x1, x2, x3, x4 = outs + + x1 = F.interpolate(x1, size=(4 * H_c, 4 * W_c), mode="bilinear", align_corners=False) + x2 = F.interpolate(x2, size=(2 * H_c, 2 * W_c), mode="bilinear", align_corners=False) + x3 = F.interpolate(x3, size=(1 * H_c, 1 * W_c), mode="bilinear", align_corners=False) + x4 = F.interpolate(x4, size=(H_c // 2, W_c // 2), mode="bilinear", align_corners=False) + # print(c1.shape, c2.shape, c3.shape, c4.shape, x1.shape, x2.shape, x3.shape, x4.shape, H_c, H_toks) + c1, c2, c3, c4 = c1 + x1, c2 + x2, c3 + x3, c4 + x4 + + # Final Norm + f1 = self.norm1(c1) + f2 = self.norm2(c2) + f3 = self.norm3(c3) + f4 = self.norm4(c4) + return [f1, f2, f3, f4] diff --git a/dinov2/eval/segmentation_m2f/models/builder.py b/dinov2/eval/segmentation_m2f/models/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..d7cf7b919f6b0e8e00bde45bc244d9c29a36fed6 --- /dev/null +++ b/dinov2/eval/segmentation_m2f/models/builder.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
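ViTAdapter.forward above returns a four-level feature pyramid [f1, f2, f3, f4] at strides 4, 8, 16 and 32, each with embed_dim channels. Its "Split & Reshape" step slices the flattened spatial-prior tokens back into the stride-8/16/32 maps; a small sketch of just that slicing, with illustrative shapes (batch 2, embed_dim 768, a 512x512 input so H_c = W_c = 32; `reshape` is used where the code above uses `view`):

import torch

B, C, H_c, W_c = 2, 768, 32, 32
n2, n3, n4 = (2 * H_c) * (2 * W_c), H_c * W_c, (H_c // 2) * (W_c // 2)
c = torch.randn(B, n2 + n3 + n4, C)  # stand-in for the fused spatial-prior tokens

# Slice the flattened token sequence back into stride-8/16/32 feature maps,
# mirroring the Split & Reshape block in ViTAdapter.forward.
c2 = c[:, :n2].transpose(1, 2).reshape(B, C, 2 * H_c, 2 * W_c)
c3 = c[:, n2 : n2 + n3].transpose(1, 2).reshape(B, C, H_c, W_c)
c4 = c[:, n2 + n3 :].transpose(1, 2).reshape(B, C, H_c // 2, W_c // 2)
print(c2.shape, c3.shape, c4.shape)
# torch.Size([2, 768, 64, 64]) torch.Size([2, 768, 32, 32]) torch.Size([2, 768, 16, 16])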
+ +from mmcv.utils import Registry + +TRANSFORMER = Registry("Transformer") +MASK_ASSIGNERS = Registry("mask_assigner") +MATCH_COST = Registry("match_cost") + + +def build_match_cost(cfg): + """Build Match Cost.""" + return MATCH_COST.build(cfg) + + +def build_assigner(cfg): + """Build Assigner.""" + return MASK_ASSIGNERS.build(cfg) + + +def build_transformer(cfg): + """Build Transformer.""" + return TRANSFORMER.build(cfg) diff --git a/dinov2/eval/segmentation_m2f/models/decode_heads/__init__.py b/dinov2/eval/segmentation_m2f/models/decode_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..01f08b88950750337781fc671adfea2a935ea8fe --- /dev/null +++ b/dinov2/eval/segmentation_m2f/models/decode_heads/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .mask2former_head import Mask2FormerHead diff --git a/dinov2/eval/segmentation_m2f/models/decode_heads/mask2former_head.py b/dinov2/eval/segmentation_m2f/models/decode_heads/mask2former_head.py new file mode 100644 index 0000000000000000000000000000000000000000..d1705fc444fa8d1583d88fca36d7fe1e060db9e7 --- /dev/null +++ b/dinov2/eval/segmentation_m2f/models/decode_heads/mask2former_head.py @@ -0,0 +1,544 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init +from mmcv.cnn.bricks.transformer import build_positional_encoding, build_transformer_layer_sequence +from mmcv.ops import point_sample +from mmcv.runner import ModuleList, force_fp32 +from mmseg.models.builder import HEADS, build_loss +from mmseg.models.decode_heads.decode_head import BaseDecodeHead + +from ...core import build_sampler, multi_apply, reduce_mean +from ..builder import build_assigner +from ..utils import get_uncertain_point_coords_with_randomness + + +@HEADS.register_module() +class Mask2FormerHead(BaseDecodeHead): + """Implements the Mask2Former head. + + See `Masked-attention Mask Transformer for Universal Image + Segmentation <https://arxiv.org/pdf/2112.01527>`_ for details. + + Args: + in_channels (list[int]): Number of channels in the input feature map. + feat_channels (int): Number of channels for features. + out_channels (int): Number of channels for output. + num_things_classes (int): Number of things. + num_stuff_classes (int): Number of stuff. + num_queries (int): Number of query in Transformer decoder. + pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel + decoder. Defaults to None. + enforce_decoder_input_project (bool, optional): Whether to add + a layer to change the embed_dim of tranformer encoder in + pixel decoder to the embed_dim of transformer decoder. + Defaults to False. + transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for + transformer decoder. Defaults to None. + positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for + transformer decoder position encoding. Defaults to None. + loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification + loss. Defaults to None. + loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss. + Defaults to None. 
+ loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss. + Defaults to None. + train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of + Mask2Former head. + test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of + Mask2Former head. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__( + self, + in_channels, + feat_channels, + out_channels, + num_things_classes=80, + num_stuff_classes=53, + num_queries=100, + num_transformer_feat_level=3, + pixel_decoder=None, + enforce_decoder_input_project=False, + transformer_decoder=None, + positional_encoding=None, + loss_cls=None, + loss_mask=None, + loss_dice=None, + train_cfg=None, + test_cfg=None, + init_cfg=None, + **kwargs, + ): + super(Mask2FormerHead, self).__init__( + in_channels=in_channels, + channels=feat_channels, + num_classes=(num_things_classes + num_stuff_classes), + init_cfg=init_cfg, + input_transform="multiple_select", + **kwargs, + ) + self.num_things_classes = num_things_classes + self.num_stuff_classes = num_stuff_classes + self.num_classes = self.num_things_classes + self.num_stuff_classes + self.num_queries = num_queries + self.num_transformer_feat_level = num_transformer_feat_level + self.num_heads = transformer_decoder.transformerlayers.attn_cfgs.num_heads + self.num_transformer_decoder_layers = transformer_decoder.num_layers + assert pixel_decoder.encoder.transformerlayers.attn_cfgs.num_levels == num_transformer_feat_level + pixel_decoder_ = copy.deepcopy(pixel_decoder) + pixel_decoder_.update(in_channels=in_channels, feat_channels=feat_channels, out_channels=out_channels) + self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1] + self.transformer_decoder = build_transformer_layer_sequence(transformer_decoder) + self.decoder_embed_dims = self.transformer_decoder.embed_dims + + self.decoder_input_projs = ModuleList() + # from low resolution to high resolution + for _ in range(num_transformer_feat_level): + if self.decoder_embed_dims != feat_channels or enforce_decoder_input_project: + self.decoder_input_projs.append(Conv2d(feat_channels, self.decoder_embed_dims, kernel_size=1)) + else: + self.decoder_input_projs.append(nn.Identity()) + self.decoder_positional_encoding = build_positional_encoding(positional_encoding) + self.query_embed = nn.Embedding(self.num_queries, feat_channels) + self.query_feat = nn.Embedding(self.num_queries, feat_channels) + # from low resolution to high resolution + self.level_embed = nn.Embedding(self.num_transformer_feat_level, feat_channels) + + self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) + self.mask_embed = nn.Sequential( + nn.Linear(feat_channels, feat_channels), + nn.ReLU(inplace=True), + nn.Linear(feat_channels, feat_channels), + nn.ReLU(inplace=True), + nn.Linear(feat_channels, out_channels), + ) + self.conv_seg = None # fix a bug here (conv_seg is not used) + + self.test_cfg = test_cfg + self.train_cfg = train_cfg + if train_cfg: + self.assigner = build_assigner(self.train_cfg.assigner) + self.sampler = build_sampler(self.train_cfg.sampler, context=self) + self.num_points = self.train_cfg.get("num_points", 12544) + self.oversample_ratio = self.train_cfg.get("oversample_ratio", 3.0) + self.importance_sample_ratio = self.train_cfg.get("importance_sample_ratio", 0.75) + + self.class_weight = loss_cls.class_weight + self.loss_cls = build_loss(loss_cls) + self.loss_mask = build_loss(loss_mask) + self.loss_dice = build_loss(loss_dice) + + def init_weights(self): + for m in 
self.decoder_input_projs: + if isinstance(m, Conv2d): + caffe2_xavier_init(m, bias=0) + + self.pixel_decoder.init_weights() + + for p in self.transformer_decoder.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas): + """Compute classification and mask targets for all images for a decoder + layer. + + Args: + cls_scores_list (list[Tensor]): Mask score logits from a single + decoder layer for all images. Each with shape [num_queries, + cls_out_channels]. + mask_preds_list (list[Tensor]): Mask logits from a single decoder + layer for all images. Each with shape [num_queries, h, w]. + gt_labels_list (list[Tensor]): Ground truth class indices for all + images. Each with shape (n, ), n is the sum of number of stuff + type and number of instance in a image. + gt_masks_list (list[Tensor]): Ground truth mask for each image, + each with shape (n, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + tuple[list[Tensor]]: a tuple containing the following targets. + + - labels_list (list[Tensor]): Labels of all images. + Each with shape [num_queries, ]. + - label_weights_list (list[Tensor]): Label weights of all + images.Each with shape [num_queries, ]. + - mask_targets_list (list[Tensor]): Mask targets of all images. + Each with shape [num_queries, h, w]. + - mask_weights_list (list[Tensor]): Mask weights of all images. + Each with shape [num_queries, ]. + - num_total_pos (int): Number of positive samples in all + images. + - num_total_neg (int): Number of negative samples in all + images. + """ + ( + labels_list, + label_weights_list, + mask_targets_list, + mask_weights_list, + pos_inds_list, + neg_inds_list, + ) = multi_apply( + self._get_target_single, cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas + ) + + num_total_pos = sum((inds.numel() for inds in pos_inds_list)) + num_total_neg = sum((inds.numel() for inds in neg_inds_list)) + return (labels_list, label_weights_list, mask_targets_list, mask_weights_list, num_total_pos, num_total_neg) + + def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks, img_metas): + """Compute classification and mask targets for one image. + + Args: + cls_score (Tensor): Mask score logits from a single decoder layer + for one image. Shape (num_queries, cls_out_channels). + mask_pred (Tensor): Mask logits for a single decoder layer for one + image. Shape (num_queries, h, w). + gt_labels (Tensor): Ground truth class indices for one image with + shape (num_gts, ). + gt_masks (Tensor): Ground truth mask for each image, each with + shape (num_gts, h, w). + img_metas (dict): Image informtation. + + Returns: + tuple[Tensor]: A tuple containing the following for one image. + + - labels (Tensor): Labels of each image. \ + shape (num_queries, ). + - label_weights (Tensor): Label weights of each image. \ + shape (num_queries, ). + - mask_targets (Tensor): Mask targets of each image. \ + shape (num_queries, h, w). + - mask_weights (Tensor): Mask weights of each image. \ + shape (num_queries, ). + - pos_inds (Tensor): Sampled positive indices for each \ + image. + - neg_inds (Tensor): Sampled negative indices for each \ + image. 
+ """ + # sample points + num_queries = cls_score.shape[0] + num_gts = gt_labels.shape[0] + + point_coords = torch.rand((1, self.num_points, 2), device=cls_score.device) + # shape (num_queries, num_points) + mask_points_pred = point_sample(mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1, 1)).squeeze(1) + # shape (num_gts, num_points) + gt_points_masks = point_sample(gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1, 1)).squeeze(1) + + # assign and sample + assign_result = self.assigner.assign(cls_score, mask_points_pred, gt_labels, gt_points_masks, img_metas) + sampling_result = self.sampler.sample(assign_result, mask_pred, gt_masks) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + + # label target + labels = gt_labels.new_full((self.num_queries,), self.num_classes, dtype=torch.long) + labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] + label_weights = gt_labels.new_ones((self.num_queries,)) + + # mask target + mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds] + mask_weights = mask_pred.new_zeros((self.num_queries,)) + mask_weights[pos_inds] = 1.0 + + return (labels, label_weights, mask_targets, mask_weights, pos_inds, neg_inds) + + def loss_single(self, cls_scores, mask_preds, gt_labels_list, gt_masks_list, img_metas): + """Loss function for outputs from a single decoder layer. + + Args: + cls_scores (Tensor): Mask score logits from a single decoder layer + for all images. Shape (batch_size, num_queries, + cls_out_channels). Note `cls_out_channels` should includes + background. + mask_preds (Tensor): Mask logits for a pixel decoder for all + images. Shape (batch_size, num_queries, h, w). + gt_labels_list (list[Tensor]): Ground truth class indices for each + image, each with shape (num_gts, ). + gt_masks_list (list[Tensor]): Ground truth mask for each image, + each with shape (num_gts, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + tuple[Tensor]: Loss components for outputs from a single \ + decoder layer. 
+ """ + num_imgs = cls_scores.size(0) + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + mask_preds_list = [mask_preds[i] for i in range(num_imgs)] + ( + labels_list, + label_weights_list, + mask_targets_list, + mask_weights_list, + num_total_pos, + num_total_neg, + ) = self.get_targets(cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas) + # shape (batch_size, num_queries) + labels = torch.stack(labels_list, dim=0) + # shape (batch_size, num_queries) + label_weights = torch.stack(label_weights_list, dim=0) + # shape (num_total_gts, h, w) + mask_targets = torch.cat(mask_targets_list, dim=0) + # shape (batch_size, num_queries) + mask_weights = torch.stack(mask_weights_list, dim=0) + + # classfication loss + # shape (batch_size * num_queries, ) + cls_scores = cls_scores.flatten(0, 1) + labels = labels.flatten(0, 1) + label_weights = label_weights.flatten(0, 1) + + class_weight = cls_scores.new_tensor(self.class_weight) + loss_cls = self.loss_cls(cls_scores, labels, label_weights, avg_factor=class_weight[labels].sum()) + + num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos])) + num_total_masks = max(num_total_masks, 1) + + # extract positive ones + # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w) + mask_preds = mask_preds[mask_weights > 0] + + if mask_targets.shape[0] == 0: + # zero match + loss_dice = mask_preds.sum() + loss_mask = mask_preds.sum() + return loss_cls, loss_mask, loss_dice + + with torch.no_grad(): + points_coords = get_uncertain_point_coords_with_randomness( + mask_preds.unsqueeze(1), None, self.num_points, self.oversample_ratio, self.importance_sample_ratio + ) + # shape (num_total_gts, h, w) -> (num_total_gts, num_points) + mask_point_targets = point_sample(mask_targets.unsqueeze(1).float(), points_coords).squeeze(1) + # shape (num_queries, h, w) -> (num_queries, num_points) + mask_point_preds = point_sample(mask_preds.unsqueeze(1), points_coords).squeeze(1) + + # dice loss + loss_dice = self.loss_dice(mask_point_preds, mask_point_targets, avg_factor=num_total_masks) + + # mask loss + # shape (num_queries, num_points) -> (num_queries * num_points, ) + mask_point_preds = mask_point_preds.reshape(-1, 1) + # shape (num_total_gts, num_points) -> (num_total_gts * num_points, ) + mask_point_targets = mask_point_targets.reshape(-1) + loss_mask = self.loss_mask(mask_point_preds, mask_point_targets, avg_factor=num_total_masks * self.num_points) + + return loss_cls, loss_mask, loss_dice + + @force_fp32(apply_to=("all_cls_scores", "all_mask_preds")) + def loss(self, all_cls_scores, all_mask_preds, gt_labels_list, gt_masks_list, img_metas): + """Loss function. + + Args: + all_cls_scores (Tensor): Classification scores for all decoder + layers with shape [num_decoder, batch_size, num_queries, + cls_out_channels]. + all_mask_preds (Tensor): Mask scores for all decoder layers with + shape [num_decoder, batch_size, num_queries, h, w]. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (n, ). n is the sum of number of stuff type + and number of instance in a image. + gt_masks_list (list[Tensor]): Ground truth mask for each image with + shape (n, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
+ """ + num_dec_layers = len(all_cls_scores) + all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)] + all_gt_masks_list = [gt_masks_list for _ in range(num_dec_layers)] + img_metas_list = [img_metas for _ in range(num_dec_layers)] + losses_cls, losses_mask, losses_dice = multi_apply( + self.loss_single, all_cls_scores, all_mask_preds, all_gt_labels_list, all_gt_masks_list, img_metas_list + ) + + loss_dict = dict() + # loss from the last decoder layer + loss_dict["loss_cls"] = losses_cls[-1] + loss_dict["loss_mask"] = losses_mask[-1] + loss_dict["loss_dice"] = losses_dice[-1] + # loss from other decoder layers + num_dec_layer = 0 + for loss_cls_i, loss_mask_i, loss_dice_i in zip(losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]): + loss_dict[f"d{num_dec_layer}.loss_cls"] = loss_cls_i + loss_dict[f"d{num_dec_layer}.loss_mask"] = loss_mask_i + loss_dict[f"d{num_dec_layer}.loss_dice"] = loss_dice_i + num_dec_layer += 1 + return loss_dict + + def forward_head(self, decoder_out, mask_feature, attn_mask_target_size): + """Forward for head part which is called after every decoder layer. + + Args: + decoder_out (Tensor): in shape (num_queries, batch_size, c). + mask_feature (Tensor): in shape (batch_size, c, h, w). + attn_mask_target_size (tuple[int, int]): target attention + mask size. + + Returns: + tuple: A tuple contain three elements. + + - cls_pred (Tensor): Classification scores in shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred (Tensor): Mask scores in shape \ + (batch_size, num_queries,h, w). + - attn_mask (Tensor): Attention mask in shape \ + (batch_size * num_heads, num_queries, h, w). + """ + decoder_out = self.transformer_decoder.post_norm(decoder_out) + decoder_out = decoder_out.transpose(0, 1) + # shape (num_queries, batch_size, c) + cls_pred = self.cls_embed(decoder_out) + # shape (num_queries, batch_size, c) + mask_embed = self.mask_embed(decoder_out) + # shape (num_queries, batch_size, h, w) + mask_pred = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_feature) + attn_mask = F.interpolate(mask_pred, attn_mask_target_size, mode="bilinear", align_corners=False) + # shape (num_queries, batch_size, h, w) -> + # (batch_size * num_head, num_queries, h, w) + attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat((1, self.num_heads, 1, 1)).flatten(0, 1) + attn_mask = attn_mask.sigmoid() < 0.5 + attn_mask = attn_mask.detach() + + return cls_pred, mask_pred, attn_mask + + def forward(self, feats, img_metas): + """Forward function. + + Args: + feats (list[Tensor]): Multi scale Features from the + upstream network, each is a 4D-tensor. + img_metas (list[dict]): List of image information. + + Returns: + tuple: A tuple contains two elements. + + - cls_pred_list (list[Tensor)]: Classification logits \ + for each decoder layer. Each is a 3D-tensor with shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred_list (list[Tensor]): Mask logits for each \ + decoder layer. Each with shape (batch_size, num_queries, \ + h, w). 
+ """ + batch_size = len(img_metas) + mask_features, multi_scale_memorys = self.pixel_decoder(feats) + # multi_scale_memorys (from low resolution to high resolution) + decoder_inputs = [] + decoder_positional_encodings = [] + for i in range(self.num_transformer_feat_level): + decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) + # shape (batch_size, c, h, w) -> (h*w, batch_size, c) + decoder_input = decoder_input.flatten(2).permute(2, 0, 1) + level_embed = self.level_embed.weight[i].view(1, 1, -1) + decoder_input = decoder_input + level_embed + # shape (batch_size, c, h, w) -> (h*w, batch_size, c) + mask = decoder_input.new_zeros((batch_size,) + multi_scale_memorys[i].shape[-2:], dtype=torch.bool) + decoder_positional_encoding = self.decoder_positional_encoding(mask) + decoder_positional_encoding = decoder_positional_encoding.flatten(2).permute(2, 0, 1) + decoder_inputs.append(decoder_input) + decoder_positional_encodings.append(decoder_positional_encoding) + # shape (num_queries, c) -> (num_queries, batch_size, c) + query_feat = self.query_feat.weight.unsqueeze(1).repeat((1, batch_size, 1)) + query_embed = self.query_embed.weight.unsqueeze(1).repeat((1, batch_size, 1)) + + cls_pred_list = [] + mask_pred_list = [] + cls_pred, mask_pred, attn_mask = self.forward_head(query_feat, mask_features, multi_scale_memorys[0].shape[-2:]) + cls_pred_list.append(cls_pred) + mask_pred_list.append(mask_pred) + + for i in range(self.num_transformer_decoder_layers): + level_idx = i % self.num_transformer_feat_level + # if a mask is all True(all background), then set it all False. + attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False + + # cross_attn + self_attn + layer = self.transformer_decoder.layers[i] + attn_masks = [attn_mask, None] + query_feat = layer( + query=query_feat, + key=decoder_inputs[level_idx], + value=decoder_inputs[level_idx], + query_pos=query_embed, + key_pos=decoder_positional_encodings[level_idx], + attn_masks=attn_masks, + query_key_padding_mask=None, + # here we do not apply masking on padded region + key_padding_mask=None, + ) + cls_pred, mask_pred, attn_mask = self.forward_head( + query_feat, mask_features, multi_scale_memorys[(i + 1) % self.num_transformer_feat_level].shape[-2:] + ) + + cls_pred_list.append(cls_pred) + mask_pred_list.append(mask_pred) + + return cls_pred_list, mask_pred_list + + def forward_train(self, x, img_metas, gt_semantic_seg, gt_labels, gt_masks): + """Forward function for training mode. + + Args: + x (list[Tensor]): Multi-level features from the upstream network, + each is a 4D-tensor. + img_metas (list[Dict]): List of image information. + gt_semantic_seg (list[tensor]):Each element is the ground truth + of semantic segmentation with the shape (N, H, W). + train_cfg (dict): The training config, which not been used in + maskformer. + gt_labels (list[Tensor]): Each element is ground truth labels of + each box, shape (num_gts,). + gt_masks (list[BitmapMasks]): Each element is masks of instances + of a image, shape (num_gts, h, w). + + Returns: + losses (dict[str, Tensor]): a dictionary of loss components + """ + + # forward + all_cls_scores, all_mask_preds = self(x, img_metas) + + # loss + losses = self.loss(all_cls_scores, all_mask_preds, gt_labels, gt_masks, img_metas) + + return losses + + def forward_test(self, inputs, img_metas, test_cfg): + """Test segment without test-time aumengtation. + + Only the output of last decoder layers was used. 
+ + Args: + inputs (list[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + img_metas (list[dict]): List of image information. + test_cfg (dict): Testing config. + + Returns: + seg_mask (Tensor): Predicted semantic segmentation logits. + """ + all_cls_scores, all_mask_preds = self(inputs, img_metas) + cls_score, mask_pred = all_cls_scores[-1], all_mask_preds[-1] + ori_h, ori_w, _ = img_metas[0]["ori_shape"] + + # semantic inference + cls_score = F.softmax(cls_score, dim=-1)[..., :-1] + mask_pred = mask_pred.sigmoid() + seg_mask = torch.einsum("bqc,bqhw->bchw", cls_score, mask_pred) + return seg_mask diff --git a/dinov2/eval/segmentation_m2f/models/losses/__init__.py b/dinov2/eval/segmentation_m2f/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..229a887817372f4991b32354180592cfb236d728 --- /dev/null +++ b/dinov2/eval/segmentation_m2f/models/losses/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .cross_entropy_loss import CrossEntropyLoss, binary_cross_entropy, cross_entropy, mask_cross_entropy +from .dice_loss import DiceLoss +from .match_costs import ClassificationCost, CrossEntropyLossCost, DiceCost diff --git a/dinov2/eval/segmentation_m2f/models/losses/cross_entropy_loss.py b/dinov2/eval/segmentation_m2f/models/losses/cross_entropy_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..0a1f9dd4aa52ebe94cc527db36b1c7fa2f53813e --- /dev/null +++ b/dinov2/eval/segmentation_m2f/models/losses/cross_entropy_loss.py @@ -0,0 +1,279 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmseg.models.builder import LOSSES +from mmseg.models.losses.utils import get_class_weight, weight_reduce_loss + + +def cross_entropy( + pred, + label, + weight=None, + class_weight=None, + reduction="mean", + avg_factor=None, + ignore_index=-100, + avg_non_ignore=False, +): + """cross_entropy. The wrapper function for :func:`F.cross_entropy` + + Args: + pred (torch.Tensor): The prediction with shape (N, 1). + label (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + Default: None. + class_weight (list[float], optional): The weight for each class. + Default: None. + reduction (str, optional): The method used to reduce the loss. + Options are 'none', 'mean' and 'sum'. Default: 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Default: None. + ignore_index (int): Specifies a target value that is ignored and + does not contribute to the input gradients. When + ``avg_non_ignore `` is ``True``, and the ``reduction`` is + ``''mean''``, the loss is averaged over non-ignored targets. + Defaults: -100. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. + `New in version 0.23.0.` + """ + + # class_weight is a manual rescaling weight given to each class. 
+ # If given, has to be a Tensor of size C element-wise losses + loss = F.cross_entropy(pred, label, weight=class_weight, reduction="none", ignore_index=ignore_index) + + # apply weights and do the reduction + # average loss over non-ignored elements + # pytorch's official cross_entropy average loss over non-ignored elements + # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa + if (avg_factor is None) and avg_non_ignore and reduction == "mean": + avg_factor = label.numel() - (label == ignore_index).sum().item() + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss(loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index): + """Expand onehot labels to match the size of prediction.""" + bin_labels = labels.new_zeros(target_shape) + valid_mask = (labels >= 0) & (labels != ignore_index) + inds = torch.nonzero(valid_mask, as_tuple=True) + + if inds[0].numel() > 0: + if labels.dim() == 3: + bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1 + else: + bin_labels[inds[0], labels[valid_mask]] = 1 + + valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float() + + if label_weights is None: + bin_label_weights = valid_mask + else: + bin_label_weights = label_weights.unsqueeze(1).expand(target_shape) + bin_label_weights = bin_label_weights * valid_mask + + return bin_labels, bin_label_weights, valid_mask + + +def binary_cross_entropy( + pred, + label, + weight=None, + reduction="mean", + avg_factor=None, + class_weight=None, + ignore_index=-100, + avg_non_ignore=False, + **kwargs, +): + """Calculate the binary CrossEntropy loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, 1). + label (torch.Tensor): The learning label of the prediction. + Note: In bce loss, label < 0 is invalid. + weight (torch.Tensor, optional): Sample-wise loss weight. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (int): The label index to be ignored. Default: -100. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. + `New in version 0.23.0.` + + Returns: + torch.Tensor: The calculated loss + """ + if pred.size(1) == 1: + # For binary class segmentation, the shape of pred is + # [N, 1, H, W] and that of label is [N, H, W]. 
+        assert label.max() <= 1, "For pred with shape [N, 1, H, W], its label must have at " "most 2 classes"
+        pred = pred.squeeze(1)  # squeeze the class dim only; a bare squeeze() could also drop a batch dim of size 1
+    if pred.dim() != label.dim():
+        assert (pred.dim() == 2 and label.dim() == 1) or (pred.dim() == 4 and label.dim() == 3), (
+            "Only pred shape [N, C], label shape [N] or pred shape [N, C, " "H, W], label shape [N, H, W] are supported"
+        )
+        # `weight` returned from `_expand_onehot_labels`
+        # has been treated for valid (non-ignore) pixels
+        label, weight, valid_mask = _expand_onehot_labels(label, weight, pred.shape, ignore_index)
+    else:
+        # should mask out the ignored elements
+        valid_mask = ((label >= 0) & (label != ignore_index)).float()
+        if weight is not None:
+            weight = weight * valid_mask
+        else:
+            weight = valid_mask
+    # average loss over non-ignored and valid elements
+    if reduction == "mean" and avg_factor is None and avg_non_ignore:
+        avg_factor = valid_mask.sum().item()
+
+    loss = F.binary_cross_entropy_with_logits(pred, label.float(), pos_weight=class_weight, reduction="none")
+    # do the reduction for the weighted loss
+    loss = weight_reduce_loss(loss, weight, reduction=reduction, avg_factor=avg_factor)
+
+    return loss
+
+
+def mask_cross_entropy(
+    pred, target, label, reduction="mean", avg_factor=None, class_weight=None, ignore_index=None, **kwargs
+):
+    """Calculate the CrossEntropy loss for masks.
+
+    Args:
+        pred (torch.Tensor): The prediction with shape (N, C), C is the number
+            of classes.
+        target (torch.Tensor): The learning label of the prediction.
+        label (torch.Tensor): ``label`` indicates the class label of the mask's
+            corresponding object. It is used to select the mask of the class
+            the object belongs to when the mask prediction is not
+            class-agnostic.
+        reduction (str, optional): The method used to reduce the loss.
+            Options are "none", "mean" and "sum".
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. Defaults to None.
+        class_weight (list[float], optional): The weight for each class.
+        ignore_index (None): Placeholder, to be consistent with other loss.
+            Default: None.
+
+    Returns:
+        torch.Tensor: The calculated loss
+    """
+    assert ignore_index is None, "BCE loss does not support ignore_index"
+    assert reduction == "mean" and avg_factor is None
+    num_rois = pred.size()[0]
+    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
+    pred_slice = pred[inds, label].squeeze(1)
+    return F.binary_cross_entropy_with_logits(pred_slice, target, weight=class_weight, reduction="mean")[None]
+
+
+@LOSSES.register_module(force=True)
+class CrossEntropyLoss(nn.Module):
+    """CrossEntropyLoss.
+
+    Args:
+        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+            instead of softmax. Defaults to False.
+        use_mask (bool, optional): Whether to use mask cross entropy loss.
+            Defaults to False.
+        reduction (str, optional): The method used to reduce the loss.
+            Options are "none", "mean" and "sum". Defaults to 'mean'.
+        class_weight (list[float] | str, optional): Weight of each class. If in
+            str format, read them from a file. Defaults to None.
+        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+        loss_name (str, optional): Name of the loss item. If you want this loss
+            item to be included into the backward graph, `loss_` must be the
+            prefix of the name. Defaults to 'loss_ce'.
+        avg_non_ignore (bool): Whether the loss is only averaged over
+            non-ignored targets. Default: False.
+ `New in version 0.23.0.` + """ + + def __init__( + self, + use_sigmoid=False, + use_mask=False, + reduction="mean", + class_weight=None, + loss_weight=1.0, + loss_name="loss_ce", + avg_non_ignore=False, + ): + super(CrossEntropyLoss, self).__init__() + assert (use_sigmoid is False) or (use_mask is False) + self.use_sigmoid = use_sigmoid + self.use_mask = use_mask + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = get_class_weight(class_weight) + self.avg_non_ignore = avg_non_ignore + if not self.avg_non_ignore and self.reduction == "mean": + warnings.warn( + "Default ``avg_non_ignore`` is False, if you would like to " + "ignore the certain label and average loss over non-ignore " + "labels, which is the same with PyTorch official " + "cross_entropy, set ``avg_non_ignore=True``." + ) + + if self.use_sigmoid: + self.cls_criterion = binary_cross_entropy + elif self.use_mask: + self.cls_criterion = mask_cross_entropy + else: + self.cls_criterion = cross_entropy + self._loss_name = loss_name + + def extra_repr(self): + """Extra repr.""" + s = f"avg_non_ignore={self.avg_non_ignore}" + return s + + def forward( + self, cls_score, label, weight=None, avg_factor=None, reduction_override=None, ignore_index=-100, **kwargs + ): + """Forward function.""" + assert reduction_override in (None, "none", "mean", "sum") + reduction = reduction_override if reduction_override else self.reduction + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + # Note: for BCE loss, label < 0 is invalid. + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + weight, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + avg_non_ignore=self.avg_non_ignore, + ignore_index=ignore_index, + **kwargs, + ) + return loss_cls + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/dinov2/eval/segmentation_m2f/models/losses/dice_loss.py b/dinov2/eval/segmentation_m2f/models/losses/dice_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..1bc5ba893c502861032ed531283f225e183eb693 --- /dev/null +++ b/dinov2/eval/segmentation_m2f/models/losses/dice_loss.py @@ -0,0 +1,153 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from mmseg.models.builder import LOSSES +from mmseg.models.losses.utils import weight_reduce_loss + + +def dice_loss(pred, target, weight=None, eps=1e-3, reduction="mean", avg_factor=None): + """Calculate dice loss, which is proposed in + `V-Net: Fully Convolutional Neural Networks for Volumetric + Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_. + + Args: + pred (torch.Tensor): The prediction, has a shape (n, *) + target (torch.Tensor): The learning label of the prediction, + shape (n, *), same shape of pred. + weight (torch.Tensor, optional): The weight of loss for each + prediction, has a shape (n,). Defaults to None. + eps (float): Avoid dividing by zero. 
Default: 1e-3. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + """ + + input = pred.flatten(1) + target = target.flatten(1).float() + + a = torch.sum(input * target, 1) + b = torch.sum(input * input, 1) + eps + c = torch.sum(target * target, 1) + eps + d = (2 * a) / (b + c) + loss = 1 - d + if weight is not None: + assert weight.ndim == loss.ndim + assert len(weight) == len(pred) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +def naive_dice_loss(pred, target, weight=None, eps=1e-3, reduction="mean", avg_factor=None): + """Calculate naive dice loss, the coefficient in the denominator is the + first power instead of the second power. + + Args: + pred (torch.Tensor): The prediction, has a shape (n, *) + target (torch.Tensor): The learning label of the prediction, + shape (n, *), same shape of pred. + weight (torch.Tensor, optional): The weight of loss for each + prediction, has a shape (n,). Defaults to None. + eps (float): Avoid dividing by zero. Default: 1e-3. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + """ + input = pred.flatten(1) + target = target.flatten(1).float() + + a = torch.sum(input * target, 1) + b = torch.sum(input, 1) + c = torch.sum(target, 1) + d = (2 * a + eps) / (b + c + eps) + loss = 1 - d + if weight is not None: + assert weight.ndim == loss.ndim + assert len(weight) == len(pred) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +@LOSSES.register_module(force=True) +class DiceLoss(nn.Module): + def __init__(self, use_sigmoid=True, activate=True, reduction="mean", naive_dice=False, loss_weight=1.0, eps=1e-3): + """Dice Loss, there are two forms of dice loss is supported: + + - the one proposed in `V-Net: Fully Convolutional Neural + Networks for Volumetric Medical Image Segmentation + <https://arxiv.org/abs/1606.04797>`_. + - the dice loss in which the power of the number in the + denominator is the first power instead of the second + power. + + Args: + use_sigmoid (bool, optional): Whether to the prediction is + used for sigmoid or softmax. Defaults to True. + activate (bool): Whether to activate the predictions inside, + this will disable the inside sigmoid operation. + Defaults to True. + reduction (str, optional): The method used + to reduce the loss. Options are "none", + "mean" and "sum". Defaults to 'mean'. + naive_dice (bool, optional): If false, use the dice + loss defined in the V-Net paper, otherwise, use the + naive dice loss in which the power of the number in the + denominator is the first power instead of the second + power.Defaults to False. + loss_weight (float, optional): Weight of loss. Defaults to 1.0. + eps (float): Avoid dividing by zero. Defaults to 1e-3. + """ + + super(DiceLoss, self).__init__() + self.use_sigmoid = use_sigmoid + self.reduction = reduction + self.naive_dice = naive_dice + self.loss_weight = loss_weight + self.eps = eps + self.activate = activate + + def forward(self, pred, target, weight=None, reduction_override=None, avg_factor=None): + """Forward function. + + Args: + pred (torch.Tensor): The prediction, has a shape (n, *). 
+            target (torch.Tensor): The label of the prediction,
+                shape (n, *), same shape as pred.
+            weight (torch.Tensor, optional): The weight of loss for each
+                prediction, has a shape (n,). Defaults to None.
+            avg_factor (int, optional): Average factor that is used to average
+                the loss. Defaults to None.
+            reduction_override (str, optional): The reduction method used to
+                override the original reduction method of the loss.
+                Options are "none", "mean" and "sum".
+
+        Returns:
+            torch.Tensor: The calculated loss
+        """
+
+        assert reduction_override in (None, "none", "mean", "sum")
+        reduction = reduction_override if reduction_override else self.reduction
+
+        if self.activate:
+            if self.use_sigmoid:
+                pred = pred.sigmoid()
+            else:
+                raise NotImplementedError
+
+        if self.naive_dice:
+            loss = self.loss_weight * naive_dice_loss(
+                pred, target, weight, eps=self.eps, reduction=reduction, avg_factor=avg_factor
+            )
+        else:
+            loss = self.loss_weight * dice_loss(
+                pred, target, weight, eps=self.eps, reduction=reduction, avg_factor=avg_factor
+            )
+
+        return loss
diff --git a/dinov2/eval/segmentation_m2f/models/losses/match_costs.py b/dinov2/eval/segmentation_m2f/models/losses/match_costs.py
new file mode 100644
index 0000000000000000000000000000000000000000..4917d2a939c01398dd49c0d90b06f4c37d283ce0
--- /dev/null
+++ b/dinov2/eval/segmentation_m2f/models/losses/match_costs.py
@@ -0,0 +1,153 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+
+from ..builder import MATCH_COST
+
+
+@MATCH_COST.register_module()
+class ClassificationCost:
+    """ClsSoftmaxCost, borrowed from
+    mmdet.core.bbox.match_costs.match_cost.ClassificationCost.
+
+    Args:
+        weight (int | float, optional): loss_weight
+
+    Examples:
+        >>> import torch
+        >>> self = ClassificationCost()
+        >>> cls_pred = torch.rand(4, 3)
+        >>> gt_labels = torch.tensor([0, 1, 2])
+        >>> factor = torch.tensor([10, 8, 10, 8])
+        >>> self(cls_pred, gt_labels)
+        tensor([[-0.3430, -0.3525, -0.3045],
+                [-0.3077, -0.2931, -0.3992],
+                [-0.3664, -0.3455, -0.2881],
+                [-0.3343, -0.2701, -0.3956]])
+    """
+
+    def __init__(self, weight=1.0):
+        self.weight = weight
+
+    def __call__(self, cls_pred, gt_labels):
+        """
+        Args:
+            cls_pred (Tensor): Predicted classification logits, shape
+                [num_query, num_class].
+            gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
+
+        Returns:
+            torch.Tensor: cls_cost value with weight
+        """
+        # Following the official DETR repo, contrary to the loss, where
+        # NLL is used, we approximate it by 1 - cls_score[gt_label].
+        # The 1 is a constant that doesn't change the matching,
+        # so it can be omitted.
+        cls_score = cls_pred.softmax(-1)
+        cls_cost = -cls_score[:, gt_labels]
+        return cls_cost * self.weight
+
+
+@MATCH_COST.register_module()
+class DiceCost:
+    """Cost of mask assignments based on dice losses.
+
+    Args:
+        weight (int | float, optional): loss_weight. Defaults to 1.
+        pred_act (bool, optional): Whether to apply sigmoid to mask_pred.
+            Defaults to False.
+        eps (float, optional): Avoid dividing by zero. Defaults to 1e-3.
+    """
+
+    def __init__(self, weight=1.0, pred_act=False, eps=1e-3):
+        self.weight = weight
+        self.pred_act = pred_act
+        self.eps = eps
+
+    def binary_mask_dice_loss(self, mask_preds, gt_masks):
+        """
+        Args:
+            mask_preds (Tensor): Mask prediction in shape (N1, H, W).
+ gt_masks (Tensor): Ground truth in shape (N2, H, W) + store 0 or 1, 0 for negative class and 1 for + positive class. + + Returns: + Tensor: Dice cost matrix in shape (N1, N2). + """ + mask_preds = mask_preds.reshape((mask_preds.shape[0], -1)) + gt_masks = gt_masks.reshape((gt_masks.shape[0], -1)).float() + numerator = 2 * torch.einsum("nc,mc->nm", mask_preds, gt_masks) + denominator = mask_preds.sum(-1)[:, None] + gt_masks.sum(-1)[None, :] + loss = 1 - (numerator + self.eps) / (denominator + self.eps) + return loss + + def __call__(self, mask_preds, gt_masks): + """ + Args: + mask_preds (Tensor): Mask prediction logits in shape (N1, H, W). + gt_masks (Tensor): Ground truth in shape (N2, H, W). + + Returns: + Tensor: Dice cost matrix in shape (N1, N2). + """ + if self.pred_act: + mask_preds = mask_preds.sigmoid() + dice_cost = self.binary_mask_dice_loss(mask_preds, gt_masks) + return dice_cost * self.weight + + +@MATCH_COST.register_module() +class CrossEntropyLossCost: + """CrossEntropyLossCost. + + Args: + weight (int | float, optional): loss weight. Defaults to 1. + use_sigmoid (bool, optional): Whether the prediction uses sigmoid + of softmax. Defaults to True. + """ + + def __init__(self, weight=1.0, use_sigmoid=True): + assert use_sigmoid, "use_sigmoid = False is not supported yet." + self.weight = weight + self.use_sigmoid = use_sigmoid + + def _binary_cross_entropy(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): The prediction with shape (num_query, 1, *) or + (num_query, *). + gt_labels (Tensor): The learning label of prediction with + shape (num_gt, *). + Returns: + Tensor: Cross entropy cost matrix in shape (num_query, num_gt). + """ + cls_pred = cls_pred.flatten(1).float() + gt_labels = gt_labels.flatten(1).float() + n = cls_pred.shape[1] + pos = F.binary_cross_entropy_with_logits(cls_pred, torch.ones_like(cls_pred), reduction="none") + neg = F.binary_cross_entropy_with_logits(cls_pred, torch.zeros_like(cls_pred), reduction="none") + cls_cost = torch.einsum("nc,mc->nm", pos, gt_labels) + torch.einsum("nc,mc->nm", neg, 1 - gt_labels) + cls_cost = cls_cost / n + + return cls_cost + + def __call__(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classification logits. + gt_labels (Tensor): Labels. + Returns: + Tensor: Cross entropy cost matrix with weight in + shape (num_query, num_gt). + """ + if self.use_sigmoid: + cls_cost = self._binary_cross_entropy(cls_pred, gt_labels) + else: + raise NotImplementedError + + return cls_cost * self.weight diff --git a/dinov2/eval/segmentation_m2f/models/plugins/__init__.py b/dinov2/eval/segmentation_m2f/models/plugins/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81a60db4de31238cb38e078683e5ca265839fe60 --- /dev/null +++ b/dinov2/eval/segmentation_m2f/models/plugins/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
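During training, these match costs are combined by the mask assigner into the query-to-ground-truth cost matrix that Hungarian matching consumes. A toy sketch of that usage on random inputs: the cost weights (2.0/5.0/5.0) are illustrative rather than taken from a config in this diff, scipy's linear_sum_assignment stands in for the assigner's matching step, and the actual head computes these costs on sampled points rather than full masks:

import torch
from scipy.optimize import linear_sum_assignment

from dinov2.eval.segmentation_m2f.models.losses import ClassificationCost, CrossEntropyLossCost, DiceCost

num_queries, num_gts, num_classes = 100, 3, 150
cls_pred = torch.randn(num_queries, num_classes + 1)   # class logits incl. background
mask_pred = torch.randn(num_queries, 64, 64)           # per-query mask logits
gt_labels = torch.tensor([4, 17, 92])
gt_masks = (torch.rand(num_gts, 64, 64) > 0.5).float()

# Each cost returns a (num_queries, num_gts) matrix; sum them with weights.
cost = (
    ClassificationCost(weight=2.0)(cls_pred, gt_labels)
    + CrossEntropyLossCost(weight=5.0)(mask_pred, gt_masks)
    + DiceCost(weight=5.0, pred_act=True)(mask_pred, gt_masks)
)
rows, cols = linear_sum_assignment(cost.detach().numpy())
# rows[k] is the query matched to ground-truth cols[k]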
+ +from .msdeformattn_pixel_decoder import MSDeformAttnPixelDecoder diff --git a/dinov2/eval/segmentation_m2f/models/plugins/msdeformattn_pixel_decoder.py b/dinov2/eval/segmentation_m2f/models/plugins/msdeformattn_pixel_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..db1947175917f73f3f24184cb09c78e092d46ef8 --- /dev/null +++ b/dinov2/eval/segmentation_m2f/models/plugins/msdeformattn_pixel_decoder.py @@ -0,0 +1,242 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import PLUGIN_LAYERS, Conv2d, ConvModule, caffe2_xavier_init, normal_init, xavier_init +from mmcv.cnn.bricks.transformer import build_positional_encoding, build_transformer_layer_sequence +from mmcv.runner import BaseModule, ModuleList + +from ...core.anchor import MlvlPointGenerator +from ..utils.transformer import MultiScaleDeformableAttention + + +@PLUGIN_LAYERS.register_module() +class MSDeformAttnPixelDecoder(BaseModule): + """Pixel decoder with multi-scale deformable attention. + + Args: + in_channels (list[int] | tuple[int]): Number of channels in the + input feature maps. + strides (list[int] | tuple[int]): Output strides of feature from + backbone. + feat_channels (int): Number of channels for feature. + out_channels (int): Number of channels for output. + num_outs (int): Number of output scales. + norm_cfg (:obj:`mmcv.ConfigDict` | dict): Config for normalization. + Defaults to dict(type='GN', num_groups=32). + act_cfg (:obj:`mmcv.ConfigDict` | dict): Config for activation. + Defaults to dict(type='ReLU'). + encoder (:obj:`mmcv.ConfigDict` | dict): Config for transformer + encoder. Defaults to `DetrTransformerEncoder`. + positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for + transformer encoder position encoding. Defaults to + dict(type='SinePositionalEncoding', num_feats=128, + normalize=True). + init_cfg (:obj:`mmcv.ConfigDict` | dict): Initialization config dict. 
+ """ + + def __init__( + self, + in_channels=[256, 512, 1024, 2048], + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_outs=3, + norm_cfg=dict(type="GN", num_groups=32), + act_cfg=dict(type="ReLU"), + encoder=dict( + type="DetrTransformerEncoder", + num_layers=6, + transformerlayers=dict( + type="BaseTransformerLayer", + attn_cfgs=dict( + type="MultiScaleDeformableAttention", + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + im2col_step=64, + dropout=0.0, + batch_first=False, + norm_cfg=None, + init_cfg=None, + ), + feedforward_channels=1024, + ffn_dropout=0.0, + operation_order=("self_attn", "norm", "ffn", "norm"), + ), + init_cfg=None, + ), + positional_encoding=dict(type="SinePositionalEncoding", num_feats=128, normalize=True), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + self.strides = strides + self.num_input_levels = len(in_channels) + self.num_encoder_levels = encoder.transformerlayers.attn_cfgs.num_levels + assert self.num_encoder_levels >= 1, "num_levels in attn_cfgs must be at least one" + input_conv_list = [] + # from top to down (low to high resolution) + for i in range(self.num_input_levels - 1, self.num_input_levels - self.num_encoder_levels - 1, -1): + input_conv = ConvModule( + in_channels[i], feat_channels, kernel_size=1, norm_cfg=norm_cfg, act_cfg=None, bias=True + ) + input_conv_list.append(input_conv) + self.input_convs = ModuleList(input_conv_list) + + self.encoder = build_transformer_layer_sequence(encoder) + self.postional_encoding = build_positional_encoding(positional_encoding) + # high resolution to low resolution + self.level_encoding = nn.Embedding(self.num_encoder_levels, feat_channels) + + # fpn-like structure + self.lateral_convs = ModuleList() + self.output_convs = ModuleList() + self.use_bias = norm_cfg is None + # from top to down (low to high resolution) + # fpn for the rest features that didn't pass in encoder + for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, -1): + lateral_conv = ConvModule( + in_channels[i], feat_channels, kernel_size=1, bias=self.use_bias, norm_cfg=norm_cfg, act_cfg=None + ) + output_conv = ConvModule( + feat_channels, + feat_channels, + kernel_size=3, + stride=1, + padding=1, + bias=self.use_bias, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + self.lateral_convs.append(lateral_conv) + self.output_convs.append(output_conv) + + self.mask_feature = Conv2d(feat_channels, out_channels, kernel_size=1, stride=1, padding=0) + + self.num_outs = num_outs + self.point_generator = MlvlPointGenerator(strides) + + def init_weights(self): + """Initialize weights.""" + for i in range(0, self.num_encoder_levels): + xavier_init(self.input_convs[i].conv, gain=1, bias=0, distribution="uniform") + + for i in range(0, self.num_input_levels - self.num_encoder_levels): + caffe2_xavier_init(self.lateral_convs[i].conv, bias=0) + caffe2_xavier_init(self.output_convs[i].conv, bias=0) + + caffe2_xavier_init(self.mask_feature, bias=0) + + normal_init(self.level_encoding, mean=0, std=1) + for p in self.encoder.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + # init_weights defined in MultiScaleDeformableAttention + for layer in self.encoder.layers: + for attn in layer.attentions: + if isinstance(attn, MultiScaleDeformableAttention): + attn.init_weights() + + def forward(self, feats): + """ + Args: + feats (list[Tensor]): Feature maps of each level. Each has + shape of (batch_size, c, h, w). 
+ + Returns: + tuple: A tuple containing the following: + + - mask_feature (Tensor): shape (batch_size, c, h, w). + - multi_scale_features (list[Tensor]): Multi scale \ + features, each in shape (batch_size, c, h, w). + """ + # generate padding mask for each level, for each image + batch_size = feats[0].shape[0] + encoder_input_list = [] + padding_mask_list = [] + level_positional_encoding_list = [] + spatial_shapes = [] + reference_points_list = [] + for i in range(self.num_encoder_levels): + level_idx = self.num_input_levels - i - 1 + feat = feats[level_idx] + feat_projected = self.input_convs[i](feat) + h, w = feat.shape[-2:] + + # no padding + padding_mask_resized = feat.new_zeros((batch_size,) + feat.shape[-2:], dtype=torch.bool) + pos_embed = self.postional_encoding(padding_mask_resized) + level_embed = self.level_encoding.weight[i] + level_pos_embed = level_embed.view(1, -1, 1, 1) + pos_embed + # (h_i * w_i, 2) + reference_points = self.point_generator.single_level_grid_priors( + feat.shape[-2:], level_idx, device=feat.device + ) + # normalize + factor = feat.new_tensor([[w, h]]) * self.strides[level_idx] + reference_points = reference_points / factor + + # shape (batch_size, c, h_i, w_i) -> (h_i * w_i, batch_size, c) + feat_projected = feat_projected.flatten(2).permute(2, 0, 1) + level_pos_embed = level_pos_embed.flatten(2).permute(2, 0, 1) + padding_mask_resized = padding_mask_resized.flatten(1) + + encoder_input_list.append(feat_projected) + padding_mask_list.append(padding_mask_resized) + level_positional_encoding_list.append(level_pos_embed) + spatial_shapes.append(feat.shape[-2:]) + reference_points_list.append(reference_points) + # shape (batch_size, total_num_query), + # total_num_query=sum([., h_i * w_i,.]) + padding_masks = torch.cat(padding_mask_list, dim=1) + # shape (total_num_query, batch_size, c) + encoder_inputs = torch.cat(encoder_input_list, dim=0) + level_positional_encodings = torch.cat(level_positional_encoding_list, dim=0) + device = encoder_inputs.device + # shape (num_encoder_levels, 2), from low + # resolution to high resolution + spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=device) + # shape (0, h_0*w_0, h_0*w_0+h_1*w_1, ...) 
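+ # e.g. three levels with spatial_shapes [[16, 16], [32, 32], [64, 64]]
+ # give per-level query counts [256, 1024, 4096], so the start indices
+ # computed below are [0, 256, 1280]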
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = torch.cat(reference_points_list, dim=0) + reference_points = reference_points[None, :, None].repeat(batch_size, 1, self.num_encoder_levels, 1) + valid_radios = reference_points.new_ones((batch_size, self.num_encoder_levels, 2)) + # shape (num_total_query, batch_size, c) + memory = self.encoder( + query=encoder_inputs, + key=None, + value=None, + query_pos=level_positional_encodings, + key_pos=None, + attn_masks=None, + key_padding_mask=None, + query_key_padding_mask=padding_masks, + spatial_shapes=spatial_shapes, + reference_points=reference_points, + level_start_index=level_start_index, + valid_radios=valid_radios, + ) + # (num_total_query, batch_size, c) -> (batch_size, c, num_total_query) + memory = memory.permute(1, 2, 0) + + # from low resolution to high resolution + num_query_per_level = [e[0] * e[1] for e in spatial_shapes] + outs = torch.split(memory, num_query_per_level, dim=-1) + outs = [x.reshape(batch_size, -1, spatial_shapes[i][0], spatial_shapes[i][1]) for i, x in enumerate(outs)] + + for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, -1): + x = feats[i] + cur_feat = self.lateral_convs[i](x) + y = cur_feat + F.interpolate(outs[-1], size=cur_feat.shape[-2:], mode="bilinear", align_corners=False) + y = self.output_convs[i](y) + outs.append(y) + multi_scale_features = outs[: self.num_outs] + + mask_feature = self.mask_feature(outs[-1]) + return mask_feature, multi_scale_features diff --git a/dinov2/eval/segmentation_m2f/models/segmentors/__init__.py b/dinov2/eval/segmentation_m2f/models/segmentors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..adf0062691e4889612e118f28ced853cd0bc33db --- /dev/null +++ b/dinov2/eval/segmentation_m2f/models/segmentors/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .encoder_decoder_mask2former import EncoderDecoderMask2Former diff --git a/dinov2/eval/segmentation_m2f/models/segmentors/encoder_decoder_mask2former.py b/dinov2/eval/segmentation_m2f/models/segmentors/encoder_decoder_mask2former.py new file mode 100644 index 0000000000000000000000000000000000000000..cfe572c9d317303bff8d51b85217d144906ebfe7 --- /dev/null +++ b/dinov2/eval/segmentation_m2f/models/segmentors/encoder_decoder_mask2former.py @@ -0,0 +1,271 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmseg.core import add_prefix +from mmseg.models import builder +from mmseg.models.builder import SEGMENTORS +from mmseg.models.segmentors.base import BaseSegmentor +from mmseg.ops import resize + + +@SEGMENTORS.register_module() +class EncoderDecoderMask2Former(BaseSegmentor): + """Encoder Decoder segmentors. + + EncoderDecoder typically consists of backbone, decode_head, auxiliary_head. + Note that auxiliary_head is only used for deep supervision during training, + which could be dumped during inference. 
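+
+ Example:
+ >>> # Hedged construction sketch: ``backbone_cfg`` and ``head_cfg`` are
+ >>> # placeholders for config dicts of components registered with
+ >>> # mmseg/this package, not configs shipped with this module.
+ >>> model = EncoderDecoderMask2Former( # doctest: +SKIP
+ ... backbone=backbone_cfg,
+ ... decode_head=head_cfg,
+ ... test_cfg=dict(mode="slide", crop_size=(512, 512), stride=(341, 341)))
+ >>> pred = model.simple_test(img, img_meta, rescale=True) # doctest: +SKIP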
+ """ + + def __init__( + self, + backbone, + decode_head, + neck=None, + auxiliary_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None, + ): + super(EncoderDecoderMask2Former, self).__init__(init_cfg) + if pretrained is not None: + assert backbone.get("pretrained") is None, "both backbone and segmentor set pretrained weight" + backbone.pretrained = pretrained + self.backbone = builder.build_backbone(backbone) + if neck is not None: + self.neck = builder.build_neck(neck) + decode_head.update(train_cfg=train_cfg) + decode_head.update(test_cfg=test_cfg) + self._init_decode_head(decode_head) + self._init_auxiliary_head(auxiliary_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + assert self.with_decode_head + + def _init_decode_head(self, decode_head): + """Initialize ``decode_head``""" + self.decode_head = builder.build_head(decode_head) + self.align_corners = self.decode_head.align_corners + self.num_classes = self.decode_head.num_classes + + def _init_auxiliary_head(self, auxiliary_head): + """Initialize ``auxiliary_head``""" + if auxiliary_head is not None: + if isinstance(auxiliary_head, list): + self.auxiliary_head = nn.ModuleList() + for head_cfg in auxiliary_head: + self.auxiliary_head.append(builder.build_head(head_cfg)) + else: + self.auxiliary_head = builder.build_head(auxiliary_head) + + def extract_feat(self, img): + """Extract features from images.""" + x = self.backbone(img) + if self.with_neck: + x = self.neck(x) + return x + + def encode_decode(self, img, img_metas): + """Encode images with backbone and decode into a semantic segmentation + map of the same size as input.""" + x = self.extract_feat(img) + out = self._decode_head_forward_test(x, img_metas) + out = resize(input=out, size=img.shape[2:], mode="bilinear", align_corners=self.align_corners) + return out + + def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg, **kwargs): + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.forward_train(x, img_metas, gt_semantic_seg, **kwargs) + + losses.update(add_prefix(loss_decode, "decode")) + return losses + + def _decode_head_forward_test(self, x, img_metas): + """Run forward function and calculate loss for decode head in + inference.""" + seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg) + return seg_logits + + def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg): + """Run forward function and calculate loss for auxiliary head in + training.""" + losses = dict() + if isinstance(self.auxiliary_head, nn.ModuleList): + for idx, aux_head in enumerate(self.auxiliary_head): + loss_aux = aux_head.forward_train(x, img_metas, gt_semantic_seg, self.train_cfg) + losses.update(add_prefix(loss_aux, f"aux_{idx}")) + else: + loss_aux = self.auxiliary_head.forward_train(x, img_metas, gt_semantic_seg, self.train_cfg) + losses.update(add_prefix(loss_aux, "aux")) + + return losses + + def forward_dummy(self, img): + """Dummy forward function.""" + seg_logit = self.encode_decode(img, None) + + return seg_logit + + def forward_train(self, img, img_metas, gt_semantic_seg, **kwargs): + """Forward function for training. + + Args: + img (Tensor): Input images. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 
+ For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + gt_semantic_seg (Tensor): Semantic segmentation masks + used if the architecture supports semantic segmentation task. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + x = self.extract_feat(img) + + losses = dict() + + loss_decode = self._decode_head_forward_train(x, img_metas, gt_semantic_seg, **kwargs) + losses.update(loss_decode) + + if self.with_auxiliary_head: + loss_aux = self._auxiliary_head_forward_train(x, img_metas, gt_semantic_seg) + losses.update(loss_aux) + + return losses + + def slide_inference(self, img, img_meta, rescale): + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = img.size() + num_classes = self.num_classes + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = img.new_zeros((batch_size, num_classes, h_img, w_img)) + count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = img[:, :, y1:y2, x1:x2] + crop_seg_logit = self.encode_decode(crop_img, img_meta) + preds += F.pad(crop_seg_logit, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + if torch.onnx.is_in_onnx_export(): + # cast count_mat to constant while exporting to ONNX + count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device) + preds = preds / count_mat + if rescale: + preds = resize( + preds, + size=img_meta[0]["ori_shape"][:2], + mode="bilinear", + align_corners=self.align_corners, + warning=False, + ) + return preds + + def whole_inference(self, img, img_meta, rescale): + """Inference with full image.""" + + seg_logit = self.encode_decode(img, img_meta) + if rescale: + # support dynamic shape for onnx + if torch.onnx.is_in_onnx_export(): + size = img.shape[2:] + else: + size = img_meta[0]["ori_shape"][:2] + seg_logit = resize(seg_logit, size=size, mode="bilinear", align_corners=self.align_corners, warning=False) + + return seg_logit + + def inference(self, img, img_meta, rescale): + """Inference with slide/whole style. + + Args: + img (Tensor): The input image of shape (N, 3, H, W). + img_meta (dict): Image info dict where each dict has: 'img_shape', + 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + rescale (bool): Whether rescale back to original shape. + + Returns: + Tensor: The output segmentation map. 
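+
+ Example:
+ >>> # Sketch of the meta format consumed here; the values are
+ >>> # illustrative and ``model``/``img`` are assumed to exist.
+ >>> img_meta = [dict(ori_shape=(512, 683, 3), img_shape=(512, 683, 3),
+ ... pad_shape=(512, 683, 3), scale_factor=1.0, flip=False)]
+ >>> seg_prob = model.inference(img, img_meta, rescale=True) # doctest: +SKIP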
+ """ + + assert self.test_cfg.mode in ["slide", "whole"] + ori_shape = img_meta[0]["ori_shape"] + assert all(_["ori_shape"] == ori_shape for _ in img_meta) + if self.test_cfg.mode == "slide": + seg_logit = self.slide_inference(img, img_meta, rescale) + else: + seg_logit = self.whole_inference(img, img_meta, rescale) + output = F.softmax(seg_logit, dim=1) + flip = img_meta[0]["flip"] + if flip: + flip_direction = img_meta[0]["flip_direction"] + assert flip_direction in ["horizontal", "vertical"] + if flip_direction == "horizontal": + output = output.flip(dims=(3,)) + elif flip_direction == "vertical": + output = output.flip(dims=(2,)) + + return output + + def simple_test(self, img, img_meta, rescale=True): + """Simple test with single image.""" + seg_logit = self.inference(img, img_meta, rescale) + seg_pred = seg_logit.argmax(dim=1) + if torch.onnx.is_in_onnx_export(): + # our inference backend only support 4D output + seg_pred = seg_pred.unsqueeze(0) + return seg_pred + seg_pred = seg_pred.cpu().numpy() + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred + + def aug_test(self, imgs, img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. + """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented seg logit inplace + seg_logit = self.inference(imgs[0], img_metas[0], rescale) + for i in range(1, len(imgs)): + cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale) + seg_logit += cur_seg_logit + seg_logit /= len(imgs) + seg_pred = seg_logit.argmax(dim=1) + seg_pred = seg_pred.cpu().numpy() + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred diff --git a/dinov2/eval/segmentation_m2f/models/utils/__init__.py b/dinov2/eval/segmentation_m2f/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e7fdc1668b1015c8feea8fa1a4691bc0ebdbd936 --- /dev/null +++ b/dinov2/eval/segmentation_m2f/models/utils/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .assigner import MaskHungarianAssigner +from .point_sample import get_uncertain_point_coords_with_randomness +from .positional_encoding import LearnedPositionalEncoding, SinePositionalEncoding +from .transformer import DetrTransformerDecoder, DetrTransformerDecoderLayer, DynamicConv, Transformer diff --git a/dinov2/eval/segmentation_m2f/models/utils/assigner.py b/dinov2/eval/segmentation_m2f/models/utils/assigner.py new file mode 100644 index 0000000000000000000000000000000000000000..3cb08fc1bb2e36336989b45a1d3850f260c05963 --- /dev/null +++ b/dinov2/eval/segmentation_m2f/models/utils/assigner.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +from abc import ABCMeta, abstractmethod + +import torch + +from ..builder import MASK_ASSIGNERS, build_match_cost + +try: + from scipy.optimize import linear_sum_assignment +except ImportError: + linear_sum_assignment = None + + +class AssignResult(metaclass=ABCMeta): + """Collection of assign results.""" + + def __init__(self, num_gts, gt_inds, labels): + self.num_gts = num_gts + self.gt_inds = gt_inds + self.labels = labels + + @property + def info(self): + info = { + "num_gts": self.num_gts, + "gt_inds": self.gt_inds, + "labels": self.labels, + } + return info + + +class BaseAssigner(metaclass=ABCMeta): + """Base assigner that assigns boxes to ground truth boxes.""" + + @abstractmethod + def assign(self, masks, gt_masks, gt_masks_ignore=None, gt_labels=None): + """Assign boxes to either a ground truth boxes or a negative boxes.""" + pass + + +@MASK_ASSIGNERS.register_module() +class MaskHungarianAssigner(BaseAssigner): + """Computes one-to-one matching between predictions and ground truth for + mask. + + This class computes an assignment between the targets and the predictions + based on the costs. The costs are weighted sum of three components: + classification cost, regression L1 cost and regression iou cost. The + targets don't include the no_object, so generally there are more + predictions than targets. After the one-to-one matching, the un-matched + are treated as backgrounds. Thus each query prediction will be assigned + with `0` or a positive integer indicating the ground truth index: + + - 0: negative sample, no assigned gt + - positive integer: positive sample, index (1-based) of assigned gt + + Args: + cls_cost (obj:`mmcv.ConfigDict`|dict): Classification cost config. + mask_cost (obj:`mmcv.ConfigDict`|dict): Mask cost config. + dice_cost (obj:`mmcv.ConfigDict`|dict): Dice cost config. + """ + + def __init__( + self, + cls_cost=dict(type="ClassificationCost", weight=1.0), + dice_cost=dict(type="DiceCost", weight=1.0), + mask_cost=dict(type="MaskFocalCost", weight=1.0), + ): + self.cls_cost = build_match_cost(cls_cost) + self.dice_cost = build_match_cost(dice_cost) + self.mask_cost = build_match_cost(mask_cost) + + def assign(self, cls_pred, mask_pred, gt_labels, gt_masks, img_meta, gt_masks_ignore=None, eps=1e-7): + """Computes one-to-one matching based on the weighted costs. + + This method assign each query prediction to a ground truth or + background. The `assigned_gt_inds` with -1 means don't care, + 0 means negative sample, and positive number is the index (1-based) + of assigned gt. + The assignment is done in the following steps, the order matters. + + 1. assign every prediction to -1 + 2. compute the weighted costs + 3. do Hungarian matching on CPU based on the costs + 4. assign all to 0 (background) first, then for each matched pair + between predictions and gts, treat this prediction as foreground + and assign the corresponding gt index (plus 1) to it. + + Args: + mask_pred (Tensor): Predicted mask, shape [num_query, h, w] + cls_pred (Tensor): Predicted classification logits, shape + [num_query, num_class]. + gt_masks (Tensor): Ground truth mask, shape [num_gt, h, w]. + gt_labels (Tensor): Label of `gt_masks`, shape (num_gt,). + img_meta (dict): Meta information for current image. + gt_masks_ignore (Tensor, optional): Ground truth masks that are + labelled as `ignored`. Default None. + eps (int | float, optional): A value added to the denominator for + numerical stability. Default 1e-7. + + Returns: + :obj:`AssignResult`: The assigned result. 
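+
+ Example:
+ >>> # Toy sketch: queries vs. ground-truth masks. It assumes the three
+ >>> # default match-cost types are registered, so it is skipped rather
+ >>> # than executed here.
+ >>> assigner = MaskHungarianAssigner() # doctest: +SKIP
+ >>> result = assigner.assign(cls_pred, mask_pred, gt_labels, gt_masks,
+ ... img_meta=None) # doctest: +SKIP
+ >>> # result.gt_inds is 0 for unmatched queries and the 1-based gt index
+ >>> # for matched ones; result.labels carries their class labels.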
+ """ + assert gt_masks_ignore is None, "Only case when gt_masks_ignore is None is supported." + num_gts, num_queries = gt_labels.shape[0], cls_pred.shape[0] + + # 1. assign -1 by default + assigned_gt_inds = cls_pred.new_full((num_queries,), -1, dtype=torch.long) + assigned_labels = cls_pred.new_full((num_queries,), -1, dtype=torch.long) + if num_gts == 0 or num_queries == 0: + # No ground truth or boxes, return empty assignment + if num_gts == 0: + # No ground truth, assign all to background + assigned_gt_inds[:] = 0 + return AssignResult(num_gts, assigned_gt_inds, labels=assigned_labels) + + # 2. compute the weighted costs + # classification and maskcost. + if self.cls_cost.weight != 0 and cls_pred is not None: + cls_cost = self.cls_cost(cls_pred, gt_labels) + else: + cls_cost = 0 + + if self.mask_cost.weight != 0: + # mask_pred shape = [nq, h, w] + # gt_mask shape = [ng, h, w] + # mask_cost shape = [nq, ng] + mask_cost = self.mask_cost(mask_pred, gt_masks) + else: + mask_cost = 0 + + if self.dice_cost.weight != 0: + dice_cost = self.dice_cost(mask_pred, gt_masks) + else: + dice_cost = 0 + cost = cls_cost + mask_cost + dice_cost + + # 3. do Hungarian matching on CPU using linear_sum_assignment + cost = cost.detach().cpu() + if linear_sum_assignment is None: + raise ImportError('Please run "pip install scipy" ' "to install scipy first.") + + matched_row_inds, matched_col_inds = linear_sum_assignment(cost) + matched_row_inds = torch.from_numpy(matched_row_inds).to(cls_pred.device) + matched_col_inds = torch.from_numpy(matched_col_inds).to(cls_pred.device) + + # 4. assign backgrounds and foregrounds + # assign all indices to backgrounds first + assigned_gt_inds[:] = 0 + # assign foregrounds based on matching results + assigned_gt_inds[matched_row_inds] = matched_col_inds + 1 + assigned_labels[matched_row_inds] = gt_labels[matched_col_inds] + return AssignResult(num_gts, assigned_gt_inds, labels=assigned_labels) diff --git a/dinov2/eval/segmentation_m2f/models/utils/point_sample.py b/dinov2/eval/segmentation_m2f/models/utils/point_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..9f1134082bafb51432618a9632592db070f87284 --- /dev/null +++ b/dinov2/eval/segmentation_m2f/models/utils/point_sample.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +from mmcv.ops import point_sample + + +def get_uncertainty(mask_pred, labels): + """Estimate uncertainty based on pred logits. + + We estimate uncertainty as L1 distance between 0.0 and the logits + prediction in 'mask_pred' for the foreground class in `classes`. + + Args: + mask_pred (Tensor): mask predication logits, shape (num_rois, + num_classes, mask_height, mask_width). + + labels (list[Tensor]): Either predicted or ground truth label for + each predicted mask, of length num_rois. 
+
+ Returns:
+ scores (Tensor): Uncertainty scores with the most uncertain
+ locations having the highest uncertainty score,
+ shape (num_rois, 1, mask_height, mask_width)
+ """
+ if mask_pred.shape[1] == 1:
+ gt_class_logits = mask_pred.clone()
+ else:
+ inds = torch.arange(mask_pred.shape[0], device=mask_pred.device)
+ gt_class_logits = mask_pred[inds, labels].unsqueeze(1)
+ return -torch.abs(gt_class_logits)
+
+
+def get_uncertain_point_coords_with_randomness(
+ mask_pred, labels, num_points, oversample_ratio, importance_sample_ratio
+):
+ """Get ``num_points`` most uncertain points with random points during
+ training.
+
+ Sample points in [0, 1] x [0, 1] coordinate space based on their
+ uncertainty. The uncertainties are calculated for each point using
+ the 'get_uncertainty()' function, which takes the point's logit
+ prediction as input.
+
+ Args:
+ mask_pred (Tensor): A tensor of shape (num_rois, num_classes,
+ mask_height, mask_width) for class-specific or class-agnostic
+ prediction.
+ labels (list): The ground truth class for each instance.
+ num_points (int): The number of points to sample.
+ oversample_ratio (int): Oversampling parameter.
+ importance_sample_ratio (float): Ratio of points that are sampled
+ via importance sampling.
+
+ Returns:
+ point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
+ that contains the coordinates of the sampled points.
+ """
+ assert oversample_ratio >= 1
+ assert 0 <= importance_sample_ratio <= 1
+ batch_size = mask_pred.shape[0]
+ num_sampled = int(num_points * oversample_ratio)
+ point_coords = torch.rand(batch_size, num_sampled, 2, device=mask_pred.device)
+ point_logits = point_sample(mask_pred, point_coords)
+ # It is crucial to calculate uncertainty based on the sampled
+ # prediction value for the points. Calculating uncertainties of the
+ # coarse predictions first and sampling them for points leads to
+ # incorrect results. To illustrate this: assume uncertainty func(
+ # logits)=-abs(logits), a sampled point between two coarse
+ # predictions with -1 and 1 logits has 0 logits, and therefore 0
+ # uncertainty value. However, if we calculate uncertainties for the
+ # coarse predictions first, both will have -1 uncertainty,
+ # and the sampled point will get -1 uncertainty.
+ point_uncertainties = get_uncertainty(point_logits, labels)
+ num_uncertain_points = int(importance_sample_ratio * num_points)
+ num_random_points = num_points - num_uncertain_points
+ idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
+ shift = num_sampled * torch.arange(batch_size, dtype=torch.long, device=mask_pred.device)
+ idx += shift[:, None]
+ point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(batch_size, num_uncertain_points, 2)
+ if num_random_points > 0:
+ rand_roi_coords = torch.rand(batch_size, num_random_points, 2, device=mask_pred.device)
+ point_coords = torch.cat((point_coords, rand_roi_coords), dim=1)
+ return point_coords
diff --git a/dinov2/eval/segmentation_m2f/models/utils/positional_encoding.py b/dinov2/eval/segmentation_m2f/models/utils/positional_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf5d6fabe946d06fe97cc799da47bae93758b34e
--- /dev/null
+++ b/dinov2/eval/segmentation_m2f/models/utils/positional_encoding.py
@@ -0,0 +1,152 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+ +import math + +import torch +import torch.nn as nn +from mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING +from mmcv.runner import BaseModule + + +@POSITIONAL_ENCODING.register_module() +class SinePositionalEncoding(BaseModule): + """Position encoding with sine and cosine functions. + + See `End-to-End Object Detection with Transformers + <https://arxiv.org/pdf/2005.12872>`_ for details. + + Args: + num_feats (int): The feature dimension for each position + along x-axis or y-axis. Note the final returned dimension + for each position is 2 times of this value. + temperature (int, optional): The temperature used for scaling + the position embedding. Defaults to 10000. + normalize (bool, optional): Whether to normalize the position + embedding. Defaults to False. + scale (float, optional): A scale factor that scales the position + embedding. The scale will be used only when `normalize` is True. + Defaults to 2*pi. + eps (float, optional): A value added to the denominator for + numerical stability. Defaults to 1e-6. + offset (float): offset add to embed when do the normalization. + Defaults to 0. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__( + self, num_feats, temperature=10000, normalize=False, scale=2 * math.pi, eps=1e-6, offset=0.0, init_cfg=None + ): + super(SinePositionalEncoding, self).__init__(init_cfg) + if normalize: + assert isinstance(scale, (float, int)), ( + "when normalize is set," "scale should be provided and in float or int type, " f"found {type(scale)}" + ) + self.num_feats = num_feats + self.temperature = temperature + self.normalize = normalize + self.scale = scale + self.eps = eps + self.offset = offset + + def forward(self, mask): + """Forward function for `SinePositionalEncoding`. + + Args: + mask (Tensor): ByteTensor mask. Non-zero values representing + ignored positions, while zero values means valid positions + for this image. Shape [bs, h, w]. + + Returns: + pos (Tensor): Returned position embedding with shape + [bs, num_feats*2, h, w]. + """ + # For convenience of exporting to ONNX, it's required to convert + # `masks` from bool to int. 
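+ # The resulting embedding follows DETR: for feature index i,
+ # pos[..., 2i] = sin(coord / temperature**(2i / num_feats)) and
+ # pos[..., 2i + 1] = cos(coord / temperature**(2i / num_feats)),
+ # computed per axis (x and y) and then concatenated channel-wise.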
+ mask = mask.to(torch.int) + not_mask = 1 - mask # logical_not + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + y_embed = (y_embed + self.offset) / (y_embed[:, -1:, :] + self.eps) * self.scale + x_embed = (x_embed + self.offset) / (x_embed[:, :, -1:] + self.eps) * self.scale + dim_t = torch.arange(self.num_feats, dtype=torch.float32, device=mask.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_feats) + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + # use `view` instead of `flatten` for dynamically exporting to ONNX + B, H, W = mask.size() + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).view(B, H, W, -1) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).view(B, H, W, -1) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + def __repr__(self): + """str: a string that describes the module""" + repr_str = self.__class__.__name__ + repr_str += f"(num_feats={self.num_feats}, " + repr_str += f"temperature={self.temperature}, " + repr_str += f"normalize={self.normalize}, " + repr_str += f"scale={self.scale}, " + repr_str += f"eps={self.eps})" + return repr_str + + +@POSITIONAL_ENCODING.register_module() +class LearnedPositionalEncoding(BaseModule): + """Position embedding with learnable embedding weights. + + Args: + num_feats (int): The feature dimension for each position + along x-axis or y-axis. The final returned dimension for + each position is 2 times of this value. + row_num_embed (int, optional): The dictionary size of row embeddings. + Default 50. + col_num_embed (int, optional): The dictionary size of col embeddings. + Default 50. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__(self, num_feats, row_num_embed=50, col_num_embed=50, init_cfg=dict(type="Uniform", layer="Embedding")): + super(LearnedPositionalEncoding, self).__init__(init_cfg) + self.row_embed = nn.Embedding(row_num_embed, num_feats) + self.col_embed = nn.Embedding(col_num_embed, num_feats) + self.num_feats = num_feats + self.row_num_embed = row_num_embed + self.col_num_embed = col_num_embed + + def forward(self, mask): + """Forward function for `LearnedPositionalEncoding`. + + Args: + mask (Tensor): ByteTensor mask. Non-zero values representing + ignored positions, while zero values means valid positions + for this image. Shape [bs, h, w]. + + Returns: + pos (Tensor): Returned position embedding with shape + [bs, num_feats*2, h, w]. 
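+
+ Example:
+ >>> # Shape check: a 4x6 mask yields a (1, 2 * num_feats, 4, 6) map.
+ >>> import torch
+ >>> pe = LearnedPositionalEncoding(num_feats=8)
+ >>> pe(torch.zeros(1, 4, 6, dtype=torch.bool)).shape
+ torch.Size([1, 16, 4, 6])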
+ """ + h, w = mask.shape[-2:] + x = torch.arange(w, device=mask.device) + y = torch.arange(h, device=mask.device) + x_embed = self.col_embed(x) + y_embed = self.row_embed(y) + pos = ( + torch.cat((x_embed.unsqueeze(0).repeat(h, 1, 1), y_embed.unsqueeze(1).repeat(1, w, 1)), dim=-1) + .permute(2, 0, 1) + .unsqueeze(0) + .repeat(mask.shape[0], 1, 1, 1) + ) + return pos + + def __repr__(self): + """str: a string that describes the module""" + repr_str = self.__class__.__name__ + repr_str += f"(num_feats={self.num_feats}, " + repr_str += f"row_num_embed={self.row_num_embed}, " + repr_str += f"col_num_embed={self.col_num_embed})" + return repr_str diff --git a/dinov2/eval/segmentation_m2f/models/utils/transformer.py b/dinov2/eval/segmentation_m2f/models/utils/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..8befe6011a34d5ccecb82c8b17b61e19f732f96b --- /dev/null +++ b/dinov2/eval/segmentation_m2f/models/utils/transformer.py @@ -0,0 +1,989 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import math +import warnings +from typing import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv.cnn import Linear, build_activation_layer, build_norm_layer, xavier_init +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.registry import FEEDFORWARD_NETWORK, TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE +from mmcv.cnn.bricks.transformer import BaseTransformerLayer, TransformerLayerSequence, build_transformer_layer_sequence +from mmcv.runner.base_module import BaseModule, Sequential +from mmcv.utils import deprecated_api_warning, to_2tuple +from torch.nn.init import normal_ + +from ..builder import TRANSFORMER + +try: + from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention + +except ImportError: + warnings.warn( + "`MultiScaleDeformableAttention` in MMCV has been moved to " + "`mmcv.ops.multi_scale_deform_attn`, please update your MMCV" + ) + from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention + + +class AdaptivePadding(nn.Module): + """Applies padding to input (if needed) so that input can get fully covered + by filter you specified. It support two modes "same" and "corner". The + "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around + input. The "corner" mode would pad zero to bottom right. + + Args: + kernel_size (int | tuple): Size of the kernel: + stride (int | tuple): Stride of the filter. Default: 1: + dilation (int | tuple): Spacing between kernel elements. + Default: 1 + padding (str): Support "same" and "corner", "corner" mode + would pad zero to bottom right, and "same" mode would + pad zero around input. Default: "corner". 
+ Example:
+ >>> kernel_size = 16
+ >>> stride = 16
+ >>> dilation = 1
+ >>> input = torch.rand(1, 1, 15, 17)
+ >>> adap_pad = AdaptivePadding(
+ ... kernel_size=kernel_size,
+ ... stride=stride,
+ ... dilation=dilation,
+ ... padding="corner")
+ >>> out = adap_pad(input)
+ >>> assert (out.shape[2], out.shape[3]) == (16, 32)
+ >>> input = torch.rand(1, 1, 16, 17)
+ >>> out = adap_pad(input)
+ >>> assert (out.shape[2], out.shape[3]) == (16, 32)
+ """
+
+ def __init__(self, kernel_size=1, stride=1, dilation=1, padding="corner"):
+ super(AdaptivePadding, self).__init__()
+
+ assert padding in ("same", "corner")
+
+ kernel_size = to_2tuple(kernel_size)
+ stride = to_2tuple(stride)
+ padding = to_2tuple(padding)
+ dilation = to_2tuple(dilation)
+
+ self.padding = padding
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.dilation = dilation
+
+ def get_pad_shape(self, input_shape):
+ input_h, input_w = input_shape
+ kernel_h, kernel_w = self.kernel_size
+ stride_h, stride_w = self.stride
+ output_h = math.ceil(input_h / stride_h)
+ output_w = math.ceil(input_w / stride_w)
+ pad_h = max((output_h - 1) * stride_h + (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
+ pad_w = max((output_w - 1) * stride_w + (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
+ return pad_h, pad_w
+
+ def forward(self, x):
+ pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
+ if pad_h > 0 or pad_w > 0:
+ if self.padding == "corner":
+ x = F.pad(x, [0, pad_w, 0, pad_h])
+ elif self.padding == "same":
+ x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
+ return x
+
+
+class PatchMerging(BaseModule):
+ """Merge patch feature map.
+
+ This layer groups the feature map by kernel_size and applies norm and
+ linear layers to the grouped feature map. Our implementation uses
+ `nn.Unfold` to merge patches, which is about 25% faster than the original
+ implementation; the trade-off is that pretrained models need to be
+ modified for compatibility.
+
+ Args:
+ in_channels (int): The number of input channels.
+ out_channels (int): The number of output channels.
+ kernel_size (int | tuple, optional): the kernel size in the unfold
+ layer. Defaults to 2.
+ stride (int | tuple, optional): the stride of the sliding blocks in the
+ unfold layer. Default: None. (Would be set as `kernel_size`)
+ padding (int | tuple | string): The padding length of
+ embedding conv. When it is a string, it means the mode
+ of adaptive padding, supporting "same" and "corner".
+ Default: "corner".
+ dilation (int | tuple, optional): dilation parameter in the unfold
+ layer. Default: 1.
+ bias (bool, optional): Whether to add bias in linear layer or not.
+ Default: False.
+ norm_cfg (dict, optional): Config dict for normalization layer.
+ Default: dict(type='LN').
+ init_cfg (dict, optional): The extra config for initialization.
+ Default: None.
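+
+ Example:
+ >>> # Shape sketch: the default 2x2 kernel halves H and W and projects
+ >>> # 4 * in_channels to out_channels.
+ >>> import torch
+ >>> merge = PatchMerging(in_channels=96, out_channels=192)
+ >>> out, out_size = merge(torch.rand(1, 56 * 56, 96), (56, 56))
+ >>> out.shape, out_size
+ (torch.Size([1, 784, 192]), (28, 28))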
+ """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size=2, + stride=None, + padding="corner", + dilation=1, + bias=False, + norm_cfg=dict(type="LN"), + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + if stride: + stride = stride + else: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adap_padding = AdaptivePadding( + kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding + ) + # disable the padding of unfold + padding = 0 + else: + self.adap_padding = None + + padding = to_2tuple(padding) + self.sampler = nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride) + + sample_dim = kernel_size[0] * kernel_size[1] * in_channels + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, sample_dim)[1] + else: + self.norm = None + + self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) + + def forward(self, x, input_size): + """ + Args: + x (Tensor): Has shape (B, H*W, C_in). + input_size (tuple[int]): The spatial shape of x, arrange as (H, W). + Default: None. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out) + - out_size (tuple[int]): Spatial shape of x, arrange as + (Merged_H, Merged_W). + """ + B, L, C = x.shape + assert isinstance(input_size, Sequence), f"Expect " f"input_size is " f"`Sequence` " f"but get {input_size}" + + H, W = input_size + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W + # Use nn.Unfold to merge patch. About 25% faster than original method, + # but need to modify pretrained model for compatibility + + if self.adap_padding: + x = self.adap_padding(x) + H, W = x.shape[-2:] + + x = self.sampler(x) + # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2) + + out_h = ( + H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * (self.sampler.kernel_size[0] - 1) - 1 + ) // self.sampler.stride[0] + 1 + out_w = ( + W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * (self.sampler.kernel_size[1] - 1) - 1 + ) // self.sampler.stride[1] + 1 + + output_size = (out_h, out_w) + x = x.transpose(1, 2) # B, H/2*W/2, 4*C + x = self.norm(x) if self.norm else x + x = self.reduction(x) + return x, output_size + + +def inverse_sigmoid(x, eps=1e-5): + """Inverse function of sigmoid. + + Args: + x (Tensor): The tensor to do the + inverse. + eps (float): EPS avoid numerical + overflow. Defaults 1e-5. + Returns: + Tensor: The x has passed the inverse + function of sigmoid, has same + shape with input. + """ + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +@FEEDFORWARD_NETWORK.register_module(force=True) +class FFN(BaseModule): + """Implements feed-forward networks (FFNs) with identity connection. + Args: + embed_dims (int): The feature dimension. Same as + `MultiheadAttention`. Defaults: 256. + feedforward_channels (int): The hidden dimension of FFNs. + Defaults: 1024. + num_fcs (int, optional): The number of fully-connected layers in + FFNs. Default: 2. + act_cfg (dict, optional): The activation config for FFNs. + Default: dict(type='ReLU') + ffn_drop (float, optional): Probability of an element to be + zeroed in FFN. Default 0.0. 
+ add_identity (bool, optional): Whether to add the + identity connection. Default: `True`. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + @deprecated_api_warning({"dropout": "ffn_drop", "add_residual": "add_identity"}, cls_name="FFN") + def __init__( + self, + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + act_cfg=dict(type="ReLU", inplace=True), + ffn_drop=0.0, + dropout_layer=None, + add_identity=True, + init_cfg=None, + with_cp=False, + **kwargs, + ): + super().__init__(init_cfg) + assert num_fcs >= 2, "num_fcs should be no less " f"than 2. got {num_fcs}." + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.num_fcs = num_fcs + self.act_cfg = act_cfg + self.activate = build_activation_layer(act_cfg) + self.with_cp = with_cp + layers = [] + in_channels = embed_dims + for _ in range(num_fcs - 1): + layers.append(Sequential(Linear(in_channels, feedforward_channels), self.activate, nn.Dropout(ffn_drop))) + in_channels = feedforward_channels + layers.append(Linear(feedforward_channels, embed_dims)) + layers.append(nn.Dropout(ffn_drop)) + self.layers = Sequential(*layers) + self.dropout_layer = build_dropout(dropout_layer) if dropout_layer else torch.nn.Identity() + self.add_identity = add_identity + + @deprecated_api_warning({"residual": "identity"}, cls_name="FFN") + def forward(self, x, identity=None): + """Forward function for `FFN`. + The function would add x to the output tensor if residue is None. + """ + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(self.layers, x) + else: + out = self.layers(x) + + if not self.add_identity: + return self.dropout_layer(out) + if identity is None: + identity = x + return identity + self.dropout_layer(out) + + +@TRANSFORMER_LAYER.register_module() +class DetrTransformerDecoderLayer(BaseTransformerLayer): + """Implements decoder layer in DETR transformer. + + Args: + attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )): + Configs for self_attention or cross_attention, the order + should be consistent with it in `operation_order`. If it is + a dict, it would be expand to the number of attention in + `operation_order`. + feedforward_channels (int): The hidden dimension for FFNs. + ffn_dropout (float): Probability of an element to be zeroed + in ffn. Default 0.0. + operation_order (tuple[str]): The execution order of operation + in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm'). + Default:None + act_cfg (dict): The activation config for FFNs. Default: `LN` + norm_cfg (dict): Config dict for normalization layer. + Default: `LN`. + ffn_num_fcs (int): The number of fully-connected layers in FFNs. + Default:2. + """ + + def __init__( + self, + attn_cfgs, + feedforward_channels, + ffn_dropout=0.0, + operation_order=None, + act_cfg=dict(type="ReLU", inplace=True), + norm_cfg=dict(type="LN"), + ffn_num_fcs=2, + **kwargs, + ): + super(DetrTransformerDecoderLayer, self).__init__( + attn_cfgs=attn_cfgs, + feedforward_channels=feedforward_channels, + ffn_dropout=ffn_dropout, + operation_order=operation_order, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + ffn_num_fcs=ffn_num_fcs, + **kwargs, + ) + assert len(operation_order) == 6 + assert set(operation_order) == set(["self_attn", "norm", "cross_attn", "ffn"]) + + +@TRANSFORMER_LAYER_SEQUENCE.register_module() +class DetrTransformerEncoder(TransformerLayerSequence): + """TransformerEncoder of DETR. 
+ + Args: + post_norm_cfg (dict): Config of last normalization layer. Default: + `LN`. Only used when `self.pre_norm` is `True` + """ + + def __init__(self, *args, post_norm_cfg=dict(type="LN"), **kwargs): + super(DetrTransformerEncoder, self).__init__(*args, **kwargs) + if post_norm_cfg is not None: + self.post_norm = build_norm_layer(post_norm_cfg, self.embed_dims)[1] if self.pre_norm else None + else: + assert not self.pre_norm, f"Use prenorm in " f"{self.__class__.__name__}," f"Please specify post_norm_cfg" + self.post_norm = None + + def forward(self, *args, **kwargs): + """Forward function for `TransformerCoder`. + + Returns: + Tensor: forwarded results with shape [num_query, bs, embed_dims]. + """ + x = super(DetrTransformerEncoder, self).forward(*args, **kwargs) + if self.post_norm is not None: + x = self.post_norm(x) + return x + + +@TRANSFORMER_LAYER_SEQUENCE.register_module() +class DetrTransformerDecoder(TransformerLayerSequence): + """Implements the decoder in DETR transformer. + + Args: + return_intermediate (bool): Whether to return intermediate outputs. + post_norm_cfg (dict): Config of last normalization layer. Default: + `LN`. + """ + + def __init__(self, *args, post_norm_cfg=dict(type="LN"), return_intermediate=False, **kwargs): + + super(DetrTransformerDecoder, self).__init__(*args, **kwargs) + self.return_intermediate = return_intermediate + if post_norm_cfg is not None: + self.post_norm = build_norm_layer(post_norm_cfg, self.embed_dims)[1] + else: + self.post_norm = None + + def forward(self, query, *args, **kwargs): + """Forward function for `TransformerDecoder`. + + Args: + query (Tensor): Input query with shape + `(num_query, bs, embed_dims)`. + + Returns: + Tensor: Results with shape [1, num_query, bs, embed_dims] when + return_intermediate is `False`, otherwise it has shape + [num_layers, num_query, bs, embed_dims]. + """ + if not self.return_intermediate: + x = super().forward(query, *args, **kwargs) + if self.post_norm: + x = self.post_norm(x)[None] + return x + + intermediate = [] + for layer in self.layers: + query = layer(query, *args, **kwargs) + if self.return_intermediate: + if self.post_norm is not None: + intermediate.append(self.post_norm(query)) + else: + intermediate.append(query) + return torch.stack(intermediate) + + +@TRANSFORMER.register_module() +class Transformer(BaseModule): + """Implements the DETR transformer. + + Following the official DETR implementation, this module copy-paste + from torch.nn.Transformer with modifications: + + * positional encodings are passed in MultiheadAttention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers + + See `paper: End-to-End Object Detection with Transformers + <https://arxiv.org/pdf/2005.12872>`_ for details. + + Args: + encoder (`mmcv.ConfigDict` | Dict): Config of + TransformerEncoder. Defaults to None. + decoder ((`mmcv.ConfigDict` | Dict)): Config of + TransformerDecoder. Defaults to None + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Defaults to None. 
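+
+ Example:
+ >>> # Hedged I/O sketch; ``encoder_cfg`` and ``decoder_cfg`` stand in
+ >>> # for registered transformer-layer-sequence configs.
+ >>> transformer = Transformer(encoder=encoder_cfg, decoder=decoder_cfg) # doctest: +SKIP
+ >>> out_dec, memory = transformer(x, mask, query_embed, pos_embed) # doctest: +SKIP
+ >>> # x: [bs, c, h, w]; mask: [bs, h, w]; query_embed: [num_query, c];
+ >>> # memory is returned reshaped back to [bs, c, h, w].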
+ """ + + def __init__(self, encoder=None, decoder=None, init_cfg=None): + super(Transformer, self).__init__(init_cfg=init_cfg) + self.encoder = build_transformer_layer_sequence(encoder) + self.decoder = build_transformer_layer_sequence(decoder) + self.embed_dims = self.encoder.embed_dims + + def init_weights(self): + # follow the official DETR to init parameters + for m in self.modules(): + if hasattr(m, "weight") and m.weight.dim() > 1: + xavier_init(m, distribution="uniform") + self._is_init = True + + def forward(self, x, mask, query_embed, pos_embed): + """Forward function for `Transformer`. + + Args: + x (Tensor): Input query with shape [bs, c, h, w] where + c = embed_dims. + mask (Tensor): The key_padding_mask used for encoder and decoder, + with shape [bs, h, w]. + query_embed (Tensor): The query embedding for decoder, with shape + [num_query, c]. + pos_embed (Tensor): The positional encoding for encoder and + decoder, with the same shape as `x`. + + Returns: + tuple[Tensor]: results of decoder containing the following tensor. + + - out_dec: Output from decoder. If return_intermediate_dec \ + is True output has shape [num_dec_layers, bs, + num_query, embed_dims], else has shape [1, bs, \ + num_query, embed_dims]. + - memory: Output results from encoder, with shape \ + [bs, embed_dims, h, w]. + """ + bs, c, h, w = x.shape + # use `view` instead of `flatten` for dynamically exporting to ONNX + x = x.view(bs, c, -1).permute(2, 0, 1) # [bs, c, h, w] -> [h*w, bs, c] + pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # [num_query, dim] -> [num_query, bs, dim] + mask = mask.view(bs, -1) # [bs, h, w] -> [bs, h*w] + memory = self.encoder(query=x, key=None, value=None, query_pos=pos_embed, query_key_padding_mask=mask) + target = torch.zeros_like(query_embed) + # out_dec: [num_layers, num_query, bs, dim] + out_dec = self.decoder( + query=target, key=memory, value=memory, key_pos=pos_embed, query_pos=query_embed, key_padding_mask=mask + ) + out_dec = out_dec.transpose(1, 2) + memory = memory.permute(1, 2, 0).reshape(bs, c, h, w) + return out_dec, memory + + +@TRANSFORMER_LAYER_SEQUENCE.register_module() +class DeformableDetrTransformerDecoder(TransformerLayerSequence): + """Implements the decoder in DETR transformer. + + Args: + return_intermediate (bool): Whether to return intermediate outputs. + coder_norm_cfg (dict): Config of last normalization layer. Default: + `LN`. + """ + + def __init__(self, *args, return_intermediate=False, **kwargs): + + super(DeformableDetrTransformerDecoder, self).__init__(*args, **kwargs) + self.return_intermediate = return_intermediate + + def forward(self, query, *args, reference_points=None, valid_ratios=None, reg_branches=None, **kwargs): + """Forward function for `TransformerDecoder`. + + Args: + query (Tensor): Input query with shape + `(num_query, bs, embed_dims)`. + reference_points (Tensor): The reference + points of offset. has shape + (bs, num_query, 4) when as_two_stage, + otherwise has shape ((bs, num_query, 2). + valid_ratios (Tensor): The radios of valid + points on the feature map, has shape + (bs, num_levels, 2) + reg_branch: (obj:`nn.ModuleList`): Used for + refining the regression results. Only would + be passed when with_box_refine is True, + otherwise would be passed a `None`. + + Returns: + Tensor: Results with shape [1, num_query, bs, embed_dims] when + return_intermediate is `False`, otherwise it has shape + [num_layers, num_query, bs, embed_dims]. 
+ """ + output = query + intermediate = [] + intermediate_reference_points = [] + for lid, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + reference_points_input = ( + reference_points[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None] + ) + else: + assert reference_points.shape[-1] == 2 + reference_points_input = reference_points[:, :, None] * valid_ratios[:, None] + output = layer(output, *args, reference_points=reference_points_input, **kwargs) + output = output.permute(1, 0, 2) + + if reg_branches is not None: + tmp = reg_branches[lid](output) + if reference_points.shape[-1] == 4: + new_reference_points = tmp + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + else: + assert reference_points.shape[-1] == 2 + new_reference_points = tmp + new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + + output = output.permute(1, 0, 2) + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append(reference_points) + + if self.return_intermediate: + return torch.stack(intermediate), torch.stack(intermediate_reference_points) + + return output, reference_points + + +@TRANSFORMER.register_module() +class DeformableDetrTransformer(Transformer): + """Implements the DeformableDETR transformer. + + Args: + as_two_stage (bool): Generate query from encoder features. + Default: False. + num_feature_levels (int): Number of feature maps from FPN: + Default: 4. + two_stage_num_proposals (int): Number of proposals when set + `as_two_stage` as True. Default: 300. + """ + + def __init__(self, as_two_stage=False, num_feature_levels=4, two_stage_num_proposals=300, **kwargs): + super(DeformableDetrTransformer, self).__init__(**kwargs) + self.as_two_stage = as_two_stage + self.num_feature_levels = num_feature_levels + self.two_stage_num_proposals = two_stage_num_proposals + self.embed_dims = self.encoder.embed_dims + self.init_layers() + + def init_layers(self): + """Initialize layers of the DeformableDetrTransformer.""" + self.level_embeds = nn.Parameter(torch.Tensor(self.num_feature_levels, self.embed_dims)) + + if self.as_two_stage: + self.enc_output = nn.Linear(self.embed_dims, self.embed_dims) + self.enc_output_norm = nn.LayerNorm(self.embed_dims) + self.pos_trans = nn.Linear(self.embed_dims * 2, self.embed_dims * 2) + self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2) + else: + self.reference_points = nn.Linear(self.embed_dims, 2) + + def init_weights(self): + """Initialize the transformer weights.""" + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MultiScaleDeformableAttention): + m.init_weights() + if not self.as_two_stage: + xavier_init(self.reference_points, distribution="uniform", bias=0.0) + normal_(self.level_embeds) + + def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes): + """Generate proposals from encoded memory. + + Args: + memory (Tensor) : The output of encoder, + has shape (bs, num_key, embed_dim). num_key is + equal the number of points on feature map from + all level. + memory_padding_mask (Tensor): Padding mask for memory. + has shape (bs, num_key). + spatial_shapes (Tensor): The shape of all feature maps. + has shape (num_level, 2). + + Returns: + tuple: A tuple of feature map and bbox prediction. 
+ + - output_memory (Tensor): The input of decoder, \ + has shape (bs, num_key, embed_dim). num_key is \ + equal the number of points on feature map from \ + all levels. + - output_proposals (Tensor): The normalized proposal \ + after a inverse sigmoid, has shape \ + (bs, num_keys, 4). + """ + + N, S, C = memory.shape + proposals = [] + _cur = 0 + for lvl, (H, W) in enumerate(spatial_shapes): + mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H * W)].view(N, H, W, 1) + valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) + valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + + grid_y, grid_x = torch.meshgrid( + torch.linspace(0, H - 1, H, dtype=torch.float32, device=memory.device), + torch.linspace(0, W - 1, W, dtype=torch.float32, device=memory.device), + ) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + + scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N, 1, 1, 2) + grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale + wh = torch.ones_like(grid) * 0.05 * (2.0**lvl) + proposal = torch.cat((grid, wh), -1).view(N, -1, 4) + proposals.append(proposal) + _cur += H * W + output_proposals = torch.cat(proposals, 1) + output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) + output_proposals = torch.log(output_proposals / (1 - output_proposals)) + output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float("inf")) + output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf")) + + output_memory = memory + output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0)) + output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) + output_memory = self.enc_output_norm(self.enc_output(output_memory)) + return output_memory, output_proposals + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + """Get the reference points used in decoder. + + Args: + spatial_shapes (Tensor): The shape of all + feature maps, has shape (num_level, 2). + valid_ratios (Tensor): The radios of valid + points on the feature map, has shape + (bs, num_levels, 2) + device (obj:`device`): The device where + reference_points should be. + + Returns: + Tensor: reference points used in decoder, has \ + shape (bs, num_keys, num_levels, 2). 
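+
+ Example:
+ >>> # Toy check on a single fully-valid 2x2 level.
+ >>> import torch
+ >>> pts = DeformableDetrTransformer.get_reference_points(
+ ... torch.tensor([[2, 2]]), torch.ones(1, 1, 2), device="cpu")
+ >>> pts.shape
+ torch.Size([1, 4, 1, 2])
+ >>> pts[0, :, 0, 0] # x centres, normalized to (0, 1)
+ tensor([0.2500, 0.7500, 0.2500, 0.7500])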
+ """ + reference_points_list = [] + for lvl, (H, W) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace(0.5, H - 0.5, H, dtype=torch.float32, device=device), + torch.linspace(0.5, W - 0.5, W, dtype=torch.float32, device=device), + ) + ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H) + ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def get_valid_ratio(self, mask): + """Get the valid radios of feature maps of all level.""" + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def get_proposal_pos_embed(self, proposals, num_pos_feats=128, temperature=10000): + """Get the position embedding of proposal.""" + scale = 2 * math.pi + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device) + dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) + # N, L, 4 + proposals = proposals.sigmoid() * scale + # N, L, 4, 128 + pos = proposals[:, :, :, None] / dim_t + # N, L, 4, 64, 2 + pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2) + return pos + + def forward( + self, mlvl_feats, mlvl_masks, query_embed, mlvl_pos_embeds, reg_branches=None, cls_branches=None, **kwargs + ): + """Forward function for `Transformer`. + + Args: + mlvl_feats (list(Tensor)): Input queries from + different level. Each element has shape + [bs, embed_dims, h, w]. + mlvl_masks (list(Tensor)): The key_padding_mask from + different level used for encoder and decoder, + each element has shape [bs, h, w]. + query_embed (Tensor): The query embedding for decoder, + with shape [num_query, c]. + mlvl_pos_embeds (list(Tensor)): The positional encoding + of feats from different level, has the shape + [bs, embed_dims, h, w]. + reg_branches (obj:`nn.ModuleList`): Regression heads for + feature maps from each decoder layer. Only would + be passed when + `with_box_refine` is True. Default to None. + cls_branches (obj:`nn.ModuleList`): Classification heads + for feature maps from each decoder layer. Only would + be passed when `as_two_stage` + is True. Default to None. + + + Returns: + tuple[Tensor]: results of decoder containing the following tensor. + + - inter_states: Outputs from decoder. If + return_intermediate_dec is True output has shape \ + (num_dec_layers, bs, num_query, embed_dims), else has \ + shape (1, bs, num_query, embed_dims). + - init_reference_out: The initial value of reference \ + points, has shape (bs, num_queries, 4). + - inter_references_out: The internal value of reference \ + points in decoder, has shape \ + (num_dec_layers, bs,num_query, embed_dims) + - enc_outputs_class: The classification score of \ + proposals generated from \ + encoder's feature maps, has shape \ + (batch, h*w, num_classes). \ + Only would be returned when `as_two_stage` is True, \ + otherwise None. + - enc_outputs_coord_unact: The regression results \ + generated from encoder's feature maps., has shape \ + (batch, h*w, 4). Only would \ + be returned when `as_two_stage` is True, \ + otherwise None. 
+ """ + assert self.as_two_stage or query_embed is not None + + feat_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (feat, mask, pos_embed) in enumerate(zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): + bs, c, h, w = feat.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + feat = feat.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose(1, 2) + lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1) + lvl_pos_embed_flatten.append(lvl_pos_embed) + feat_flatten.append(feat) + mask_flatten.append(mask) + feat_flatten = torch.cat(feat_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=feat_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack([self.get_valid_ratio(m) for m in mlvl_masks], 1) + + reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=feat.device) + + feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) + lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) + memory = self.encoder( + query=feat_flatten, + key=None, + value=None, + query_pos=lvl_pos_embed_flatten, + query_key_padding_mask=mask_flatten, + spatial_shapes=spatial_shapes, + reference_points=reference_points, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + **kwargs, + ) + + memory = memory.permute(1, 0, 2) + bs, _, c = memory.shape + if self.as_two_stage: + output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes) + enc_outputs_class = cls_branches[self.decoder.num_layers](output_memory) + enc_outputs_coord_unact = reg_branches[self.decoder.num_layers](output_memory) + output_proposals + + topk = self.two_stage_num_proposals + topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1] + topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) + topk_coords_unact = topk_coords_unact.detach() + reference_points = topk_coords_unact.sigmoid() + init_reference_out = reference_points + pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))) + query_pos, query = torch.split(pos_trans_out, c, dim=2) + else: + query_pos, query = torch.split(query_embed, c, dim=1) + query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1) + query = query.unsqueeze(0).expand(bs, -1, -1) + reference_points = self.reference_points(query_pos).sigmoid() + init_reference_out = reference_points + + # decoder + query = query.permute(1, 0, 2) + memory = memory.permute(1, 0, 2) + query_pos = query_pos.permute(1, 0, 2) + inter_states, inter_references = self.decoder( + query=query, + key=None, + value=memory, + query_pos=query_pos, + key_padding_mask=mask_flatten, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reg_branches=reg_branches, + **kwargs, + ) + + inter_references_out = inter_references + if self.as_two_stage: + return inter_states, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact + return inter_states, init_reference_out, inter_references_out, None, None + + +@TRANSFORMER.register_module() +class DynamicConv(BaseModule): 
+ """Implements Dynamic Convolution. + + This module generate parameters for each sample and + use bmm to implement 1*1 convolution. Code is modified + from the `official github repo <https://github.com/PeizeSun/ + SparseR-CNN/blob/main/projects/SparseRCNN/sparsercnn/head.py#L258>`_ . + + Args: + in_channels (int): The input feature channel. + Defaults to 256. + feat_channels (int): The inner feature channel. + Defaults to 64. + out_channels (int, optional): The output feature channel. + When not specified, it will be set to `in_channels` + by default + input_feat_shape (int): The shape of input feature. + Defaults to 7. + with_proj (bool): Project two-dimentional feature to + one-dimentional feature. Default to True. + act_cfg (dict): The activation config for DynamicConv. + norm_cfg (dict): Config dict for normalization layer. Default + layer normalization. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__( + self, + in_channels=256, + feat_channels=64, + out_channels=None, + input_feat_shape=7, + with_proj=True, + act_cfg=dict(type="ReLU", inplace=True), + norm_cfg=dict(type="LN"), + init_cfg=None, + ): + super(DynamicConv, self).__init__(init_cfg) + self.in_channels = in_channels + self.feat_channels = feat_channels + self.out_channels_raw = out_channels + self.input_feat_shape = input_feat_shape + self.with_proj = with_proj + self.act_cfg = act_cfg + self.norm_cfg = norm_cfg + self.out_channels = out_channels if out_channels else in_channels + + self.num_params_in = self.in_channels * self.feat_channels + self.num_params_out = self.out_channels * self.feat_channels + self.dynamic_layer = nn.Linear(self.in_channels, self.num_params_in + self.num_params_out) + + self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1] + self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1] + + self.activation = build_activation_layer(act_cfg) + + num_output = self.out_channels * input_feat_shape**2 + if self.with_proj: + self.fc_layer = nn.Linear(num_output, self.out_channels) + self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1] + + def forward(self, param_feature, input_feature): + """Forward function for `DynamicConv`. + + Args: + param_feature (Tensor): The feature can be used + to generate the parameter, has shape + (num_all_proposals, in_channels). + input_feature (Tensor): Feature that + interact with parameters, has shape + (num_all_proposals, in_channels, H, W). + + Returns: + Tensor: The output feature has shape + (num_all_proposals, out_channels). 
+ """ + input_feature = input_feature.flatten(2).permute(2, 0, 1) + + input_feature = input_feature.permute(1, 0, 2) + parameters = self.dynamic_layer(param_feature) + + param_in = parameters[:, : self.num_params_in].view(-1, self.in_channels, self.feat_channels) + param_out = parameters[:, -self.num_params_out :].view(-1, self.feat_channels, self.out_channels) + + # input_feature has shape (num_all_proposals, H*W, in_channels) + # param_in has shape (num_all_proposals, in_channels, feat_channels) + # feature has shape (num_all_proposals, H*W, feat_channels) + features = torch.bmm(input_feature, param_in) + features = self.norm_in(features) + features = self.activation(features) + + # param_out has shape (batch_size, feat_channels, out_channels) + features = torch.bmm(features, param_out) + features = self.norm_out(features) + features = self.activation(features) + + if self.with_proj: + features = features.flatten(1) + features = self.fc_layer(features) + features = self.fc_norm(features) + features = self.activation(features) + + return features diff --git a/dinov2/eval/segmentation_m2f/ops/modules/__init__.py b/dinov2/eval/segmentation_m2f/ops/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49aa8fe612fd4c088e294707c5ee16bd1cb5b5e7 --- /dev/null +++ b/dinov2/eval/segmentation_m2f/ops/modules/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/fundamentalvision/Deformable-DETR/tree/main/models/ops/modules +# https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 + +from .ms_deform_attn import MSDeformAttn diff --git a/dinov2/eval/segmentation_m2f/ops/modules/ms_deform_attn.py b/dinov2/eval/segmentation_m2f/ops/modules/ms_deform_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..d8b4fa23712e87d1a2682b57e71ee37fe8524cff --- /dev/null +++ b/dinov2/eval/segmentation_m2f/ops/modules/ms_deform_attn.py @@ -0,0 +1,185 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+
+import math
+import warnings
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.autograd import Function
+from torch.cuda.amp import custom_fwd
+from torch.nn.init import constant_, xavier_uniform_
+
+
+class MSDeformAttnFunction(Function):
+    @staticmethod
+    @custom_fwd(cast_inputs=torch.float32)
+    def forward(
+        ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step
+    ):
+        # Only the forward pass is implemented: this wraps the pure-PyTorch core
+        # below, which needs neither the level start index nor the im2col step
+        # used by the CUDA kernel.
+        output = ms_deform_attn_core_pytorch(
+            value,
+            value_spatial_shapes,
+            # value_level_start_index,
+            sampling_locations,
+            attention_weights,
+        )
+        return output
+
+
+def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
+    # for debugging and testing only; use the CUDA version for speed
+    N_, S_, M_, D_ = value.shape
+    _, Lq_, M_, L_, P_, _ = sampling_locations.shape
+    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
+    sampling_grids = 2 * sampling_locations - 1
+    sampling_value_list = []
+    for lid_, (H_, W_) in enumerate(value_spatial_shapes):
+        # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
+        value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_)
+        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
+        sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
+        # N_*M_, D_, Lq_, P_
+        sampling_value_l_ = F.grid_sample(
+            value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
+        )
+        sampling_value_list.append(sampling_value_l_)
+    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_)
+    attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_)
+    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_ * D_, Lq_)
+    return output.transpose(1, 2).contiguous()
+
+
+def _is_power_of_2(n):
+    if (not isinstance(n, int)) or (n < 0):
+        raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
+    return (n & (n - 1) == 0) and n != 0
+
+
+class MSDeformAttn(nn.Module):
+    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, ratio=1.0):
+        """Multi-Scale Deformable Attention Module.
+
+        :param d_model: hidden dimension
+        :param n_levels: number of feature levels
+        :param n_heads: number of attention heads
+        :param n_points: number of sampling points per attention head per feature level
+        """
+        super().__init__()
+        if d_model % n_heads != 0:
+            raise ValueError("d_model must be divisible by n_heads, but got {} and {}".format(d_model, n_heads))
+        _d_per_head = d_model // n_heads
+        # setting _d_per_head to a power of 2 is more efficient in the CUDA implementation
+        if not _is_power_of_2(_d_per_head):
+            warnings.warn(
+                "You'd better set d_model in MSDeformAttn to make "
+                "the dimension of each attention head a power of 2 "
+                "which is more efficient in our CUDA implementation."
+            )
+
+        self.im2col_step = 64
+
+        self.d_model = d_model
+        self.n_levels = n_levels
+        self.n_heads = n_heads
+        self.n_points = n_points
+        self.ratio = ratio
+        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
+        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
+        self.value_proj = nn.Linear(d_model, int(d_model * ratio))
+        self.output_proj = nn.Linear(int(d_model * ratio), d_model)
+
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        constant_(self.sampling_offsets.weight.data, 0.0)
+        thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
+        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+        grid_init = (
+            (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
+            .view(self.n_heads, 1, 1, 2)
+            .repeat(1, self.n_levels, self.n_points, 1)
+        )
+        for i in range(self.n_points):
+            grid_init[:, :, i, :] *= i + 1
+
+        with torch.no_grad():
+            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+        constant_(self.attention_weights.weight.data, 0.0)
+        constant_(self.attention_weights.bias.data, 0.0)
+        xavier_uniform_(self.value_proj.weight.data)
+        constant_(self.value_proj.bias.data, 0.0)
+        xavier_uniform_(self.output_proj.weight.data)
+        constant_(self.output_proj.bias.data, 0.0)
+
+    def forward(
+        self,
+        query,
+        reference_points,
+        input_flatten,
+        input_spatial_shapes,
+        input_level_start_index,
+        input_padding_mask=None,
+    ):
+        """
+        :param query: (N, Length_{query}, C)
+        :param reference_points: (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0, 0), bottom-right (1, 1), including padding area,
+            or (N, Length_{query}, n_levels, 4) with an additional (w, h) to form reference boxes
+        :param input_flatten: (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l, C)
+        :param input_spatial_shapes: (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+        :param input_level_start_index: (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, ..., H_0*W_0+H_1*W_1+...+H_{L-2}*W_{L-2}]
+        :param input_padding_mask: (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l), True for padding elements, False for non-padding elements
+
+        :return output: (N, Length_{query}, C)
+        """
+        N, Len_q, _ = query.shape
+        N, Len_in, _ = input_flatten.shape
+        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
+
+        value = self.value_proj(input_flatten)
+        if input_padding_mask is not None:
+            value = value.masked_fill(input_padding_mask[..., None], float(0))
+
+        value = value.view(N, Len_in, self.n_heads, int(self.ratio * self.d_model) // self.n_heads)
+        sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
+        attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
+        attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
+
+        if reference_points.shape[-1] == 2:
+            offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
+            sampling_locations = (
+                reference_points[:, :, None, :, None, :]
+                + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
+            )
+        elif reference_points.shape[-1] == 4:
+            sampling_locations = (
+                reference_points[:, :, None, :, None, 
:2] + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + ) + else: + raise ValueError( + "Last dim of reference_points must be 2 or 4, but get {} instead.".format(reference_points.shape[-1]) + ) + output = MSDeformAttnFunction.apply( + value, + input_spatial_shapes, + input_level_start_index, + sampling_locations, + attention_weights, + self.im2col_step, + ) + output = self.output_proj(output) + return output diff --git a/notebooks/semantic_segmentation.ipynb b/notebooks/semantic_segmentation.ipynb index 08815fefad7c9a60969fc58f2ce3d630eca365b2..8cea91446624638ebf9a494daf29fe3f3811d29c 100644 --- a/notebooks/semantic_segmentation.ipynb +++ b/notebooks/semantic_segmentation.ipynb @@ -165,16 +165,15 @@ "BACKBONE_SIZE = \"small\" # in (\"small\", \"base\", \"large\" or \"giant\")\n", "\n", "\n", - "BACKBONE_ARCHS = {\n", + "backbone_archs = {\n", " \"small\": \"vits14\",\n", " \"base\": \"vitb14\",\n", " \"large\": \"vitl14\",\n", " \"giant\": \"vitg14\",\n", "}\n", - "\n", - "\n", - "backbone_arch = BACKBONE_ARCHS[BACKBONE_SIZE]\n", + "backbone_arch = backbone_archs[BACKBONE_SIZE]\n", "backbone_name = f\"dinov2_{backbone_arch}\"\n", + "\n", "backbone_model = torch.hub.load(repo_or_dir=\"facebookresearch/dinov2\", model=backbone_name)\n", "backbone_model.eval()\n", "backbone_model.cuda()" @@ -200,20 +199,20 @@ "text": [ "/private/home/plabatut/.conda/envs/dinov2-extras-conda/lib/python3.9/site-packages/mmseg/models/losses/cross_entropy_loss.py:235: UserWarning: Default ``avg_non_ignore`` is False, if you would like to ignore the certain label and average loss over non-ignore labels, which is the same with PyTorch official cross_entropy, set ``avg_non_ignore=True``.\n", " warnings.warn(\n", - "2023-08-31 06:05:23,461 - mmcv - INFO - initialize BNHead with init_cfg {'type': 'Normal', 'std': 0.01, 'override': {'name': 'conv_seg'}}\n", - "2023-08-31 06:05:23,463 - mmcv - INFO - \n", + "2023-08-31 06:29:03,743 - mmcv - INFO - initialize BNHead with init_cfg {'type': 'Normal', 'std': 0.01, 'override': {'name': 'conv_seg'}}\n", + "2023-08-31 06:29:03,744 - mmcv - INFO - \n", "decode_head.conv_seg.weight - torch.Size([21, 1536, 1, 1]): \n", "NormalInit: mean=0, std=0.01, bias=0 \n", " \n", - "2023-08-31 06:05:23,464 - mmcv - INFO - \n", + "2023-08-31 06:29:03,745 - mmcv - INFO - \n", "decode_head.conv_seg.bias - torch.Size([21]): \n", "NormalInit: mean=0, std=0.01, bias=0 \n", " \n", - "2023-08-31 06:05:23,464 - mmcv - INFO - \n", + "2023-08-31 06:29:03,745 - mmcv - INFO - \n", "decode_head.bn.weight - torch.Size([1536]): \n", "The value is the same before and after calling `init_weights` of EncoderDecoder \n", " \n", - "2023-08-31 06:05:23,465 - mmcv - INFO - \n", + "2023-08-31 06:29:03,746 - mmcv - INFO - \n", "decode_head.bn.bias - torch.Size([1536]): \n", "The value is the same before and after calling `init_weights` of EncoderDecoder \n", " \n" @@ -360,9 +359,8 @@ "}\n", "\n", "\n", - "colormap = DATASET_COLORMAPS[HEAD_DATASET]\n", - "\n", - "def render_segmentation(segmentation_logits):\n", + "def render_segmentation(segmentation_logits, dataset):\n", + " colormap = DATASET_COLORMAPS[dataset]\n", " colormap_array = np.array(colormap, dtype=np.uint8)\n", " segmentation_values = colormap_array[segmentation_logits + 1]\n", " return Image.fromarray(segmentation_values)\n", @@ -370,7 +368,1086 @@ "\n", "array = np.array(image)[:, :, ::-1] # BGR\n", "segmentation_logits = inference_segmentor(model, array)[0]\n", - "segmented_image = 
render_segmentation(segmentation_logits)\n", + "segmented_image = render_segmentation(segmentation_logits, HEAD_DATASET)\n", + "display(segmented_image)" + ] + }, + { + "cell_type": "markdown", + "id": "de40012e-a01e-4e73-bb71-3048f16d41c8", + "metadata": {}, + "source": [ + "## Load pretrained segmentation model (Mask2Former)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "ff2cbbbe-c53c-4e5b-977f-c2a7d93f4b8c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/private/home/plabatut/github/patricklabatut/dinov2/dinov2/eval/segmentation_m2f/models/losses/cross_entropy_loss.py:222: UserWarning: Default ``avg_non_ignore`` is False, if you would like to ignore the certain label and average loss over non-ignore labels, which is the same with PyTorch official cross_entropy, set ``avg_non_ignore=True``.\n", + " warnings.warn(\n", + "/private/home/plabatut/.conda/envs/dinov2-extras-conda/lib/python3.9/site-packages/mmcv/ops/multi_scale_deform_attn.py:209: UserWarning: You'd better set embed_dims in MultiScaleDeformAttention to make the dimension of each attention head a power of 2 which is more efficient in our CUDA implementation.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "load checkpoint from http path: https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_ade20k_m2f.pth\n" + ] + }, + { + "data": { + "text/plain": [ + "EncoderDecoderMask2Former(\n", + " (backbone): ViTAdapter(\n", + " (patch_embed): PatchEmbed(\n", + " (proj): Conv2d(3, 1536, kernel_size=(14, 14), stride=(14, 14))\n", + " (norm): Identity()\n", + " )\n", + " (pos_drop): Dropout(p=0.0, inplace=False)\n", + " (blocks): Sequential(\n", + " (0): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): Identity()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (1): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (2): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): 
Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (3): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (4): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (5): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (6): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (7): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, 
inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (8): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (9): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (10): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (11): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (12): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " 
(qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (13): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (14): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (15): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (16): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (17): Block(\n", + " 
(norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (18): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (19): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (20): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (21): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): 
Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (22): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (23): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (24): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (25): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (26): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, 
out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (27): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (28): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (29): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (30): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (31): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), 
eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (32): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (33): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (34): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (35): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (36): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): 
Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (37): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (38): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (39): Block(\n", + " (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): Attention(\n", + " (qkv): Linear(in_features=1536, out_features=4608, bias=True)\n", + " (attn_drop): Dropout(p=0.0, inplace=False)\n", + " (proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (drop_path): DropPath()\n", + " (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (mlp): SwiGLUFFN(\n", + " (w1): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w2): Linear(in_features=1536, out_features=4096, bias=True)\n", + " (w3): Linear(in_features=4096, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " )\n", + " (norm_pre): Identity()\n", + " (spm): SpatialPriorModule(\n", + " (stem): Sequential(\n", + " (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (1): SyncBatchNorm(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (4): SyncBatchNorm(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (7): SyncBatchNorm(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (8): ReLU(inplace=True)\n", + " (9): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", + " )\n", + " (conv2): Sequential(\n", + " (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), 
bias=False)\n", + " (1): SyncBatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " )\n", + " (conv3): Sequential(\n", + " (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (1): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " )\n", + " (conv4): Sequential(\n", + " (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", + " (1): SyncBatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " )\n", + " (fc1): Conv2d(64, 1536, kernel_size=(1, 1), stride=(1, 1))\n", + " (fc2): Conv2d(128, 1536, kernel_size=(1, 1), stride=(1, 1))\n", + " (fc3): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1))\n", + " (fc4): Conv2d(256, 1536, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (interactions): Sequential(\n", + " (0): InteractionBlockWithCls(\n", + " (injector): Injector(\n", + " (query_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (feat_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): MSDeformAttn(\n", + " (sampling_offsets): Linear(in_features=1536, out_features=576, bias=True)\n", + " (attention_weights): Linear(in_features=1536, out_features=288, bias=True)\n", + " (value_proj): Linear(in_features=1536, out_features=768, bias=True)\n", + " (output_proj): Linear(in_features=768, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (extractor): Extractor(\n", + " (query_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (feat_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): MSDeformAttn(\n", + " (sampling_offsets): Linear(in_features=1536, out_features=192, bias=True)\n", + " (attention_weights): Linear(in_features=1536, out_features=96, bias=True)\n", + " (value_proj): Linear(in_features=1536, out_features=768, bias=True)\n", + " (output_proj): Linear(in_features=768, out_features=1536, bias=True)\n", + " )\n", + " (ffn): ConvFFN(\n", + " (fc1): Linear(in_features=1536, out_features=384, bias=True)\n", + " (dwconv): DWConv(\n", + " (dwconv): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384)\n", + " )\n", + " (act): GELU(approximate='none')\n", + " (fc2): Linear(in_features=384, out_features=1536, bias=True)\n", + " (drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ffn_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (drop_path): DropPath()\n", + " )\n", + " )\n", + " (1): InteractionBlockWithCls(\n", + " (injector): Injector(\n", + " (query_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (feat_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): MSDeformAttn(\n", + " (sampling_offsets): Linear(in_features=1536, out_features=576, bias=True)\n", + " (attention_weights): Linear(in_features=1536, out_features=288, bias=True)\n", + " (value_proj): Linear(in_features=1536, out_features=768, bias=True)\n", + " (output_proj): Linear(in_features=768, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (extractor): Extractor(\n", + " (query_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (feat_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): MSDeformAttn(\n", + " (sampling_offsets): Linear(in_features=1536, out_features=192, bias=True)\n", + " (attention_weights): 
Linear(in_features=1536, out_features=96, bias=True)\n", + " (value_proj): Linear(in_features=1536, out_features=768, bias=True)\n", + " (output_proj): Linear(in_features=768, out_features=1536, bias=True)\n", + " )\n", + " (ffn): ConvFFN(\n", + " (fc1): Linear(in_features=1536, out_features=384, bias=True)\n", + " (dwconv): DWConv(\n", + " (dwconv): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384)\n", + " )\n", + " (act): GELU(approximate='none')\n", + " (fc2): Linear(in_features=384, out_features=1536, bias=True)\n", + " (drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ffn_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (drop_path): DropPath()\n", + " )\n", + " )\n", + " (2): InteractionBlockWithCls(\n", + " (injector): Injector(\n", + " (query_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (feat_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): MSDeformAttn(\n", + " (sampling_offsets): Linear(in_features=1536, out_features=576, bias=True)\n", + " (attention_weights): Linear(in_features=1536, out_features=288, bias=True)\n", + " (value_proj): Linear(in_features=1536, out_features=768, bias=True)\n", + " (output_proj): Linear(in_features=768, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (extractor): Extractor(\n", + " (query_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (feat_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): MSDeformAttn(\n", + " (sampling_offsets): Linear(in_features=1536, out_features=192, bias=True)\n", + " (attention_weights): Linear(in_features=1536, out_features=96, bias=True)\n", + " (value_proj): Linear(in_features=1536, out_features=768, bias=True)\n", + " (output_proj): Linear(in_features=768, out_features=1536, bias=True)\n", + " )\n", + " (ffn): ConvFFN(\n", + " (fc1): Linear(in_features=1536, out_features=384, bias=True)\n", + " (dwconv): DWConv(\n", + " (dwconv): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384)\n", + " )\n", + " (act): GELU(approximate='none')\n", + " (fc2): Linear(in_features=384, out_features=1536, bias=True)\n", + " (drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ffn_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (drop_path): DropPath()\n", + " )\n", + " )\n", + " (3): InteractionBlockWithCls(\n", + " (injector): Injector(\n", + " (query_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (feat_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): MSDeformAttn(\n", + " (sampling_offsets): Linear(in_features=1536, out_features=576, bias=True)\n", + " (attention_weights): Linear(in_features=1536, out_features=288, bias=True)\n", + " (value_proj): Linear(in_features=1536, out_features=768, bias=True)\n", + " (output_proj): Linear(in_features=768, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (extractor): Extractor(\n", + " (query_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (feat_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): MSDeformAttn(\n", + " (sampling_offsets): Linear(in_features=1536, out_features=192, bias=True)\n", + " (attention_weights): Linear(in_features=1536, out_features=96, bias=True)\n", + " (value_proj): Linear(in_features=1536, out_features=768, bias=True)\n", + " (output_proj): Linear(in_features=768, out_features=1536, bias=True)\n", + " )\n", + " (ffn): ConvFFN(\n", + " (fc1): 
Linear(in_features=1536, out_features=384, bias=True)\n", + " (dwconv): DWConv(\n", + " (dwconv): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384)\n", + " )\n", + " (act): GELU(approximate='none')\n", + " (fc2): Linear(in_features=384, out_features=1536, bias=True)\n", + " (drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ffn_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (drop_path): DropPath()\n", + " )\n", + " (extra_extractors): Sequential(\n", + " (0): Extractor(\n", + " (query_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (feat_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): MSDeformAttn(\n", + " (sampling_offsets): Linear(in_features=1536, out_features=192, bias=True)\n", + " (attention_weights): Linear(in_features=1536, out_features=96, bias=True)\n", + " (value_proj): Linear(in_features=1536, out_features=768, bias=True)\n", + " (output_proj): Linear(in_features=768, out_features=1536, bias=True)\n", + " )\n", + " (ffn): ConvFFN(\n", + " (fc1): Linear(in_features=1536, out_features=384, bias=True)\n", + " (dwconv): DWConv(\n", + " (dwconv): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384)\n", + " )\n", + " (act): GELU(approximate='none')\n", + " (fc2): Linear(in_features=384, out_features=1536, bias=True)\n", + " (drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ffn_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (drop_path): DropPath()\n", + " )\n", + " (1): Extractor(\n", + " (query_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (feat_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (attn): MSDeformAttn(\n", + " (sampling_offsets): Linear(in_features=1536, out_features=192, bias=True)\n", + " (attention_weights): Linear(in_features=1536, out_features=96, bias=True)\n", + " (value_proj): Linear(in_features=1536, out_features=768, bias=True)\n", + " (output_proj): Linear(in_features=768, out_features=1536, bias=True)\n", + " )\n", + " (ffn): ConvFFN(\n", + " (fc1): Linear(in_features=1536, out_features=384, bias=True)\n", + " (dwconv): DWConv(\n", + " (dwconv): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384)\n", + " )\n", + " (act): GELU(approximate='none')\n", + " (fc2): Linear(in_features=384, out_features=1536, bias=True)\n", + " (drop): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (ffn_norm): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)\n", + " (drop_path): DropPath()\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (up): ConvTranspose2d(1536, 1536, kernel_size=(2, 2), stride=(2, 2))\n", + " (norm1): SyncBatchNorm(1536, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (norm2): SyncBatchNorm(1536, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (norm3): SyncBatchNorm(1536, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (norm4): SyncBatchNorm(1536, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (decode_head): Mask2FormerHead(\n", + " input_transform=multiple_select, ignore_index=255, align_corners=False\n", + " (loss_decode): CrossEntropyLoss(avg_non_ignore=False)\n", + " (conv_seg): None\n", + " (dropout): Dropout2d(p=0.1, inplace=False)\n", + " (pixel_decoder): MSDeformAttnPixelDecoder(\n", + " (input_convs): ModuleList(\n", + " (0-2): 3 x ConvModule(\n", + " (conv): Conv2d(1536, 1536, kernel_size=(1, 1), 
stride=(1, 1))\n", + " (gn): GroupNorm(32, 1536, eps=1e-05, affine=True)\n", + " )\n", + " )\n", + " (encoder): DetrTransformerEncoder(\n", + " (layers): ModuleList(\n", + " (0-5): 6 x BaseTransformerLayer(\n", + " (attentions): ModuleList(\n", + " (0): MultiScaleDeformableAttention(\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " (sampling_offsets): Linear(in_features=1536, out_features=768, bias=True)\n", + " (attention_weights): Linear(in_features=1536, out_features=384, bias=True)\n", + " (value_proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (output_proj): Linear(in_features=1536, out_features=1536, bias=True)\n", + " )\n", + " )\n", + " (ffns): ModuleList(\n", + " (0): FFN(\n", + " (activate): ReLU(inplace=True)\n", + " (layers): Sequential(\n", + " (0): Sequential(\n", + " (0): Linear(in_features=1536, out_features=6144, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (1): Linear(in_features=6144, out_features=1536, bias=True)\n", + " (2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (dropout_layer): Identity()\n", + " )\n", + " )\n", + " (norms): ModuleList(\n", + " (0-1): 2 x LayerNorm((1536,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (postional_encoding): SinePositionalEncoding(num_feats=768, temperature=10000, normalize=True, scale=6.283185307179586, eps=1e-06)\n", + " (level_encoding): Embedding(3, 1536)\n", + " (lateral_convs): ModuleList(\n", + " (0): ConvModule(\n", + " (conv): Conv2d(1536, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (gn): GroupNorm(32, 1536, eps=1e-05, affine=True)\n", + " )\n", + " )\n", + " (output_convs): ModuleList(\n", + " (0): ConvModule(\n", + " (conv): Conv2d(1536, 1536, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (gn): GroupNorm(32, 1536, eps=1e-05, affine=True)\n", + " (activate): ReLU(inplace=True)\n", + " )\n", + " )\n", + " (mask_feature): Conv2d(1536, 1536, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " (transformer_decoder): DetrTransformerDecoder(\n", + " (layers): ModuleList(\n", + " (0-8): 9 x DetrTransformerDecoderLayer(\n", + " (attentions): ModuleList(\n", + " (0-1): 2 x MultiheadAttention(\n", + " (attn): MultiheadAttention(\n", + " (out_proj): NonDynamicallyQuantizableLinear(in_features=1536, out_features=1536, bias=True)\n", + " )\n", + " (proj_drop): Dropout(p=0.0, inplace=False)\n", + " (dropout_layer): Identity()\n", + " )\n", + " )\n", + " (ffns): ModuleList(\n", + " (0): FFN(\n", + " (activate): ReLU(inplace=True)\n", + " (layers): Sequential(\n", + " (0): Sequential(\n", + " (0): Linear(in_features=1536, out_features=6144, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (1): Linear(in_features=6144, out_features=1536, bias=True)\n", + " (2): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (dropout_layer): Identity()\n", + " )\n", + " )\n", + " (norms): ModuleList(\n", + " (0-2): 3 x LayerNorm((1536,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " )\n", + " (post_norm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (decoder_input_projs): ModuleList(\n", + " (0-2): 3 x Identity()\n", + " )\n", + " (decoder_positional_encoding): SinePositionalEncoding(num_feats=768, temperature=10000, normalize=True, scale=6.283185307179586, eps=1e-06)\n", + " (query_embed): Embedding(100, 1536)\n", + " (query_feat): Embedding(100, 1536)\n", + " (level_embed): 
Embedding(3, 1536)\n", + " (cls_embed): Linear(in_features=1536, out_features=151, bias=True)\n", + " (mask_embed): Sequential(\n", + " (0): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " (2): Linear(in_features=1536, out_features=1536, bias=True)\n", + " (3): ReLU(inplace=True)\n", + " (4): Linear(in_features=1536, out_features=1536, bias=True)\n", + " )\n", + " (loss_cls): CrossEntropyLoss(avg_non_ignore=False)\n", + " (loss_mask): CrossEntropyLoss(avg_non_ignore=False)\n", + " (loss_dice): DiceLoss()\n", + " )\n", + ")" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import dinov2.eval.segmentation_m2f.models.segmentors\n", + "\n", + "CONFIG_URL = f\"{DINOV2_BASE_URL}/dinov2_vitg14/dinov2_vitg14_ade20k_m2f_config.py\"\n", + "CHECKPOINT_URL = f\"{DINOV2_BASE_URL}/dinov2_vitg14/dinov2_vitg14_ade20k_m2f.pth\"\n", + "\n", + "cfg_str = load_config_from_url(CONFIG_URL)\n", + "cfg = mmcv.Config.fromstring(cfg_str, file_format=\".py\")\n", + "\n", + "model = init_segmentor(cfg)\n", + "load_checkpoint(model, CHECKPOINT_URL, map_location=\"cpu\")\n", + "model.cuda()\n", + "model.eval()" + ] + }, + { + "cell_type": "markdown", + "id": "53c0309f-df2b-4912-bca5-e57d8b3875b3", + "metadata": {}, + "source": [ + "## Semantic segmentation on sample image" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "f4abb13b-0e5a-4a40-8d44-21da4286ba7d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/private/home/plabatut/.conda/envs/dinov2-extras-conda/lib/python3.9/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. 
(Triggered internally at /opt/conda/conda-bld/pytorch_1678402374358/work/aten/src/ATen/native/TensorShape.cpp:3483.)\n", + "  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]\n" + ] + }, + { + "data": { + "image/png": "<base64-encoded PNG omitted: rendered ADE20K segmentation of the sample image, 640x480>\n", + "text/plain": [ + "<PIL.Image.Image image mode=RGB size=640x480>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "array = np.array(image)[:, :, ::-1]  # RGB -> BGR, the channel order mmseg models expect\n", + "segmentation_logits = inference_segmentor(model, array)[0]\n", + "segmented_image = render_segmentation(segmentation_logits, \"ade20k\")\n", "display(segmented_image)" ] }
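For reference, `inference_segmentor` returns one `(H, W)` array of per-pixel ADE20K class indices per input image, and `render_segmentation` (a helper defined earlier in the notebook, outside this hunk) colorizes it. Below is a minimal sketch of that colorization step, assuming the `ade20k` palette shipped with mmseg; the notebook's own helper may differ in detail:

```python
import numpy as np
from PIL import Image
from mmseg.core.evaluation import get_palette  # 150 RGB colors for "ade20k"


def colorize_segmentation(segmentation_logits: np.ndarray, dataset: str = "ade20k") -> Image.Image:
    # Look up each per-pixel class index in the dataset palette.
    palette = np.asarray(get_palette(dataset), dtype=np.uint8)  # (num_classes, 3)
    return Image.fromarray(palette[segmentation_logits])
```

Indexing the `(num_classes, 3)` palette array with the `(H, W)` label map broadcasts to an `(H, W, 3)` uint8 RGB image in a single step, so no explicit loop over pixels is needed.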
njOmCgHpur62PtGVfT7v+a3rPuIyy6biy6xTk1AbseCajA5up6YvkCjsJ7H0p9R8u1/pxMwEBlFpoy5xqFDxr3YYPffouzzgTYxUgAW9OvU5q3mhkH393Potz5O8vVR1viCjDY9GXtTqEnTO197kUP4q9fvc31W58PsJ1ggBl1xS00vanw3O764je/yHsAJmCAkUbPwYWmNx2v7+bq+rcbJxUNI8DgMcBMUs1EeNpcq+7RZLkHVscbFkgppW/f37gdB6OduP64LFUmNiwTMPyROZhmddPtxPp+/cPncx3PCrJvAKeeE7CLkQDqYMbdlXH9OfVfgtZggGPCrj/L7Ql565ssQcMuq9CMELO+Wc6ZKmsVOrsBAXZBMC3QYPoLey2vwbcIzoIGOCVmYk/LeJvor3/4/KvPvsvyWxdnWIDtBNMClyQVpMQ6kl2EU6CTCRgoke4yWpD6JgEGCqK7/XVfqywL0VahexJgOMAqdByiO8XeV8/JWXHG3zQiwLaBgf7kM5Sn346Fkvyrh997NsNZvkDAIqS3CAutVG+uroM3OPtdONK4G3G4IJgWuCCYdni3lEXctycArOZgg6dMxgH3m0NtAKfRt6I0BNMCQ/BoJqo6dLf6Gn3Dr7B3poyw/pzcCxqAPsLed3OoIPVNlqAB6G+186hn160/x6lvmhJg1yMBHNMnS3UMlEUkOdrub8cEDDCbQe3Z/uI6SryV8SZcZZkUYEMw0DiZOabL8ENKca4GDrX+nKZPwBpM3dyTsm5l5bPQiflXD79PWTMcc/05zbIErcHArnXyUFY7T9tcXZ/+ou19sgc/9+BVDn5jrCx8OYDZrNCAmrrbU89PuftlwTOcUbT15+Q6YGAuS7/0b66u1Xf2X7+mbi06i4D1TXNNwFahoWWj0xu5FuXa+6qGmoktRO/yhQAmGff6Lr2r2f1SR4jxyg0OewZWsgQN0I4g73syrkWHMluAPZ4BGhRhomKQprbSI4+/yRI0QIMKvaS4MgIM0K5cJV5hJzj4+Jvm3QO2Cg300c4SaEGaWpoOwklYwHhesivTZbiyb2vMi4DT7EvQLggGqMCJBs+1WL3oKnT89edkAgZgkBnn4xWuRwo7/iYBBmCoUhocub5pibOgrUIDx1S2ubi0yF+uaDfY2lXE+nNa6DIkDYY2RQ4Gy3n6fc+Y5FLqm5ZbgnZJErRGfUeLNkFOF+FU6uDrz8keMDBR9zqb/dWWgCJkODJ3wgKm8iI70ebqur4heOv0j8dXn3034+9V0PpzMgEDkNHXP3y+xIeNv/6cFg2wbWAAzlqowfGZgAGoShHjbxJgALKbcQgupb5JgAGowxe/+UVB9U1LB9g2MBX49v1N7kMAKrT4BKzBAJw2/WKksi5A6liCBiCnGev7+tXbqUezIjfiACCPee/CUZw1JmCr0JTONjDMbq76lrj43LEEDb1oMMxoodm3rLOgLUEDsJ7lbv5cVn3TagH2hGAAztb34B05jv1f2/oWl96OCRggs4ofhbRr9Ox7sMr37y6mHU5+602lTsUCaFaf+o67IWWh429a+SQsDQZoUOOXGx1jXxaABfWsb4MPJVw7wIZgAEgmYACWY/H5BGdBAzC/QeltcP05mYAB8qryGqSl61vBNUgpS4BtAwOQUvr6h8/bnH07eSZgDQZINY6/X332ndOee7IEDcAM+qc35ajvwUcF531+cLYAG4IBTthcXec+hAHin3IV8IZZJmAAJsl4rdHErOatsgAD5FHHBvDQ+k4ff+s4BTrlDbBVaICirV/fmrgRB/Rye3Wf+xAglsinXBUh8xK0IZgiqC/syVXfatafU/YAJw0GKI3Zdxb5AwzBGX9ZQh1nYJ02e31rGn9TkAAbgoGmnK1v5IuAPeBoLk7CAljQw/t325pWMPVafJ6RAMMp1p+Zoivu0O6GHX/Vd14hlqAB6lPBvDuO+vYUJcC2gQnI+Mto9dU3+zOOKjsDK1mChmPUl9HarO/6g2/A5ysMIsAAc6qsvs55Xk6UJehkFRog8BlYpy09/j5dfy59/E2hApw0mDCsP0MKsO97TAX1TdECnLzwAcTQc993hfrujb911DcFDHDSYKBY1WwAh936raa+KeBJWJebzYeHh1k+1NOQf/v+ZpaPTN28BaRxYVeeKxNxAp7FwdfQ26t7r62c5icE+litvvVd/rtVbYBP8ArLMX42IOYlv1WKGODLzWbp06GNwsASCr2IaCvs1m+Vwu0Bz6JnXLe/zN4wyfhLAEX02/g7l4gTcGfNa4K98gLZZa9v/PG3plOgU8wAv3719vWrt2ndLmpw4/wAkNHm6jpvfb/67DtnPq8vYoBf3D3bvs3RYKBuBQ2+69d39xToysbfVOse8Gi7DbYx3A7vvchCehsXOsCjb8rx7fub6S+pt1f3GgwMtbm6Png/rOy53RO/vhVfAdyJuAQdh8EImIX6DlV9fVPwCbiTdxK1KA0M1eV2OwerLwdFD3C3Ct1VMHv/Dg7E2Y8KiCladzvqG4cl6KncVAsoRSn1fbr+XN8p0KmIAF9uNt3fRO5c5GMDSOXUtx3Rl6ALstdgS9NAHAXVt5HxNxUxAaedIbi/7P2zNA0E0f8uV9nr25QyArw1KGnZGwyQXVn3mNwbf3dvjFifwgI8VPYGG4KBjOI/X+GEitPbKSbAl5tNtxBdXNKsRQNZFLTv22nh5hu7ignwrhKTVtwBA+Xq/3SjFKa+e6off1OhAR4k+yr0lgYDKxi07Byzvo0oLMC7p0P375kGA42oo74tjL+puACPpsFA9Yqub2sbwKn0ABd6VZIGA7MbtOkbrb67Ghl/U+kBTsU2GGBGRV9ulJocf1MFAU7DGxwhw4ZgIIvIs29qafxNdQQ4De9ZkAbLcAS+C5Su9DtNtjn+pmoCnMpscPLqD0xT1p0m2VXe05AuN5sPDw8H/9Pt1f2grA5t8EKxHHrYAIOUUt+m1p9TTRNwZ9GBcrlMmoOBEfqMv8Hr2+z6cyo0wCOeThifBgODVFDfxpW3BN2ZcSE6jr0GF/pZACuor76trT+nQifgs5abJteMopkYOKi++rbp4uXLl7mPYbxjQ/DW7L3MFUXT8EK8y6E4NdV3uwHc4Pibap2At6q51raaTyQUX1KKU1N9KXUPuHP/7uLm+jH3UVAk9aUs9V3v2/j4myqYgPucwj7XS232l2xzMLSpvvqSSg9w//dNE9MVqnxxjqRcvoZUJuxtJjmh7CXolNKLu2cf0plTsba6l90gN8CaotxLrbIL+N2EE2rd9G35/htbZU/AnaH35ajjJbiOzwKYooLBt9kN4FTBBDzOiVteFBQ2c/AgBX1noVP6U345rdEA7yn3pVmDzyr3m0vjzta39NmXGpagGycwJ/jiUCj1bUElE7ALgnlKfSnU6fpKbzVMwDVQmqd8TahSZfVt+QysVFOAGz+pXW+gDsfG3wpOeN7qXq4br2+qJsC+kezydoTKVJPe1PywtKuSAANUwHVHTRHgehj7oGgH61vTynPyAIafqifAL+6eDb0lFkBkNaWXp+oJMEC5no6/6
lu9Sq4DBqhGrel1+tWe2gJ8udl8eOj7cKT6RLsz5elt6VCHCnl9/cPnX332Xa3p3WMDuFNbgFPzDQ6izxlh8gy71Lc1Fe4Bv371tuWFjmrOhb69uq/mcwHu3128uHumvrsqnIBf3D17/ept7qPIKftC9Izh7D6UaZjs9n4IvTtkugon4OSSpOpeHUzDZPTt+5unbwG9KRyk5VXJEyqcgKlV9sme1pz+edv+V+8OGUeAq1Vlrnp+Ut++v/GayCBT/rAo8WnG32PqXILu+K5buYWzDq4wj/5Qs3wcGlFzgJMGAyfNnkwNpr/KA5w0uNVlMa+DnDbj4LvaRy6aC5Ceqj/ApLoa3P9z8SLIU10dV/jZkOGOEegEAW5FTfvBGsw46/88+AnsGH8PqjnA22+5t2BbDTYYOrlaaBTmmMovQ3JXLFyS1Lgg8WvzUiXDz2k1T8Cdbg72c7C1wp//dV7yBi1EB3kVZjWr7fUOFfCQFuJV96z6A7zlp2GrqffgNCh45IIf3uxsAB/TRIB9+59ausEBX2ICHhJLKOIbHXM6n5GBp4/K94D33L+7uLl+zH0Ucc376KGAm69PP7VoR8hEZVXt4NH6mWxHKwHuhmAnZO06dl/lKm8ifYxnzFWjmh/ayk7XsgB5QhNL0BxT+p/w2Y+/mhdxKlDuT+N2/Vl9T2srwC/untmZ2HOwYaWHeYrqN+eqVOu3rOifRvU9q60Ac1CX273oztLgFV47FnqvEPYiFhpU1o9iN+Sobx8tBtgQ/JQ5+JiCXvioWxEZ9uo6SIsBTn5KeqvpDtIAoTQaYAaZ0uByV6EpRfy5cEZNfbLVE2B6mTIKl/6SUfrxU5mwa9FOfh6quQB7RNIUYRtsCG5WzBStoNlPvCbNBZiJwu4KxzwqgGNaDLAhuFYLNVjaw2p8Coy5Fm39ub8WA8x0YZsU9sBgIaEarL6DNBpgQ3DFukXyuUqs6MQXocFeS0doNMBMF79M048w/ufYsgjVoaO+4wiwH52VFPdyqb4UJOZ+MKe18jhCljDv84OX8DSinsBKxbI8hNvlv6O1G+AXd888HrhBWluHyG/78srSYMZpegna+7VZ9P/T7kWTWfhBOs3XpxRNBzj9/wbbBl6NlwYm8iPUx/Z5mqt9ucwzI7S7BA3Qgm2Dl1iavtxsLu9m/6itaH0C3jIEr8YEA1n4oxeNAAO0wtVKoQiwu2JBMcRjFnNl+HKzmf5BWibAzGDo3pKXUUbwYzMvX8/sBBgogFpEc//uwq0UJhLgn7AKPZohGIrjj2FeAkw2/vBDduP+GJpVZiHAKbmEfCZugMdCvFdblC9vLgK8zzu7NfmTDzRLgMlMg6FQ1g4nEuA/2v1JMgSvzM0BgAYJ8I+8m5vONjC04MXdMy+Y0wnwYYbg9RmCgaYI8E94T5eXtWjIYtCfO/PJXAQYCM17snX4Oq9PgAFIqV+Djb8zEuCj/JwBsBwB3mcbGOKwLrqy019wY8m8BPgADQY4xivkXD7JfQCh3b+7uLl+zH0Uhbm9uje1QH0uN5vLu9wHURcT8GHe4k3hdhxQroNXA15uNlkOpm4CfIY9j3Fur+5lGMrlovwVCDALkmGAYwT4PEPwRIMaLNhsmcCCsP68EAFmDT2zqr5sqW8El5uN+i7HWdCsRFwBdpmAj/KEYMjF+EsLTMB9uSYYlqO4Mb365k/uXv4h91FUS4BPeXH37PWrt9t/1GBq0rN5y+0diC6NswR9hjty0LjuetB5Y+ka04K8+kYmluIrC/QyVzKlFzqWoIexCk0dxlVw+3+NWJfWXdgjwOft7QQDagrTWYIGgAwEuBfXBAMwLwHuy+nQQJucCL0QX9YxDMGUzp1BITsBHsAQDDTIzbAWIsAjGYIBmEKAASADAYZG3V7d2wmGjNyIA5p2sMFP77Nxe3Xf/ctjzV7i1hx7v5e7f1CZi5cvX+Y+hsLs3hXLbSnhhCnJfFp6Ac7FSVgLMQEP5s6U0FPP8brn/wiVEWBgVaPjul0Ghzo4CWuM7QXBLkaC1agvlRHgqTQYgBEEeCR3xQJgCgEGgAwEeDw7wQCMJsAAkIEAT2InGKibu3AsR4Cn6hpsFRoW5Rok6iPAs9FgoDLG30UJMABkIMBzMgQD0JMAz8CpWAAMJcAAkIEAA0AGAjwz28AA9CHA87ANDMAgAjwbt4YGauIi4KUJ8JzMwbAEt8GiSgK8CEMwAKcJ8MwMwQD0IcBLMQTDLKw/U6tPch8AwGHSS90EeH4v7p69fvU291FAUHtZvb26P/troEoCDCzrdE21lmYJ8CK6Ifj+3cXN9WPuY4H1qCn0J8BLsRBNO3QXRnAWNDCJ+sI4ArwsFyNRN/WF0SxBL85OMFWSXpjIBAwAGQgwAPs8CmkFArwgDygE4BgBBuAnjL/rcBLWslwNDKc9vRWl07vyUt/VmIAX1y1EW4WGXbdX991fx/7T+ocEKxPgNXhIMGz17KsGUz1L0OtxQTDNUtNSWH9ekwCvSoNphOLCWQIMzEx9S2T2XZ894JXYBgZglwkYGMOYWxPjbxYm4PW4HgmALQFelYVoIBrjby4CDNAu9c1IgDOwCg1EoL55CXAeGgzQOAFem21gAJIAZ2QIBjKy/pydAANABgKcgVVoOMtTgameAOfhphwAjRNgIBzjLy0Q4GwMwXCQ+tIIAQaADAQ4J0MwkMurb7z+Z+YbAAAZCHBmLkkCaJMA5/fi7tnlZpP7KABYlQADgTgFmnYIMABkIMBROBcaoCkCHIgGA7RDgIFAbq/ucx8CrESAYzEEAzRCgKNwQTBAUwQYADIQ4EAMwZBsA9MMAY7lxd0z28AALRDgiDSYxhmCaYEAA0AGAhyOhwRDSun26r77K/eBwFI+yX0AAKcMbbDHOVAKE3BEhmAYzdxMKS5evnyZ+xg46sPDQ+5DgLIZiI+5e/mH3IfQOkvQQM0OTsOqTAQCHNrlZmMIhtntVlmMycUeMNA0G8bkIsDRORULlua8LbKwBA2Q0hyjsNVsBjEBF8AQDEUwSTOIAAPMSYPpSYCj84xCKI4G04cAF0CDAeojwGW43GxyHwIwQPAh2G2wIhBggLaobxACDAAZCDAAZCDAAA2x/hyHAANABgJcBs9EAqiMAAPML+Z9oa0/hyLAADOLWV+iEWCAJhh/oxHgArx+9Tb3IQB9xRx/1TcgAY5OfYGJ1DcmAS7AzfVj7kMAeok5/hLTJ7kPAKASQeq79xwIj3IJS4ABZpClvmefuaS+kQkwwFTr1zf44w7pQ4ABxos5+FIEAQYYaeX66m5lBDg6p0BDTGvWd3R6Pzw82AYOy2VIAIMVUd+OR7mEZQIGGKCg9BKcAIfmrSvEsVx6lw6theiYBBjgvIXqu9qMq8EB2QMGOKP0+hKTAAOcUk19Pzw8eLhLKAIclw1gyC7I7Z3n4rLGUAQYoCHe2cchwACQgbOgAdrS7QS/uHuW+0Bq1melQYAB1vaz9DeblO0U6Jvrx/t3F7l+9/qMXtUXYIBV/Sz9Te5DSDfXj6/fvTUEH7TaNrkAA5DZweadvnNIBWeTCTBA
i26uH1PWW2OdLWgFiT3NWdAAjcpVuA8PD9XHtQ8TMAAr0d1dJmCAdq1ZRPXdI8AAR81+x+YIp0AThADH5dlhEIFnFs3C+PuUAAM0bYVHJKnvQQIMcEb1Q/CiDVbfYwQ4Lj+1EEfFDV70GYVex05wGRJAL12Dd58QfLDKhT5C+PWree5MuTdM3+jvcQIMMMDZUXj3F5QS4+7xDBOfkrTCXnJlLEEDLOX26r77K/eBDDCuowf/r0UXtysgwACQwb8Cfzi08Mu2voAAAAAASUVORK5CYII=\n", + "text/plain": [ + "<PIL.Image.Image image mode=RGB size=640x480>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "array = np.array(image)[:, :, ::-1] # BGR\n", + "segmentation_logits = inference_segmentor(model, array)[0]\n", + "segmented_image = render_segmentation(segmentation_logits, \"ade20k\")\n", "display(segmented_image)" ] }