# Skip to content
# Snippets Groups Projects
# simple_tab.py 25.15 KiB
from PyQt6.QtWidgets import (QWidget, QVBoxLayout, QHBoxLayout, 
                             QGroupBox, QCheckBox, QMessageBox, QPushButton,
                             QProgressBar, QComboBox, QLabel, QSizePolicy, QScrollArea)
from PyQt6.QtCore import Qt, QThread, pyqtSignal, QTimer
import os
import sys

# Import utilities from debug_tool
from debug_tool.utils.qt_widgets import (create_group_with_text, create_button_layout, 
                                       create_info_group, create_preview_group)
from debug_tool.utils.file_handlers import select_file, clean_directory, copy_file, run_command
from debug_tool.utils.image_handlers import update_preview, load_and_resize_image, convert_cv_to_pixmap
from debug_tool.utils.clickable_label import ClickableLabel

# Import existing module implementations
from debug_tool.tabs.shifter_tab import ShifterTab
from debug_tool.tabs.depth_tab import DepthTab
from debug_tool.tabs.material_tab import MaterialTab
from debug_tool.tabs.edge_net_tab import EdgeNetTab

class PipelineWorker(QThread):
    """Worker thread that runs every pipeline stage for the tab it is given.

    The worker keeps no pipeline state of its own: it reads the cached
    options (``should_shift_image``, ``should_include_top``), the input path
    and the module instances (``shifter``, ``depth``, ``material``,
    ``edge_net``) from ``tab_instance`` and reports back through signals.

    Signals:
        progress(str): one status line per step.
        finished(bool, str): emitted once by ``run`` with the success flag
            and a summary message.
    """

    progress = pyqtSignal(str)
    # NOTE: this reuses the name of QThread's built-in, argument-less
    # ``finished`` signal. The name is kept because callers connect to it.
    finished = pyqtSignal(bool, str)

    def __init__(self, tab_instance):
        """Remember the tab whose modules and options the pipeline uses."""
        super().__init__()
        self.tab = tab_instance

    def run(self):
        """Thread entry point: run the pipeline and emit ``finished`` once."""
        try:
            self.run_pipeline()
        except Exception as e:
            # Top-level boundary of the thread: report the failure to the UI
            # instead of letting the exception escape the thread.
            print(f"Pipeline failed with error: {str(e)}")  # Debug print
            self.finished.emit(False, f"Pipeline failed: {str(e)}")
        else:
            self.finished.emit(True, "Pipeline completed successfully!")

    def clean_temp_files(self):
        """Clean the working directories of the depth, material and EdgeNet modules.

        Every step is announced on ``progress`` before it runs, so the status
        log shows what is currently happening rather than what just finished.
        """
        self.progress.emit("Cleaning input directory...")
        self.tab.depth.clean_input_dir()  # This one doesn't have dialogs
        self.progress.emit("Removing RGB image from 360monodepthexecution...")
        self.tab.depth.remove_rgb()
        self.progress.emit("Cleaning Dynamic-Backward-Attention-Transformer temp directory...")
        self.tab.material.clean_working_dir(silent=True)
        self.progress.emit("Cleaning EdgeNet output directory...")
        self.tab.edge_net.clean_output_directory(silent=True)

    def copy_file(self):
        """Point the depth module at the right image and let it copy the file.

        The shifted image is used when shifting was requested, otherwise the
        image originally selected by the user.
        """
        if self.tab.should_shift_image:
            self.tab.depth.depth_input_path = self.tab.shifter.shifted_image_path
            self.progress.emit("Copying shifted image to scripts/360monodepthexecution...")
        else:
            self.tab.depth.depth_input_path = self.tab.input_path
            self.progress.emit("Copying input file to scripts/360monodepthexecution...")

        self.tab.depth.copy_file()

    def shift_image(self):
        """Run the image shifter if requested.

        Returns:
            The shifted image path when shifting ran, otherwise ``None``.

        Raises:
            RuntimeError: if the shifter process reports failure.
        """
        print("Starting shift_image")  # Debug print
        if not self.tab.should_shift_image:
            self.progress.emit("Skipping image shift...")
            # A previous run with shifting enabled leaves the material module
            # pointing at the shifted image; restore the original input so
            # material recognition and depth estimation use the same file.
            self.tab.material.input_file_path = self.tab.input_path
            return None

        try:
            self.progress.emit("Shifting input image...")
            # Set input path for shifter
            self.tab.shifter.input_file_path = self.tab.input_path
            # Use the thread-safe version
            success, output = self.tab.shifter.run_shifter_process(self.progress.emit)

            if not success:
                raise RuntimeError(f"Image shifting failed: {output}")

            print("Completed shift_image")

            # Change material recognition input file path to shifted image
            shifted_path = self.tab.shifter.shifted_image_path
            self.tab.material.input_file_path = shifted_path
            return shifted_path
        except Exception as e:
            print(f"Shift image failed: {str(e)}")
            raise

    def run_depth_estimation(self):
        """Run the depth module's depth estimation step."""
        print("Starting depth_estimation")  # Debug print
        self.progress.emit("Running depth estimation...")
        self.tab.depth.run_depth_estimation()
        print("Completed depth_estimation")  # Debug print

    def _print_material_debug_info(self):
        """Print the directories used by the material module (debug aid)."""
        material = self.tab.material
        print(f"Current working directory: {os.getcwd()}")
        print(f"Material recognition directory: {material.material_recog_dir}")
        print(f"Checking if cubemap directory exists: {os.path.exists(material.cubemap_dir)}")
        print(f"Checking if material output directory exists: {os.path.exists(material.material_output_dir)}")
        print("Files in cubemap directory:")
        if os.path.exists(material.cubemap_dir):
            print("\n".join(os.listdir(material.cubemap_dir)))
        print("Files in material output directory:")
        if os.path.exists(material.material_output_dir):
            print("\n".join(os.listdir(material.material_output_dir)))

    def run_material_recognition(self):
        """Run the three material steps: split 360, recognition, combine.

        Raises:
            RuntimeError: if any of the three steps returns a falsy result.
        """
        print("Starting material_recognition")
        self.progress.emit("Running material recognition...")
        material = self.tab.material

        try:
            print(f"Input file path: {material.input_file_path}")

            print("Running split 360...")
            if not material.run_split_360():
                raise RuntimeError("Split 360 failed")

            print("Running material recognition...")
            if not material.run_material_recognition():
                raise RuntimeError("Material recognition step failed")

            print("Starting combine step...")
            self._print_material_debug_info()

            if not material.run_combine():
                raise RuntimeError("Combine step failed")

        except Exception as e:
            print(f"Material recognition error: {str(e)}")
            raise

        print("Completed material_recognition")

    def run_edge_net(self):
        """Run the EdgeNet process with the cached "include top" option."""
        print("Starting edge_net")
        self.progress.emit("Running EdgeNet enhance360.py and infer360.py...")

        try:
            self.tab.edge_net.include_top = self.tab.should_include_top  # Use cached state
            self.tab.edge_net._run_edge_net_process()
            print("Completed edge_net")
        except Exception as e:
            print(f"EdgeNet failed: {str(e)}")
            raise

    def run_post_processing(self):
        """Run the mesh split and blender flip steps of the EdgeNet module."""
        self.progress.emit("Running post-processing...")
        self.tab.edge_net._run_mesh_split_process()
        self.progress.emit("Completed mesh split (replace.py)")
        self.tab.edge_net._run_blender_flip_process()
        self.progress.emit("Completed blender flip (blenderFlip.py)")
        self.progress.emit("Post-processing completed!")

    def run_pipeline(self):
        """Run all stages in order; any exception aborts the remaining stages."""
        self.clean_temp_files()
        # Shifting runs before the copy because copy_file() picks up the
        # shifter's output path when shifting is enabled.
        self.shift_image()
        self.copy_file()
        self.run_depth_estimation()
        self.run_material_recognition()
        self.run_edge_net()
        self.run_post_processing()
        self.progress.emit("Pipeline completed!")

