import os
import re
import subprocess
import sys
import time

import paramiko
from dotenv import load_dotenv
from scp import SCPClient

load_dotenv()
# NOTE(review): load_dotenv() does not override variables already present in the
# process environment by default, so an exported HOSTNAME from the shell would
# take precedence over the value in .env — confirm this is intended.
hostname = os.getenv("HOSTNAME")
password = os.getenv("PASSWORD")
# Remote cluster account. The generic USERNAME variable is deliberately not read
# (it was disabled in earlier revisions); a dedicated variable can override the
# historical default account instead. Never commit credentials here.
username = os.getenv("REMOTE_USERNAME", "kproject")

def send_files(shifted_disparity_path, shifted_t_path):
    """Upload the two scene-completion input files to the remote host with scp.

    Files keep their local base names on the remote side, so they MUST be
    named ``shifted_disparity.png`` and ``shifted_t.png``. They are copied into
    the remote location given by the ``REMOTE_INPUT_PATH`` environment variable.

    Args:
        shifted_disparity_path: Local path of ``shifted_disparity.png``.
        shifted_t_path: Local path of ``shifted_t.png``.

    Returns:
        True if both uploads succeeded, otherwise False.
    """
    remote_dir = os.getenv("REMOTE_INPUT_PATH")
    if not hostname or not remote_dir:
        print("Error during SCP: HOSTNAME and REMOTE_INPUT_PATH must be set")
        return False
    # Built once, outside any nested quoting: a double-quoted call inside a
    # double-quoted f-string is a SyntaxError before Python 3.12.
    remote_target = f"{username}@{hostname}:{remote_dir}"

    for local_path in (shifted_disparity_path, shifted_t_path):
        try:
            subprocess.run(["scp", local_path, remote_target], check=True)
        except (subprocess.CalledProcessError, OSError) as e:
            # OSError covers a missing scp executable. Return a plain False:
            # any non-empty tuple is truthy and would read as success.
            print(f"Error during SCP: {e}")
            return False
    return True

# Shell prefix for every Slurm submission: enter the project checkout and
# activate its conda environment before calling sbatch from there.
_REMOTE_JOB_SETUP = (
    "cd mona/MDBNet360_GDP/"
    " && module load conda"
    " && source activate"
    " && conda activate ssc_env"
)

# Slurm job states that mean "not finished yet". Any other state, or the job no
# longer being listed by squeue at all, is treated as finished.
_ACTIVE_JOB_STATES = frozenset(
    {"PENDING", "CONFIGURING", "RUNNING", "COMPLETING", "SUSPENDED", "REQUEUED", "RESIZING"}
)

# Environment variables naming the remote result files to download.
_REMOTE_OUTPUT_VARS = ("REMOTE_OUTPUT_PATH_OBJ", "REMOTE_OUTPUT_PATH_MTL")


def _submit_job(client, script_name):
    """Submit ``script_name`` with sbatch over ``client``; return its job id or None."""
    _, stdout, stderr = client.exec_command(
        _REMOTE_JOB_SETUP
        + f" && sbatch --partition=ecsstudents --account=ecsstudents {script_name}"
    )
    output = stdout.read().decode()
    errors = stderr.read().decode()
    print(output)
    # On success sbatch prints "Submitted batch job <id>" on stdout.
    match = re.search(r"Submitted batch job (\d+)", output)
    if match is None:
        print(f"Submitting {script_name} failed: {errors.strip()}")
        return None
    return match.group(1)


def _wait_for_job(client, job_id, poll_seconds=30):
    """Block until Slurm job ``job_id`` is no longer in an active state."""
    while True:
        time.sleep(poll_seconds)
        # Query this specific job (not every job of the user) and re-read the
        # state on every iteration so the loop never tests stale output.
        # NOTE(review): a transient squeue failure also yields empty output and
        # is treated as "finished".
        _, stdout, _ = client.exec_command(
            f"squeue --noheader --format=%T --jobs={job_id}"
        )
        state = stdout.read().decode().strip()
        if state not in _ACTIVE_JOB_STATES:
            return
        print(f"Job {job_id} is {state}; sleeping for {poll_seconds}")


