Commit 2e4ad739 authored by L.H.Byrne

working Hetero-GAT

parent a6d7f72c
Showing 451 additions and 77 deletions
10 files added.
custom_logger.py:

import logging
import sys


def setup_custom_logger(name, level):
    formatter = logging.Formatter(fmt='%(asctime)s %(levelname)-8s %(message)s',
                                  datefmt='%Y-%m-%d %H:%M:%S')
    screen_handler = logging.StreamHandler(stream=sys.stdout)
    screen_handler.setFormatter(formatter)
    logger = logging.getLogger(name)
    logger.propagate = False
    logger.setLevel(level)
    logger.handlers.clear()
    logger.addHandler(screen_handler)
    return logger
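A minimal usage sketch of this helper (illustrative only; the 'example' logger name is arbitrary):

import logging
from custom_logger import setup_custom_logger

log = setup_custom_logger('example', logging.INFO)
log.info("Writes timestamped records to stdout.")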
dataset_in_memory.py:

 import logging
 import os
+import re
+from typing import List

 import torch
 from torch_geometric.data import InMemoryDataset
@@ -9,9 +11,12 @@ from custom_logger import setup_custom_logger
 log = setup_custom_logger('in-memory-dataset', logging.INFO)

 class UserGraphDatasetInMemory(InMemoryDataset):
-    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
+    def __init__(self, root, file_name_out: str, question_ids: List[int] = None, transform=None, pre_transform=None, pre_filter=None):
+        self._file_name_out = file_name_out
+        self._question_ids = question_ids
         super().__init__(root, transform, pre_transform, pre_filter)
         self.data, self.slices = torch.load(self.processed_paths[0])
+        self.data = self.data.apply(lambda x: x.detach())

     @property
     def processed_dir(self):
@@ -27,7 +32,7 @@ class UserGraphDatasetInMemory(InMemoryDataset):
     @property
     def processed_file_names(self):
-        return ['in-memory-dataset.pt']
+        return [self._file_name_out]

     def download(self):
         pass
@@ -37,15 +42,34 @@ class UserGraphDatasetInMemory(InMemoryDataset):
         data_list = []

         for f in self.raw_file_names:
+            question_id_search = re.search(r"id_(\d+)", f)
+            if question_id_search:
+                if int(question_id_search.group(1)) not in self._question_ids:
+                    continue
             data = torch.load(os.path.join(self.raw_dir, f))
             data_list.append(data)

         data, slices = self.collate(data_list)
-        torch.save((data, slices), os.path.join(self.processed_dir, self.processed_file_names[0]))
+        self.processed_paths[0] = f"{len(data_list)}-{self.processed_file_names[0]}"
+        torch.save((data, slices), self.processed_paths[0])

 if __name__ == '__main__':
-    dataset = UserGraphDatasetInMemory('../data/')
-    print(dataset.get(3))
\ No newline at end of file
+    question_ids = set()
+    # Split by question ids
+    for f in os.listdir("../data/processed"):
+        question_id_search = re.search(r"id_(\d+)", f)
+        if question_id_search:
+            question_ids.add(int(question_id_search.group(1)))
+    #question_ids = list(question_ids)[:len(question_ids)* 0.6]
+
+    train_ids = list(question_ids)[:int(len(question_ids) * 0.7)]
+    test_ids = [x for x in question_ids if x not in train_ids]
+
+    log.info(f"Training question count {len(train_ids)}")
+    log.info(f"Testing question count {len(test_ids)}")
+
+    train_dataset = UserGraphDatasetInMemory('../data/', file_name_out=f'train-{len(train_ids)}-qs.pt', question_ids=train_ids)
+    test_dataset = UserGraphDatasetInMemory('../data/', file_name_out=f'test-{len(test_ids)}-qs.pt', question_ids=test_ids)
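Because the split is made on question ids rather than on individual graphs, every graph derived from a given question lands in exactly one split. A quick usage sketch (hypothetical; assumes the __main__ block above has just run):

# Sanity-check the id split, then wrap the datasets in PyG loaders.
assert set(train_ids).isdisjoint(test_ids)  # no question appears in both splits

from torch_geometric.loader import DataLoader

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64)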
@@ -19,6 +19,10 @@ from dataset_in_memory import UserGraphDatasetInMemory
 from Visualize import GraphVisualization

 log = setup_custom_logger("heterogenous_GAT_model", logging.INFO)
+torch.multiprocessing.set_sharing_strategy('file_system')
+import resource
+rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))

 class HeteroGNN(torch.nn.Module):
@@ -44,7 +48,6 @@ class HeteroGNN(torch.nn.Module):
         self.lin = Linear(-1, out_channels)
         self.softmax = torch.nn.Softmax(dim=-1)
-
     def forward(self, x_dict, edge_index_dict, batch_dict, post_emb):
         # print("IN", post_emb.shape)
         for conv in self.convs:
@@ -70,17 +73,25 @@ class HeteroGNN(torch.nn.Module):
         return out

+'''
+'''
 def train(model, train_loader):
     running_loss = 0.0
     model.train()

     for i, data in enumerate(train_loader):  # Iterate in batches over the training dataset.
-        data = data.to(device)
+        data.to(device)
         optimizer.zero_grad()  # Clear gradients.

-        #print("DATA IN", data.question_emb.shape, data.answer_emb.shape)
-        post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1)
+        if INCLUDE_ANSWER:
+            post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
+        else:
+            post_emb = data.question_emb.to(device)

         out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb)  # Perform a single forward pass.
         loss = criterion(out, torch.squeeze(data.label, -1))  # Compute the loss.
@@ -89,13 +100,12 @@ def train(model, train_loader):
         running_loss += loss.item()
         if i % 5 == 0:
-            log.info(f"[{i+1}] Loss: {running_loss / 2000}")
+            log.info(f"[{i + 1}] Loss: {running_loss / 5}")
             running_loss = 0.0

 def test(loader):
-    table = wandb.Table(columns=["graph", "ground_truth", "prediction"]) if use_wandb else None
+    table = wandb.Table(columns=["ground_truth", "prediction"]) if use_wandb else None
     model.eval()

     predictions = []
@@ -104,9 +114,12 @@ def test(loader):
     loss_ = 0
     for data in loader:  # Iterate in batches over the training/test dataset.
-        data = data.to(device)
+        data.to(device)

-        post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
+        if INCLUDE_ANSWER:
+            post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
+        else:
+            post_emb = data.question_emb.to(device)

         out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb)  # Perform a single forward pass.
@@ -115,13 +128,26 @@ def test(loader):
         pred = out.argmax(dim=1)  # Use the class with highest probability.
         predictions += list([x.item() for x in pred])
         true_labels += list([x.item() for x in data.label])
+        # log.info([(x, y) for x,y in zip([x.item() for x in pred], [x.item() for x in data.label])])

         if use_wandb:
-            graph_html = wandb.Html(plotly.io.to_html(create_graph_vis(data)))
+            #graph_html = wandb.Html(plotly.io.to_html(create_graph_vis(data)))
             for pred, label in zip(pred, torch.squeeze(data.label, -1)):
-                table.add_data(graph_html, label, pred)
+                table.add_data(label, pred)

-    return accuracy_score(true_labels, predictions), f1_score(true_labels, predictions), loss_ / len(loader), table
+    #print([(x, y) for x, y in zip(predictions, true_labels)])
+    test_results = {
+        "accuracy": accuracy_score(true_labels, predictions),
+        "f1-score": f1_score(true_labels, predictions),
+        "loss": loss_ / len(loader),
+        "table": table,
+        "preds": predictions,
+        "trues": true_labels
+    }
+    return test_results

 def create_graph_vis(graph):
     g = to_networkx(graph.to_homogeneous())
@@ -132,6 +158,7 @@ def create_graph_vis(graph):
     fig = vis.create_figure()
     return fig

 def init_wandb(project_name: str, dataset):
     wandb.init(project=project_name, name="setup")
     # Log all the details about the data to W&B.
@@ -158,96 +185,104 @@ def init_wandb(project_name: str, dataset):
     # End the W&B run
     wandb.finish()

 def start_wandb_for_training(wandb_project_name: str, wandb_run_name: str):
     wandb.init(project=wandb_project_name, name=wandb_run_name)
     # wandb.use_artifact("static-graphs:latest")

 def save_model(model, model_name: str):
     torch.save(model.state_dict(), os.path.join("..", "models", model_name))

 if __name__ == '__main__':
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     log.info(f"Proceeding with {device} . .")

-    in_memory_dataset = False
+    in_memory_dataset = True
     # Datasets
     if in_memory_dataset:
-        dataset = UserGraphDatasetInMemory(root="../data")
+        train_dataset = UserGraphDatasetInMemory(root="../data", file_name_out='train-4175-qs.pt')
+        test_dataset = UserGraphDatasetInMemory(root="../data", file_name_out='test-1790-qs.pt')
     else:
         dataset = UserGraphDataset(root="../data", skip_processing=True)
         train_size = int(0.7 * len(dataset))
         val_size = int(0.1 * len(dataset))
         test_size = len(dataset) - (train_size + val_size)
-        log.info(f"Train Dataset Size: {train_size}")
-        log.info(f"Validation Dataset Size: {val_size}")
-        log.info(f"Test Dataset Size: {test_size}")
-        train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
+        train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
+
+    log.info(f"Train Dataset Size: {len(train_dataset)}")
+    log.info(f"Test Dataset Size: {len(test_dataset)}")

     # Weights&Biases dashboard
     data_details = {
-        "num_node_features": dataset.num_node_features,
+        "num_node_features": train_dataset.num_node_features,
         "num_classes": 2
     }
     log.info(f"Data Details:\n{data_details}")
-    log.info(train_dataset[0])

     setup_wandb = False
     wandb_project_name = "heterogeneous-GAT-model"
     if setup_wandb:
         init_wandb(wandb_project_name, dataset)
-    use_wandb = False
+    use_wandb = True
     if use_wandb:
         wandb_run_name = f"run@{time.strftime('%Y%m%d-%H%M%S')}"
         start_wandb_for_training(wandb_project_name, wandb_run_name)

-    calculate_class_weights = False
+    calculate_class_weights = True
     # Class weights
     sampler = None
     if calculate_class_weights:
         log.info(f"Calculating class weights")
         train_labels = [x.label for x in train_dataset]
         counts = [train_labels.count(x) for x in [0, 1]]
-        print(counts)
         class_weights = [1 - (x / sum(counts)) for x in counts]
+        print(class_weights)
         sampler = torch.utils.data.WeightedRandomSampler([class_weights[x] for x in train_labels], len(train_labels))

+    TRAIN_BATCH_SIZE = 512
+    log.info(f"Train DataLoader batch size is set to {TRAIN_BATCH_SIZE}")
     # Dataloaders
-    train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=64)
-    val_loader = DataLoader(val_dataset, batch_size=16)
-    test_loader = DataLoader(test_dataset, batch_size=16)
+    train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=TRAIN_BATCH_SIZE, num_workers=14)
+    test_loader = DataLoader(test_dataset, batch_size=512, num_workers=14)

     # Model
-    model = HeteroGNN(hidden_channels=64, out_channels=2, num_layers=3).to(device)
+    model = HeteroGNN(hidden_channels=64, out_channels=2, num_layers=3)
+    model.to(device)
+
+    # Experiment config
+    INCLUDE_ANSWER = False

-    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+    optimizer = torch.optim.Adam(model.parameters())
     criterion = torch.nn.CrossEntropyLoss()

-    for epoch in range(1, 5):
+    for epoch in range(1, 40):
         log.info(f"Epoch: {epoch:03d} > > >")
         train(model, train_loader)
-        train_acc, train_f1, train_loss, train_table = test(train_loader)
-        val_acc, val_f1, val_loss, val_table = test(val_loader)
-        test_acc, test_f1, test_loss, test_table = test(test_loader)
+        train_info = test(train_loader)
+        test_info = test(test_loader)

-        print(f'Epoch: {epoch:03d}, Train F1: {train_f1:.4f}, Validation F1: {val_f1:.4f} Test F1: {test_f1:.4f}')
+        print(f'Epoch: {epoch:03d}, Train F1: {train_info["f1-score"]:.4f}, Test F1: {test_info["f1-score"]:.4f}')
         checkpoint_file_name = f"../models/model-{epoch}.pt"
         torch.save(model.state_dict(), checkpoint_file_name)

         if use_wandb:
             wandb.log({
-                "train/loss": train_loss,
-                "train/accuracy": train_acc,
-                "train/f1": train_f1,
-                "train/table": train_table,
-                "val/loss": val_loss,
-                "val/accuracy": val_acc,
-                "val/f1": val_f1,
-                "val/table": val_table,
-                "test/loss": test_loss,
-                "test/accuracy": test_acc,
-                "test/f1": test_f1,
-                "test/table": test_table,
+                "train/loss": train_info["loss"],
+                "train/accuracy": train_info["accuracy"],
+                "train/f1": train_info["f1-score"],
+                "train/table": train_info["table"],
+                "test/loss": test_info["loss"],
+                "test/accuracy": test_info["accuracy"],
+                "test/f1": test_info["f1-score"],
+                "test/table": test_info["table"]
             })

     # Log model checkpoint as an artifact to W&B
     # artifact = wandb.Artifact(name="heterogenous-GAT-static-graphs", type="model")
@@ -256,8 +291,10 @@ if __name__ == '__main__':
     # artifact.add_file(checkpoint_file_name)
     # wandb.log_artifact(artifact)

-    print(f'Test F1: {test_f1:.4f}')
+    print(f'Test F1: {test_info["f1-score"]:.4f}')
     save_model(model, "model.pt")
     if use_wandb:
+        wandb.log({"test/cm": wandb.plot.confusion_matrix(probs=None, y_true=test_info["trues"], preds=test_info["preds"], class_names=["neutral", "upvoted"])})
         wandb.finish()
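The heterogeneous GAT training script, shown in full: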
import json
import logging
import os
import string
import time
import networkx as nx
import plotly
import torch
from sklearn.metrics import f1_score, accuracy_score
from torch_geometric.loader import DataLoader
from torch_geometric.nn import HeteroConv, GATConv, Linear, global_mean_pool
import wandb
from torch_geometric.utils import to_networkx
from custom_logger import setup_custom_logger
from dataset import UserGraphDataset
from dataset_in_memory import UserGraphDatasetInMemory
from Visualize import GraphVisualization
log = setup_custom_logger("heterogenous_GAT_model", logging.INFO)
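# Share tensors between DataLoader workers through the file system rather than
# open file descriptors; together with the RLIMIT_NOFILE lines below, this
# works around "too many open files" errors from multi-worker loading.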
torch.multiprocessing.set_sharing_strategy('file_system')
import resource
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
# resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))
class HeteroGNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, num_layers):
        super().__init__()

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv({
                ('tag', 'describes', 'question'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('tag', 'describes', 'answer'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('tag', 'describes', 'comment'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('module', 'imported_in', 'question'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('module', 'imported_in', 'answer'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('question', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('answer', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('comment', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('question', 'rev_imported_in', 'module'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
                ('answer', 'rev_imported_in', 'module'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
            }, aggr='sum')
            self.convs.append(conv)

        self.lin = Linear(-1, out_channels)
        self.softmax = torch.nn.Softmax(dim=-1)
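    # Forward pass: heterogeneous GAT message passing per relation, global mean
    # pooling per node type, concatenation with the post embedding, then a
    # linear classification head.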
    def forward(self, x_dict, edge_index_dict, batch_dict, post_emb):
        # print("IN", post_emb.shape)
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()}

        outs = []
        for x, batch in zip(x_dict.values(), batch_dict.values()):
            if len(x):
                outs.append(global_mean_pool(x, batch=batch, size=len(post_emb)))
            else:
                outs.append(torch.zeros(1, x.size(-1)))

        # print([x.shape for x in outs])
        out = torch.cat(outs, dim=1)
        out = torch.cat([out, post_emb], dim=1)

        # print("B4 LINEAR", out.shape)
        out = self.lin(out)
        out = out.relu()
        out = self.softmax(out)
        return out
'''
'''
def train(model, train_loader):
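    # One epoch over train_loader: forward pass, cross-entropy loss, backprop,
    # optimizer step; logs an averaged running loss every few batches.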
    running_loss = 0.0
    model.train()

    for i, data in enumerate(train_loader):  # Iterate in batches over the training dataset.
        data.to(device)
        optimizer.zero_grad()  # Clear gradients.

        if INCLUDE_ANSWER:
            post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
        else:
            post_emb = data.question_emb.to(device)

        out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb)  # Perform a single forward pass.
        loss = criterion(out, torch.squeeze(data.label, -1))  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.

        running_loss += loss.item()
        if i % 5 == 0:
            log.info(f"[{i + 1}] Loss: {running_loss / 5}")
            running_loss = 0.0
def test(loader):
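    # Runs the model over `loader` and collects accuracy, binary F1, mean loss,
    # the optional W&B table, and the raw predictions/true labels.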
    table = wandb.Table(columns=["ground_truth", "prediction"]) if use_wandb else None
    model.eval()

    predictions = []
    true_labels = []
    loss_ = 0

    for data in loader:  # Iterate in batches over the training/test dataset.
        data.to(device)

        if INCLUDE_ANSWER:
            post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
        else:
            post_emb = data.question_emb.to(device)

        out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb)  # Perform a single forward pass.
        loss = criterion(out, torch.squeeze(data.label, -1))  # Compute the loss.
        loss_ += loss.item()

        pred = out.argmax(dim=1)  # Use the class with highest probability.
        predictions += list([x.item() for x in pred])
        true_labels += list([x.item() for x in data.label])
        # log.info([(x, y) for x,y in zip([x.item() for x in pred], [x.item() for x in data.label])])

        if use_wandb:
            #graph_html = wandb.Html(plotly.io.to_html(create_graph_vis(data)))
            for pred, label in zip(pred, torch.squeeze(data.label, -1)):
                table.add_data(label, pred)

    #print([(x, y) for x, y in zip(predictions, true_labels)])
    test_results = {
        "accuracy": accuracy_score(true_labels, predictions),
        "f1-score": f1_score(true_labels, predictions),
        "loss": loss_ / len(loader),
        "table": table,
        "preds": predictions,
        "trues": true_labels
    }
    return test_results
def create_graph_vis(graph):
    g = to_networkx(graph.to_homogeneous())
    pos = nx.spring_layout(g)
    vis = GraphVisualization(
        g, pos, node_text_position='top left', node_size=20,
    )
    fig = vis.create_figure()
    return fig
def init_wandb(project_name: str, dataset):
    wandb.init(project=project_name, name="setup")

    # Log all the details about the data to W&B.
    wandb.log(data_details)

    # Log exploratory visualizations for each data point to W&B
    table = wandb.Table(columns=["Graph", "Number of Nodes", "Number of Edges", "Label"])
    for graph in dataset:
        fig = create_graph_vis(graph)
        n_nodes = graph.num_nodes
        n_edges = graph.num_edges
        label = graph.label.item()

        # graph_vis = plotly.io.to_html(fig, full_html=False)
        table.add_data(wandb.Plotly(fig), n_nodes, n_edges, label)
    wandb.log({"data": table})

    # Log the dataset to W&B as an artifact.
    dataset_artifact = wandb.Artifact(name="static-graphs", type="dataset", metadata=data_details)
    dataset_artifact.add_dir("../data/")
    wandb.log_artifact(dataset_artifact)

    # End the W&B run
    wandb.finish()
def start_wandb_for_training(wandb_project_name: str, wandb_run_name: str):
    wandb.init(project=wandb_project_name, name=wandb_run_name)
    # wandb.use_artifact("static-graphs:latest")

def save_model(model, model_name: str):
    torch.save(model.state_dict(), os.path.join("..", "models", model_name))
if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log.info(f"Proceeding with {device} . .")

    in_memory_dataset = True
    # Datasets
    if in_memory_dataset:
        train_dataset = UserGraphDatasetInMemory(root="../data", file_name_out='train-4175-qs.pt')
        test_dataset = UserGraphDatasetInMemory(root="../data", file_name_out='test-1790-qs.pt')
    else:
        dataset = UserGraphDataset(root="../data", skip_processing=True)
        train_size = int(0.7 * len(dataset))
        val_size = int(0.1 * len(dataset))
        test_size = len(dataset) - (train_size + val_size)
        train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

    log.info(f"Train Dataset Size: {len(train_dataset)}")
    log.info(f"Test Dataset Size: {len(test_dataset)}")

    # Weights&Biases dashboard
    data_details = {
        "num_node_features": train_dataset.num_node_features,
        "num_classes": 2
    }
    log.info(f"Data Details:\n{data_details}")

    setup_wandb = False
    wandb_project_name = "heterogeneous-GAT-model"
    if setup_wandb:
        init_wandb(wandb_project_name, dataset)
    use_wandb = True
    if use_wandb:
        wandb_run_name = f"run@{time.strftime('%Y%m%d-%H%M%S')}"
        start_wandb_for_training(wandb_project_name, wandb_run_name)

    calculate_class_weights = True
    # Class weights
    sampler = None
    if calculate_class_weights:
        log.info(f"Calculating class weights")
        train_labels = [x.label for x in train_dataset]
        counts = [train_labels.count(x) for x in [0, 1]]
        class_weights = [1 - (x / sum(counts)) for x in counts]
        print(class_weights)
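        # Weight each example by 1 minus its class's share of the training set,
        # so the rarer class is oversampled.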
        sampler = torch.utils.data.WeightedRandomSampler([class_weights[x] for x in train_labels], len(train_labels))

    TRAIN_BATCH_SIZE = 512
    log.info(f"Train DataLoader batch size is set to {TRAIN_BATCH_SIZE}")
    # Dataloaders
    train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=TRAIN_BATCH_SIZE, num_workers=14)
    test_loader = DataLoader(test_dataset, batch_size=512, num_workers=14)

    # Model
    model = HeteroGNN(hidden_channels=64, out_channels=2, num_layers=3)
    model.to(device)

    # Experiment config
    INCLUDE_ANSWER = True

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(1, 40):
        log.info(f"Epoch: {epoch:03d} > > >")
        train(model, train_loader)
        train_info = test(train_loader)
        test_info = test(test_loader)

        print(f'Epoch: {epoch:03d}, Train F1: {train_info["f1-score"]:.4f}, Test F1: {test_info["f1-score"]:.4f}')
        checkpoint_file_name = f"../models/model-{epoch}.pt"
        torch.save(model.state_dict(), checkpoint_file_name)

        if use_wandb:
            wandb.log({
                "train/loss": train_info["loss"],
                "train/accuracy": train_info["accuracy"],
                "train/f1": train_info["f1-score"],
                "train/table": train_info["table"],
                "test/loss": test_info["loss"],
                "test/accuracy": test_info["accuracy"],
                "test/f1": test_info["f1-score"],
                "test/table": test_info["table"]
            })

    # Log model checkpoint as an artifact to W&B
    # artifact = wandb.Artifact(name="heterogenous-GAT-static-graphs", type="model")
    # checkpoint_file_name = f"../models/model-{epoch}.pt"
    # torch.save(model.state_dict(), checkpoint_file_name)
    # artifact.add_file(checkpoint_file_name)
    # wandb.log_artifact(artifact)

    print(f'Test F1: {test_info["f1-score"]:.4f}')
    save_model(model, "model.pt")
    if use_wandb:
        wandb.log({"test/cm": wandb.plot.confusion_matrix(probs=None, y_true=test_info["trues"], preds=test_info["preds"], class_names=["neutral", "upvoted"])})
        wandb.finish()
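To reuse a checkpoint written by save_model, a minimal loading sketch (an illustration, not part of the commit; depending on the PyTorch/PyG versions, the lazily-initialised layers may need one dummy forward pass before load_state_dict):

import os
import torch

# Rebuild the architecture with the training-time hyperparameters, then restore weights.
model = HeteroGNN(hidden_channels=64, out_channels=2, num_layers=3)
model.load_state_dict(torch.load(os.path.join("..", "models", "model.pt")))
model.eval()  # disable training-time behaviour before scoring new graphs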