diff --git a/embeddings/__init__.py b/embeddings/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/embeddings/__pycache__/ModuleEmbeddings.cpython-39.pyc b/embeddings/__pycache__/ModuleEmbeddings.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4f517b2832d279eceff26690cf50123b46795f4a
Binary files /dev/null and b/embeddings/__pycache__/ModuleEmbeddings.cpython-39.pyc differ
diff --git a/embeddings/__pycache__/NextTagEmbedding.cpython-39.pyc b/embeddings/__pycache__/NextTagEmbedding.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5bc5968612d8fbb03504a3b43e0a545dab411d0d
Binary files /dev/null and b/embeddings/__pycache__/NextTagEmbedding.cpython-39.pyc differ
diff --git a/embeddings/__pycache__/Visualize.cpython-39.pyc b/embeddings/__pycache__/Visualize.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..74ac8f2a2d5b9996aef84b0c816ad0946d3c1e4d
Binary files /dev/null and b/embeddings/__pycache__/Visualize.cpython-39.pyc differ
diff --git a/embeddings/__pycache__/custom_logger.cpython-38.pyc b/embeddings/__pycache__/custom_logger.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a06a90bb35dfdd823d5608f747aeabf6a251fefe
Binary files /dev/null and b/embeddings/__pycache__/custom_logger.cpython-38.pyc differ
diff --git a/embeddings/__pycache__/custom_logger.cpython-39.pyc b/embeddings/__pycache__/custom_logger.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..97fd1d8280827875e666c1a6c8b9995ee9958120
Binary files /dev/null and b/embeddings/__pycache__/custom_logger.cpython-39.pyc differ
diff --git a/embeddings/__pycache__/dataset.cpython-39.pyc b/embeddings/__pycache__/dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd97a912b458d313f1ee3b398718af41bcc1a946
Binary files /dev/null and b/embeddings/__pycache__/dataset.cpython-39.pyc differ
diff --git a/embeddings/__pycache__/dataset_in_memory.cpython-39.pyc b/embeddings/__pycache__/dataset_in_memory.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b5680a08ba3da9769d0355f247721d87c640624
Binary files /dev/null and b/embeddings/__pycache__/dataset_in_memory.cpython-39.pyc differ
diff --git a/embeddings/__pycache__/post_embedding_builder.cpython-39.pyc b/embeddings/__pycache__/post_embedding_builder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c28eb16ed26e7813cbf1dd1fd1857bcb9fed14ba
Binary files /dev/null and b/embeddings/__pycache__/post_embedding_builder.cpython-39.pyc differ
diff --git a/embeddings/__pycache__/static_graph_construction.cpython-39.pyc b/embeddings/__pycache__/static_graph_construction.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c148ef5bc25a569593d938dd5ee030df7d33280
Binary files /dev/null and b/embeddings/__pycache__/static_graph_construction.cpython-39.pyc differ
diff --git a/embeddings/__pycache__/unixcoder.cpython-39.pyc b/embeddings/__pycache__/unixcoder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..328777fd5508c3dbe7b68f74baa751b377c14a97
Binary files /dev/null and b/embeddings/__pycache__/unixcoder.cpython-39.pyc differ
diff --git a/embeddings/custom_logger.py b/embeddings/custom_logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fb2391569ebb7ca2cf7bf372cb11fd708b0bfed
--- /dev/null
+++ b/embeddings/custom_logger.py
@@ -0,0 +1,16 @@
+import logging
+import sys
+
+
+def setup_custom_logger(name, level):
+    formatter = logging.Formatter(fmt='%(asctime)s %(levelname)-8s %(message)s',
+                                  datefmt='%Y-%m-%d %H:%M:%S')
+    screen_handler = logging.StreamHandler(stream=sys.stdout)
+    screen_handler.setFormatter(formatter)
+    logger = logging.getLogger(name)
+    logger.propagate = False
+    logger.setLevel(level)
+
+    logger.handlers.clear()
+    logger.addHandler(screen_handler)
+    return logger
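A minimal usage sketch for the setup_custom_logger helper added above (illustrative only; the logger name 'demo' is arbitrary):

    import logging
    from custom_logger import setup_custom_logger

    # One stdout handler with timestamps; propagate=False plus handlers.clear()
    # keeps records from being duplicated on repeated setup calls.
    log = setup_custom_logger('demo', logging.INFO)
    log.info('logger is ready')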
diff --git a/embeddings/dataset_in_memory.py b/embeddings/dataset_in_memory.py
index 153055e362ba7a2164820943821e540113ece47b..1ecbf6a7b6ca609b0c1c604fbb5d643d5ea79cee 100644
--- a/embeddings/dataset_in_memory.py
+++ b/embeddings/dataset_in_memory.py
@@ -1,5 +1,7 @@
 import logging
 import os
+import re
+from typing import List
 
 import torch
 from torch_geometric.data import InMemoryDataset
@@ -9,9 +11,12 @@ from custom_logger import setup_custom_logger
 log = setup_custom_logger('in-memory-dataset', logging.INFO)
 
 class UserGraphDatasetInMemory(InMemoryDataset):
-    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
+    def __init__(self, root, file_name_out: str, question_ids: List[int] = None, transform=None, pre_transform=None, pre_filter=None):
+        self._file_name_out = file_name_out
+        self._question_ids = question_ids
         super().__init__(root, transform, pre_transform, pre_filter)
         self.data, self.slices = torch.load(self.processed_paths[0])
+        # Detach so graphs loaded from disk carry no stale autograd history.
+        self.data = self.data.apply(lambda x: x.detach())
 
     @property
     def processed_dir(self):
@@ -27,7 +32,7 @@ class UserGraphDatasetInMemory(InMemoryDataset):
 
     @property
     def processed_file_names(self):
-        return ['in-memory-dataset.pt']
+        return [self._file_name_out]
 
     def download(self):
         pass
@@ -37,15 +42,34 @@ class UserGraphDatasetInMemory(InMemoryDataset):
         data_list = []
 
         for f in self.raw_file_names:
+            # Keep only graphs whose question id falls in the requested split
+            # (no filter when question_ids is None).
+            question_id_search = re.search(r"id_(\d+)", f)
+            if question_id_search and self._question_ids is not None:
+                if int(question_id_search.group(1)) not in self._question_ids:
+                    continue
+
             data = torch.load(os.path.join(self.raw_dir, f))
             data_list.append(data)
 
         data, slices = self.collate(data_list)
-        torch.save((data, slices), os.path.join(self.processed_dir, self.processed_file_names[0]))
+        # `processed_paths` is a property derived from `processed_file_names`,
+        # so assigning to one of its elements has no effect; save straight to
+        # the derived path instead.
+        torch.save((data, slices), self.processed_paths[0])
 
 
 if __name__ == '__main__':
-    dataset = UserGraphDatasetInMemory('../data/')
-
-    print(dataset.get(3))
\ No newline at end of file
+    question_ids = set()
+    # Split by question ids
+    for f in os.listdir("../data/processed"):
+        question_id_search = re.search(r"id_(\d+)", f)
+        if question_id_search:
+            question_ids.add(int(question_id_search.group(1)))
+
+    # 70/30 train/test split over question ids
+    train_ids = list(question_ids)[:int(len(question_ids) * 0.7)]
+    test_ids = [x for x in question_ids if x not in train_ids]
+
+    log.info(f"Training question count {len(train_ids)}")
+    log.info(f"Testing question count {len(test_ids)}")
+
+    # Keyword arguments: the signature is (root, file_name_out, question_ids, ...),
+    # so passing the ids positionally would land in file_name_out.
+    train_dataset = UserGraphDatasetInMemory('../data/', file_name_out=f'train-{len(train_ids)}-qs.pt', question_ids=train_ids)
+    test_dataset = UserGraphDatasetInMemory('../data/', file_name_out=f'test-{len(test_ids)}-qs.pt', question_ids=test_ids)
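A quick illustration of the id_(\d+) filename convention that process() relies on above (the file name below is made up; only the regex comes from the patch):

    import re

    f = 'graph_id_12345.pt'  # hypothetical raw file name
    match = re.search(r'id_(\d+)', f)
    if match:
        question_id = int(match.group(1))  # -> 12345, used to assign the graph to a split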
diff --git a/embeddings/hetero_GAT.py b/embeddings/hetero_GAT.py
index 95a6b9983a5b8e902c12648de10c2e51c1e145dc..0943795d340e8939345c4d2893db569eea6f1cb4 100644
--- a/embeddings/hetero_GAT.py
+++ b/embeddings/hetero_GAT.py
@@ -19,6 +19,10 @@ from dataset_in_memory import UserGraphDatasetInMemory
 from Visualize import GraphVisualization
 
 log = setup_custom_logger("heterogenous_GAT_model", logging.INFO)
+torch.multiprocessing.set_sharing_strategy('file_system')
+import resource
+rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))
 
 
 class HeteroGNN(torch.nn.Module):
@@ -28,25 +32,24 @@ class HeteroGNN(torch.nn.Module):
         self.convs = torch.nn.ModuleList()
         for _ in range(num_layers):
             conv = HeteroConv({
-                ('tag', 'describes', 'question') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('tag', 'describes', 'answer') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('tag', 'describes', 'comment') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('module', 'imported_in', 'question') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('module', 'imported_in', 'answer') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('question', 'rev_describes', 'tag') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('answer', 'rev_describes', 'tag') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('comment', 'rev_describes', 'tag') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('question', 'rev_imported_in', 'module') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('answer', 'rev_imported_in', 'module') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
+                ('tag', 'describes', 'question'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('tag', 'describes', 'answer'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('tag', 'describes', 'comment'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('module', 'imported_in', 'question'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('module', 'imported_in', 'answer'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('question', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('answer', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('comment', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('question', 'rev_imported_in', 'module'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('answer', 'rev_imported_in', 'module'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
             }, aggr='sum')
             self.convs.append(conv)
 
         self.lin = Linear(-1, out_channels)
         self.softmax = torch.nn.Softmax(dim=-1)
 
-    def forward(self, x_dict, edge_index_dict, batch_dict, post_emb):
-        #print("IN", post_emb.shape)
+    def forward(self, x_dict, edge_index_dict, batch_dict, post_emb):
+        # print("IN", post_emb.shape)
         for conv in self.convs:
             x_dict = conv(x_dict, edge_index_dict)
             x_dict = {key: x.relu() for key, x in x_dict.items()}
@@ -58,29 +61,37 @@ class HeteroGNN(torch.nn.Module):
             else:
                 outs.append(torch.zeros(1, x.size(-1)))
 
-        #print([x.shape for x in outs])
+        # print([x.shape for x in outs])
         out = torch.cat(outs, dim=1)
 
         out = torch.cat([out, post_emb], dim=1)
 
-        #print("B4 LINEAR", out.shape)
+        # print("B4 LINEAR", out.shape)
         out = self.lin(out)
         out = out.relu()
         out = self.softmax(out)
         return out
 
+'''
+
+'''
+
 
 def train(model, train_loader):
     running_loss = 0.0
 
     model.train()
     for i, data in enumerate(train_loader):  # Iterate in batches over the training dataset.
-        data = data.to(device)
+        data.to(device)
 
         optimizer.zero_grad()  # Clear gradients.
- #print("DATA IN", data.question_emb.shape, data.answer_emb.shape) - post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1) + + if INCLUDE_ANSWER: + post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device) + else: + post_emb = data.question_emb.to(device) + out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb) # Perform a single forward pass. loss = criterion(out, torch.squeeze(data.label, -1)) # Compute the loss. @@ -89,13 +100,12 @@ def train(model, train_loader): running_loss += loss.item() if i % 5 == 0: - log.info(f"[{i+1}] Loss: {running_loss / 2000}") + log.info(f"[{i + 1}] Loss: {running_loss / 5}") running_loss = 0.0 - def test(loader): - table = wandb.Table(columns=["graph", "ground_truth", "prediction"]) if use_wandb else None + table = wandb.Table(columns=["ground_truth", "prediction"]) if use_wandb else None model.eval() predictions = [] @@ -104,9 +114,12 @@ def test(loader): loss_ = 0 for data in loader: # Iterate in batches over the training/test dataset. - data = data.to(device) - - post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device) + data.to(device) + + if INCLUDE_ANSWER: + post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device) + else: + post_emb = data.question_emb.to(device) out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb) # Perform a single forward pass. @@ -115,13 +128,26 @@ def test(loader): pred = out.argmax(dim=1) # Use the class with highest probability. predictions += list([x.item() for x in pred]) true_labels += list([x.item() for x in data.label]) - + # log.info([(x, y) for x,y in zip([x.item() for x in pred], [x.item() for x in data.label])]) if use_wandb: - graph_html = wandb.Html(plotly.io.to_html(create_graph_vis(data))) + #graph_html = wandb.Html(plotly.io.to_html(create_graph_vis(data))) + for pred, label in zip(pred, torch.squeeze(data.label, -1)): - table.add_data(graph_html, label, pred) + table.add_data(label, pred) + + + + #print([(x, y) for x, y in zip(predictions, true_labels)]) + test_results = { + "accuracy": accuracy_score(true_labels, predictions), + "f1-score": f1_score(true_labels, predictions), + "loss": loss_ / len(loader), + "table": table, + "preds": predictions, + "trues": true_labels + } + return test_results - return accuracy_score(true_labels, predictions), f1_score(true_labels, predictions), loss_ / len(loader), table def create_graph_vis(graph): g = to_networkx(graph.to_homogeneous()) @@ -132,6 +158,7 @@ def create_graph_vis(graph): fig = vis.create_figure() return fig + def init_wandb(project_name: str, dataset): wandb.init(project=project_name, name="setup") # Log all the details about the data to W&B. 
@@ -145,7 +172,7 @@ def init_wandb(project_name: str, dataset):
         n_edges = graph.num_edges
         label = graph.label.item()
 
-        #graph_vis = plotly.io.to_html(fig, full_html=False)
+        # graph_vis = plotly.io.to_html(fig, full_html=False)
 
         table.add_data(wandb.Plotly(fig), n_nodes, n_edges, label)
     wandb.log({"data": table})
@@ -158,106 +185,116 @@ def init_wandb(project_name: str, dataset):
     # End the W&B run
     wandb.finish()
 
+
 def start_wandb_for_training(wandb_project_name: str, wandb_run_name: str):
     wandb.init(project=wandb_project_name, name=wandb_run_name)
-    #wandb.use_artifact("static-graphs:latest")
+    # wandb.use_artifact("static-graphs:latest")
+
 
 def save_model(model, model_name: str):
     torch.save(model.state_dict(), os.path.join("..", "models", model_name))
 
+
 if __name__ == '__main__':
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     log.info(f"Proceeding with {device} . .")
 
-    in_memory_dataset = False
+    in_memory_dataset = True
     # Datasets
     if in_memory_dataset:
-        dataset = UserGraphDatasetInMemory(root="../data")
+        train_dataset = UserGraphDatasetInMemory(root="../data", file_name_out='train-4175-qs.pt')
+        test_dataset = UserGraphDatasetInMemory(root="../data", file_name_out='test-1790-qs.pt')
     else:
         dataset = UserGraphDataset(root="../data", skip_processing=True)
+        train_size = int(0.7 * len(dataset))
+        # Sizes passed to random_split must sum to len(dataset), so the former
+        # 10% validation share is folded into the test split here.
+        test_size = len(dataset) - train_size
 
-    train_size = int(0.7 * len(dataset))
-    val_size = int(0.1 * len(dataset))
-    test_size = len(dataset) - (train_size + val_size)
-
-    log.info(f"Train Dataset Size: {train_size}")
-    log.info(f"Validation Dataset Size: {val_size}")
-    log.info(f"Test Dataset Size: {test_size}")
-    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
+        train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
 
+    log.info(f"Train Dataset Size: {len(train_dataset)}")
+    log.info(f"Test Dataset Size: {len(test_dataset)}")
+
     # Weights&Biases dashboard
     data_details = {
-        "num_node_features": dataset.num_node_features,
+        "num_node_features": train_dataset.num_node_features,
         "num_classes": 2
     }
     log.info(f"Data Details:\n{data_details}")
 
+    log.info(train_dataset[0])
+
     setup_wandb = False
     wandb_project_name = "heterogeneous-GAT-model"
     if setup_wandb:
         init_wandb(wandb_project_name, dataset)
-    use_wandb = False
+    use_wandb = True
     if use_wandb:
         wandb_run_name = f"run@{time.strftime('%Y%m%d-%H%M%S')}"
         start_wandb_for_training(wandb_project_name, wandb_run_name)
-
-    calculate_class_weights = False
-    #Class weights
+
+    calculate_class_weights = True
+    # Class weights
     sampler = None
     if calculate_class_weights:
         log.info(f"Calculating class weights")
         train_labels = [x.label for x in train_dataset]
-        counts = [train_labels.count(x) for x in [0,1]]
+        counts = [train_labels.count(x) for x in [0, 1]]
+        print(counts)
         class_weights = [1 - (x / sum(counts)) for x in counts]
+        print(class_weights)
         sampler = torch.utils.data.WeightedRandomSampler([class_weights[x] for x in train_labels], len(train_labels))
 
+    TRAIN_BATCH_SIZE = 512
+    log.info(f"Train DataLoader batch size is set to {TRAIN_BATCH_SIZE}")
+
     # Dataloaders
-    train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=64)
-    val_loader = DataLoader(val_dataset, batch_size=16)
-    test_loader = DataLoader(test_dataset, batch_size=16)
+    train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=TRAIN_BATCH_SIZE, num_workers=14)
+
+    test_loader = DataLoader(test_dataset, batch_size=512, num_workers=14)
 
     # Model
-    model = HeteroGNN(hidden_channels=64, out_channels=2, num_layers=3).to(device)
-
-    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+    model = HeteroGNN(hidden_channels=64, out_channels=2, num_layers=3)
+    model.to(device)
+
+    # Experiment config
+    INCLUDE_ANSWER = False
+
+    optimizer = torch.optim.Adam(model.parameters())
     criterion = torch.nn.CrossEntropyLoss()
 
-    for epoch in range(1, 5):
+    for epoch in range(1, 40):
         log.info(f"Epoch: {epoch:03d} > > >")
         train(model, train_loader)
-        train_acc, train_f1, train_loss, train_table = test(train_loader)
-        val_acc, val_f1, val_loss, val_table = test(val_loader)
-        test_acc, test_f1, test_loss, test_table = test(test_loader)
+        train_info = test(train_loader)
+        test_info = test(test_loader)
 
-        print(f'Epoch: {epoch:03d}, Train F1: {train_f1:.4f}, Validation F1: {val_f1:.4f} Test F1: {test_f1:.4f}')
+        print(f'Epoch: {epoch:03d}, Train F1: {train_info["f1-score"]:.4f}, Test F1: {test_info["f1-score"]:.4f}')
         checkpoint_file_name = f"../models/model-{epoch}.pt"
         torch.save(model.state_dict(), checkpoint_file_name)
         if use_wandb:
             wandb.log({
-                "train/loss": train_loss,
-                "train/accuracy": train_acc,
-                "train/f1": train_f1,
-                "train/table": train_table,
-                "val/loss": val_loss,
-                "val/accuracy": val_acc,
-                "val/f1": val_f1,
-                "val/table": val_table,
-                "test/loss": test_loss,
-                "test/accuracy": test_acc,
-                "test/f1": test_f1,
-                "test/table": test_table,
+                "train/loss": train_info["loss"],
+                "train/accuracy": train_info["accuracy"],
+                "train/f1": train_info["f1-score"],
+                "train/table": train_info["table"],
+                "test/loss": test_info["loss"],
+                "test/accuracy": test_info["accuracy"],
+                "test/f1": test_info["f1-score"],
+                "test/table": test_info["table"]
             })
         # Log model checkpoint as an artifact to W&B
         # artifact = wandb.Artifact(name="heterogenous-GAT-static-graphs", type="model")
         # checkpoint_file_name = f"../models/model-{epoch}.pt"
         # torch.save(model.state_dict(), checkpoint_file_name)
         # artifact.add_file(checkpoint_file_name)
         # wandb.log_artifact(artifact)
 
-    print(f'Test F1: {test_f1:.4f}')
+    print(f'Test F1: {test_info["f1-score"]:.4f}')
 
     save_model(model, "model.pt")
     if use_wandb:
+        wandb.log({"test/cm": wandb.plot.confusion_matrix(probs=None, y_true=test_info["trues"], preds=test_info["preds"], class_names=["neutral", "upvoted"])})
         wandb.finish()
+
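The class-weight arithmetic used in the training script above, as a standalone sketch (the counts are made up; the weights follow the 1 - count/total rule from the patch):

    import torch

    counts = [900, 100]  # hypothetical label counts for classes 0 and 1
    class_weights = [1 - (c / sum(counts)) for c in counts]  # -> [0.1, 0.9]

    # Each sample is drawn with probability proportional to its class weight,
    # so the minority class is oversampled towards balanced batches.
    train_labels = [0] * 900 + [1] * 100
    sampler = torch.utils.data.WeightedRandomSampler(
        weights=[class_weights[x] for x in train_labels],
        num_samples=len(train_labels),
    )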
diff --git a/embeddings/hetero_GAT.py.save b/embeddings/hetero_GAT.py.save
new file mode 100644
index 0000000000000000000000000000000000000000..865939d3c47c49ffbf0511dc847d8d031db35564
--- /dev/null
+++ b/embeddings/hetero_GAT.py.save
@@ -0,0 +1,297 @@
+import json
+import logging
+import os
+import string
+import time
+
+import networkx as nx
+import plotly
+import torch
+from sklearn.metrics import f1_score, accuracy_score
+from torch_geometric.loader import DataLoader
+from torch_geometric.nn import HeteroConv, GATConv, Linear, global_mean_pool
+import wandb
+from torch_geometric.utils import to_networkx
+
+from custom_logger import setup_custom_logger
+from dataset import UserGraphDataset
+from dataset_in_memory import UserGraphDatasetInMemory
+from Visualize import GraphVisualization
+
+log = setup_custom_logger("heterogenous_GAT_model", logging.INFO)
+torch.multiprocessing.set_sharing_strategy('file_system')
+import resource
+rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+# resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))
+
+
+class HeteroGNN(torch.nn.Module):
+    def __init__(self, hidden_channels, out_channels, num_layers):
+        super().__init__()
+
+        self.convs = torch.nn.ModuleList()
+        for _ in range(num_layers):
+            conv = HeteroConv({
+                ('tag', 'describes', 'question'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('tag', 'describes', 'answer'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('tag', 'describes', 'comment'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('module', 'imported_in', 'question'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('module', 'imported_in', 'answer'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('question', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('answer', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('comment', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('question', 'rev_imported_in', 'module'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('answer', 'rev_imported_in', 'module'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+            }, aggr='sum')
+            self.convs.append(conv)
+
+        self.lin = Linear(-1, out_channels)
+        self.softmax = torch.nn.Softmax(dim=-1)
+
+    def forward(self, x_dict, edge_index_dict, batch_dict, post_emb):
+        # print("IN", post_emb.shape)
+        for conv in self.convs:
+            x_dict = conv(x_dict, edge_index_dict)
+            x_dict = {key: x.relu() for key, x in x_dict.items()}
+
+        outs = []
+        for x, batch in zip(x_dict.values(), batch_dict.values()):
+            if len(x):
+                outs.append(global_mean_pool(x, batch=batch, size=len(post_emb)))
+            else:
+                outs.append(torch.zeros(1, x.size(-1)))
+
+        # print([x.shape for x in outs])
+        out = torch.cat(outs, dim=1)
+
+        out = torch.cat([out, post_emb], dim=1)
+
+        # print("B4 LINEAR", out.shape)
+        out = self.lin(out)
+        out = out.relu()
+        out = self.softmax(out)
+        return out
+
+
+'''
+
+'''
+
+
+def train(model, train_loader):
+    running_loss = 0.0
+
+    model.train()
+    for i, data in enumerate(train_loader):  # Iterate in batches over the training dataset.
+        data.to(device)
+
+        optimizer.zero_grad()  # Clear gradients.
+
+        if INCLUDE_ANSWER:
+            post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
+        else:
+            post_emb = data.question_emb.to(device)
+
+        out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb)  # Perform a single forward pass.
+
+        loss = criterion(out, torch.squeeze(data.label, -1))  # Compute the loss.
+        loss.backward()  # Derive gradients.
+        optimizer.step()  # Update parameters based on gradients.
+
+        running_loss += loss.item()
+        if i % 5 == 0:
+            log.info(f"[{i + 1}] Loss: {running_loss / 5}")
+            running_loss = 0.0
+
+
+def test(loader):
+    table = wandb.Table(columns=["ground_truth", "prediction"]) if use_wandb else None
+    model.eval()
+
+    predictions = []
+    true_labels = []
+
+    loss_ = 0
+
+    for data in loader:  # Iterate in batches over the training/test dataset.
+        data.to(device)
+
+        if INCLUDE_ANSWER:
+            post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
+        else:
+            post_emb = data.question_emb.to(device)
+
+        out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb)  # Perform a single forward pass.
+
+        loss = criterion(out, torch.squeeze(data.label, -1))  # Compute the loss.
+        loss_ += loss.item()
+        pred = out.argmax(dim=1)  # Use the class with highest probability.
+        predictions += list([x.item() for x in pred])
+        true_labels += list([x.item() for x in data.label])
+        # log.info([(x, y) for x, y in zip([x.item() for x in pred], [x.item() for x in data.label])])
+        if use_wandb:
+            # graph_html = wandb.Html(plotly.io.to_html(create_graph_vis(data)))
+
+            for pred, label in zip(pred, torch.squeeze(data.label, -1)):
+                table.add_data(label, pred)
+
+    # print([(x, y) for x, y in zip(predictions, true_labels)])
+    test_results = {
+        "accuracy": accuracy_score(true_labels, predictions),
+        "f1-score": f1_score(true_labels, predictions),
+        "loss": loss_ / len(loader),
+        "table": table,
+        "preds": predictions,
+        "trues": true_labels
+    }
+    return test_results
+
+
+def create_graph_vis(graph):
+    g = to_networkx(graph.to_homogeneous())
+    pos = nx.spring_layout(g)
+    vis = GraphVisualization(
+        g, pos, node_text_position='top left', node_size=20,
+    )
+    fig = vis.create_figure()
+    return fig
+
+
+def init_wandb(project_name: str, dataset):
+    wandb.init(project=project_name, name="setup")
+    # Log all the details about the data to W&B.
+    wandb.log(data_details)
+
+    # Log exploratory visualizations for each data point to W&B
+    table = wandb.Table(columns=["Graph", "Number of Nodes", "Number of Edges", "Label"])
+    for graph in dataset:
+        fig = create_graph_vis(graph)
+        n_nodes = graph.num_nodes
+        n_edges = graph.num_edges
+        label = graph.label.item()
+
+        # graph_vis = plotly.io.to_html(fig, full_html=False)
+
+        table.add_data(wandb.Plotly(fig), n_nodes, n_edges, label)
+    wandb.log({"data": table})
+
+    # Log the dataset to W&B as an artifact.
+    dataset_artifact = wandb.Artifact(name="static-graphs", type="dataset", metadata=data_details)
+    dataset_artifact.add_dir("../data/")
+    wandb.log_artifact(dataset_artifact)
+
+    # End the W&B run
+    wandb.finish()
+
+
+def start_wandb_for_training(wandb_project_name: str, wandb_run_name: str):
+    wandb.init(project=wandb_project_name, name=wandb_run_name)
+    # wandb.use_artifact("static-graphs:latest")
+
+
+def save_model(model, model_name: str):
+    torch.save(model.state_dict(), os.path.join("..", "models", model_name))
+
+
+if __name__ == '__main__':
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    log.info(f"Proceeding with {device} . .")
.") + + in_memory_dataset = True + # Datasets + if in_memory_dataset: + train_dataset = UserGraphDatasetInMemory(root="../data", file_name_out='train-4175-qs.pt') + test_dataset = UserGraphDatasetInMemory(root="../data", file_name_out='test-1790-qs.pt') + else: + dataset = UserGraphDataset(root="../data", skip_processing=True) + train_size = int(0.7 * len(dataset)) + val_size = int(0.1 * len(dataset)) + test_size = len(dataset) - (train_size + val_size) + + + train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size]) + + log.info(f"Train Dataset Size: {len(train_dataset)}") + log.info(f"Test Dataset Size: {len(test_dataset)}") + + # Weights&Biases dashboard + data_details = { + "num_node_features": train_dataset.num_node_features, + "num_classes": 2 + } + log.info(f"Data Details:\n{data_details}") + + setup_wandb = False + wandb_project_name = "heterogeneous-GAT-model" + if setup_wandb: + init_wandb(wandb_project_name, dataset) + use_wandb = True + if use_wandb: + wandb_run_name = f"run@{time.strftime('%Y%m%d-%H%M%S')}" + start_wandb_for_training(wandb_project_name, wandb_run_name) + + calculate_class_weights = True + # Class weights + sampler = None + if calculate_class_weights: + log.info(f"Calculating class weights") + train_labels = [x.label for x in train_dataset] + counts = [train_labels.count(x) for x in [0, 1]] + class_weights = [1 - (x / sum(counts)) for x in counts] + print(class_weights) + sampler = torch.utils.data.WeightedRandomSampler([class_weights[x] for x in train_labels], len(train_labels)) + + TRAIN_BATCH_SIZE = 512 + log.info(f"Train DataLoader batch size is set to {TRAIN_BATCH_SIZE}") + + # Dataloaders + train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=TRAIN_BATCH_SIZE, num_workers=14) + + test_loader = DataLoader(test_dataset, batch_size=512, num_workers=14) + + # Model + model = HeteroGNN(hidden_channels=64, out_channels=2, num_layers=3) + model.to(device) + + # Experiment config + INCLUDE_ANSWER = True + + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + criterion = torch.nn.CrossEntropyLoss() + + for epoch in range(1, 40): + log.info(f"Epoch: {epoch:03d} > > >") + train(model, train_loader) + train_info = test(train_loader) + test_info = test(test_loader) + + print(f'Epoch: {epoch:03d}, Train F1: {train_info["f1-score"]:.4f}, Test F1: {test_info["f1-score"]:.4f}') + checkpoint_file_name = f"../models/model-{epoch}.pt" + torch.save(model.state_dict(), checkpoint_file_name) + if use_wandb: + wandb.log({ + "train/loss": train_info["loss"], + "train/accuracy": train_info["accuracy"], + "train/f1": train_info["f1-score"], + "train/table": train_info["table"], + "test/loss": test_info["loss"], + "test/accuracy": test_info["accuracy"], + "test/f1": test_info["f1-score"], + "test/table": test_info["table"] + }) + # Log model checkpoint as an artifact to W&B + # artifact = wandb.Artifact(name="heterogenous-GAT-static-graphs", type="model") + # checkpoint_file_name = f "../models/model-{epoch}.pt" + # torch.save(model.state_dict(), checkpoint_file_name) + # artifact.add_file(checkpoint_file_name) + # wandb.log_artifact(artifact) + + print(f'Test F1: {test_f1:.4f}') + + save_model(model, "model.pt") + if use_wandb: + wandb.log({"test/cm": wandb.plot.confusion_matrix(probs=None, y_true=test_info["trues"], preds=test_info["preds"], class_names=["neutral", "upvoted"])}) + wandb.finish() +