diff --git a/scripts/scene_completion.py b/scripts/scene_completion.py index 6baef21f331197d3d88ea8cebfb0e564d8e4e9ff..e3514d42c20f249760a19f42177b7ee379b3f4a0 100644 --- a/scripts/scene_completion.py +++ b/scripts/scene_completion.py @@ -113,7 +113,7 @@ def get_completed_scene(shifted_disparity_path, shifted_t_path): output = stdout.read().decode() print(output) - remote_file_path = "/mainfs/ECShome/kproject/mona/MDBNet360_GDP/output/scene_completed_prediction.obj" + remote_file_path = os.getenv("REMOTE_OUTPUT_PATH_OBJ") current_dir = os.getcwd() parent_dir = os.path.dirname(current_dir) local_file_folder = os.path.join(parent_dir, "edgenet-360/Output") @@ -121,7 +121,7 @@ def get_completed_scene(shifted_disparity_path, shifted_t_path): with SCPClient(client.get_transport()) as scp: scp.get(remote_file_path, local_file_path) # Download file - remote_file_path = "/mainfs/ECShome/kproject/mona/MDBNet360_GDP/output/scene_completed_prediction.mtl" + remote_file_path = os.getenv("REMOTE_OUTPUT_PATH_MTL") local_file_path = os.path.join(local_file_folder,"scene_completed_prediction.mtl") with SCPClient(client.get_transport()) as scp: scp.get(remote_file_path, local_file_path) # Download file diff --git a/scripts/simple_tab.py b/scripts/simple_tab.py index 026ab0d4931b73b4b587c853b6d94c922ca8205e..f4b69bd9fb818e03122145e2531b060c78eeafdd 100644 --- a/scripts/simple_tab.py +++ b/scripts/simple_tab.py @@ -17,6 +17,7 @@ from debug_tool.tabs.shifter_tab import ShifterTab from debug_tool.tabs.depth_tab import DepthTab from debug_tool.tabs.material_tab import MaterialTab from debug_tool.tabs.edge_net_tab import EdgeNetTab +from debug_tool.tabs.mbdnet_tab import MBDNetTab from Depth2Disparity.depth_to_disparity import Depth2Disparity @@ -24,12 +25,16 @@ class PipelineWorker(QThread): progress = pyqtSignal(str) finished = pyqtSignal(bool, str) - def __init__(self, tab_instance): + def __init__(self, tab_instance, edgenet_flag=True,input_depth_flag=False, depth_map_path=None): super().__init__() + self.tab = tab_instance self.distance_points = self.tab.distance_points # added distance points to class for disparity calculation self.ratio = self.tab.ratio # image scale ratio for reversal - + self.edgenet_flag = edgenet_flag + self.input_depth_flag = input_depth_flag + self.depth_file_path = depth_map_path + def run(self): try: self.run_pipeline() @@ -58,7 +63,17 @@ class PipelineWorker(QThread): self.progress.emit("Copying input file to scripts/360monodepthexecution...") self.tab.depth.copy_file() - + + def copy_depth_map(self): + try: + if self.depth_file_path: + dest_path = os.path.join(self.tab.config_reader.directories['edgeNetDir'], 'Data', 'Input', 'depth_e.png') + copy_file(self.depth_file_path, dest_path) + self.progress.emit("Depth map copied successfully!") + except Exception as e: + print(f"Copy depth map failed: {str(e)}") + raise + def shift_image(self): print("Starting shift_image") # Debug print if not self.tab.should_shift_image: @@ -169,25 +184,42 @@ class PipelineWorker(QThread): self.tab.edge_net._run_blender_flip_process() self.progress.emit("Completed blender flip (blenderFlip.py)") self.progress.emit("Post-processing completed!") - self.progress.emit("File saved in edgenet-360\Output") + self.progress.emit("File saved in edgenet-360 -> Output") + + def run_mbdnet(self): + print("Starting MBDNet") + self.progress.emit("Running MBDNet...") + try: + self.tab.mbdnet._run_mbdnet_process() + print("Completed MBDNet") + except Exception as e: + print(f"MBDNet failed: {str(e)}") + raise def run_pipeline(self): - self.clean_temp_files() - self.shift_image() - self.copy_file() - self.run_depth_estimation() - #self.run_depth_to_disparity() - self.run_material_recognition() - self.run_edge_net() - self.run_post_processing() - self.progress.emit("Pipeline completed!") - + self.clean_temp_files() + self.shift_image() + self.copy_file() + if self.input_depth_flag: + self.copy_depth_map() + else: + self.run_depth_estimation() + #self.run_depth_to_disparity() + self.run_material_recognition() + if self.edgenet_flag: + self.run_edge_net() + self.run_post_processing() + else: + self.run_mbdnet() + self.progress.emit("Pipeline completed!") +9 class SimpleTab(QWidget): def __init__(self, config_reader): super().__init__() self.config_reader = config_reader self.input_path = None self.pipeline_thread = None + self.depth_input_path = None # Store states that will be used by worker thread self.should_shift_image = False @@ -198,6 +230,7 @@ class SimpleTab(QWidget): self.depth = DepthTab(self.config_reader) self.material = MaterialTab(self.config_reader) self.edge_net = EdgeNetTab(self.config_reader) + self.mbdnet = MBDNetTab(self.config_reader) # Hide their UIs as we'll use our own self.shifter.hide() @@ -297,7 +330,11 @@ class SimpleTab(QWidget): self.shift_image_check = QCheckBox("Shift Input Image") self.shift_image_check.setStyleSheet("QCheckBox { margin: 5px; background-color: #3e3e3e;}") options_layout.addWidget(self.shift_image_check) - + + self.Input_depth_check = QCheckBox("Input Depth Map") + self.Input_depth_check.setStyleSheet("QCheckBox { margin: 5px; background-color: #3e3e3e;}") + self.Input_depth_check.clicked.connect(self.handle_depth_select) + options_layout.addWidget(self.Input_depth_check) # SSC Model selection ssc_model_layout = QHBoxLayout() ssc_model_label = QLabel("SSC Model:") @@ -591,7 +628,21 @@ class SimpleTab(QWidget): self.file_selected = True self.flash_timer.stop() self.select_btn.setStyleSheet("QPushButton { margin: 5px; padding: 5px; border-radius: 10px;}") - + + def handle_depth_select(self): + self.update_status("Input Depth Map selected") + file_path = select_file( + self, + "Select Depth Map", + "Images (*.png)", + initial_dir=self.config_reader.directories['edgeNetDir'] + '/Data' + ) + + if file_path: + self.depth_input_path = file_path + self.update_status(f"Selected input file: {file_path}") + + def start_flashing(self): if not self.file_selected: self.flash_timer.start(1000) # Flash every 1000 milliseconds @@ -637,11 +688,14 @@ class SimpleTab(QWidget): self.progress_bar.show() self.run_pipeline_btn.setEnabled(False) - #TODO: Add distance points to the pipeline for depth estimation # Get the distance points self.distance_points = self.distance_preview.get_points() - self.pipeline_thread = PipelineWorker(self) + self.pipeline_thread = PipelineWorker( + self, + edgenet_flag=self.ssc_model_combo.currentText() == "EdgeNet360", + input_depth_flag=self.Input_depth_check.isChecked(), + depth_map_path=self.depth_input_path) # Connect signals self.pipeline_thread.progress.connect(self.update_status)