From 5c053f2285466cb884d148dc54489ec654327668 Mon Sep 17 00:00:00 2001
From: manurare <mreyarea@gmail.com>
Date: Mon, 3 Jun 2024 09:26:42 +0200
Subject: [PATCH] Add zoedepth persp monodepth

---
 code/python/requirements.txt              |  2 +-
 code/python/src/main.py                   |  2 +-
 code/python/src/utility/depthmap_utils.py | 41 +++++++++++++++++++++--
 3 files changed, 40 insertions(+), 5 deletions(-)

diff --git a/code/python/requirements.txt b/code/python/requirements.txt
index 7de70ff..3b9acc2 100644
--- a/code/python/requirements.txt
+++ b/code/python/requirements.txt
@@ -6,4 +6,4 @@ opencv-python>=4.5.1.48
 torch>=1.8.1
 torchvision>=0.9.1
 wheel>=0.36.2
-timm>=0.4.12
\ No newline at end of file
+timm==0.6.7
\ No newline at end of file
diff --git a/code/python/src/main.py b/code/python/src/main.py
index 4cbce87..b16bf6b 100644
--- a/code/python/src/main.py
+++ b/code/python/src/main.py
@@ -97,7 +97,7 @@ class Options():
         parser.add_argument("--padding", type=float, default="0.3")
         parser.add_argument("--multires_levels", type=int, default=1, help="Levels of multi-resolution pyramid. If > 1"
                                                                            "then --grid_size is the lowest resolution")
-        parser.add_argument("--persp_monodepth", type=str, default="midas2", choices=["midas2", "midas3", "boost"])
+        parser.add_argument("--persp_monodepth", type=str, default="midas2", choices=["midas2", "midas3", "boost", "zoedepth"])
         parser.add_argument('--depthalignstep', type=int, nargs='+', default=[1, 2, 3, 4])
         parser.add_argument("--rm_debug_folder", default=True, action='store_false')
         parser.add_argument("--intermediate_data", default=False, action='store_true', help="save intermediate data"
diff --git a/code/python/src/utility/depthmap_utils.py b/code/python/src/utility/depthmap_utils.py
index 8604942..dc2f4ff 100644
--- a/code/python/src/utility/depthmap_utils.py
+++ b/code/python/src/utility/depthmap_utils.py
@@ -5,6 +5,9 @@ from skimage.transform import pyramid_gaussian
 from PIL import Image
 
 import numpy as np
+import torch
+from torchvision import transforms
+from torchvision.utils import make_grid
 
 from struct import unpack
 import os
@@ -171,6 +174,8 @@ def run_persp_monodepth(rgb_image_data_list, persp_monodepth, use_large_model=Tr
         return MiDaS_torch_hub_data(rgb_image_data_list, persp_monodepth, use_large_model=use_large_model)
     if persp_monodepth == "boost":
         return boosting_monodepth(rgb_image_data_list)
+    if persp_monodepth == "zoedepth":
+        return zoedepth_monodepth(rgb_image_data_list)
 
 
 def MiDaS_torch_hub_data(rgb_image_data_list, persp_monodepth, use_large_model=True):
@@ -182,7 +187,6 @@ def MiDaS_torch_hub_data(rgb_image_data_list, persp_monodepth, use_large_model=T
     :param use_large_model: the MiDaS model type.
     :type use_large_model: bool, optional
     """
-    import torch
 
     # 1)initial PyTorch run-time environment
     if use_large_model:
@@ -248,7 +252,6 @@ def MiDaS_torch_hub_file(rgb_image_path, use_large_model=True):
     :type use_large_model: bool, optional
     """
     import cv2
-    import torch
     # import urllib.request
     # import matplotlib.pyplot as plt
 
@@ -298,7 +301,6 @@ def boosting_monodepth(rgb_image_data_list):
     # Load merge network
     import cv2
     import argparse
-    import torch
     import warnings
     warnings.simplefilter('ignore', np.RankWarning)
 
@@ -541,6 +543,39 @@ def boosting_monodepth(rgb_image_data_list):
 
     return depthmaps
 
 
+@torch.no_grad()
+def zoedepth_monodepth(rgb_image_data_list):
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
+    torch.hub.help("intel-isl/MiDaS", "DPT_BEiT_L_384", force_reload=True)
+
+    tfoms = transforms.Compose([transforms.ToTensor()])
+
+    repo = "isl-org/ZoeDepth"
+    model_zoe = torch.hub.load(repo, "ZoeD_NK", pretrained=True)
+    model_zoe = model_zoe.to(device)
+    model_zoe.eval()
+
+    depthmaps = []
+    for img in rgb_image_data_list:
+        img_t = tfoms(img / 255.).unsqueeze(0).type(torch.float32).to(device)
+        out = model_zoe(img_t)['metric_depth']
+
+        out = torch.nn.functional.interpolate(
+            out,
+            size=img.shape[:2],
+            mode="nearest-exact",
+        ).squeeze(0)
+
+        if torch.any(out < 0):
+            log.warn("Negative depth value")
+            out = torch.clamp(out, min=1e-6)
+        depthmaps.append(out)
+
+    del model_zoe
+    # grid = make_grid(depthmaps, nrow=5)[0]
+    return [depth2disparity(d.squeeze().cpu().numpy()) for d in depthmaps]
+
 def read_dpt(dpt_file_path):
     """read depth map from *.dpt file.
--
GitLab
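
Usage note (not part of the patch): a minimal sketch of how the new ZoeDepth branch can be exercised through run_persp_monodepth, assuming code/python/src is on the Python path; the dummy input image is a placeholder, since in the pipeline the list comes from the perspective projections. From the command line, the same branch is selected by passing --persp_monodepth zoedepth to main.py alongside the pipeline's usual arguments.

    # Sketch only: exercises the "zoedepth" branch added in this patch.
    import numpy as np
    from utility import depthmap_utils  # assumes code/python/src is on PYTHONPATH

    # One placeholder uint8 RGB image (H, W, 3); the existing perspective-depth
    # paths consume numpy images in this form.
    rgb_list = [np.zeros((384, 512, 3), dtype=np.uint8)]

    # Dispatches to zoedepth_monodepth(); ZoeDepth's metric depth is converted
    # with depth2disparity, so each returned map is a disparity array.
    disparity_list = depthmap_utils.run_persp_monodepth(rgb_list, persp_monodepth="zoedepth")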