diff --git a/IMU/code/connect.py b/IMU/code/connect.py index 69c7d087617e7db084da10facd9e347de6222d7c..df84b83f300e38d6787387d35f29a81f9eba86c7 100644 --- a/IMU/code/connect.py +++ b/IMU/code/connect.py @@ -4,7 +4,6 @@ import numpy as np from time import sleep, time from math import atan2, asin, cos, sin from vpython import vector, arrow, color, box, compound, scene, cross -# from sourceCode.IMU import * # 假设 IMU 是一个模块,如果没有可以移除这行 def send_command_to_unity(command): host = '127.0.0.1' # Unity服务器的IP地址 diff --git a/integration/Window.py b/integration/Window.py index fd9b1a27ef676159e91e1c41fb86961f31049a09..61f70dbc1ce55ac7fc39e66ca4aca88340aeeca6 100644 --- a/integration/Window.py +++ b/integration/Window.py @@ -1,5 +1,7 @@ import threading +import time import tkinter as tk +from tkinter import font as tkFont from matplotlib.figure import Figure from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg from mpl_toolkits.mplot3d.art3d import Poly3DCollection @@ -8,6 +10,14 @@ import serial import serial.tools.list_ports from time import sleep import math +import datetime +import os +import pandas as pd +import torch +from torch import nn +from torch.utils.data import DataLoader, TensorDataset, random_split +import pickle +import joblib class Window: def __init__(self, root): @@ -75,11 +85,11 @@ class Window: wraplength=self.width / 2) self.label1.place(relx=0.5, rely=0.9, anchor='center') # Add a button to start data transmission - self.start_button = tk.Button(self.frame2, text="connect", command=self.start_data_transmission) - self.start_button.place(relx=0.5, rely=0.8, anchor='center') + self.start_buttonConnect = tk.Button(self.frame2, text="connect", command=self.start_data_transmission) + self.start_buttonConnect.place(relx=0.5, rely=0.8, anchor='center') - self.start_button = tk.Button(self.frame2, text="Disconnect", command=self.disconnect) - self.start_button.place(relx=0.7, rely=0.8, anchor='center') + self.start_buttonDisConnect = tk.Button(self.frame2, text="Disconnect", command=self.disconnect) + self.start_buttonDisConnect.place(relx=0.7, rely=0.8, anchor='center') else: print("IMU is not connected") @@ -447,7 +457,243 @@ class Window: # Call update_display() again after 50 milliseconds - self.EMG_display_id=self.root.after(50, self.EMG_Display) + self.EMG_display_id=self.root.after(1, self.EMG_Display) + + def _decode(self, serial_data): + serial_string = serial_data.decode(errors="ignore") + adc_string_1 = "" + adc_string_2 = "" + self.adc_values = [0, 0] + if '\n' in serial_string: + # remove new line character + serial_string = serial_string.replace("\n", "") + if serial_string != '': + # Convert number to binary, placing 0s in empty spots + serial_string = format(int(serial_string, 10), "024b") + + # Separate the input number from the data + for i0 in range(0, 12): + adc_string_1 += serial_string[i0] + for i0 in range(12, 24): + adc_string_2 += serial_string[i0] + + self.adc_values[0] = int(adc_string_1, base=2) + self.adc_values[1] = int(adc_string_2, base=2) + + return self.adc_values + +class WelcomeWindow: + def __init__(self, root): + self.root = root + self.root.title("Welcome") + self.width = 1000 + self.height = 600 + screen_width = self.root.winfo_screenwidth() + screen_height = self.root.winfo_screenheight() + x = (screen_width // 2) - (self.width // 2) + y = (screen_height // 2) - (self.height // 2) + self.root.geometry(f"{self.width}x{self.height}+{x}+{y}") + + # Configure the grid to be expandable + self.root.columnconfigure(0, weight=1) + self.root.columnconfigure(1, weight=1) + self.root.rowconfigure(0, weight=1) + self.root.rowconfigure(1, weight=1) + + self.frame1 = tk.Frame(self.root, borderwidth=1, relief="solid", width=self.width, height=self.height) + self.frame1.grid(row=0, column=0, columnspan=2, rowspan=2, sticky="nsew") + self.button1 = tk.Button(self.frame1, text="Start", command=self.startButton) + self.button1.place(relx=0.5, rely=0.8, anchor='center') + + def startButton(self): + self.root.destroy() # Close the welcome window + new_root = tk.Tk() + app = trainingInterface(new_root) + new_root.mainloop() + +class trainingInterface: + def __init__(self, root): + self.root = root + self.root.title("preparation Interface") + self.width = 1000 + self.height = 600 + self.width = 1000 + self.height = 600 + screen_width = self.root.winfo_screenwidth() + screen_height = self.root.winfo_screenheight() + x = (screen_width // 2) - (self.width // 2) + y = (screen_height // 2) - (self.height // 2) + self.root.geometry(f"{self.width}x{self.height}+{x}+{y}") + self.ports = [port.device for port in serial.tools.list_ports.comports()] + + # Configure the grid to be expandable + self.root.columnconfigure(0, weight=1) + self.root.columnconfigure(1, weight=1) + self.root.rowconfigure(0, weight=1) + self.root.rowconfigure(1, weight=1) + + + # Create a frame + self.frame1 = tk.Frame(self.root, borderwidth=1, relief="solid", width=self.width, height=(self.height *2/ 3)) + self.frame1.grid(row=0, column=0, padx=10, pady=10, sticky="nsew") + + + self.frame2 = tk.Frame(self.root, borderwidth=1, relief="solid", width=self.width, height=self.height *1/ 3) + self.frame2.grid(row=1, column=0, padx=10, pady=10, sticky="nsew") + + self.initialEMGTraining() + if 'COM5' in self.ports: + + self.emg_data_1 = [-1] * 41 + self.emg_data_2 = [-1] * 41 + self.savingData=[] + self.openHandButton=tk.Button(self.frame2,text="Hand Open",command=self.EMG_connect_HandOpen,width=15, height=2,font=("Helvetica", 12)) + self.openHandButton.place(relx=0.3, rely=0.3, anchor='center') + self.handCloseButton=tk.Button(self.frame2,text="Hand Close",command=self.handCloseButton,width=15, height=2,font=("Helvetica", 12)) + self.handCloseButton.place(relx=0.7, rely=0.3, anchor='center') + self.gameStartButton = tk.Button(self.frame2, text="Start", command=self.startButton, width=15, + height=2,font=("Helvetica", 12)) + self.gameStartButton.place(relx=0.5, rely=0.5, anchor='center') + if 'COM5' not in self.ports: + self.label=tk.Label(self.frame2, text="No ports found, Please check the hardware connection",font=("Helvetica", 15)) + self.label.place(relx=0.5, rely=0.3, anchor='center') + + def startButton(self): + self.root.destroy() # Close the welcome window + new_root = tk.Tk() + app = Window(new_root) + new_root.mainloop() + + def EMG_connect_HandOpen(self): + self.arduino_EMG = serial.Serial('COM5', 9600, timeout=1) + gesture = "handOpen" + self.start_countdown(2) + self.displayAndsaveDate() + + + def handCloseButton(self): + self.arduino_EMG = serial.Serial('COM5', 9600, timeout=1) + gesture = "handOpen" + self.start_countdown_close(2) + self.displayAndsaveDate() + + + def EMG_disconnect(self): + if self.arduino_EMG is not None: + self.arduino_EMG.close() + self.arduino_EMG = None + + def start_countdown(self, count): + if count > 0: + self.startSave=True + if count<11: + self.openHandButton.config(text=str(count)) + self.frame2.after(1000, self.start_countdown, count - 1) + else: + self.openHandButton.config(text="Hand Open") + self.startSave = False + self.savedDataOpen = [] + for i in self.savingData: + self.savedDataOpen.append(i) + print(self.savedDataOpen) + self.savingData.clear() + self.EMG_disconnect() + + def start_countdown_close(self, count): + if count > 0: + self.startSave=True + if count<11: + self.handCloseButton.config(text=str(count)) + self.frame2.after(1000, self.start_countdown_close, count - 1) + else: + self.handCloseButton.config(text="Hand Close") + self.startSave = False + self.savedDataClose=[] + print(self.savingData) + for i in self.savingData: + self.savedDataClose.append(i) + self.savingData.clear() + print(self.savedDataClose) + self.EMG_disconnect() + self.trainData() + + def displayAndsaveDate(self): + if self.startSave: + try: + while (self.arduino_EMG.inWaiting() > 0) : + data = self.arduino_EMG.readline() + emg_data = self._decode(data) + if emg_data is not None: + print(f"EMG 1: {emg_data[0]} , EMG 2: {emg_data[1]}") + # Append the new data to the lists + + self.emg_data_1.append(emg_data[0]) + self.emg_data_1.pop(0) + self.emg_data_2.append(emg_data[1]) + self.emg_data_2.pop(0) + if self.startSave==True: + self.savingData.append([emg_data[0],emg_data[1]]) + print(len(self.savingData)) + + + # Update the line data to shift the line from right to left + self.line1.set_data(range(len(self.emg_data_1)), self.emg_data_1) + self.line2.set_data(range(len(self.emg_data_2)), self.emg_data_2) + + # Redraw the canvas + self.canvas1.draw() # Redraw the canvas + + except Exception as e: + print(f"An error occurred: {e}") + + self.EMG_display_id = self.root.after(50, self.displayAndsaveDate) + + + + + def initialEMGTraining(self): + self.EMG_transmitting = False + fig = Figure(figsize=(self.frame1.winfo_width() / 100, self.frame1.winfo_height() / 100)) + self.ax1 = fig.add_subplot(111) + + self.ax1.set_title("Electromyography Envelope", fontsize=14, pad=0) + self.ax1.set_xlim(0, 5) + self.ax1.set_ylim(0, 5) + self.ax1.set_xlabel("Sample (20 samples per second)", fontsize=8, labelpad=-2) + self.ax1.set_ylabel("Magnitude", labelpad=0) + self.ax1.set_xticks(np.arange(0, 41, 8)) + self.ax1.set_yticks(np.arange(0, 1001, 200)) + + for x_tick in self.ax1.get_xticks(): + self.ax1.axvline(x_tick, color='gray', linestyle='--', linewidth=0.5) + for y_tick in self.ax1.get_yticks(): + self.ax1.axhline(y_tick, color='gray', linestyle='--', linewidth=0.5) + + self.line1, = self.ax1.plot([], [], color='red', label='Outer Wrist Muscle (Extensor Carpi Ulnaris)') + self.line2, = self.ax1.plot([], [], color='blue', label='Inner Wrist Muscle (Flexor Carpi Radialis)') + self.ax1.legend(fontsize=9, loc='upper right') + + # Embed the plot in the tkinter frame + self.canvas1 = FigureCanvasTkAgg(fig, master=self.frame1) + self.canvas1.draw() + self.canvas1.get_tk_widget().pack(fill=tk.BOTH, expand=True) + + # Bind the resizing event to the figure update + self.frame1.bind("<Configure>", self.on_frame_resize) + + def on_frame_resize(self, event): + width = self.frame1.winfo_width() + height = self.frame1.winfo_height() + self.canvas1.get_tk_widget().config(width=width, height=height) + self.canvas1.draw() + + ''' + Train Data + ''' + def trainData(self): + if (self.savedDataClose !=[])and (self.savedDataOpen!=[]): + train_Data=Train_Data() + train_Data.run(self.savedDataOpen, self.savedDataClose) def _decode(self, serial_data): serial_string = serial_data.decode(errors="ignore") @@ -472,7 +718,108 @@ class Window: return self.adc_values + +class Train_Data: + def __init__(self, num_inputs=2, num_outputs=2, num_hiddens=10, lr=0.001, num_epochs=300, batch_size=2010): + self.num_inputs = num_inputs + self.num_outputs = num_outputs + self.num_hiddens = num_hiddens + self.lr = lr + self.num_epochs = num_epochs + self.batch_size = batch_size + + self.w1 = nn.Parameter(torch.randn(num_inputs, num_hiddens, requires_grad=True)) + self.b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True)) + self.w2 = nn.Parameter(torch.randn(num_hiddens, num_outputs, requires_grad=True)) + self.b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True)) + self.params = [self.w1, self.b1, self.w2, self.b2] + + self.loss = nn.CrossEntropyLoss() + self.updater = torch.optim.Adam(self.params, lr=self.lr) + + def relu(self, x): + return torch.max(x, torch.zeros_like(x)) + + def net(self, x): + H = self.relu(x @ self.w1 + self.b1) + return H @ self.w2 + self.b2 + + def dataManage(self, savedDataOpen, savedDataClose): + # Combine the data and create labels + data = savedDataOpen + savedDataClose + labels = [0] * len(savedDataOpen) + [1] * len(savedDataClose) # 0 for HandOpen, 1 for HandClose + + # Convert to tensor + data_tensor = torch.tensor(data, dtype=torch.float32) + labels_tensor = torch.tensor(labels, dtype=torch.long) + + # Create TensorDataset + dataset = TensorDataset(data_tensor, labels_tensor) + total_size = len(dataset) + train_size = int(0.7 * total_size) + test_size = total_size - train_size + + # Split the data into training and testing sets + train_dataset, test_dataset = random_split(dataset, [train_size, test_size]) + + train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True) + test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False) + + return train_loader, test_loader + + + def train(self,savedDataOpen,savedDataClose): + train_loader, test_loader = self.dataManage(savedDataOpen,savedDataClose) + + for epoch in range(self.num_epochs): + for X, y in train_loader: + y_hat = self.net(X) + l = self.loss(y_hat, y) + self.updater.zero_grad() + l.backward() + self.updater.step() + + with torch.no_grad(): + train_loss = sum(self.loss(self.net(X), y).item() for X, y in train_loader) / len(train_loader) + print(f'epoch {epoch + 1}, train loss {train_loss:.3f}') + + self.evaluate_accuracy(test_loader) + + def evaluate_accuracy(self, test_loader): + correct = 0 + total = 0 + with torch.no_grad(): + for X, y in test_loader: + y_hat = self.net(X) + _, predicted = torch.max(y_hat, 1) + total += y.size(0) + correct += (predicted == y).sum().item() + + accuracy = correct / total + print(f'Test Accuracy: {accuracy:.4f}') + + def save_model(self, file_path='trained_mlp.pkl'): + joblib.dump(self.net, file_path) + + def load_model(self, file_path='trained_mlp.pkl'): + self.loaded_model = joblib.load(file_path) + return self.loaded_model + + def run(self,savedDataOpen,savedDataClose): + self.train(savedDataOpen,savedDataClose) + self.save_model() + + def predict(self, X): + self.load_model() # Ensure the model parameters are loaded + with torch.no_grad(): + y_hat = self.net(X) + _, predicted = torch.max(y_hat, 1) + return predicted + + + + if __name__ == "__main__": - root = tk.Tk() - app = Window(root) - root.mainloop() \ No newline at end of file + root1 = tk.Tk() + appWelcome = WelcomeWindow(root1) + root1.mainloop() diff --git a/integration/Window2.py b/integration/Window2.py index b88a861c80fbe136a7859423dd764e43fe501564..084839745f9fa5b4812136d3585d41257ebf1440 100644 --- a/integration/Window2.py +++ b/integration/Window2.py @@ -80,11 +80,11 @@ class Window: wraplength=self.width / 2) self.label1.place(relx=0.5, rely=0.9, anchor='center') # Add a button to start data transmission - self.start_button = tk.Button(self.frame2, text="connect", command=self.start_data_transmission) - self.start_button.place(relx=0.5, rely=0.8, anchor='center') + self.start_buttonConnect = tk.Button(self.frame2, text="connect", command=self.start_data_transmission) + self.start_buttonConnect.place(relx=0.5, rely=0.8, anchor='center') - self.start_button = tk.Button(self.frame2, text="Disconnect", command=self.disconnect) - self.start_button.place(relx=0.7, rely=0.8, anchor='center') + self.start_buttonDisConnect = tk.Button(self.frame2, text="Disconnect", command=self.disconnect) + self.start_buttonDisConnect.place(relx=0.7, rely=0.8, anchor='center') else: print("IMU is not connected") @@ -328,7 +328,8 @@ class Window: if len(row) == 6: splitPacket = cleandata.split(',') - + emg1 = float(splitPacket[0]) # emg sensor data + emg2 = float(splitPacket[1]) q0 = float(splitPacket[2]) # qw q1 = float(splitPacket[3]) # qx q2 = float(splitPacket[4]) # qy @@ -342,6 +343,8 @@ class Window: self.roll_label.config( text="roll is : "+str(roll)) self.pitch_label.config(text="pitch is : "+str(pitch)) self.yaw_label.config(text="yaw is : "+str(yaw)) + print(f"EMG1 is: {emg1}") + print(f"EMG2 is: {emg2}") # Rotation matrices Rz = np.array([