From 5c053f2285466cb884d148dc54489ec654327668 Mon Sep 17 00:00:00 2001
From: manurare <mreyarea@gmail.com>
Date: Mon, 3 Jun 2024 09:26:42 +0200
Subject: [PATCH] Add zoedepth persp monodepth

---
 code/python/requirements.txt              |  2 +-
 code/python/src/main.py                   |  2 +-
 code/python/src/utility/depthmap_utils.py | 41 +++++++++++++++++++++--
 3 files changed, 40 insertions(+), 5 deletions(-)

diff --git a/code/python/requirements.txt b/code/python/requirements.txt
index 7de70ff..3b9acc2 100644
--- a/code/python/requirements.txt
+++ b/code/python/requirements.txt
@@ -6,4 +6,4 @@ opencv-python>=4.5.1.48
 torch>=1.8.1
 torchvision>=0.9.1
 wheel>=0.36.2
-timm>=0.4.12
\ No newline at end of file
+timm==0.6.7
\ No newline at end of file
diff --git a/code/python/src/main.py b/code/python/src/main.py
index 4cbce87..b16bf6b 100644
--- a/code/python/src/main.py
+++ b/code/python/src/main.py
@@ -97,7 +97,7 @@ class Options():
         parser.add_argument("--padding", type=float, default="0.3")
         parser.add_argument("--multires_levels", type=int, default=1, help="Levels of multi-resolution pyramid. If > 1"
                                                                            "then --grid_size is the lowest resolution")
-        parser.add_argument("--persp_monodepth", type=str, default="midas2", choices=["midas2", "midas3", "boost"])
+        parser.add_argument("--persp_monodepth", type=str, default="midas2", choices=["midas2", "midas3", "boost", "zoedepth"])
         parser.add_argument('--depthalignstep', type=int, nargs='+', default=[1, 2, 3, 4])
         parser.add_argument("--rm_debug_folder", default=True, action='store_false')
         parser.add_argument("--intermediate_data", default=False, action='store_true', help="save intermediate data"
diff --git a/code/python/src/utility/depthmap_utils.py b/code/python/src/utility/depthmap_utils.py
index 8604942..dc2f4ff 100644
--- a/code/python/src/utility/depthmap_utils.py
+++ b/code/python/src/utility/depthmap_utils.py
@@ -5,6 +5,9 @@ from skimage.transform import pyramid_gaussian
 
 from PIL import Image
 import numpy as np
+import torch
+from torchvision import transforms
+from torchvision.utils import make_grid
 
 from struct import unpack
 import os
@@ -171,6 +174,8 @@ def run_persp_monodepth(rgb_image_data_list, persp_monodepth, use_large_model=Tr
         return MiDaS_torch_hub_data(rgb_image_data_list, persp_monodepth, use_large_model=use_large_model)
     if persp_monodepth == "boost":
         return boosting_monodepth(rgb_image_data_list)
+    if persp_monodepth == "zoedepth":
+        return zoedepth_monodepth(rgb_image_data_list)
 
 
 def MiDaS_torch_hub_data(rgb_image_data_list, persp_monodepth, use_large_model=True):
@@ -182,7 +187,6 @@ def MiDaS_torch_hub_data(rgb_image_data_list, persp_monodepth, use_large_model=T
     :param use_large_model: the MiDaS model type.
     :type use_large_model: bool, optional
     """
-    import torch
 
     # 1)initial PyTorch run-time environment
     if use_large_model:
@@ -248,7 +252,6 @@ def MiDaS_torch_hub_file(rgb_image_path, use_large_model=True):
     :type use_large_model: bool, optional
     """
     import cv2
-    import torch
     # import urllib.request
 
     # import matplotlib.pyplot as plt
@@ -298,7 +301,6 @@ def boosting_monodepth(rgb_image_data_list):
     # Load merge network
     import cv2
     import argparse
-    import torch
     import warnings
     warnings.simplefilter('ignore', np.RankWarning)
 
@@ -541,6 +543,39 @@ def boosting_monodepth(rgb_image_data_list):
 
     return depthmaps
 
+@torch.no_grad()
+def zoedepth_monodepth(rgb_image_data_list):
+    """Run ZoeDepth (ZoeD_NK) on each image and return the disparity maps.
+
+    :param rgb_image_data_list: RGB images to process; each one is assumed to be
+        an HxWx3 array with values in [0, 255] (TODO confirm against callers).
+    :return: depth2disparity() of each predicted metric depth map, in input order.
+    :rtype: list
+    """
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # Bypass torch.hub's forked-repo validation, then force a fresh MiDaS download.
+    torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
+    torch.hub.help("intel-isl/MiDaS", "DPT_BEiT_L_384", force_reload=True)
+    model_zoe = torch.hub.load("isl-org/ZoeDepth", "ZoeD_NK", pretrained=True)
+    model_zoe = model_zoe.to(device).eval()
+    to_tensor = transforms.ToTensor()
+
+    disparity_maps = []
+    for img in rgb_image_data_list:
+        img_t = to_tensor(img / 255.).unsqueeze(0).type(torch.float32).to(device)
+        out = model_zoe(img_t)['metric_depth']
+        # Resize the prediction back to the input image resolution.
+        # NOTE(review): "nearest-exact" looks newer than the torch>=1.8.1 floor - confirm.
+        out = torch.nn.functional.interpolate(
+            out, size=img.shape[:2], mode="nearest-exact").squeeze(0)
+        if torch.any(out < 0):
+            log.warn("Negative depth value")
+            out = torch.clamp(out, min=1e-6)
+        # Convert right away so depth tensors do not accumulate on the device.
+        disparity_maps.append(depth2disparity(out.squeeze().cpu().numpy()))
+
+    return disparity_maps
+
 
 def read_dpt(dpt_file_path):
     """read depth map from *.dpt file.
-- 
GitLab