Skip to content
Snippets Groups Projects
Commit d7825dbb authored by ym13n22's avatar ym13n22
Browse files

modify train_file

parent 882eb335
No related branches found
No related tags found
No related merge requests found
......@@ -9,7 +9,7 @@ import joblib
class MLPClassifier:
def __init__(self, num_inputs=2, num_outputs=4, num_hiddens=10, lr=0.01, num_epochs=400, batch_size=2010,
def __init__(self, num_inputs=2, num_outputs=2, num_hiddens=10, lr=0.001, num_epochs=300, batch_size=2010,
data_dir='Data/Data'):
self.num_inputs = num_inputs
self.num_outputs = num_outputs
......@@ -36,7 +36,7 @@ class MLPClassifier:
return H @ self.w2 + self.b2
def load_data(self):
classes = ['HandClosed', 'HandOpen', 'WristExtension', 'WristFlexation']
classes = ['HandClosed', 'HandOpen']
data = []
labels = []
for label, class_name in enumerate(classes):
......@@ -51,6 +51,7 @@ class MLPClassifier:
data = torch.tensor(np.vstack(data), dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.long)
print(labels)
return data, labels
def load_data_loaders(self):
......@@ -119,18 +120,30 @@ class MLPClassifier:
if __name__ == '__main__':
mlp_classifier = MLPClassifier()
#mlp_classifier.run()
file_path='trained_mlp.pkl'
if not os.path.exists(file_path):
mlp_classifier.run()
test_data = torch.tensor([[97,58], [25,34], [42,7],[298,9],[10,39]], dtype=torch.float32)
'''
np.random.seed(42)
num_random_points = 1000
random_data = np.random.uniform(low=0, high=100, size=(num_random_points, 2))
test_data = torch.tensor(random_data, dtype=torch.float32)
predictions = mlp_classifier.predict(test_data)
print("Predictions:", predictions.numpy())
'''
test_data = torch.tensor([[32,5], [20,5],[30,6],[29,23],[24,6]], dtype=torch.float32)
# 进行预测
predictions = mlp_classifier.predict(test_data)
# 输出预测结果
print("Predictions:", predictions.numpy())
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment