Skip to content
Snippets Groups Projects
Commit 882eb335 authored by ym13n22's avatar ym13n22
Browse files

add mlp_model to integration

parent 343b3525
No related branches found
No related tags found
No related merge requests found
import os
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, random_split
import numpy as np
import pickle
import joblib
class MLPClassifier:
def __init__(self, num_inputs=2, num_outputs=2, num_hiddens=10, lr=0.001, num_epochs=300, batch_size=2010,
data_dir='Data/Data'):
self.num_inputs = num_inputs
self.num_outputs = num_outputs
self.num_hiddens = num_hiddens
self.lr = lr
self.num_epochs = num_epochs
self.batch_size = batch_size
self.data_dir = data_dir
self.w1 = nn.Parameter(torch.randn(num_inputs, num_hiddens, requires_grad=True))
self.b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))
self.w2 = nn.Parameter(torch.randn(num_hiddens, num_outputs, requires_grad=True))
self.b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))
self.params = [self.w1, self.b1, self.w2, self.b2]
self.loss = nn.CrossEntropyLoss()
self.updater = torch.optim.Adam(self.params, lr=self.lr)
def relu(self, x):
return torch.max(x, torch.zeros_like(x))
def net(self, x):
H = self.relu(x @ self.w1 + self.b1)
return H @ self.w2 + self.b2
def load_data(self):
classes = ['HandClosed', 'HandOpen']
data = []
labels = []
for label, class_name in enumerate(classes):
class_dir = os.path.join(self.data_dir, class_name)
for file_name in os.listdir(class_dir):
if file_name.endswith('.txt'):
file_path = os.path.join(class_dir, file_name)
df = pd.read_csv(file_path, delimiter='\t')
emg_data = df[['EMG1', 'EMG2']].values
data.append(emg_data)
labels.extend([label] * len(emg_data))
data = torch.tensor(np.vstack(data), dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.long)
print(labels)
return data, labels
def load_data_loaders(self):
data, labels = self.load_data()
dataset = TensorDataset(data, labels)
total_size = len(dataset)
train_size = int(0.7 * total_size)
test_size = total_size - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)
return train_loader, test_loader
def train(self):
train_loader, test_loader = self.load_data_loaders()
for epoch in range(self.num_epochs):
for X, y in train_loader:
y_hat = self.net(X)
l = self.loss(y_hat, y)
self.updater.zero_grad()
l.backward()
self.updater.step()
with torch.no_grad():
train_loss = sum(self.loss(self.net(X), y).item() for X, y in train_loader) / len(train_loader)
print(f'epoch {epoch + 1}, train loss {train_loss:.3f}')
self.evaluate_accuracy(test_loader)
def evaluate_accuracy(self, test_loader):
correct = 0
total = 0
with torch.no_grad():
for X, y in test_loader:
y_hat = self.net(X)
_, predicted = torch.max(y_hat, 1)
total += y.size(0)
correct += (predicted == y).sum().item()
accuracy = correct / total
print(f'Test Accuracy: {accuracy:.4f}')
def save_model(self, file_path='trained_mlp.pkl'):
joblib.dump(self.net, file_path)
def load_model(self, file_path='trained_mlp.pkl'):
self.loaded_model = joblib.load(file_path)
return self.loaded_model
def run(self):
self.train()
self.save_model()
def predict(self, X):
self.load_model() # Ensure the model parameters are loaded
with torch.no_grad():
y_hat = self.net(X)
_, predicted = torch.max(y_hat, 1)
return predicted
if __name__ == '__main__':
mlp_classifier = MLPClassifier()
#mlp_classifier.run()
file_path='trained_mlp.pkl'
if not os.path.exists(file_path):
mlp_classifier.run()
'''
np.random.seed(42)
num_random_points = 1000
random_data = np.random.uniform(low=0, high=100, size=(num_random_points, 2))
test_data = torch.tensor(random_data, dtype=torch.float32)
predictions = mlp_classifier.predict(test_data)
print("Predictions:", predictions.numpy())
'''
test_data = torch.tensor([[32,5], [20,5],[30,6],[29,23],[24,6]], dtype=torch.float32)
predictions = mlp_classifier.predict(test_data)
print("Predictions:", predictions.numpy())
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment