Skip to content
Snippets Groups Projects
Commit ee6821fa authored by Liam Byrne's avatar Liam Byrne
Browse files

addressed the graph visualization slowing down

parent 0a21c2ca
No related branches found
No related tags found
No related merge requests found
...@@ -117,9 +117,9 @@ def test(loader): ...@@ -117,9 +117,9 @@ def test(loader):
true_labels += list([x.item() for x in data.label]) true_labels += list([x.item() for x in data.label])
if use_wandb: if use_wandb:
print("PRED", pred, data.label) graph_html = wandb.Html(plotly.io.to_html(create_graph_vis(data)))
for pred, label in zip(pred, torch.squeeze(data.label, -1)): for pred, label in zip(pred, torch.squeeze(data.label, -1)):
table.add_data(wandb.Html(plotly.io.to_html(create_graph_vis(data))), label, pred) table.add_data(graph_html, label, pred)
#print("PRED", predictions, true_labels) #print("PRED", predictions, true_labels)
return accuracy_score(true_labels, predictions), f1_score(true_labels, predictions), loss_ / len(loader), table return accuracy_score(true_labels, predictions), f1_score(true_labels, predictions), loss_ / len(loader), table
...@@ -133,8 +133,8 @@ def create_graph_vis(graph): ...@@ -133,8 +133,8 @@ def create_graph_vis(graph):
fig = vis.create_figure() fig = vis.create_figure()
return fig return fig
def init_wandb(project_name: str, run_name: str, dataset): def init_wandb(project_name: str, dataset):
wandb.init(project=project_name, name=run_name) wandb.init(project=project_name, name="setup")
# Log all the details about the data to W&B. # Log all the details about the data to W&B.
wandb.log(data_details) wandb.log(data_details)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment