Skip to content
Snippets Groups Projects
Commit 8900fcaf authored by Andrei Vasile
Browse files

Write main program script for training and predicting

Done via CLI
parent a30c0758
Branches
No related tags found
No related merge requests found
from pathlib import Path
from pandas import read_csv
from autogluon.tabular import TabularDataset, TabularPredictor
from datetime import datetime
# Maps the menu key typed by the user to the dataset name that is embedded
# in the CSV file names and in the per-dataset model directory.
datasets = {
    "a": "DataBinary",
    "b": "DataMulti",
}
def main():
    """Run the interactive command-line entry point.

    Asks which dataset to work on and hands the chosen key ("a" or "b")
    to handleDataset. Any other answer prints a message and ends the
    program.

    Raises:
        SystemExit: If the answer is not one of the offered datasets.
    """
    print("Training and predictor")
    choice = input(
        "Which set would you like to work on?\n[A] DataBinary\n[B] DataMulti\n"
    )
    # Tolerate stray whitespace and either letter case in the answer.
    choice = choice.strip().lower()
    if choice not in ("a", "b"):
        print("I don't know that dataset!")
        # exit() is injected by the site module and is not guaranteed to
        # exist; raising SystemExit ends the program the same way.
        raise SystemExit
    handleDataset(choice)
def handleDataset(ds):
    """Ask whether to predict or train on the chosen dataset and dispatch.

    Args:
        ds: Dataset key as used in ``datasets`` ("a" or "b").
    """
    action = input(
        f"What would you like to do to {ds}?\n[A] Predict\n[B] Train\n"
    ).strip().lower()
    if action == "a":
        handlePredict(ds)
    elif action == "b":
        handleTrain(ds)
    else:
        # An unknown answer used to fall through without any feedback;
        # report it the same way main reports an unknown dataset.
        print("I don't know that action!")
def handlePredict(ds):
    """List the saved models of a dataset and predict with the chosen one.

    The models are the entries of ``<root>/models/<dataset name>``. The
    user picks one by its listed number, or types ``q`` to quit.

    Args:
        ds: Dataset key as used in ``datasets`` ("a" or "b").

    Raises:
        SystemExit: If the user chooses to quit.
    """
    print(f"Predicting {ds} with model")
    root = Path(__file__).parent.parent
    model_dir = root / "models" / datasets[ds]
    # Snapshot the directory once, in a stable order, so the number shown
    # next to a model is exactly the index that gets used below.
    models = sorted(model_dir.glob("*"))
    print("Available predictors: ")
    for index, model in enumerate(models):
        # "<2" left-aligns in a width of two; the former "2<" spec meant
        # "fill with the character 2" and padded nothing.
        print(f"[{index:<2}] {model}")
    print("[q] Quit")

    answer = input().strip()
    if answer.lower() == "q":
        raise SystemExit

    try:
        index = int(answer)
    except ValueError:
        print(f"'{answer}' is neither a listed number nor 'q'.")
        return
    # Reject negatives as well: they would silently pick from the end.
    if not 0 <= index < len(models):
        print(f"There is no predictor numbered {index}.")
        return

    try:
        predict(ds, models[index])
    except Exception as err:
        # Still best effort (the program does not crash), but say what went
        # wrong instead of printing an empty line.
        print(f"Prediction failed: {err}")
def predict(ds, model):
    """Predict labels for a dataset's testing CSV and save the result.

    Reads ``resources/Testing<name>.csv`` (no header row), loads the
    predictor stored at ``model``, stores its predictions in column 128 of
    the frame and writes the frame to ``output/Testing<name>.csv``.

    Args:
        ds: Dataset key as used in ``datasets`` ("a" or "b").
        model: Location of the saved predictor to load.
    """
    root = Path(__file__).parent.parent
    file_name = f"Testing{datasets[ds]}.csv"
    testing_data = read_csv(root / "resources" / file_name, header=None)

    predictor = TabularPredictor.load(path=model)
    predictions = predictor.predict(data=testing_data)
    # 128 is the label column the predictors are trained on in handleTrain.
    testing_data[128] = predictions
    print(predictions)

    save_path = root / "output" / file_name
    # to_csv does not create missing directories, so make sure it exists.
    save_path.parent.mkdir(parents=True, exist_ok=True)
    destination = save_path.absolute().as_posix()
    print(f"Saving to {destination}")
    # NOTE(review): the row index is written as the first column because
    # index defaults to True - confirm that this is the wanted layout.
    testing_data.to_csv(path_or_buf=destination, header=False)
def handleTrain(ds):
    """Train a new predictor for a dataset and save it under ``models``.

    The training data is ``resources/Training<name>.csv`` (no header row)
    with the label in column 128. The model is saved to
    ``models/<name>/<ISO timestamp of the start of training>``.

    Args:
        ds: Dataset key as used in ``datasets``; "a" is trained as a
            binary problem, anything else as multiclass.
    """
    print("Starting training")
    root = Path(__file__).parent.parent
    dataset_name = datasets[ds]
    model_dir = root / "models" / dataset_name / datetime.now().isoformat()
    model_path = model_dir.absolute().as_posix()
    print(f"Saving model to {model_path}")

    # Resolve the CSV relative to the project root, like predict does,
    # instead of relying on one developer's absolute home directory.
    train_data = read_csv(
        root / "resources" / f"Training{dataset_name}.csv", header=None
    )
    predictor = TabularPredictor(
        label=128,
        problem_type="binary" if ds == "a" else "multiclass",
        eval_metric="accuracy",
        path=model_path,
    )
    predictor.fit(train_data)
    print("Finished predictor!")
# Start the interactive CLI only when the file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment