Skip to content
Snippets Groups Projects
Commit 99cfcfb4 authored by ym13n22's avatar ym13n22
Browse files

train data

parent 15234b08
No related branches found
No related tags found
No related merge requests found
...@@ -108,11 +108,27 @@ class MLPClassifier: ...@@ -108,11 +108,27 @@ class MLPClassifier:
self.train() self.train()
self.save_model() self.save_model()
def predict(self, X):
self.load_model() # Ensure the model parameters are loaded
with torch.no_grad():
y_hat = self.net(X)
_, predicted = torch.max(y_hat, 1)
return predicted
if __name__ == '__main__': if __name__ == '__main__':
mlp_classifier = MLPClassifier() mlp_classifier = MLPClassifier()
file_path='trained_mlp.pkl' file_path='trained_mlp.pkl'
mlp_classifier.run() if not os.path.exists(file_path):
mlp_classifier.run()
test_data = torch.tensor([[97,58], [25,34], [42,7],[298,9],[10,39]], dtype=torch.float32)
# 进行预测
predictions = mlp_classifier.predict(test_data)
# 输出预测结果
print("Predictions:", predictions.numpy())
......
No preview for this file type
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment