diff --git a/Windows/mlp_model.py b/Windows/mlp_model.py index d93feec22595c75fd888981fe1fe13dbce1466e3..9401ceaf7ee10043735517151c3fe736895d905d 100644 --- a/Windows/mlp_model.py +++ b/Windows/mlp_model.py @@ -108,11 +108,27 @@ class MLPClassifier: self.train() self.save_model() + def predict(self, X): + self.load_model() # Ensure the model parameters are loaded + with torch.no_grad(): + y_hat = self.net(X) + _, predicted = torch.max(y_hat, 1) + return predicted + + if __name__ == '__main__': mlp_classifier = MLPClassifier() file_path='trained_mlp.pkl' - mlp_classifier.run() + if not os.path.exists(file_path): + mlp_classifier.run() + test_data = torch.tensor([[97,58], [25,34], [42,7],[298,9],[10,39]], dtype=torch.float32) + + # 进行预测 + predictions = mlp_classifier.predict(test_data) + + # 输出预测结果 + print("Predictions:", predictions.numpy()) diff --git a/Windows/trained_mlp.pkl b/Windows/trained_mlp.pkl index 7719b6dbe8ef20f94b6c89ec71a8ed500902906a..7c0431e06e4de36b9fe4631f7b23032980e9dca0 100644 Binary files a/Windows/trained_mlp.pkl and b/Windows/trained_mlp.pkl differ