diff --git a/IMU/code/connect.py b/IMU/code/connect.py
index 69c7d087617e7db084da10facd9e347de6222d7c..df84b83f300e38d6787387d35f29a81f9eba86c7 100644
--- a/IMU/code/connect.py
+++ b/IMU/code/connect.py
@@ -4,7 +4,6 @@ import numpy as np
 from time import sleep, time
 from math import atan2, asin, cos, sin
 from vpython import vector, arrow, color, box, compound, scene, cross
-# from sourceCode.IMU import *  # 假设 IMU 是一个模块,如果没有可以移除这行
 
 def send_command_to_unity(command):
     host = '127.0.0.1'  # Unity服务器的IP地址
diff --git a/integration/Window.py b/integration/Window.py
index fd9b1a27ef676159e91e1c41fb86961f31049a09..61f70dbc1ce55ac7fc39e66ca4aca88340aeeca6 100644
--- a/integration/Window.py
+++ b/integration/Window.py
@@ -1,5 +1,7 @@
 import threading
+import time
 import tkinter as tk
+from tkinter import font as tkFont
 from matplotlib.figure import Figure
 from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
 from mpl_toolkits.mplot3d.art3d import Poly3DCollection
@@ -8,6 +10,14 @@ import serial
 import serial.tools.list_ports
 from time import sleep
 import math
+import datetime
+import os
+import pandas as pd
+import torch
+from torch import nn
+from torch.utils.data import DataLoader, TensorDataset, random_split
+import pickle
+import joblib
 
 class Window:
     def __init__(self, root):
@@ -75,11 +85,11 @@ class Window:
                                    wraplength=self.width / 2)
             self.label1.place(relx=0.5, rely=0.9, anchor='center')
             # Add a button to start data transmission
-            self.start_button = tk.Button(self.frame2, text="connect", command=self.start_data_transmission)
-            self.start_button.place(relx=0.5, rely=0.8, anchor='center')
+            self.start_buttonConnect = tk.Button(self.frame2, text="connect", command=self.start_data_transmission)
+            self.start_buttonConnect.place(relx=0.5, rely=0.8, anchor='center')
 
-            self.start_button = tk.Button(self.frame2, text="Disconnect", command=self.disconnect)
-            self.start_button.place(relx=0.7, rely=0.8, anchor='center')
+            self.start_buttonDisConnect = tk.Button(self.frame2, text="Disconnect", command=self.disconnect)
+            self.start_buttonDisConnect.place(relx=0.7, rely=0.8, anchor='center')
 
         else:
             print("IMU is not connected")
@@ -447,7 +457,243 @@ class Window:
 
 
             # Call update_display() again after 50 milliseconds
