Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision
  • GDP_4.2.1
  • GDP_4.2.6
  • GDP_4.4.7
  • master
  • results
5 results

Target

Select target project
  • gdp-project-4/avvr-pipeline-gdp4
1 result
Select Git revision
  • GDP_4.2.1
  • GDP_4.2.6
  • GDP_4.4.7
  • master
  • results
5 results
Show changes
Commits on Source (10)
...@@ -113,7 +113,7 @@ def get_completed_scene(shifted_disparity_path, shifted_t_path): ...@@ -113,7 +113,7 @@ def get_completed_scene(shifted_disparity_path, shifted_t_path):
output = stdout.read().decode() output = stdout.read().decode()
print(output) print(output)
remote_file_path = "/mainfs/ECShome/kproject/mona/MDBNet360_GDP/output/scene_completed_prediction.obj" remote_file_path = os.getenv("REMOTE_OUTPUT_PATH_OBJ")
current_dir = os.getcwd() current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir) parent_dir = os.path.dirname(current_dir)
local_file_folder = os.path.join(parent_dir, "edgenet-360/Output") local_file_folder = os.path.join(parent_dir, "edgenet-360/Output")
...@@ -121,7 +121,7 @@ def get_completed_scene(shifted_disparity_path, shifted_t_path): ...@@ -121,7 +121,7 @@ def get_completed_scene(shifted_disparity_path, shifted_t_path):
with SCPClient(client.get_transport()) as scp: with SCPClient(client.get_transport()) as scp:
scp.get(remote_file_path, local_file_path) # Download file scp.get(remote_file_path, local_file_path) # Download file
remote_file_path = "/mainfs/ECShome/kproject/mona/MDBNet360_GDP/output/scene_completed_prediction.mtl" remote_file_path = os.getenv("REMOTE_OUTPUT_PATH_MTL")
local_file_path = os.path.join(local_file_folder,"scene_completed_prediction.mtl") local_file_path = os.path.join(local_file_folder,"scene_completed_prediction.mtl")
with SCPClient(client.get_transport()) as scp: with SCPClient(client.get_transport()) as scp:
scp.get(remote_file_path, local_file_path) # Download file scp.get(remote_file_path, local_file_path) # Download file
......
...@@ -17,14 +17,18 @@ from debug_tool.tabs.shifter_tab import ShifterTab ...@@ -17,14 +17,18 @@ from debug_tool.tabs.shifter_tab import ShifterTab
from debug_tool.tabs.depth_tab import DepthTab from debug_tool.tabs.depth_tab import DepthTab
from debug_tool.tabs.material_tab import MaterialTab from debug_tool.tabs.material_tab import MaterialTab
from debug_tool.tabs.edge_net_tab import EdgeNetTab from debug_tool.tabs.edge_net_tab import EdgeNetTab
from debug_tool.tabs.mbdnet_tab import MBDNetTab
class PipelineWorker(QThread): class PipelineWorker(QThread):
progress = pyqtSignal(str) progress = pyqtSignal(str)
finished = pyqtSignal(bool, str) finished = pyqtSignal(bool, str)
def __init__(self, tab_instance): def __init__(self, tab_instance, edgenet_flag=True,input_depth_flag=False, depth_map_path=None):
super().__init__() super().__init__()
self.tab = tab_instance self.tab = tab_instance
self.edgenet_flag = edgenet_flag
self.input_depth_flag = input_depth_flag
self.depth_file_path = depth_map_path
def run(self): def run(self):
try: try:
...@@ -54,7 +58,17 @@ class PipelineWorker(QThread): ...@@ -54,7 +58,17 @@ class PipelineWorker(QThread):
self.progress.emit("Copying input file to scripts/360monodepthexecution...") self.progress.emit("Copying input file to scripts/360monodepthexecution...")
self.tab.depth.copy_file() self.tab.depth.copy_file()
def copy_depth_map(self):
try:
if self.depth_file_path:
dest_path = os.path.join(self.tab.config_reader.directories['edgeNetDir'], 'Data', 'Input', 'depth_e.png')
copy_file(self.depth_file_path, dest_path)
self.progress.emit("Depth map copied successfully!")
except Exception as e:
print(f"Copy depth map failed: {str(e)}")
raise
def shift_image(self): def shift_image(self):
print("Starting shift_image") # Debug print print("Starting shift_image") # Debug print
if not self.tab.should_shift_image: if not self.tab.should_shift_image:
...@@ -145,16 +159,32 @@ class PipelineWorker(QThread): ...@@ -145,16 +159,32 @@ class PipelineWorker(QThread):
self.tab.edge_net._run_blender_flip_process() self.tab.edge_net._run_blender_flip_process()
self.progress.emit("Completed blender flip (blenderFlip.py)") self.progress.emit("Completed blender flip (blenderFlip.py)")
self.progress.emit("Post-processing completed!") self.progress.emit("Post-processing completed!")
self.progress.emit("File saved in edgenet-360\Output") self.progress.emit("File saved in edgenet-360 -> Output")
def run_mbdnet(self):
print("Starting MBDNet")
self.progress.emit("Running MBDNet...")
try:
self.tab.mbdnet._run_mbdnet_process()
print("Completed MBDNet")
except Exception as e:
print(f"MBDNet failed: {str(e)}")
raise
def run_pipeline(self): def run_pipeline(self):
self.clean_temp_files() self.clean_temp_files()
self.shift_image() self.shift_image()
self.copy_file() self.copy_file()
self.run_depth_estimation() if self.input_depth_flag:
self.copy_depth_map()
else:
self.run_depth_estimation()
self.run_material_recognition() self.run_material_recognition()
self.run_edge_net() if self.edgenet_flag:
self.run_post_processing() self.run_edge_net()
self.run_post_processing()
else:
self.run_mbdnet()
self.progress.emit("Pipeline completed!") self.progress.emit("Pipeline completed!")
class SimpleTab(QWidget): class SimpleTab(QWidget):
...@@ -163,6 +193,7 @@ class SimpleTab(QWidget): ...@@ -163,6 +193,7 @@ class SimpleTab(QWidget):
self.config_reader = config_reader self.config_reader = config_reader
self.input_path = None self.input_path = None
self.pipeline_thread = None self.pipeline_thread = None
self.depth_input_path = None
# Store states that will be used by worker thread # Store states that will be used by worker thread
self.should_shift_image = False self.should_shift_image = False
...@@ -173,6 +204,7 @@ class SimpleTab(QWidget): ...@@ -173,6 +204,7 @@ class SimpleTab(QWidget):
self.depth = DepthTab(self.config_reader) self.depth = DepthTab(self.config_reader)
self.material = MaterialTab(self.config_reader) self.material = MaterialTab(self.config_reader)
self.edge_net = EdgeNetTab(self.config_reader) self.edge_net = EdgeNetTab(self.config_reader)
self.mbdnet = MBDNetTab(self.config_reader)
# Hide their UIs as we'll use our own # Hide their UIs as we'll use our own
self.shifter.hide() self.shifter.hide()
...@@ -272,7 +304,11 @@ class SimpleTab(QWidget): ...@@ -272,7 +304,11 @@ class SimpleTab(QWidget):
self.shift_image_check = QCheckBox("Shift Input Image") self.shift_image_check = QCheckBox("Shift Input Image")
self.shift_image_check.setStyleSheet("QCheckBox { margin: 5px; background-color: #3e3e3e;}") self.shift_image_check.setStyleSheet("QCheckBox { margin: 5px; background-color: #3e3e3e;}")
options_layout.addWidget(self.shift_image_check) options_layout.addWidget(self.shift_image_check)
self.Input_depth_check = QCheckBox("Input Depth Map")
self.Input_depth_check.setStyleSheet("QCheckBox { margin: 5px; background-color: #3e3e3e;}")
self.Input_depth_check.clicked.connect(self.handle_depth_select)
options_layout.addWidget(self.Input_depth_check)
# SSC Model selection # SSC Model selection
ssc_model_layout = QHBoxLayout() ssc_model_layout = QHBoxLayout()
ssc_model_label = QLabel("SSC Model:") ssc_model_label = QLabel("SSC Model:")
...@@ -566,7 +602,21 @@ class SimpleTab(QWidget): ...@@ -566,7 +602,21 @@ class SimpleTab(QWidget):
self.file_selected = True self.file_selected = True
self.flash_timer.stop() self.flash_timer.stop()
self.select_btn.setStyleSheet("QPushButton { margin: 5px; padding: 5px; border-radius: 10px;}") self.select_btn.setStyleSheet("QPushButton { margin: 5px; padding: 5px; border-radius: 10px;}")
def handle_depth_select(self):
self.update_status("Input Depth Map selected")
file_path = select_file(
self,
"Select Depth Map",
"Images (*.png)",
initial_dir=self.config_reader.directories['edgeNetDir'] + '/Data'
)
if file_path:
self.depth_input_path = file_path
self.update_status(f"Selected input file: {file_path}")
def start_flashing(self): def start_flashing(self):
if not self.file_selected: if not self.file_selected:
self.flash_timer.start(1000) # Flash every 1000 milliseconds self.flash_timer.start(1000) # Flash every 1000 milliseconds
...@@ -612,7 +662,11 @@ class SimpleTab(QWidget): ...@@ -612,7 +662,11 @@ class SimpleTab(QWidget):
self.progress_bar.show() self.progress_bar.show()
self.run_pipeline_btn.setEnabled(False) self.run_pipeline_btn.setEnabled(False)
self.pipeline_thread = PipelineWorker(self) self.pipeline_thread = PipelineWorker(
self,
edgenet_flag=self.ssc_model_combo.currentText() == "EdgeNet360",
input_depth_flag=self.Input_depth_check.isChecked(),
depth_map_path=self.depth_input_path)
# Connect signals # Connect signals
self.pipeline_thread.progress.connect(self.update_status) self.pipeline_thread.progress.connect(self.update_status)
......