diff --git a/scripts/simple_tab.py b/scripts/simple_tab.py index b3136521d0a54c259b153c419ec06115dd438c67..1730f81d26edd85170154aa54341e4463fd27d00 100644 --- a/scripts/simple_tab.py +++ b/scripts/simple_tab.py @@ -23,10 +23,12 @@ class PipelineWorker(QThread): progress = pyqtSignal(str) finished = pyqtSignal(bool, str) - def __init__(self, tab_instance, edgenet_flag=True): + def __init__(self, tab_instance, edgenet_flag=True,input_depth_flag=False, depth_map_path=None): super().__init__() self.tab = tab_instance self.edgenet_flag = edgenet_flag + self.input_depth_flag = input_depth_flag + self.depth_file_path = depth_map_path def run(self): try: @@ -56,7 +58,17 @@ class PipelineWorker(QThread): self.progress.emit("Copying input file to scripts/360monodepthexecution...") self.tab.depth.copy_file() - + + def copy_depth_map(self): + try: + if self.depth_file_path: + dest_path = os.path.join(self.tab.config_reader.directories['edgeNetDir'], 'Data', 'Input', 'depth_e.png') + copy_file(self.depth_file_path, dest_path) + self.progress.emit("Depth map copied successfully!") + except Exception as e: + print(f"Copy depth map failed: {str(e)}") + raise + def shift_image(self): print("Starting shift_image") # Debug print if not self.tab.should_shift_image: @@ -163,7 +175,10 @@ class PipelineWorker(QThread): self.clean_temp_files() self.shift_image() self.copy_file() - self.run_depth_estimation() + if self.input_depth_flag: + self.copy_depth_map() + else: + self.run_depth_estimation() self.run_material_recognition() if self.edgenet_flag: self.run_edge_net() @@ -178,6 +193,7 @@ class SimpleTab(QWidget): self.config_reader = config_reader self.input_path = None self.pipeline_thread = None + self.depth_input_path = None # Store states that will be used by worker thread self.should_shift_image = False @@ -288,7 +304,11 @@ class SimpleTab(QWidget): self.shift_image_check = QCheckBox("Shift Input Image") self.shift_image_check.setStyleSheet("QCheckBox { margin: 5px; background-color: #3e3e3e;}") options_layout.addWidget(self.shift_image_check) - + + self.Input_depth_check = QCheckBox("Input Depth Map") + self.Input_depth_check.setStyleSheet("QCheckBox { margin: 5px; background-color: #3e3e3e;}") + self.Input_depth_check.clicked.connect(self.handle_depth_select) + options_layout.addWidget(self.Input_depth_check) # SSC Model selection ssc_model_layout = QHBoxLayout() ssc_model_label = QLabel("SSC Model:") @@ -582,7 +602,21 @@ class SimpleTab(QWidget): self.file_selected = True self.flash_timer.stop() self.select_btn.setStyleSheet("QPushButton { margin: 5px; padding: 5px; border-radius: 10px;}") - + + def handle_depth_select(self): + self.update_status("Input Depth Map selected") + file_path = select_file( + self, + "Select Depth Map", + "Images (*.png)", + initial_dir=self.config_reader.directories['edgeNetDir'] + '/Data' + ) + + if file_path: + self.depth_input_path = file_path + self.update_status(f"Selected input file: {file_path}") + + def start_flashing(self): if not self.file_selected: self.flash_timer.start(1000) # Flash every 1000 milliseconds @@ -628,7 +662,11 @@ class SimpleTab(QWidget): self.progress_bar.show() self.run_pipeline_btn.setEnabled(False) - self.pipeline_thread = PipelineWorker(self,edgenet_flag=self.ssc_model_combo.currentText() == "EdgeNet360") + self.pipeline_thread = PipelineWorker( + self, + edgenet_flag=self.ssc_model_combo.currentText() == "EdgeNet360", + input_depth_flag=self.Input_depth_check.isChecked(), + depth_map_path=self.depth_input_path) # Connect signals self.pipeline_thread.progress.connect(self.update_status)