-            self.EMG_display_id=self.root.after(50, self.EMG_Display)
+            self.EMG_display_id=self.root.after(1, self.EMG_Display)
+
+    def _decode(self, serial_data):
+        serial_string = serial_data.decode(errors="ignore")
+        adc_string_1 = ""
+        adc_string_2 = ""
+        self.adc_values = [0, 0]
+        if '\n' in serial_string:
+            # remove new line character
+            serial_string = serial_string.replace("\n", "")
+            if serial_string != '':
+                # Convert number to binary, placing 0s in empty spots
+                serial_string = format(int(serial_string, 10), "024b")
+
+                # Separate the input number from the data
+                for i0 in range(0, 12):
+                    adc_string_1 += serial_string[i0]
+                for i0 in range(12, 24):
+                    adc_string_2 += serial_string[i0]
+
+                self.adc_values[0] = int(adc_string_1, base=2)
+                self.adc_values[1] = int(adc_string_2, base=2)
+
+                return self.adc_values
+
+class WelcomeWindow:
+    def __init__(self, root):
+        self.root = root
+        self.root.title("Welcome")
+        self.width = 1000
+        self.height = 600
+        screen_width = self.root.winfo_screenwidth()
+        screen_height = self.root.winfo_screenheight()
+        x = (screen_width // 2) - (self.width // 2)
+        y = (screen_height // 2) - (self.height // 2)
+        self.root.geometry(f"{self.width}x{self.height}+{x}+{y}")
+
+        # Configure the grid to be expandable
+        self.root.columnconfigure(0, weight=1)
+        self.root.columnconfigure(1, weight=1)
+        self.root.rowconfigure(0, weight=1)
+        self.root.rowconfigure(1, weight=1)
+
+        self.frame1 = tk.Frame(self.root, borderwidth=1, relief="solid", width=self.width, height=self.height)
+        self.frame1.grid(row=0, column=0, columnspan=2, rowspan=2, sticky="nsew")
+        self.button1 = tk.Button(self.frame1, text="Start", command=self.startButton)
+        self.button1.place(relx=0.5, rely=0.8, anchor='center')
+
+    def startButton(self):
+        self.root.destroy()  # Close the welcome window
+        new_root = tk.Tk()
+        app = trainingInterface(new_root)
+        new_root.mainloop()
+
+class trainingInterface:
+    def __init__(self, root):
+        self.root = root
+        self.root.title("preparation Interface")
+        self.width = 1000
+        self.height = 600
+        self.width = 1000
+        self.height = 600
+        screen_width = self.root.winfo_screenwidth()
+        screen_height = self.root.winfo_screenheight()
+        x = (screen_width // 2) - (self.width // 2)
+        y = (screen_height // 2) - (self.height // 2)
+        self.root.geometry(f"{self.width}x{self.height}+{x}+{y}")
+        self.ports = [port.device for port in serial.tools.list_ports.comports()]
+
+        # Configure the grid to be expandable
+        self.root.columnconfigure(0, weight=1)
+        self.root.columnconfigure(1, weight=1)
+        self.root.rowconfigure(0, weight=1)
+        self.root.rowconfigure(1, weight=1)
+
+
+        # Create a frame
+        self.frame1 = tk.Frame(self.root, borderwidth=1, relief="solid", width=self.width, height=(self.height *2/ 3))
+        self.frame1.grid(row=0, column=0, padx=10, pady=10, sticky="nsew")
+
+
+        self.frame2 = tk.Frame(self.root, borderwidth=1, relief="solid", width=self.width, height=self.height *1/ 3)
+        self.frame2.grid(row=1, column=0, padx=10, pady=10, sticky="nsew")
+
+        self.initialEMGTraining()
+        if 'COM5' in self.ports:
+
+            self.emg_data_1 = [-1] * 41
+            self.emg_data_2 = [-1] * 41
+            self.savingData=[]
+            self.openHandButton=tk.Button(self.frame2,text="Hand Open",command=self.EMG_connect_HandOpen,width=15, height=2,font=("Helvetica", 12))
+            self.openHandButton.place(relx=0.3, rely=0.3, anchor='center')
+            self.handCloseButton=tk.Button(self.frame2,text="Hand Close",command=self.handCloseButton,width=15, height=2,font=("Helvetica", 12))
+            self.handCloseButton.place(relx=0.7, rely=0.3, anchor='center')
+            self.gameStartButton = tk.Button(self.frame2, text="Start", command=self.startButton, width=15,
+                                         height=2,font=("Helvetica", 12))
+            self.gameStartButton.place(relx=0.5, rely=0.5, anchor='center')
+        if 'COM5' not in self.ports:
+            self.label=tk.Label(self.frame2, text="No ports found, Please check the hardware connection",font=("Helvetica", 15))
+            self.label.place(relx=0.5, rely=0.3, anchor='center')
+
+    def startButton(self):
+        self.root.destroy()  # Close the welcome window
+        new_root = tk.Tk()
+        app = Window(new_root)
+        new_root.mainloop()
+
+    def EMG_connect_HandOpen(self):
+        self.arduino_EMG = serial.Serial('COM5', 9600, timeout=1)
+        gesture = "handOpen"
+        self.start_countdown(2)
+        self.displayAndsaveDate()
+
+
+    def handCloseButton(self):
+        self.arduino_EMG = serial.Serial('COM5', 9600, timeout=1)
+        gesture = "handOpen"
+        self.start_countdown_close(2)
+        self.displayAndsaveDate()
+
+
+    def EMG_disconnect(self):
+        if self.arduino_EMG is not None:
+            self.arduino_EMG.close()
+            self.arduino_EMG = None
+
+    def start_countdown(self, count):
+        if count > 0:
+            self.startSave=True
+            if count<11:
+             self.openHandButton.config(text=str(count))
+            self.frame2.after(1000, self.start_countdown, count - 1)
+        else:
+            self.openHandButton.config(text="Hand Open")
+            self.startSave = False
+            self.savedDataOpen = []
+            for i in self.savingData:
+                self.savedDataOpen.append(i)
+            print(self.savedDataOpen)
+            self.savingData.clear()
+            self.EMG_disconnect()
+
+    def start_countdown_close(self, count):
+        if count > 0:
+            self.startSave=True
+            if count<11:
+             self.handCloseButton.config(text=str(count))
+            self.frame2.after(1000, self.start_countdown_close, count - 1)
+        else:
+            self.handCloseButton.config(text="Hand Close")
+            self.startSave = False
+            self.savedDataClose=[]
+            print(self.savingData)
+            for i in self.savingData:
+             self.savedDataClose.append(i)
+            self.savingData.clear()
+            print(self.savedDataClose)
+            self.EMG_disconnect()
+            self.trainData()
+
+    def displayAndsaveDate(self):
+      if self.startSave:
+        try:
+            while (self.arduino_EMG.inWaiting() > 0) :
+                data = self.arduino_EMG.readline()
+                emg_data = self._decode(data)
+                if emg_data is not None:
+                    print(f"EMG 1: {emg_data[0]} , EMG 2: {emg_data[1]}")
+                    # Append the new data to the lists
+
+                    self.emg_data_1.append(emg_data[0])
+                    self.emg_data_1.pop(0)
+                    self.emg_data_2.append(emg_data[1])
+                    self.emg_data_2.pop(0)
+                    if self.startSave==True:
+                     self.savingData.append([emg_data[0],emg_data[1]])
+                     print(len(self.savingData))
+
+
+                    # Update the line data to shift the line from right to left
+                    self.line1.set_data(range(len(self.emg_data_1)), self.emg_data_1)
+                    self.line2.set_data(range(len(self.emg_data_2)), self.emg_data_2)
+
+                    # Redraw the canvas
+                    self.canvas1.draw()  # Redraw the canvas
+
+        except Exception as e:
+            print(f"An error occurred: {e}")
+
+        self.EMG_display_id = self.root.after(50, self.displayAndsaveDate)
+
+
+
+
+    def initialEMGTraining(self):
+        self.EMG_transmitting = False
+        fig = Figure(figsize=(self.frame1.winfo_width() / 100, self.frame1.winfo_height() / 100))
+        self.ax1 = fig.add_subplot(111)
+
+        self.ax1.set_title("Electromyography Envelope", fontsize=14, pad=0)
+        self.ax1.set_xlim(0, 5)
+        self.ax1.set_ylim(0, 5)
+        self.ax1.set_xlabel("Sample (20 samples per second)", fontsize=8, labelpad=-2)
+        self.ax1.set_ylabel("Magnitude", labelpad=0)
+        self.ax1.set_xticks(np.arange(0, 41, 8))
+        self.ax1.set_yticks(np.arange(0, 1001, 200))
+
+        for x_tick in self.ax1.get_xticks():
+            self.ax1.axvline(x_tick, color='gray', linestyle='--', linewidth=0.5)
+        for y_tick in self.ax1.get_yticks():
+            self.ax1.axhline(y_tick, color='gray', linestyle='--', linewidth=0.5)
+
+        self.line1, = self.ax1.plot([], [], color='red', label='Outer Wrist Muscle (Extensor Carpi Ulnaris)')
+        self.line2, = self.ax1.plot([], [], color='blue', label='Inner Wrist Muscle (Flexor Carpi Radialis)')
+        self.ax1.legend(fontsize=9, loc='upper right')
+
+        # Embed the plot in the tkinter frame
+        self.canvas1 = FigureCanvasTkAgg(fig, master=self.frame1)
+        self.canvas1.draw()
+        self.canvas1.get_tk_widget().pack(fill=tk.BOTH, expand=True)
+
+        # Bind the resizing event to the figure update
+        self.frame1.bind("<Configure>", self.on_frame_resize)
+
+    def on_frame_resize(self, event):
+        width = self.frame1.winfo_width()
+        height = self.frame1.winfo_height()
+        self.canvas1.get_tk_widget().config(width=width, height=height)
+        self.canvas1.draw()
+
+    '''
+    Train Data
+    '''
+    def trainData(self):
+        if (self.savedDataClose !=[])and (self.savedDataOpen!=[]):
+         train_Data=Train_Data()
+         train_Data.run(self.savedDataOpen, self.savedDataClose)
 
     def _decode(self, serial_data):
         serial_string = serial_data.decode(errors="ignore")
@@ -472,7 +718,108 @@ class Window:
 
                 return self.adc_values
 
+
+class Train_Data:
+    def __init__(self, num_inputs=2, num_outputs=2, num_hiddens=10, lr=0.001, num_epochs=300, batch_size=2010):
+        self.num_inputs = num_inputs
+        self.num_outputs = num_outputs
+        self.num_hiddens = num_hiddens
+        self.lr = lr
+        self.num_epochs = num_epochs
+        self.batch_size = batch_size
+
+        self.w1 = nn.Parameter(torch.randn(num_inputs, num_hiddens, requires_grad=True))
+        self.b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))
+        self.w2 = nn.Parameter(torch.randn(num_hiddens, num_outputs, requires_grad=True))
+        self.b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))
+        self.params = [self.w1, self.b1, self.w2, self.b2]
+
+        self.loss = nn.CrossEntropyLoss()
+        self.updater = torch.optim.Adam(self.params, lr=self.lr)
+
+    def relu(self, x):
+        return torch.max(x, torch.zeros_like(x))
+
+    def net(self, x):
+        H = self.relu(x @ self.w1 + self.b1)
+        return H @ self.w2 + self.b2
+
+    def dataManage(self, savedDataOpen, savedDataClose):
+        # Combine the data and create labels
+        data = savedDataOpen + savedDataClose
+        labels = [0] * len(savedDataOpen) + [1] * len(savedDataClose)  # 0 for HandOpen, 1 for HandClose
+
+        # Convert to tensor
+        data_tensor = torch.tensor(data, dtype=torch.float32)
+        labels_tensor = torch.tensor(labels, dtype=torch.long)
+
+        # Create TensorDataset
+        dataset = TensorDataset(data_tensor, labels_tensor)
+        total_size = len(dataset)
+        train_size = int(0.7 * total_size)
+        test_size = total_size - train_size
+
+        # Split the data into training and testing sets
+        train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
+
+        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
+        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)
+
+        return train_loader, test_loader
+
+
+    def train(self,savedDataOpen,savedDataClose):
+        train_loader, test_loader = self.dataManage(savedDataOpen,savedDataClose)
+
+        for epoch in range(self.num_epochs):
+            for X, y in train_loader:
+                y_hat = self.net(X)
+                l = self.loss(y_hat, y)
+                self.updater.zero_grad()
+                l.backward()
+                self.updater.step()
+
+            with torch.no_grad():
+                train_loss = sum(self.loss(self.net(X), y).item() for X, y in train_loader) / len(train_loader)
+            print(f'epoch {epoch + 1}, train loss {train_loss:.3f}')
+
+            self.evaluate_accuracy(test_loader)
+
+    def evaluate_accuracy(self, test_loader):
+        correct = 0
+        total = 0
+        with torch.no_grad():
+            for X, y in test_loader:
+                y_hat = self.net(X)
+                _, predicted = torch.max(y_hat, 1)
+                total += y.size(0)
+                correct += (predicted == y).sum().item()
+
+        accuracy = correct / total
+        print(f'Test Accuracy: {accuracy:.4f}')
+
+    def save_model(self, file_path='trained_mlp.pkl'):
+        joblib.dump(self.net, file_path)
+
+    def load_model(self, file_path='trained_mlp.pkl'):
+        self.loaded_model = joblib.load(file_path)
+        return self.loaded_model
+
+    def run(self,savedDataOpen,savedDataClose):
+        self.train(savedDataOpen,savedDataClose)
+        self.save_model()
+
+    def predict(self, X):
+        self.load_model()  # Ensure the model parameters are loaded
+        with torch.no_grad():
+            y_hat = self.net(X)
+            _, predicted = torch.max(y_hat, 1)
+        return predicted
+
+
+
+
 if __name__ == "__main__":
