Skip to content
Snippets Groups Projects
Select Git revision
  • 46060854ced46b60ede07ac7d6ec6ce5333646c7
  • main default protected
2 results

main.py

Blame
  • main.py 2.47 KiB
    from pathlib import Path
    from pandas import read_csv, DataFrame
    from autogluon.tabular import TabularDataset, TabularPredictor
    from datetime import datetime
    
    
    # Maps the menu key typed by the user to the dataset name that is used
    # when building model directory and training-file names.
    datasets = dict(a="DataBinary", b="DataMulti")
    
    def main():
        """Run the interactive entry point: choose a dataset, then an action.

        Prompts on stdin for a dataset key, normalises the answer (surrounding
        whitespace and letter case are ignored) and passes it to
        ``handleDataset``.

        Raises:
            SystemExit: If the answer is not a key of ``datasets``.
        """
        print("Training and predictor")

        choice = input("Which set would you like to work on?\n[A] DataBinary\n[B] DataMulti\n")
        # Strip as well as lower-case so answers such as "A " or " b" work.
        choice = choice.strip().lower()

        if choice not in datasets:
            print("I don't know that dataset!")
            # Raise SystemExit directly: the ``exit`` helper is injected by the
            # ``site`` module and is not guaranteed to be present.
            raise SystemExit

        handleDataset(choice)
    
    def handleDataset(ds):
        """Ask which action to run on dataset ``ds`` and dispatch to it.

        "a" runs ``handlePredict`` and "b" runs ``handleTrain``; the answer is
        matched case-insensitively with surrounding whitespace ignored.

        Args:
            ds: Dataset key chosen by the user ("a" or "b").
        """
        action = input(f"What would you like to do to {ds}?\n[A] Predict\n[B] Train\n").strip().lower()
        if action == 'b':
            handleTrain(ds)
        elif action == 'a':
            handlePredict(ds)
        else:
            # Tell the user instead of returning without any feedback.
            print("I don't know that action!")
    
    def handlePredict(ds):
        """List the saved predictors for dataset ``ds`` and run the chosen one.

        The predictors are the entries of ``models/<dataset name>`` under the
        parent of the directory containing this file.  The user picks one by
        its number, or enters "q" to quit.  Invalid selections and failed
        predictions are reported on stdout instead of being silently ignored.

        Args:
            ds: Dataset key, used to look up the dataset name in ``datasets``.

        Raises:
            SystemExit: If the user enters "q".
        """
        print(f"Predicting {ds} with model")

        root = Path(__file__).parent.parent
        path = root / "models" / datasets[ds]

        # Materialise and sort once so the numbers shown are exactly the
        # entries indexed below; a second glob could return a different order.
        models = sorted(path.glob('*'))

        print("Available predictors: ")
        for index, model in enumerate(models):
            # "<2" left-aligns the number in a field of width two.
            print(f"[{index:<2}] {model}")
        print("[q] Quit")

        answer = input().strip()
        if answer.lower() == 'q':
            raise SystemExit

        try:
            index = int(answer)
        except ValueError:
            print(f"'{answer}' is not a valid selection.")
            return

        # Reject negatives explicitly; indexing would otherwise count from the
        # end of the list and run a predictor the user did not ask for.
        if not 0 <= index < len(models):
            print(f"No predictor numbered {index}.")
            return

        try:
            predict(ds, models[index])
        except Exception as err:
            # Still best-effort (no traceback), but say what went wrong.
            print(f"Prediction failed: {err}")
    
    def predict(ds, model):
        """Run a saved predictor over the test data for ``ds`` and save the result.

        Reads ``resources/TestingData<suffix>.csv`` (no header row), loads the
        predictor stored at ``model``, prints its predictions and writes them,
        one per line, to ``output/TestingResults<suffix>.csv``.  Both paths are
        relative to the parent of the directory containing this file, and
        ``<suffix>`` is "Multi" for ``ds == 'b'`` and "Binary" otherwise.

        Args:
            ds: Dataset key selecting the Binary or Multi files.
            model: Path of the saved predictor to load.
        """
        root = Path(__file__).parent.parent
        suffix = "Multi" if ds == 'b' else "Binary"
        testing = root / "resources" / f"TestingData{suffix}.csv"
        savePath = root / "output" / f"TestingResults{suffix}.csv"

        testingData = read_csv(testing, header=None)
        predictor = TabularPredictor.load(path=model)

        predictions = predictor.predict(data=testingData)
        print(predictions)

        # Make sure the output directory exists; open() does not create it.
        savePath.parent.mkdir(parents=True, exist_ok=True)
        savePath = savePath.absolute().as_posix()
        print(f"Saving to {savePath}")

        with open(savePath, "w") as f:
            f.writelines(f"{label}\n" for label in predictions)
    
    def handleTrain(ds):
        """Train a new predictor for dataset ``ds`` and save it under ``models/``.

        The training data is read from ``resources/Training<dataset name>.csv``
        (no header row, so columns are numbered) and the predictor is saved to
        ``models/<dataset name>/<ISO timestamp>``.  Both paths are relative to
        the parent of the directory containing this file.

        Args:
            ds: Dataset key; "a" trains a binary classifier, anything else a
                multiclass one.
        """
        print("Starting training")

        root = Path(__file__).parent.parent
        model_dir = root / "models" / datasets[ds] / datetime.now().isoformat()
        model_path = model_dir.absolute().as_posix()
        print(f"Saving model to {model_path}")

        # Locate the training file relative to the project root, the same way
        # the model directory is, rather than through a machine-specific
        # absolute path.
        training = root / "resources" / f"Training{datasets[ds]}.csv"
        train_data = read_csv(training, header=None)

        predictor = TabularPredictor(
            label=128,  # column 128 of the headerless CSV holds the target
            problem_type="binary" if ds == "a" else "multiclass",
            eval_metric="accuracy",
            path=model_path,
        )
        predictor.fit(train_data)
        print("Finished predictor!")
    
    
    # Start the interactive prompt only when this file is run as a script,
    # not when it is imported as a module.
    if __name__ == "__main__":
        main()