class SimpleTab(QWidget):
    """Single-page front end that runs the whole pipeline from one button.

    The tab owns hidden instances of the shifter, depth, material and EdgeNet
    tabs and reuses their processing methods; the stages themselves are
    executed by a ``PipelineWorker`` thread started from
    ``run_full_pipeline``.
    """

    # Stylesheets swapped on the "Select Input Image" button while it flashes.
    # toggle_flash() recognises the highlighted state by its DarkOrange rule.
    _SELECT_BTN_STYLE = "QPushButton { margin: 5px; padding: 5px; border-radius: 10px;}"
    _SELECT_BTN_FLASH_STYLE = "QPushButton { margin: 5px; padding: 5px; background-color: DarkOrange; border-radius: 10px;}"
    # Number of distance points that must be picked before the run button is enabled
    _REQUIRED_POINTS = 2

    def __init__(self, config_reader):
        """Create the module instances and build the scrollable UI.

        Args:
            config_reader: configuration object handed to every module tab;
                its ``directories['edgeNetDir']`` entry is used as the start
                directory of the file dialog.
        """
        super().__init__()
        self.config_reader = config_reader
        self.input_path = None
        self.pipeline_thread = None

        # Store states that will be used by worker thread
        self.should_shift_image = False
        self.should_include_top = False
        # Filled in by run_full_pipeline(); defined here so they always exist
        self.selected_model = None
        self.distance_points = []

        # Initialize module instances
        self.shifter = ShifterTab(self.config_reader)
        self.depth = DepthTab(self.config_reader)
        self.material = MaterialTab(self.config_reader)
        self.edge_net = EdgeNetTab(self.config_reader)

        # Hide their UIs as we'll use our own
        for module in (self.shifter, self.depth, self.material, self.edge_net):
            module.hide()

        # Scroll Area
        self.scroll_area = QScrollArea(self)
        self.scroll_area.setWidgetResizable(True)

        self.scroll_content = QWidget()

        # Both must exist before setup_ui(), which calls start_flashing()
        self.file_selected = False
        self.flash_timer = QTimer(self)
        self.flash_timer.timeout.connect(self.toggle_flash)
        self.setup_ui(self.scroll_content)

        self.scroll_area.setWidget(self.scroll_content)
        layout = QVBoxLayout(self)
        layout.addWidget(self.scroll_area)

    def setup_ui(self,parent_widget):
        """Build the four sections of the tab inside ``parent_widget``."""
        layout = QVBoxLayout(parent_widget)
        layout.addWidget(self._build_controls_group())
        layout.addWidget(self._build_distance_group())
        layout.addWidget(self._build_status_group())
        layout.addWidget(self._build_preview_group())

        # Start flashing if no file is selected
        self.start_flashing()

    def _build_controls_group(self):
        """Create the "Pipeline Controls" box: info rows, options, progress bar, buttons."""
        controls_group = QGroupBox("Pipeline Controls")
        controls_group.setStyleSheet("""
            QGroupBox {
                font-weight: bold;
                border: 2px solid grey;
                border-radius: 20px;
                margin-top: 10px;
                background-color: #3e3e3e;
                padding: 15px;
            }
            QGroupBox::title { 
                margin: 10px;
                background-color: transparent;
                color: white;
            }
        """)
        controls_layout = QVBoxLayout(controls_group)

        # Info display
        info_rows = [
            ("Input Image:", "No file selected"),
            ("Status:", "Ready - Waiting for input"),
        ]
        self.info_group, self.info_labels = create_info_group("Information", info_rows)
        self.info_group.setStyleSheet("""
            QGroupBox {
                font-weight: bold;
                border: 2px solid grey;
                border-radius: 10px;
                margin-top: 10px;
                background-color: #3e3e3e;
                padding: 20px;
            }
            QGroupBox::title { 
                margin: 10px;
                background-color: transparent;
                color: white;
            }
            QLabel{
                margin: 5px;
                background-color: #3e3e3e;
                color: white;
            }                 
        """)
        for label in self.info_labels.values():
            label.setStyleSheet("""
                QLabel{
                    margin: 5px;
                    background-color: #3e3e3e;
                    color: white;
                }
            """)
        controls_layout.addWidget(self.info_group)

        # Options
        options_layout = QHBoxLayout()

        self.include_top_check = QCheckBox("Include Top in Mesh")
        self.include_top_check.setStyleSheet("""
            QCheckBox {
                margin: 5px;
                padding: 5px;
                background-color: #3e3e3e;
                color: white;
                border: none;
                border-radius: 5px;
            }
        """)
        options_layout.addWidget(self.include_top_check)

        self.shift_image_check = QCheckBox("Shift Input Image")
        self.shift_image_check.setStyleSheet("QCheckBox { margin: 5px; background-color: #3e3e3e;}")
        options_layout.addWidget(self.shift_image_check)

        # SSC Model selection
        ssc_model_layout = QHBoxLayout()
        ssc_model_label = QLabel("SSC Model:")
        ssc_model_label.setStyleSheet("""
            QLabel {
                margin: 5px;
                background-color: transparent;
                color: white;
                font-weight: bold;
            }
        """)
        # Pin the label to its natural width so it sits next to the combo box
        ssc_model_label.setFixedWidth(ssc_model_label.sizeHint().width())
        ssc_model_layout.addWidget(ssc_model_label)

        self.ssc_model_combo = QComboBox()
        self.ssc_model_combo.addItems(["EdgeNet360", "MDBNet"])
        self.ssc_model_combo.setStyleSheet("""
            QComboBox {
                margin: 5px;
                padding: 5px;
                background-color: #1e1e1e;
                color: white;
                border: none;
                border-radius: 5px;
            }
            QComboBox QAbstractItemView {
                background-color: #1e1e1e;
                color: white;
                selection-background-color: #5e5e5e;
                border-radius: 5px;
            }
        """)
        self.ssc_model_combo.setFixedWidth(150)
        ssc_model_layout.addWidget(self.ssc_model_combo)
        ssc_model_layout.setAlignment(Qt.AlignmentFlag.AlignLeft)
        # Add the horizontal layout to the options layout
        options_layout.addLayout(ssc_model_layout)

        controls_layout.addLayout(options_layout)

        # Progress Bar
        self.progress_bar = QProgressBar()
        self.progress_bar.setMinimum(0)
        self.progress_bar.setMaximum(0)  # Makes it an indefinite progress bar
        self.progress_bar.hide()  # Hidden by default
        self.progress_bar.setStyleSheet("""
            QProgressBar {
                border: 2px solid grey;
                border-radius: 5px;
                text-align: center;
            }
            QProgressBar::chunk {
                background-color: #05B8CC;
                width: 20px;
            }
        """)
        controls_layout.addWidget(self.progress_bar)

        # Buttons
        self.run_pipeline_btn = QPushButton("Run Pipeline")
        self.run_pipeline_btn.clicked.connect(self.run_full_pipeline)
        self.run_pipeline_btn.setEnabled(False)  # Disabled by default
        self.run_pipeline_btn.setStyleSheet("""
            QPushButton {
                margin: 5px;
                padding: 5px;
                border-radius: 10px;
            }
            QPushButton:enabled {
                background-color: green;
                color: white;
            }
            QPushButton:disabled {
            background-color: red;
            color: white;
            }
        """)
        self.run_pipeline_btn.setFixedSize(600, 40)  # Explicit size

        buttons_layout = QHBoxLayout()
        self.select_btn = QPushButton("Select Input Image")
        self.select_btn.clicked.connect(self.handle_file_select)
        self.select_btn.setStyleSheet("""
            QPushButton {
                margin: 5px;
                padding: 5px;
                border-radius: 10px;
            }
        """)
        self.select_btn.setFixedSize(600, 40)  # Explicit size

        buttons_layout.addWidget(self.select_btn)
        buttons_layout.addWidget(self.run_pipeline_btn)
        controls_layout.addLayout(buttons_layout)

        return controls_group

    def _build_distance_group(self):
        """Create the (initially hidden) "Image Distance" box with the clickable preview."""
        self.image_distance_group = QGroupBox("Image Distance")
        distance_layout = QVBoxLayout(self.image_distance_group)
        info_label = QLabel("Please select two point on the image and input the distance from the camera to that point.")
        self.counter_label = QLabel("(0/2)")
        self.distance_preview = ClickableLabel()
        self.distance_preview.setAlignment(Qt.AlignmentFlag.AlignCenter)
        self.counter_label.setAlignment(Qt.AlignmentFlag.AlignBottom | Qt.AlignmentFlag.AlignCenter)
        self.points_info_label = QLabel()  # Label to display points and distances
        self.points_info_label.setAlignment(Qt.AlignmentFlag.AlignTop | Qt.AlignmentFlag.AlignCenter)
        self.image_distance_group.setStyleSheet("""
           QGroupBox {
               font-weight: bold;
               border: 2px solid grey;
               border-radius: 20px;
               margin-top: 10px;
               background-color: #3e3e3e;
               padding: 20px;
           }
           QGroupBox::title { 
               margin: 10px;
               background-color: transparent;
               color: white;
           }
           QLabel {
               margin: 5px;
               background-color: #3e3e3e;
               color: white;
           }
        """)
        # Center the ClickableLabel within its parent layout
        distance_preview_layout = QHBoxLayout()
        distance_preview_layout.addStretch()
        distance_preview_layout.addWidget(self.distance_preview)
        distance_preview_layout.addStretch()

        self.distance_reset_btn = QPushButton("Reset Points")
        self.distance_reset_btn.clicked.connect(self.distance_preview.clear_points)
        self.distance_reset_btn.setFixedSize(150, 40)
        self.distance_reset_btn.setStyleSheet("""
           QPushButton {
               margin: 5px;
               padding: 5px;
               border-radius: 10px;
           }
        """)
        distance_btn_layout = QHBoxLayout()
        distance_btn_layout.addStretch()
        distance_btn_layout.addWidget(self.distance_reset_btn)
        distance_btn_layout.addStretch()

        distance_layout.addWidget(info_label)
        distance_layout.addLayout(distance_preview_layout)
        distance_layout.addWidget(self.points_info_label)
        distance_layout.addWidget(self.counter_label)
        # Add the centering layout, not the bare button: adding the button
        # directly would drop the stretches that were just set up around it.
        distance_layout.addLayout(distance_btn_layout)
        self.image_distance_group.hide()

        self.distance_preview.point_added.connect(self.update_counter_Label)

        return self.image_distance_group

    def _build_status_group(self):
        """Create the "Pipeline Status" box holding the status log text widget."""
        status_group, self.status_text = create_group_with_text("Pipeline Status", 300)
        status_group.setStyleSheet("""
            QGroupBox {
                font-weight: bold;
                border: 2px solid grey;
                border-radius: 20px;
                margin-top: 10px;
                background-color: #3e3e3e;
                padding: 20px;
            }
            QGroupBox::title { 
                margin: 10px;
                background-color: transparent;
                color: white;
            }
        """)
        self.status_text.setStyleSheet("""
            QTextEdit {
                background-color: #1e1e1e;
                color: white;
                border: 2px solid grey;
                border-radius: 10px;
                padding: 10px;
                font-size: 14px;
            }
            QTextEdit:focus {
                border: 2px solid #05B8CC;
            }
        """)
        return status_group

    def _build_preview_group(self):
        """Create the "Preview" box with the input and current-output previews."""
        preview_group = QGroupBox("Preview")
        preview_group.setStyleSheet("""
            QGroupBox {
                font-weight: bold;
                border: 2px solid grey;
                border-radius: 20px;
                margin-top: 10px;
                background-color: #3e3e3e;
                padding: 20px;
            }
            QGroupBox::title { 
                margin: 10px;
                background-color: transparent;
                color: white;
            }
        """)
        preview_layout = QHBoxLayout(preview_group)

        input_group, self.input_preview = create_preview_group("Input Image")
        output_group, self.output_preview = create_preview_group("Current Output")
        # Both inner boxes share one stylesheet
        inner_box_style = """
            QGroupBox {
                font-weight: bold;
                border: 2px solid grey;
                border-radius: 10px;
                margin-top: 10px;
                background-color: #1e1e1e;
                padding: 20px;
            }
            QGroupBox::title { 
                margin: 10px;
                background-color: transparent;
                color: white;
            }
            QLabel {
                margin: 5px;
                background-color: #1e1e1e;
                color: white;
            }
        """
        for inner_group in (input_group, output_group):
            inner_group.setStyleSheet(inner_box_style)
            preview_layout.addWidget(inner_group)

        return preview_group

    def handle_file_select(self):
        """Let the user pick the input image and propagate it to all modules.

        Nothing changes when ``select_file`` returns a falsy value (treated
        as a cancelled dialog); in particular the reminder flash on the
        select button keeps running until a file has actually been chosen.
        """
        file_path = select_file(
            self,
            "Select Input Image",
            "Images (*.png *.jpg *.jpeg)",
            initial_dir=self.config_reader.directories['edgeNetDir'] + '/Data'
        )

        if not file_path:
            return

        self.input_path = file_path
        self.info_labels["Input Image:"].setText(os.path.basename(file_path))
        self.update_status(f"Selected input file: {file_path}")
        update_preview(self.input_preview, file_path,
                     error_callback=self.update_status)
        update_preview(self.distance_preview,file_path,max_size=1500)
        # Fix the clickable label to the size of the resized image
        # (presumably so click positions map onto the image — confirm
        # against ClickableLabel).
        pixmap = load_and_resize_image(file_path, 1500)
        pixmap = convert_cv_to_pixmap(pixmap)
        self.distance_preview.setFixedSize(pixmap.size())
        self.image_distance_group.show()
        self.update_status("Waiting for distance points...")
        self.info_labels["Status:"].setText("Waiting for distance points...")
        # Enable the run pipeline button
        # NOTE(review): this enables the button with no distance points picked,
        # while update_counter_Label() disables it until two points exist —
        # confirm which behaviour is intended.
        self.run_pipeline_btn.setEnabled(True)

        # Provide input path to all modules
        self.shifter.input_file_path = file_path
        self.depth.depth_input_path = file_path
        self.material.input_file_path = file_path
        # self.edge_net.input_path = file_path # edgenet have default input path

        # A file is selected now: stop the reminder flash and restore the
        # plain button style.
        self.file_selected = True
        self.flash_timer.stop()
        self.select_btn.setStyleSheet(self._SELECT_BTN_STYLE)

    def start_flashing(self):
        """Start the select-button flash timer while no file has been selected."""
        if not self.file_selected:
            self.flash_timer.start(1000)  # Flash every 1000 milliseconds

    def toggle_flash(self):
        """Timer slot: alternate the select button between plain and DarkOrange."""
        highlighted = "background-color: DarkOrange;" in self.select_btn.styleSheet()
        self.select_btn.setStyleSheet(
            self._SELECT_BTN_STYLE if highlighted else self._SELECT_BTN_FLASH_STYLE
        )

    def update_counter_Label(self):
        """Refresh the point counter/details and gate the run button on two points."""
        points = self.distance_preview.get_points()
        count = len(points)
        self.counter_label.setText(f"({count}/2)")
        points_info = "\n".join(
            f"Point {i+1}: (x={x:.2f}, y={y:.2f}), Distance: {distance:.2f} meters"
            for i, (x, y, distance) in enumerate(points)
        )
        self.points_info_label.setText(points_info)

        # enable run pipeline button if 2 points are selected
        if count == self._REQUIRED_POINTS:
            self.run_pipeline_btn.setEnabled(True)
            self.update_status("Distance points selected. Ready to run pipeline.")
        else:
            self.run_pipeline_btn.setEnabled(False)
            self.update_status("Waiting for distance points...")
            self.info_labels["Status:"].setText("Waiting for distance points...")

    def run_full_pipeline(self):
        """Validate the state, lock the controls and start the worker thread."""
        if not self.input_path:
            QMessageBox.warning(self, "Warning", "Please select an input file first")
            return

        if self.pipeline_thread and self.pipeline_thread.isRunning():
            QMessageBox.warning(self, "Warning", "Pipeline is already running")
            return

        # Cache checkbox states before starting thread
        self.should_shift_image = self.shift_image_check.isChecked()
        self.should_include_top = self.include_top_check.isChecked()

        # Show progress bar and update status
        self.progress_bar.show()
        self.run_pipeline_btn.setEnabled(False)

        self.pipeline_thread = PipelineWorker(self)

        # Connect signals
        self.pipeline_thread.progress.connect(self.update_status)
        self.pipeline_thread.finished.connect(self.pipeline_completed)

        # Disable controls while running
        self.disable_buttons_while_running()
        self.progress_bar.setEnabled(True)  # Keep progress bar enabled

        #TODO: Add model selection for EdgeNet or MDBNet
        # Set the SSC model
        self.selected_model = self.ssc_model_combo.currentText()

        #TODO: Add distance points to the pipeline for depth estimation
        # Get the distance points
        self.distance_points = self.distance_preview.get_points()

        # Start the pipeline
        self.pipeline_thread.start()

    def pipeline_completed(self, success, message):
        """Slot for the worker's ``finished`` signal: unlock the UI and report.

        Args:
            success: ``True`` when every stage completed.
            message: summary text; logged to the status box and, on failure,
                shown in the error dialog.
        """
        self.setEnabled(True)
        # Re-enabling the tab does not undo the explicit per-widget disables
        # done in disable_buttons_while_running(), so revert them here.
        self._set_controls_enabled(True)
        self.progress_bar.hide()
        self.run_pipeline_btn.setEnabled(True)
        self.update_status(message)

        if success:
            QMessageBox.information(self, "Success", "Pipeline completed successfully!")
        else:
            # The worker already prefixes its message; avoid showing
            # "Pipeline failed: Pipeline failed: ...".
            prefix = "Pipeline failed"
            dialog_text = message if message.startswith(prefix) else f"{prefix}: {message}"
            QMessageBox.critical(self, "Error", dialog_text)

    def update_status(self, message):
        """Append ``message`` to the status log and mirror it in the Status row.

        The Status row shows the text after the last ``"..."`` of the
        message; when nothing follows the ellipsis (e.g.
        ``"Running depth estimation..."``) the whole message is shown so the
        row never goes blank.
        """
        self.status_text.append(message)
        tail = message.split("...")[-1]
        self.info_labels["Status:"].setText(tail if tail.strip() else message)
        # Scroll to bottom
        scrollbar = self.status_text.verticalScrollBar()
        scrollbar.setValue(scrollbar.maximum())

    def _set_controls_enabled(self, enabled):
        """Enable or disable every input control that is locked during a run."""
        lockable_widgets = (
            self.select_btn,
            self.include_top_check,
            self.shift_image_check,
            self.ssc_model_combo,
            self.distance_reset_btn,
            self.distance_preview,
        )
        for widget in lockable_widgets:
            widget.setEnabled(enabled)

    def disable_buttons_while_running(self):
        """Lock the run button and all input controls while the pipeline runs."""
        self.run_pipeline_btn.setEnabled(False)
        self._set_controls_enabled(False)