Skip to content
Snippets Groups Projects
Unverified Commit d5b0405e authored by Patrick Labatut's avatar Patrick Labatut Committed by GitHub
Browse files

Add semantic segmentation (linear) code (#185)

Add semantic segmentation (linear) code + demo notebook
parent d5c376b5
No related branches found
No related tags found
No related merge requests found
Showing with 943 additions and 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .optimizer import DistOptimizerHook
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
try:
import apex
except ImportError:
print("apex is not installed")
from mmcv.runner import OptimizerHook, HOOKS
@HOOKS.register_module()
class DistOptimizerHook(OptimizerHook):
"""Optimizer hook for distributed training."""
def __init__(self, update_interval=1, grad_clip=None, coalesce=True, bucket_size_mb=-1, use_fp16=False):
self.grad_clip = grad_clip
self.coalesce = coalesce
self.bucket_size_mb = bucket_size_mb
self.update_interval = update_interval
self.use_fp16 = use_fp16
def before_run(self, runner):
runner.optimizer.zero_grad()
def after_train_iter(self, runner):
runner.outputs["loss"] /= self.update_interval
if self.use_fp16:
# runner.outputs['loss'].backward()
with apex.amp.scale_loss(runner.outputs["loss"], runner.optimizer) as scaled_loss:
scaled_loss.backward()
else:
runner.outputs["loss"].backward()
if self.every_n_iters(runner, self.update_interval):
if self.grad_clip is not None:
self.clip_grads(runner.model.parameters())
runner.optimizer.step()
runner.optimizer.zero_grad()
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .backbones import * # noqa: F403
from .decode_heads import * # noqa: F403
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .vision_transformer import DinoVisionTransformer
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from mmcv.runner import BaseModule
from mmseg.models.builder import BACKBONES
@BACKBONES.register_module()
class DinoVisionTransformer(BaseModule):
"""Vision Transformer."""
def __init__(
self,
*args,
**kwargs,
):
super().__init__()
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .linear_head import BNHead
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from mmseg.models.builder import HEADS
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.ops import resize
@HEADS.register_module()
class BNHead(BaseDecodeHead):
"""Just a batchnorm."""
def __init__(self, resize_factors=None, **kwargs):
super().__init__(**kwargs)
assert self.in_channels == self.channels
self.bn = nn.SyncBatchNorm(self.in_channels)
self.resize_factors = resize_factors
def _forward_feature(self, inputs):
"""Forward function for feature maps before classifying each pixel with
``self.cls_seg`` fc.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
feats (Tensor): A tensor of shape (batch_size, self.channels,
H, W) which is feature map for last layer of decoder head.
"""
# print("inputs", [i.shape for i in inputs])
x = self._transform_inputs(inputs)
# print("x", x.shape)
feats = self.bn(x)
# print("feats", feats.shape)
return feats
def _transform_inputs(self, inputs):
"""Transform inputs for decoder.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
Tensor: The transformed inputs
"""
if self.input_transform == "resize_concat":
# accept lists (for cls token)
input_list = []
for x in inputs:
if isinstance(x, list):
input_list.extend(x)
else:
input_list.append(x)
inputs = input_list
# an image descriptor can be a local descriptor with resolution 1x1
for i, x in enumerate(inputs):
if len(x.shape) == 2:
inputs[i] = x[:, :, None, None]
# select indices
inputs = [inputs[i] for i in self.in_index]
# Resizing shenanigans
# print("before", *(x.shape for x in inputs))
if self.resize_factors is not None:
assert len(self.resize_factors) == len(inputs), (len(self.resize_factors), len(inputs))
inputs = [
resize(input=x, scale_factor=f, mode="bilinear" if f >= 1 else "area")
for x, f in zip(inputs, self.resize_factors)
]
# print("after", *(x.shape for x in inputs))
upsampled_inputs = [
resize(input=x, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners)
for x in inputs
]
inputs = torch.cat(upsampled_inputs, dim=1)
elif self.input_transform == "multiple_select":
inputs = [inputs[i] for i in self.in_index]
else:
inputs = inputs[self.in_index]
return inputs
def forward(self, inputs):
"""Forward function."""
output = self._forward_feature(inputs)
output = self.cls_seg(output)
return output
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
ADE20K_COLORMAP = [
(0, 0, 0),
(120, 120, 120),
(180, 120, 120),
(6, 230, 230),
(80, 50, 50),
(4, 200, 3),
(120, 120, 80),
(140, 140, 140),
(204, 5, 255),
(230, 230, 230),
(4, 250, 7),
(224, 5, 255),
(235, 255, 7),
(150, 5, 61),
(120, 120, 70),
(8, 255, 51),
(255, 6, 82),
(143, 255, 140),
(204, 255, 4),
(255, 51, 7),
(204, 70, 3),
(0, 102, 200),
(61, 230, 250),
(255, 6, 51),
(11, 102, 255),
(255, 7, 71),
(255, 9, 224),
(9, 7, 230),
(220, 220, 220),
(255, 9, 92),
(112, 9, 255),
(8, 255, 214),
(7, 255, 224),
(255, 184, 6),
(10, 255, 71),
(255, 41, 10),
(7, 255, 255),
(224, 255, 8),
(102, 8, 255),
(255, 61, 6),
(255, 194, 7),
(255, 122, 8),
(0, 255, 20),
(255, 8, 41),
(255, 5, 153),
(6, 51, 255),
(235, 12, 255),
(160, 150, 20),
(0, 163, 255),
(140, 140, 140),
(250, 10, 15),
(20, 255, 0),
(31, 255, 0),
(255, 31, 0),
(255, 224, 0),
(153, 255, 0),
(0, 0, 255),
(255, 71, 0),
(0, 235, 255),
(0, 173, 255),
(31, 0, 255),
(11, 200, 200),
(255, 82, 0),
(0, 255, 245),
(0, 61, 255),
(0, 255, 112),
(0, 255, 133),
(255, 0, 0),
(255, 163, 0),
(255, 102, 0),
(194, 255, 0),
(0, 143, 255),
(51, 255, 0),
(0, 82, 255),
(0, 255, 41),
(0, 255, 173),
(10, 0, 255),
(173, 255, 0),
(0, 255, 153),
(255, 92, 0),
(255, 0, 255),
(255, 0, 245),
(255, 0, 102),
(255, 173, 0),
(255, 0, 20),
(255, 184, 184),
(0, 31, 255),
(0, 255, 61),
(0, 71, 255),
(255, 0, 204),
(0, 255, 194),
(0, 255, 82),
(0, 10, 255),
(0, 112, 255),
(51, 0, 255),
(0, 194, 255),
(0, 122, 255),
(0, 255, 163),
(255, 153, 0),
(0, 255, 10),
(255, 112, 0),
(143, 255, 0),
(82, 0, 255),
(163, 255, 0),
(255, 235, 0),
(8, 184, 170),
(133, 0, 255),
(0, 255, 92),
(184, 0, 255),
(255, 0, 31),
(0, 184, 255),
(0, 214, 255),
(255, 0, 112),
(92, 255, 0),
(0, 224, 255),
(112, 224, 255),
(70, 184, 160),
(163, 0, 255),
(153, 0, 255),
(71, 255, 0),
(255, 0, 163),
(255, 204, 0),
(255, 0, 143),
(0, 255, 235),
(133, 255, 0),
(255, 0, 235),
(245, 0, 255),
(255, 0, 122),
(255, 245, 0),
(10, 190, 212),
(214, 255, 0),
(0, 204, 255),
(20, 0, 255),
(255, 255, 0),
(0, 153, 255),
(0, 41, 255),
(0, 255, 204),
(41, 0, 255),
(41, 255, 0),
(173, 0, 255),
(0, 245, 255),
(71, 0, 255),
(122, 0, 255),
(0, 255, 184),
(0, 92, 255),
(184, 255, 0),
(0, 133, 255),
(255, 214, 0),
(25, 194, 194),
(102, 255, 0),
(92, 0, 255),
]
ADE20K_CLASS_NAMES = [
"",
"wall",
"building;edifice",
"sky",
"floor;flooring",
"tree",
"ceiling",
"road;route",
"bed",
"windowpane;window",
"grass",
"cabinet",
"sidewalk;pavement",
"person;individual;someone;somebody;mortal;soul",
"earth;ground",
"door;double;door",
"table",
"mountain;mount",
"plant;flora;plant;life",
"curtain;drape;drapery;mantle;pall",
"chair",
"car;auto;automobile;machine;motorcar",
"water",
"painting;picture",
"sofa;couch;lounge",
"shelf",
"house",
"sea",
"mirror",
"rug;carpet;carpeting",
"field",
"armchair",
"seat",
"fence;fencing",
"desk",
"rock;stone",
"wardrobe;closet;press",
"lamp",
"bathtub;bathing;tub;bath;tub",
"railing;rail",
"cushion",
"base;pedestal;stand",
"box",
"column;pillar",
"signboard;sign",
"chest;of;drawers;chest;bureau;dresser",
"counter",
"sand",
"sink",
"skyscraper",
"fireplace;hearth;open;fireplace",
"refrigerator;icebox",
"grandstand;covered;stand",
"path",
"stairs;steps",
"runway",
"case;display;case;showcase;vitrine",
"pool;table;billiard;table;snooker;table",
"pillow",
"screen;door;screen",
"stairway;staircase",
"river",
"bridge;span",
"bookcase",
"blind;screen",
"coffee;table;cocktail;table",
"toilet;can;commode;crapper;pot;potty;stool;throne",
"flower",
"book",
"hill",
"bench",
"countertop",
"stove;kitchen;stove;range;kitchen;range;cooking;stove",
"palm;palm;tree",
"kitchen;island",
"computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system",
"swivel;chair",
"boat",
"bar",
"arcade;machine",
"hovel;hut;hutch;shack;shanty",
"bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle",
"towel",
"light;light;source",
"truck;motortruck",
"tower",
"chandelier;pendant;pendent",
"awning;sunshade;sunblind",
"streetlight;street;lamp",
"booth;cubicle;stall;kiosk",
"television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box",
"airplane;aeroplane;plane",
"dirt;track",
"apparel;wearing;apparel;dress;clothes",
"pole",
"land;ground;soil",
"bannister;banister;balustrade;balusters;handrail",
"escalator;moving;staircase;moving;stairway",
"ottoman;pouf;pouffe;puff;hassock",
"bottle",
"buffet;counter;sideboard",
"poster;posting;placard;notice;bill;card",
"stage",
"van",
"ship",
"fountain",
"conveyer;belt;conveyor;belt;conveyer;conveyor;transporter",
"canopy",
"washer;automatic;washer;washing;machine",
"plaything;toy",
"swimming;pool;swimming;bath;natatorium",
"stool",
"barrel;cask",
"basket;handbasket",
"waterfall;falls",
"tent;collapsible;shelter",
"bag",
"minibike;motorbike",
"cradle",
"oven",
"ball",
"food;solid;food",
"step;stair",
"tank;storage;tank",
"trade;name;brand;name;brand;marque",
"microwave;microwave;oven",
"pot;flowerpot",
"animal;animate;being;beast;brute;creature;fauna",
"bicycle;bike;wheel;cycle",
"lake",
"dishwasher;dish;washer;dishwashing;machine",
"screen;silver;screen;projection;screen",
"blanket;cover",
"sculpture",
"hood;exhaust;hood",
"sconce",
"vase",
"traffic;light;traffic;signal;stoplight",
"tray",
"ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin",
"fan",
"pier;wharf;wharfage;dock",
"crt;screen",
"plate",
"monitor;monitoring;device",
"bulletin;board;notice;board",
"shower",
"radiator",
"glass;drinking;glass",
"clock",
"flag",
]
VOC2012_COLORMAP = [
(0, 0, 0),
(128, 0, 0),
(0, 128, 0),
(128, 128, 0),
(0, 0, 128),
(128, 0, 128),
(0, 128, 128),
(128, 128, 128),
(64, 0, 0),
(192, 0, 0),
(64, 128, 0),
(192, 128, 0),
(64, 0, 128),
(192, 0, 128),
(64, 128, 128),
(192, 128, 128),
(0, 64, 0),
(128, 64, 0),
(0, 192, 0),
(128, 192, 0),
(0, 64, 128),
]
VOC2012_CLASS_NAMES = [
"",
"aeroplane",
"bicycle",
"bird",
"boat",
"bottle",
"bus",
"car",
"cat",
"chair",
"cow",
"diningtable",
"dog",
"horse",
"motorbike",
"person",
"pottedplant",
"sheep",
"sofa",
"train",
"tvmonitor",
]
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment