Skip to content
Snippets Groups Projects
Commit 9802fe93 authored by bav1g20's avatar bav1g20
Browse files

updated data saving as spec was not clear

parent 46060854
Branches JulioWithHashTable
No related tags found
No related merge requests found
......@@ -27,6 +27,39 @@ def handleDataset(ds):
handleTrain(ds)
elif i =='a':
handlePredict(ds)
elif i=='c':
internal(ds)
def internal(ds):
print(f"Predicting {ds} with model")
root = Path(__file__).parent.parent
path = root / "models" / datasets[ds]
files = path.glob('*')
print("Available predictors: ")
for i,val in enumerate(files):
print(f"[{i:2<}] {val}")
print("[q] Quit")
i = input()
if i == 'q':
exit()
root = Path(__file__).parent.parent
testing = root / "resources" / "TestingData{}.csv".format("Multi" if ds =='b' else "Binary")
training = root / "resources" / "TrainingData{}.csv".format("Multi" if ds =='b' else "Binary")
testingData = read_csv(testing, header=None)
trainingData = read_csv(training, header=None)
x = list(path.glob('*'))
f = x[int(i)]
predictor = TabularPredictor.load(path=f.absolute().as_posix())
n = predictor.leaderboard()
print(n["model"].values.tolist())
def handlePredict(ds):
print(f"Predicting {ds} with model")
......@@ -47,7 +80,6 @@ def handlePredict(ds):
try:
i = int(i)
predict(ds, list(path.glob('*'))[i])
except Exception:
print()
......@@ -66,9 +98,7 @@ def predict(ds, model):
savePath = savePath.absolute().as_posix()
print(f"Saving to {savePath}")
with open(savePath, "w") as f:
for i in a:
f.write(f"{i}\n")
testingData.to_csv(savePath, header=False, index=False)
def handleTrain(ds):
print("Starting training")
......@@ -85,7 +115,7 @@ def handleTrain(ds):
eval_metric="accuracy",
path=file.absolute().as_posix()
)
predictor.fit(train_data)
predictor.fit(train_data,verbosity=4)
print("Finished predictor!")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment