From 5c053f2285466cb884d148dc54489ec654327668 Mon Sep 17 00:00:00 2001
From: manurare <mreyarea@gmail.com>
Date: Mon, 3 Jun 2024 09:26:42 +0200
Subject: [PATCH] Add ZoeDepth as a persp_monodepth option
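
Add ZoeDepth (ZoeD_NK) as an additional backend for --persp_monodepth.
Pin timm to 0.6.7 and move the torch import in depthmap_utils.py to
module level so the new backend and its decorator can use it.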
---
code/python/requirements.txt | 2 +-
code/python/src/main.py | 2 +-
code/python/src/utility/depthmap_utils.py | 55 ++++++++++++++++++++++++++++---
3 files changed, 54 insertions(+), 5 deletions(-)
diff --git a/code/python/requirements.txt b/code/python/requirements.txt
index 7de70ff..3b9acc2 100644
--- a/code/python/requirements.txt
+++ b/code/python/requirements.txt
@@ -6,4 +6,4 @@ opencv-python>=4.5.1.48
torch>=1.8.1
torchvision>=0.9.1
wheel>=0.36.2
-timm>=0.4.12
\ No newline at end of file
+timm==0.6.7
\ No newline at end of file
diff --git a/code/python/src/main.py b/code/python/src/main.py
index 4cbce87..b16bf6b 100644
--- a/code/python/src/main.py
+++ b/code/python/src/main.py
@@ -97,7 +97,7 @@ class Options():
parser.add_argument("--padding", type=float, default="0.3")
parser.add_argument("--multires_levels", type=int, default=1, help="Levels of multi-resolution pyramid. If > 1"
"then --grid_size is the lowest resolution")
- parser.add_argument("--persp_monodepth", type=str, default="midas2", choices=["midas2", "midas3", "boost"])
+ parser.add_argument("--persp_monodepth", type=str, default="midas2", choices=["midas2", "midas3", "boost", "zoedepth"])
parser.add_argument('--depthalignstep', type=int, nargs='+', default=[1, 2, 3, 4])
parser.add_argument("--rm_debug_folder", default=True, action='store_false')
parser.add_argument("--intermediate_data", default=False, action='store_true', help="save intermediate data"
diff --git a/code/python/src/utility/depthmap_utils.py b/code/python/src/utility/depthmap_utils.py
index 8604942..dc2f4ff 100644
--- a/code/python/src/utility/depthmap_utils.py
+++ b/code/python/src/utility/depthmap_utils.py
@@ -5,6 +5,9 @@ from skimage.transform import pyramid_gaussian
from PIL import Image
import numpy as np
+import torch
+from torchvision import transforms
+from torchvision.utils import make_grid
from struct import unpack
import os
@@ -171,6 +174,8 @@ def run_persp_monodepth(rgb_image_data_list, persp_monodepth, use_large_model=Tr
return MiDaS_torch_hub_data(rgb_image_data_list, persp_monodepth, use_large_model=use_large_model)
if persp_monodepth == "boost":
return boosting_monodepth(rgb_image_data_list)
+    if persp_monodepth == "zoedepth":
+        return zoedepth_monodepth(rgb_image_data_list)
def MiDaS_torch_hub_data(rgb_image_data_list, persp_monodepth, use_large_model=True):
@@ -182,7 +187,6 @@ def MiDaS_torch_hub_data(rgb_image_data_list, persp_monodepth, use_large_model=T
:param use_large_model: the MiDaS model type.
:type use_large_model: bool, optional
"""
- import torch
# 1)initial PyTorch run-time environment
if use_large_model:
@@ -248,7 +252,6 @@ def MiDaS_torch_hub_file(rgb_image_path, use_large_model=True):
:type use_large_model: bool, optional
"""
import cv2
- import torch
# import urllib.request
# import matplotlib.pyplot as plt
@@ -298,7 +301,6 @@ def boosting_monodepth(rgb_image_data_list):
# Load merge network
import cv2
import argparse
- import torch
import warnings
warnings.simplefilter('ignore', np.RankWarning)
@@ -541,6 +543,53 @@
return depthmaps
+@torch.no_grad()
+def zoedepth_monodepth(rgb_image_data_list):
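+    """Estimate metric depth for each perspective image with ZoeDepth (ZoeD_NK).
+
+    :param rgb_image_data_list: perspective RGB images, HxWx3 arrays with values in [0, 255].
+    :type rgb_image_data_list: list
+    :return: estimated disparity maps, one per input image.
+    :rtype: list
+    """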
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
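+    # Bypass torch.hub's forked-repo validation and pre-fetch the MiDaS repo so the
+    # DPT_BEiT_L_384 backbone that ZoeDepth builds on is available in the hub cache.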
+    torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
+    torch.hub.help("intel-isl/MiDaS", "DPT_BEiT_L_384", force_reload=True)
+
+    tfoms = transforms.Compose([transforms.ToTensor()])
+
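+    # Load the ZoeD_NK model (trained on NYU Depth v2 + KITTI) from torch hub.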
+    repo = "isl-org/ZoeDepth"
+    model_zoe = torch.hub.load(repo, "ZoeD_NK", pretrained=True)
+    model_zoe = model_zoe.to(device)
+    model_zoe.eval()
+
+    depthmaps = []
+    for img in rgb_image_data_list:
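+        # HWC image in [0, 255] -> normalized 1xCxHxW float tensor on the target device.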
+        img_t = tfoms(img / 255.).unsqueeze(0).type(torch.float32).to(device)
+        out = model_zoe(img_t)['metric_depth']
+
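+        # Resize the metric-depth prediction back to the input image resolution.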
+        out = torch.nn.functional.interpolate(
+            out,
+            size=img.shape[:2],
+            mode="nearest-exact",
+        ).squeeze(0)
+
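+        # Clamp spurious negative depths so the depth-to-disparity conversion stays valid.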
+        if torch.any(out < 0):
+            log.warn("ZoeDepth produced negative depth values; clamping to 1e-6")
+            out = torch.clamp(out, min=1e-6)
+        depthmaps.append(out)
+
+    del model_zoe
+    # grid = make_grid(depthmaps, nrow=5)[0]
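+    # Convert metric depth to disparity to match the output of the other monodepth backends.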
+    return [depth2disparity(d.squeeze().cpu().numpy()) for d in depthmaps]
+
def read_dpt(dpt_file_path):
"""read depth map from *.dpt file.
--
GitLab