-    root = tk.Tk()
-    app = Window(root)
-    root.mainloop()
\ No newline at end of file
+    root1 = tk.Tk()
+    appWelcome = WelcomeWindow(root1)
+    root1.mainloop()
diff --git a/integration/Window2.py b/integration/Window2.py
index b88a861c80fbe136a7859423dd764e43fe501564..084839745f9fa5b4812136d3585d41257ebf1440 100644
--- a/integration/Window2.py
+++ b/integration/Window2.py
@@ -80,11 +80,11 @@ class Window:
                                    wraplength=self.width / 2)
             self.label1.place(relx=0.5, rely=0.9, anchor='center')
             # Add a button to start data transmission
-            self.start_button = tk.Button(self.frame2, text="connect", command=self.start_data_transmission)
-            self.start_button.place(relx=0.5, rely=0.8, anchor='center')
+            self.start_buttonConnect = tk.Button(self.frame2, text="connect", command=self.start_data_transmission)
+            self.start_buttonConnect.place(relx=0.5, rely=0.8, anchor='center')
 
-            self.start_button = tk.Button(self.frame2, text="Disconnect", command=self.disconnect)
-            self.start_button.place(relx=0.7, rely=0.8, anchor='center')
+            self.start_buttonDisConnect = tk.Button(self.frame2, text="Disconnect", command=self.disconnect)
+            self.start_buttonDisConnect.place(relx=0.7, rely=0.8, anchor='center')
 
         else:
             print("IMU is not connected")
@@ -328,7 +328,8 @@ class Window:
 
                     if len(row) == 6:
                         splitPacket = cleandata.split(',')
-
+                        emg1 = float(splitPacket[0])  # emg sensor data
+                        emg2 = float(splitPacket[1])
                         q0 = float(splitPacket[2])  # qw
                         q1 = float(splitPacket[3])  # qx
                         q2 = float(splitPacket[4])  # qy
@@ -342,6 +343,8 @@ class Window:
                         self.roll_label.config( text="roll is : "+str(roll))
                         self.pitch_label.config(text="pitch is : "+str(pitch))
                         self.yaw_label.config(text="yaw is : "+str(yaw))
+                        print(f"EMG1 is: {emg1}")
+                        print(f"EMG2 is: {emg2}")
 
                         # Rotation matrices
                         Rz = np.array([