From 343b3525110860f0ce586334a15a695081931064 Mon Sep 17 00:00:00 2001
From: ym13n22 <ym13n22@soton.ac.uk>
Date: Thu, 25 Jul 2024 12:26:42 +0100
Subject: [PATCH] add predict version

---
 Windows/Window.py | 128 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 128 insertions(+)

diff --git a/Windows/Window.py b/Windows/Window.py
index 775ae41..239272a 100644
--- a/Windows/Window.py
+++ b/Windows/Window.py
@@ -11,6 +11,14 @@ from time import sleep
 import math
 import mlp_model
 import torch
+import os
+import pandas as pd
+import torch
+from torch import nn
+from torch.utils.data import DataLoader, TensorDataset, random_split
+import numpy as np
+import pickle
+import joblib
 
 class Window:
     def __init__(self, root):
@@ -266,6 +274,9 @@ class Window:
         self.gesture_label = tk.Label(self.frame3, text=f"Gesture is :")
         self.gesture_label.config(font=("Arial", 12))
         self.gesture_label.place(relx=0.1, rely=0.6, anchor='w')
+        self.gesture_predict_label = tk.Label(self.frame3, text="")
+        self.gesture_predict_label.config(font=("Arial", 12))
+        self.gesture_predict_label.place(relx=0.2, rely=0.7, anchor='w')
 
 
 
@@ -432,9 +443,20 @@ class Window:
                     test_data = torch.tensor([[emg_data[0], emg_data[0]]], dtype=torch.float32)
                     predictions=self.mlp_classifier.predict(test_data)
                     print("Predictions:", predictions.numpy())
+                    predict=""
+                    if (predictions.numpy()[0]==0):
+                      predict="HandClosed"
+                    if (predictions.numpy()[0]==1):
+                      predict="HandOpen"
+                    if (predictions.numpy()[0]==2):
+                      predict="WristExtension"
+                    if (predictions.numpy()[0]==3):
+                      predict="WristFlexation"
+
 
                     self.outer_EMG_Number.config(text=f"{emg_data[0]}")
                     self.inner_EMG_Number.config(text=f"{emg_data[1]}")
+                    self.gesture_predict_label.config(text=f"{predict}")
 
 
 
@@ -483,6 +505,112 @@ class Window:
                 self.adc_values[1] = int(adc_string_2, base=2)
 
                 return self.adc_values
+class MLPClassifier:
+    def __init__(self, num_inputs=2, num_outputs=4, num_hiddens=10, lr=0.01, num_epochs=400, batch_size=2010,
+                 data_dir='Data/Data'):
+        self.num_inputs = num_inputs
+        self.num_outputs = num_outputs
+        self.num_hiddens = num_hiddens
+        self.lr = lr
+        self.num_epochs = num_epochs
+        self.batch_size = batch_size
+        self.data_dir = data_dir
+
+        self.w1 = nn.Parameter(torch.randn(num_inputs, num_hiddens, requires_grad=True))
+        self.b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))
+        self.w2 = nn.Parameter(torch.randn(num_hiddens, num_outputs, requires_grad=True))
+        self.b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))
+        self.params = [self.w1, self.b1, self.w2, self.b2]
+
+        self.loss = nn.CrossEntropyLoss()
+        self.updater = torch.optim.Adam(self.params, lr=self.lr)
+
+    def relu(self, x):
+        return torch.max(x, torch.zeros_like(x))
+
+    def net(self, x):
+        H = self.relu(x @ self.w1 + self.b1)
+        return H @ self.w2 + self.b2
+
+    def load_data(self):
+        classes = ['HandClosed', 'HandOpen', 'WristExtension', 'WristFlexation']
+        data = []
+        labels = []
+        for label, class_name in enumerate(classes):
+            class_dir = os.path.join(self.data_dir, class_name)
+            for file_name in os.listdir(class_dir):
+                if file_name.endswith('.txt'):
+                    file_path = os.path.join(class_dir, file_name)
+                    df = pd.read_csv(file_path, delimiter='\t')
+                    emg_data = df[['EMG1', 'EMG2']].values
+                    data.append(emg_data)
+                    labels.extend([label] * len(emg_data))
+
+        data = torch.tensor(np.vstack(data), dtype=torch.float32)
+        labels = torch.tensor(labels, dtype=torch.long)
+        return data, labels
+
+    def load_data_loaders(self):
+        data, labels = self.load_data()
+
+        dataset = TensorDataset(data, labels)
+        total_size = len(dataset)
+        train_size = int(0.7 * total_size)
+        test_size = total_size - train_size
+        train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
+
+        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
+        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)
+
+        return train_loader, test_loader
+
+    def train(self):
+        train_loader, test_loader = self.load_data_loaders()
+
+        for epoch in range(self.num_epochs):
+            for X, y in train_loader:
+                y_hat = self.net(X)
+                l = self.loss(y_hat, y)
+                self.updater.zero_grad()
+                l.backward()
+                self.updater.step()
+
+            with torch.no_grad():
+                train_loss = sum(self.loss(self.net(X), y).item() for X, y in train_loader) / len(train_loader)
+            print(f'epoch {epoch + 1}, train loss {train_loss:.3f}')
+
+            self.evaluate_accuracy(test_loader)
+
+    def evaluate_accuracy(self, test_loader):
+        correct = 0
+        total = 0
+        with torch.no_grad():
+            for X, y in test_loader:
+                y_hat = self.net(X)
+                _, predicted = torch.max(y_hat, 1)
+                total += y.size(0)
+                correct += (predicted == y).sum().item()
+
+        accuracy = correct / total
+        print(f'Test Accuracy: {accuracy:.4f}')
+
+    def save_model(self, file_path='trained_mlp.pkl'):
+        joblib.dump(self.net, file_path)
+
+    def load_model(self, file_path='trained_mlp.pkl'):
+        self.loaded_model = joblib.load(file_path)
+        return self.loaded_model
+
+    def run(self):
+        self.train()
+        self.save_model()
+
+    def predict(self, X):
+        self.load_model()  # Ensure the model parameters are loaded
+        with torch.no_grad():
+            y_hat = self.net(X)
+            _, predicted = torch.max(y_hat, 1)
+        return predicted
 
 if __name__ == "__main__":
     root = tk.Tk()
-- 
GitLab