diff --git a/scripts/simple_tab.py b/scripts/simple_tab.py index 867c06ab3385c05cd959f7f19be0288d913229b2..b3136521d0a54c259b153c419ec06115dd438c67 100644 --- a/scripts/simple_tab.py +++ b/scripts/simple_tab.py @@ -17,14 +17,16 @@ from debug_tool.tabs.shifter_tab import ShifterTab from debug_tool.tabs.depth_tab import DepthTab from debug_tool.tabs.material_tab import MaterialTab from debug_tool.tabs.edge_net_tab import EdgeNetTab +from debug_tool.tabs.mbdnet_tab import MBDNetTab class PipelineWorker(QThread): progress = pyqtSignal(str) finished = pyqtSignal(bool, str) - def __init__(self, tab_instance): + def __init__(self, tab_instance, edgenet_flag=True): super().__init__() self.tab = tab_instance + self.edgenet_flag = edgenet_flag def run(self): try: @@ -145,7 +147,17 @@ class PipelineWorker(QThread): self.tab.edge_net._run_blender_flip_process() self.progress.emit("Completed blender flip (blenderFlip.py)") self.progress.emit("Post-processing completed!") - self.progress.emit("File saved in edgenet-360\Output") + self.progress.emit("File saved in edgenet-360 -> Output") + + def run_mbdnet(self): + print("Starting MBDNet") + self.progress.emit("Running MBDNet...") + try: + self.tab.mbdnet._run_mbdnet_process() + print("Completed MBDNet") + except Exception as e: + print(f"MBDNet failed: {str(e)}") + raise def run_pipeline(self): self.clean_temp_files() @@ -153,8 +165,11 @@ class PipelineWorker(QThread): self.copy_file() self.run_depth_estimation() self.run_material_recognition() - self.run_edge_net() - self.run_post_processing() + if self.edgenet_flag: + self.run_edge_net() + self.run_post_processing() + else: + self.run_mbdnet() self.progress.emit("Pipeline completed!") class SimpleTab(QWidget): @@ -173,6 +188,7 @@ class SimpleTab(QWidget): self.depth = DepthTab(self.config_reader) self.material = MaterialTab(self.config_reader) self.edge_net = EdgeNetTab(self.config_reader) + self.mbdnet = MBDNetTab(self.config_reader) # Hide their UIs as we'll use our own self.shifter.hide() @@ -612,7 +628,7 @@ class SimpleTab(QWidget): self.progress_bar.show() self.run_pipeline_btn.setEnabled(False) - self.pipeline_thread = PipelineWorker(self) + self.pipeline_thread = PipelineWorker(self,edgenet_flag=self.ssc_model_combo.currentText() == "EdgeNet360") # Connect signals self.pipeline_thread.progress.connect(self.update_status)