Select Git revision
-
Andrei Vasile authored
main.py 2.47 KiB
from pathlib import Path
from pandas import read_csv, DataFrame
from autogluon.tabular import TabularDataset, TabularPredictor
from datetime import datetime
# Maps the menu key chosen by the user ("a" / "b") to the dataset name that
# is used to build the model directory and training CSV file names.
datasets = dict(a="DataBinary", b="DataMulti")
def main():
    """Ask which dataset to work on and hand over to the action menu.

    Reads one line from standard input. The answer is matched
    case-insensitively and with surrounding whitespace ignored. Anything
    other than "a" or "b" prints an error message and terminates the
    program by raising ``SystemExit``.
    """
    # Local import: only the error path below needs it.
    import sys

    print("Training and predictor")
    choice = input(
        "Which set would you like to work on?\n[A] DataBinary\n[B] DataMulti\n"
    )
    # Tolerate stray spaces around the letter as well as either letter case.
    choice = choice.strip().lower()
    if choice not in ("a", "b"):
        print("I don't know that dataset!")
        # sys.exit() rather than the bare exit(): the latter is injected by
        # the `site` module for interactive use and is not always available.
        sys.exit()
    handleDataset(choice)
def handleDataset(ds):
    """Ask what to do with the chosen dataset and dispatch to that action.

    Args:
        ds: Dataset key as selected in ``main`` ("a" or "b").

    The answer is read from standard input and matched case-insensitively
    with surrounding whitespace ignored: "a" runs ``handlePredict`` and "b"
    runs ``handleTrain``. Any other answer prints a short notice and returns
    without doing anything else.
    """
    choice = input(
        f"What would you like to do to {ds}?\n[A] Predict\n[B] Train\n"
    ).strip().lower()
    if choice == "a":
        handlePredict(ds)
    elif choice == "b":
        handleTrain(ds)
    else:
        # Previously an unknown answer ended the function silently, leaving
        # the user with no hint about why nothing happened.
        print("I don't know that action!")
def handlePredict(ds):
    """List the saved predictors for a dataset and run the one the user picks.

    Args:
        ds: Dataset key ("a" or "b"); used to look up the dataset name in
            ``datasets`` and passed on to ``predict``.

    The entries of ``<root>/models/<dataset name>`` are printed with an
    index, where ``<root>`` is two directory levels above this file. The
    user then enters an index, or "q" to quit (raises ``SystemExit``). An
    answer that is not an integer, or an index that does not exist, prints a
    notice and returns. Errors raised by ``predict`` itself are no longer
    swallowed and propagate to the caller.
    """
    # Local import: only the quit path below needs it.
    import sys

    print(f"Predicting {ds} with model")
    root = Path(__file__).parent.parent
    models_dir = root / "models" / datasets[ds]
    # List the directory exactly once (sorted for a stable order) so that the
    # index shown in the menu always refers to the entry that gets selected.
    models = sorted(models_dir.glob('*'))
    print("Available predictors: ")
    for index, model in enumerate(models):
        # "<2" left-aligns the index in a two-character column; the former
        # spec "2<" meant fill character '2' with no width, i.e. no padding.
        print(f"[{index:<2}] {model}")
    print("[q] Quit")
    choice = input().strip()
    if choice.lower() == 'q':
        sys.exit()
    try:
        # As before, a negative number counts from the end of the listing.
        model = models[int(choice)]
    except (ValueError, IndexError):
        print(f"Invalid selection: {choice!r}")
        return
    predict(ds, model)
def predict(ds, model):
    """Run a saved predictor on the dataset's testing CSV and save the output.

    Args:
        ds: Dataset key; "b" selects the "Multi" files, any other value the
            "Binary" files.
        model: Location of the saved predictor, handed to
            ``TabularPredictor.load``.

    Reads ``<root>/resources/TestingData<Multi|Binary>.csv`` without a header
    row, prints the predictions, and writes them one per line to
    ``<root>/output/TestingResults<Multi|Binary>.csv``. ``<root>`` is two
    directory levels above this file. The output directory is created if it
    does not exist yet.
    """
    root = Path(__file__).parent.parent
    suffix = "Multi" if ds == 'b' else "Binary"
    testing = root / "resources" / f"TestingData{suffix}.csv"
    savePath = (root / "output" / f"TestingResults{suffix}.csv").absolute()

    testingData = read_csv(testing, header=None)
    predictor = TabularPredictor.load(path=model)
    predictions = predictor.predict(data=testingData)
    # NOTE(review): the predictions are attached as column 128 but only the
    # predictions themselves are written below - confirm whether the full
    # table was meant to be saved.
    testingData[128] = predictions
    print(predictions)

    print(f"Saving to {savePath.as_posix()}")
    # Opening the file fails if the output directory is missing, so make sure
    # it exists first.
    savePath.parent.mkdir(parents=True, exist_ok=True)
    with savePath.open("w") as f:
        f.writelines(f"{value}\n" for value in predictions)
def handleTrain(ds):
    """Train a new predictor for a dataset and save it under a timestamp.

    Args:
        ds: Dataset key; "a" trains a binary problem, any other value a
            multiclass one. Also used to look up the dataset name in
            ``datasets``.

    The training data is read, without a header row, from
    ``<root>/resources/Training<dataset name>.csv`` and column 128 is used as
    the label. The model is stored in
    ``<root>/models/<dataset name>/<ISO timestamp>``. ``<root>`` is two
    directory levels above this file.
    """
    print("Starting training")
    root = Path(__file__).parent.parent
    model_dir = root / "models" / datasets[ds] / datetime.now().isoformat()
    model_path = model_dir.absolute().as_posix()
    print(f"Saving model to {model_path}")

    # Resolve the CSV relative to the project root, the same way predict()
    # finds its testing data, instead of a hard-coded absolute path that only
    # exists on one machine.
    training_csv = root / "resources" / f"Training{datasets[ds]}.csv"
    train_data = read_csv(training_csv, header=None)

    predictor = TabularPredictor(
        label=128,
        problem_type="binary" if ds == "a" else "multiclass",
        eval_metric="accuracy",
        path=model_path,
    )
    predictor.fit(train_data)
    print("Finished predictor!")
# Start the interactive menu only when the file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()