def get_completed_scene(shifted_disparity_path, shifted_t_path):
    """Run the remote scene-completion pipeline and download its results.

    Uploads the two input images (see ``send_files``), submits
    ``run_enhance360_job.sh`` and waits for it, then submits ``run_obj_job.sh``
    and waits for it. Finally it downloads the remote files named by
    ``REMOTE_OUTPUT_PATH_OBJ`` / ``REMOTE_OUTPUT_PATH_MTL`` into
    ``<parent of cwd>/edgenet-360/Output``.

    Args:
        shifted_disparity_path: Local path of ``shifted_disparity.png``.
        shifted_t_path: Local path of ``shifted_t.png``.

    Returns:
        ``(True, mtl_path)`` on success, where ``mtl_path`` is the local path of
        ``scene_completed_prediction.mtl`` (the ``.obj`` is saved next to it).
        Returns ``(False, False)`` on any handled failure.
    """
    # Fail fast: discovering a missing setting only after the jobs ran wastes
    # the whole compute time.
    missing = [name for name in _REMOTE_OUTPUT_VARS if not os.getenv(name)]
    if missing:
        print(f"Missing remote output path setting(s): {', '.join(missing)}")
        return False, False

    # send_files signals success with exactly True; compare explicitly so an
    # error value such as the tuple (False, False) is not mistaken for success.
    if send_files(shifted_disparity_path, shifted_t_path) is not True:
        return False, False

    client = paramiko.SSHClient()
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    # try/finally closes the SSH connection on every exit path, including
    # early returns and exceptions raised while waiting or downloading.
    try:
        try:
            client.connect(hostname, username=username, password=password)
        except (paramiko.SSHException, OSError) as e:
            print(f"Failed to connect to {hostname}: {e}")
            return False, False
        print("SSH connection established.")

        if not client.get_transport().is_active():
            print("Connection is not active.")
            return False, False
        print("Connection is active.")

        enhance_job = _submit_job(client, "run_enhance360_job.sh")
        if enhance_job is None:
            return False, False
        _wait_for_job(client, enhance_job)
        print("Finish Enhancing")

        obj_job = _submit_job(client, "run_obj_job.sh")
        if obj_job is None:
            return False, False
        _wait_for_job(client, obj_job)

        local_folder = os.path.join(
            os.path.dirname(os.getcwd()), "edgenet-360", "Output"
        )
        os.makedirs(local_folder, exist_ok=True)
        obj_path = os.path.join(local_folder, "scene_completed_prediction.obj")
        mtl_path = os.path.join(local_folder, "scene_completed_prediction.mtl")

        with SCPClient(client.get_transport()) as scp:
            scp.get(os.getenv("REMOTE_OUTPUT_PATH_OBJ"), obj_path)
            scp.get(os.getenv("REMOTE_OUTPUT_PATH_MTL"), mtl_path)

        print("OUTPUT DOWNLOADED")
        # NOTE(review): the .mtl path is returned, matching earlier behaviour;
        # confirm callers do not expect the .obj path instead.
        return True, mtl_path
    finally:
        client.close()


# out = get_completed_scene("C:\\Project\\AVVR-Pipeline-GDP4\\edgenet-360\\Data\\Courtyard\\shifted-disparity.png", "C:\\Project\\AVVR-Pipeline-GDP4\\edgenet-360\\Data\\Courtyard\\shifted_t.png")
# print(out)

if __name__ == "__main__":
    if len(sys.argv) != 3:
        print("Usage: python scene_completion.py <shifted_disparity_path> <shifted_t_path>")
        sys.exit(1)

    shifted_disparity_path = sys.argv[1]
    shifted_t_path = sys.argv[2]

    # get_completed_scene returns (success, path); on failure the path is
    # False, so report the error and exit non-zero instead of printing it.
    succeeded, saved_path = get_completed_scene(shifted_disparity_path, shifted_t_path)
    if not succeeded:
        print("Scene completion failed.")
        sys.exit(1)
    print(f"Shifted image saved as {saved_path}")