Skip to content
Snippets Groups Projects
Commit 9b0facef authored by mhz1g21's avatar mhz1g21
Browse files

Merge branch 'master' into 'GDP_4.5.1-ImplementDisparity'

# Conflicts:
#   scripts/simple_tab.py
parents 0240142d fab33b04
No related branches found
No related tags found
1 merge request!19Gdp 4.5.1 implement disparity
......@@ -113,7 +113,7 @@ def get_completed_scene(shifted_disparity_path, shifted_t_path):
output = stdout.read().decode()
print(output)
remote_file_path = "/mainfs/ECShome/kproject/mona/MDBNet360_GDP/output/scene_completed_prediction.obj"
remote_file_path = os.getenv("REMOTE_OUTPUT_PATH_OBJ")
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
local_file_folder = os.path.join(parent_dir, "edgenet-360/Output")
......@@ -121,7 +121,7 @@ def get_completed_scene(shifted_disparity_path, shifted_t_path):
with SCPClient(client.get_transport()) as scp:
scp.get(remote_file_path, local_file_path) # Download file
remote_file_path = "/mainfs/ECShome/kproject/mona/MDBNet360_GDP/output/scene_completed_prediction.mtl"
remote_file_path = os.getenv("REMOTE_OUTPUT_PATH_MTL")
local_file_path = os.path.join(local_file_folder,"scene_completed_prediction.mtl")
with SCPClient(client.get_transport()) as scp:
scp.get(remote_file_path, local_file_path) # Download file
......
......@@ -17,6 +17,7 @@ from debug_tool.tabs.shifter_tab import ShifterTab
from debug_tool.tabs.depth_tab import DepthTab
from debug_tool.tabs.material_tab import MaterialTab
from debug_tool.tabs.edge_net_tab import EdgeNetTab
from debug_tool.tabs.mbdnet_tab import MBDNetTab
from Depth2Disparity.depth_to_disparity import Depth2Disparity
......@@ -24,12 +25,16 @@ class PipelineWorker(QThread):
progress = pyqtSignal(str)
finished = pyqtSignal(bool, str)
def __init__(self, tab_instance):
def __init__(self, tab_instance, edgenet_flag=True,input_depth_flag=False, depth_map_path=None):
super().__init__()
self.tab = tab_instance
self.distance_points = self.tab.distance_points # added distance points to class for disparity calculation
self.ratio = self.tab.ratio # image scale ratio for reversal
self.edgenet_flag = edgenet_flag
self.input_depth_flag = input_depth_flag
self.depth_file_path = depth_map_path
def run(self):
try:
self.run_pipeline()
......@@ -58,7 +63,17 @@ class PipelineWorker(QThread):
self.progress.emit("Copying input file to scripts/360monodepthexecution...")
self.tab.depth.copy_file()
def copy_depth_map(self):
try:
if self.depth_file_path:
dest_path = os.path.join(self.tab.config_reader.directories['edgeNetDir'], 'Data', 'Input', 'depth_e.png')
copy_file(self.depth_file_path, dest_path)
self.progress.emit("Depth map copied successfully!")
except Exception as e:
print(f"Copy depth map failed: {str(e)}")
raise
def shift_image(self):
print("Starting shift_image") # Debug print
if not self.tab.should_shift_image:
......@@ -169,25 +184,42 @@ class PipelineWorker(QThread):
self.tab.edge_net._run_blender_flip_process()
self.progress.emit("Completed blender flip (blenderFlip.py)")
self.progress.emit("Post-processing completed!")
self.progress.emit("File saved in edgenet-360\Output")
self.progress.emit("File saved in edgenet-360 -> Output")
def run_mbdnet(self):
print("Starting MBDNet")
self.progress.emit("Running MBDNet...")
try:
self.tab.mbdnet._run_mbdnet_process()
print("Completed MBDNet")
except Exception as e:
print(f"MBDNet failed: {str(e)}")
raise
def run_pipeline(self):
self.clean_temp_files()
self.shift_image()
self.copy_file()
self.run_depth_estimation()
#self.run_depth_to_disparity()
self.run_material_recognition()
self.run_edge_net()
self.run_post_processing()
self.progress.emit("Pipeline completed!")
self.clean_temp_files()
self.shift_image()
self.copy_file()
if self.input_depth_flag:
self.copy_depth_map()
else:
self.run_depth_estimation()
#self.run_depth_to_disparity()
self.run_material_recognition()
if self.edgenet_flag:
self.run_edge_net()
self.run_post_processing()
else:
self.run_mbdnet()
self.progress.emit("Pipeline completed!")
9
class SimpleTab(QWidget):
def __init__(self, config_reader):
super().__init__()
self.config_reader = config_reader
self.input_path = None
self.pipeline_thread = None
self.depth_input_path = None
# Store states that will be used by worker thread
self.should_shift_image = False
......@@ -198,6 +230,7 @@ class SimpleTab(QWidget):
self.depth = DepthTab(self.config_reader)
self.material = MaterialTab(self.config_reader)
self.edge_net = EdgeNetTab(self.config_reader)
self.mbdnet = MBDNetTab(self.config_reader)
# Hide their UIs as we'll use our own
self.shifter.hide()
......@@ -297,7 +330,11 @@ class SimpleTab(QWidget):
self.shift_image_check = QCheckBox("Shift Input Image")
self.shift_image_check.setStyleSheet("QCheckBox { margin: 5px; background-color: #3e3e3e;}")
options_layout.addWidget(self.shift_image_check)
self.Input_depth_check = QCheckBox("Input Depth Map")
self.Input_depth_check.setStyleSheet("QCheckBox { margin: 5px; background-color: #3e3e3e;}")
self.Input_depth_check.clicked.connect(self.handle_depth_select)
options_layout.addWidget(self.Input_depth_check)
# SSC Model selection
ssc_model_layout = QHBoxLayout()
ssc_model_label = QLabel("SSC Model:")
......@@ -591,7 +628,21 @@ class SimpleTab(QWidget):
self.file_selected = True
self.flash_timer.stop()
self.select_btn.setStyleSheet("QPushButton { margin: 5px; padding: 5px; border-radius: 10px;}")
def handle_depth_select(self):
self.update_status("Input Depth Map selected")
file_path = select_file(
self,
"Select Depth Map",
"Images (*.png)",
initial_dir=self.config_reader.directories['edgeNetDir'] + '/Data'
)
if file_path:
self.depth_input_path = file_path
self.update_status(f"Selected input file: {file_path}")
def start_flashing(self):
if not self.file_selected:
self.flash_timer.start(1000) # Flash every 1000 milliseconds
......@@ -637,11 +688,14 @@ class SimpleTab(QWidget):
self.progress_bar.show()
self.run_pipeline_btn.setEnabled(False)
#TODO: Add distance points to the pipeline for depth estimation
# Get the distance points
self.distance_points = self.distance_preview.get_points()
self.pipeline_thread = PipelineWorker(self)
self.pipeline_thread = PipelineWorker(
self,
edgenet_flag=self.ssc_model_combo.currentText() == "EdgeNet360",
input_depth_flag=self.Input_depth_check.isChecked(),
depth_map_path=self.depth_input_path)
# Connect signals
self.pipeline_thread.progress.connect(self.update_status)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment