diff --git a/embeddings/NextTagEmbedding.py b/embeddings/NextTagEmbedding.py index 0d0ad727298fc98185310d3cccbfdf14a1f54a1f..6d5a8576940c8d378ca4388ccc5f21a01b1e792f 100644 --- a/embeddings/NextTagEmbedding.py +++ b/embeddings/NextTagEmbedding.py @@ -162,16 +162,18 @@ class NextTagEmbedding(nn.Module): if __name__ == '__main__': tet = NextTagEmbeddingTrainer(context_length=2, emb_size=30, excluded_tags=['python'], database_path="../stackoverflow.db") - tet.from_db() - print(len(tet.post_tags)) - print(len(tet.tag_vocab)) + #tet.from_db() + #print(len(tet.post_tags)) + #print(len(tet.tag_vocab)) #tet = NextTagEmbeddingTrainer(context_length=3, emb_size=50) - #tet.from_files("../data/raw/all_tags.csv", "../data/raw/tag_vocab.csv") + tet.from_files("../all_tags.csv", "../tag_vocab.csv") # assert len(tet.post_tags) == 84187510, "Incorrect number of post tags!" # assert len(tet.tag_vocab) == 63653, "Incorrect vocab size!" + + print(len(tet.post_tags)) tet.train(1000, 1) # tet.to_tensorboard(f"run@{time.strftime('%Y%m%d-%H%M%S')}") diff --git a/embeddings/__pycache__/dataset.cpython-39.pyc b/embeddings/__pycache__/dataset.cpython-39.pyc index fd97a912b458d313f1ee3b398718af41bcc1a946..5a89c09867736d07f5956e8241d511ed95476100 100644 Binary files a/embeddings/__pycache__/dataset.cpython-39.pyc and b/embeddings/__pycache__/dataset.cpython-39.pyc differ diff --git a/embeddings/__pycache__/dataset_in_memory.cpython-39.pyc b/embeddings/__pycache__/dataset_in_memory.cpython-39.pyc index b144eff4b8fc104d87e4e3c6da288641d1083a7f..b0f15c4568700226cde4638c04bf1bfcf832dee9 100644 Binary files a/embeddings/__pycache__/dataset_in_memory.cpython-39.pyc and b/embeddings/__pycache__/dataset_in_memory.cpython-39.pyc differ diff --git a/embeddings/__pycache__/helper_functions.cpython-39.pyc b/embeddings/__pycache__/helper_functions.cpython-39.pyc index c703478f4863918e0f7707354d90e4d738152c74..0235bc1e59c4ce3b5175b5dfeb5eb5f57eba5468 100644 Binary files a/embeddings/__pycache__/helper_functions.cpython-39.pyc and b/embeddings/__pycache__/helper_functions.cpython-39.pyc differ diff --git a/embeddings/__pycache__/hetero_GAT_constants.cpython-39.pyc b/embeddings/__pycache__/hetero_GAT_constants.cpython-39.pyc index dc4dbeba04cf288cfd9bf78b062a95b945be411b..4fb58f303d20a0568ce0f968a6241a656623153b 100644 Binary files a/embeddings/__pycache__/hetero_GAT_constants.cpython-39.pyc and b/embeddings/__pycache__/hetero_GAT_constants.cpython-39.pyc differ diff --git a/embeddings/__pycache__/static_graph_construction.cpython-39.pyc b/embeddings/__pycache__/static_graph_construction.cpython-39.pyc index 2d145e62ef1474e70b94717c8d6774b17b1e19f1..1ef4ec5d29e091fa7b326d0c1910abf0eddefc42 100644 Binary files a/embeddings/__pycache__/static_graph_construction.cpython-39.pyc and b/embeddings/__pycache__/static_graph_construction.cpython-39.pyc differ diff --git a/embeddings/gnn_sweep.py b/embeddings/gnn_sweep.py new file mode 100644 index 0000000000000000000000000000000000000000..4d5d830af74d2f80ab7fec622e4d9617b88a42e1 --- /dev/null +++ b/embeddings/gnn_sweep.py @@ -0,0 +1,321 @@ +import json +import logging +import os +import string +import time + +import networkx as nx +import pandas as pd +import plotly +import torch +from sklearn.metrics import f1_score, accuracy_score +from torch_geometric.loader import DataLoader +from torch_geometric.nn import HeteroConv, GATv2Conv, GATConv, Linear, global_mean_pool, GCNConv, SAGEConv +from helper_functions import calculate_class_weights, split_test_train_pytorch +import wandb +from 
torch_geometric.utils import to_networkx +import torch.nn.functional as F +from sklearn.model_selection import KFold +from torch.optim.lr_scheduler import ExponentialLR +import pickle + +from custom_logger import setup_custom_logger +from dataset import UserGraphDataset +from dataset_in_memory import UserGraphDatasetInMemory +from Visualize import GraphVisualization +import helper_functions +from hetero_GAT_constants import OS_NAME, TRAIN_BATCH_SIZE, TEST_BATCH_SIZE, IN_MEMORY_DATASET, INCLUDE_ANSWER, USE_WANDB, WANDB_PROJECT_NAME, NUM_WORKERS, EPOCHS, NUM_LAYERS, HIDDEN_CHANNELS, FINAL_MODEL_OUT_PATH, SAVE_CHECKPOINTS, WANDB_RUN_NAME, CROSS_VALIDATE, FOLD_FILES, USE_CLASS_WEIGHTS_SAMPLER, USE_CLASS_WEIGHTS_LOSS, DROPOUT, GAMMA, START_LR, PICKLE_PATH_KF, ROOT, TRAIN_DATA_PATH, TEST_DATA_PATH, WARM_START_FILE, MODEL, REL_SUBSET + +log = setup_custom_logger("heterogenous_GAT_model", logging.INFO) + +if OS_NAME == "linux": + torch.multiprocessing.set_sharing_strategy('file_system') + import resource + rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) + resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1])) + + +""" +G +A +T +""" + +class HeteroGAT(torch.nn.Module): + """ + Heterogeneous Graph Attentional Network (GAT) model. + """ + def __init__(self, hidden_channels, out_channels, num_layers): + super().__init__() + log.info("MODEL: GAT") + + self.convs = torch.nn.ModuleList() + + # Create Graph Attentional layers + for _ in range(num_layers): + conv = HeteroConv({ + ('tag', 'describes', 'question'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6), + ('tag', 'describes', 'answer'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6), + ('tag', 'describes', 'comment'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6), + ('module', 'imported_in', 'question'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6), + ('module', 'imported_in', 'answer'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6), + ('question', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6), + ('answer', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6), + ('comment', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6), + ('question', 'rev_imported_in', 'module'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6), + ('answer', 'rev_imported_in', 'module'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6), + }, aggr='sum') + self.convs.append(conv) + + self.lin1 = Linear(-1, hidden_channels) + self.lin2 = Linear(hidden_channels, out_channels) + self.softmax = torch.nn.Softmax(dim=-1) + + def forward(self, x_dict, edge_index_dict, batch_dict, post_emb): + x_dict = {key: x_dict[key] for key in x_dict.keys() if key in ["question", "answer", "comment", "tag"]} + + + for conv in self.convs: + break + x_dict = conv(x_dict, edge_index_dict) + x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()} + x_dict = {key: F.dropout(x, p=DROPOUT, training=self.training) for key, x in x_dict.items()} + + outs = [] + + for x, batch in zip(x_dict.values(), batch_dict.values()): + if len(x): + outs.append(global_mean_pool(x, batch=batch, size=len(post_emb)).to(device)) + else: + outs.append(torch.zeros(1, x.size(-1)).to(device)) + + + out = torch.cat(outs, dim=1).to(device) + + out = torch.cat([out, post_emb], dim=1).to(device) + + out = F.dropout(out, p=DROPOUT, training=self.training) + + + out = 
self.lin1(out) + out = F.leaky_relu(out) + + out = self.lin2(out) + out = F.leaky_relu(out) + + out = self.softmax(out) + return out + + +""" +T +R +A +I +N +""" +def train_epoch(train_loader): + running_loss = 0.0 + model.train() + + for i, data in enumerate(train_loader): # Iterate in batches over the training dataset. + data.to(device) + + optimizer.zero_grad() # Clear gradients. + + if INCLUDE_ANSWER: + # Concatenate question and answer embeddings to form post embeddings + post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device) + else: + # Use only question embeddings as post embedding + post_emb = data.question_emb.to(device) + post_emb.requires_grad = True + + out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb) # Perform a single forward pass. + + #y = torch.tensor([1 if x > 0 else 0 for x in data.score]).to(device) + loss = criterion(out, torch.squeeze(data.label, -1)) # Compute the loss. + loss.backward() # Derive gradients. + optimizer.step() # Update parameters based on gradients. + + running_loss += loss.item() + if i % 5 == 0: + log.info(f"[{i + 1}] Loss: {running_loss / 5}") + running_loss = 0.0 + +""" +T +E +S +T +""" +def test(loader): + table = wandb.Table(columns=["ground_truth", "prediction"]) if USE_WANDB else None + model.eval() + + predictions = [] + true_labels = [] + + cumulative_loss = 0 + + for data in loader: # Iterate in batches over the training/test dataset. + data.to(device) + + if INCLUDE_ANSWER: + post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device) + else: + post_emb = data.question_emb.to(device) + + out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb) # Perform a single forward pass. + + #y = torch.tensor([1 if x > 0 else 0 for x in data.score]).to(device) + loss = criterion(out, torch.squeeze(data.label, -1)) # Compute the loss. + cumulative_loss += loss.item() + + # Use the class with highest probability. 
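+        # (the model's forward already ends in softmax, so argmax over `out` gives the same class index as argmax over the raw logits)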
+ pred = out.argmax(dim=1) + + # Cache the predictions for calculating metrics + predictions += list([x.item() for x in pred]) + true_labels += list([x.item() for x in data.label]) + + # Log table of predictions to WandB + if USE_WANDB: + #graph_html = wandb.Html(plotly.io.to_html(create_graph_vis(data))) + + for pred, label in zip(pred, torch.squeeze(data.label, -1)): + table.add_data(label, pred) + + # Collate results into a single dictionary + test_results = { + "accuracy": accuracy_score(true_labels, predictions), + "f1-score-weighted": f1_score(true_labels, predictions, average='weighted'), + "f1-score-macro": f1_score(true_labels, predictions, average='macro'), + "loss": cumulative_loss / len(loader), + "table": table, + "preds": predictions, + "trues": true_labels + } + return test_results + + + + +""" +SWEEP +""" + +def build_dataset(train_batch_size): + train_dataset = UserGraphDatasetInMemory(root=ROOT, file_name_out=TRAIN_DATA_PATH, question_ids=[]) + test_dataset = UserGraphDatasetInMemory(root=ROOT, file_name_out=TEST_DATA_PATH, question_ids=[]) + + class_weights = calculate_class_weights(train_dataset).to(device) + train_labels = [x.label for x in train_dataset] + sampler = torch.utils.data.WeightedRandomSampler([class_weights[x] for x in train_labels], len(train_labels)) + + + # Dataloaders + log.info(f"Train DataLoader batch size is set to {TRAIN_BATCH_SIZE}") + train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=train_batch_size, num_workers=NUM_WORKERS) + test_loader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, num_workers=NUM_WORKERS) + return train_loader, test_loader + + +def build_network(channels, layers): + model = HeteroGAT(hidden_channels=channels, out_channels=2, num_layers=layers) + return model.to(device) + + + + +def train(config=None): + # Initialize a new wandb run + with wandb.init(config=config): + # If called by wandb.agent, as below, + # this config will be set by Sweep Controller + config = wandb.config + + train_loader, test_loader = build_dataset(config.batch_size) + + DROPOUT = config.dropout + global model + model = build_network(config.hidden_channels, config.num_layers) + + + # Optimizers & Loss function + global optimizer + optimizer = torch.optim.Adam(model.parameters(), lr=config.initial_lr) + global scheduler + scheduler = ExponentialLR(optimizer, gamma=GAMMA, verbose=True) + + # Cross Entropy Loss (with optional class weights) + global criterion + criterion = torch.nn.CrossEntropyLoss() + + for epoch in range(config.epochs): + train_epoch(train_loader) + f1 = test(test_loader) + wandb.log({'validation/weighted-f1': f1, "epoch": epoch}) + +def test(loader): + model.eval() + + predictions = [] + true_labels = [] + + for data in loader: # Iterate in batches over the training/test dataset. + data.to(device) + + if INCLUDE_ANSWER: + post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device) + else: + post_emb = data.question_emb.to(device) + + out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb) # Perform a single forward pass. + + loss = criterion(out, torch.squeeze(data.label, -1)) # Compute the loss. + + # Use the class with highest probability. 
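+        # (predicted class indices are accumulated across all batches; the weighted F1 over them is returned after the loop)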
+ pred = out.argmax(dim=1) + + # Cache the predictions for calculating metrics + predictions += list([x.item() for x in pred]) + true_labels += list([x.item() for x in data.label]) + + + return f1_score(true_labels, predictions, average='weighted') + +""" +M +A +I +N +""" +if __name__ == '__main__': + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log.info(f"Proceeding with {device} . .") + + wandb.login() + + sweep_configuration = sweep_configuration = { + 'method': 'bayes', + 'name': 'sweep', + 'metric': { + 'goal': 'maximize', + 'name': 'validation/weighted-f1' + }, + 'parameters': { + 'batch_size': {'values': [32, 64, 128, 256]}, + 'epochs': {'max': 100, 'min': 5}, + 'initial_lr': {'max': 0.015, 'min': 0.0001}, + 'num_layers': {'values': [1,2,3]}, + 'hidden_channels': {'values': [32, 64, 128, 256]}, + 'dropout': {'max': 0.9, 'min': 0.2} + } + } + + sweep_id = wandb.sweep(sweep=sweep_configuration, project=WANDB_PROJECT_NAME) + wandb.agent(sweep_id, function=train, count=100) + + + diff --git a/embeddings/helper_functions.py b/embeddings/helper_functions.py index d5ec672b4dd906f5fc243551e565c0a815ce21c5..9699079ce566f037b516b3762e7f9ff831896fa5 100644 --- a/embeddings/helper_functions.py +++ b/embeddings/helper_functions.py @@ -65,7 +65,8 @@ def log_results_to_wandb(results_map, results_name: str): wandb.log({ f"{results_name}/loss": results_map["loss"], f"{results_name}/accuracy": results_map["accuracy"], - f"{results_name}/f1": results_map["f1-score"], + f"{results_name}/f1-macro": results_map["f1-score-macro"], + f"{results_name}/f1-weighted": results_map["f1-score-weighted"], f"{results_name}/table": results_map["table"] }) diff --git a/embeddings/hetero_GAT.py b/embeddings/hetero_GAT.py index f9e4f171e5e8a05c0545db3571a85e43c1886929..2275741e002bb5721c601e844b5d96a84c9f9785 100644 --- a/embeddings/hetero_GAT.py +++ b/embeddings/hetero_GAT.py @@ -10,7 +10,7 @@ import plotly import torch from sklearn.metrics import f1_score, accuracy_score from torch_geometric.loader import DataLoader -from torch_geometric.nn import HeteroConv, GATv2Conv, GATConv, Linear, global_mean_pool, GCNConv, SAGEConv +from torch_geometric.nn import HeteroConv, GATv2Conv, GATConv, Linear, global_mean_pool, GCNConv, SAGEConv, GraphConv from helper_functions import calculate_class_weights, split_test_train_pytorch import wandb from torch_geometric.utils import to_networkx @@ -24,7 +24,7 @@ from dataset import UserGraphDataset from dataset_in_memory import UserGraphDatasetInMemory from Visualize import GraphVisualization import helper_functions -from hetero_GAT_constants import OS_NAME, TRAIN_BATCH_SIZE, TEST_BATCH_SIZE, IN_MEMORY_DATASET, INCLUDE_ANSWER, USE_WANDB, WANDB_PROJECT_NAME, NUM_WORKERS, EPOCHS, NUM_LAYERS, HIDDEN_CHANNELS, FINAL_MODEL_OUT_PATH, SAVE_CHECKPOINTS, WANDB_RUN_NAME, CROSS_VALIDATE, FOLD_FILES, USE_CLASS_WEIGHTS_SAMPLER, USE_CLASS_WEIGHTS_LOSS, DROPOUT, GAMMA, START_LR, PICKLE_PATH_KF, ROOT, TRAIN_DATA_PATH, TEST_DATA_PATH, WARM_START_FILE, MODEL +from hetero_GAT_constants import OS_NAME, TRAIN_BATCH_SIZE, TEST_BATCH_SIZE, IN_MEMORY_DATASET, INCLUDE_ANSWER, USE_WANDB, WANDB_PROJECT_NAME, NUM_WORKERS, EPOCHS, NUM_LAYERS, HIDDEN_CHANNELS, FINAL_MODEL_OUT_PATH, SAVE_CHECKPOINTS, WANDB_RUN_NAME, CROSS_VALIDATE, FOLD_FILES, USE_CLASS_WEIGHTS_SAMPLER, USE_CLASS_WEIGHTS_LOSS, DROPOUT, GAMMA, START_LR, PICKLE_PATH_KF, ROOT, TRAIN_DATA_PATH, TEST_DATA_PATH, WARM_START_FILE, MODEL, REL_SUBSET log = setup_custom_logger("heterogenous_GAT_model", logging.INFO) @@ -47,6 +47,7 @@ 
class HeteroGAT(torch.nn.Module): """ def __init__(self, hidden_channels, out_channels, num_layers): super().__init__() + log.info("MODEL: GAT") self.convs = torch.nn.ModuleList() @@ -71,12 +72,17 @@ class HeteroGAT(torch.nn.Module): self.softmax = torch.nn.Softmax(dim=-1) def forward(self, x_dict, edge_index_dict, batch_dict, post_emb): + x_dict = {key: x_dict[key] for key in x_dict.keys() if key in ["question", "answer", "comment", "tag", "module"]} + + for conv in self.convs: + break x_dict = conv(x_dict, edge_index_dict) x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()} x_dict = {key: F.dropout(x, p=DROPOUT, training=self.training) for key, x in x_dict.items()} outs = [] + for x, batch in zip(x_dict.values(), batch_dict.values()): if len(x): outs.append(global_mean_pool(x, batch=batch, size=len(post_emb)).to(device)) @@ -120,7 +126,7 @@ class HeteroGraphSAGE(torch.nn.Module): """ def __init__(self, hidden_channels, out_channels, num_layers): super().__init__() - + log.info("MODEL: GraphSAGE") self.convs = torch.nn.ModuleList() # Create Graph Attentional layers @@ -173,6 +179,87 @@ class HeteroGraphSAGE(torch.nn.Module): out = self.softmax(out) return out +""" +G +R +A +P +H +C +O +N +V +""" +""" +G +A +T +""" + +class HeteroGraphConv(torch.nn.Module): + """ + Heterogeneous GraphConv model. + """ + def __init__(self, hidden_channels, out_channels, num_layers): + super().__init__() + log.info("MODEL: GraphConv") + + self.convs = torch.nn.ModuleList() + + # Create Graph Attentional layers + for _ in range(num_layers): + conv = HeteroConv({ + ('tag', 'describes', 'question'): GraphConv((-1, -1), hidden_channels, add_self_loops=False, heads=6), + ('tag', 'describes', 'answer'): GraphConv((-1, -1), hidden_channels, add_self_loops=False, heads=6), + ('tag', 'describes', 'comment'): GraphConv((-1, -1), hidden_channels, add_self_loops=False, heads=6), + ('module', 'imported_in', 'question'): GraphConv((-1, -1), hidden_channels, add_self_loops=False, heads=6), + ('module', 'imported_in', 'answer'): GraphConv((-1, -1), hidden_channels, add_self_loops=False, heads=6), + ('question', 'rev_describes', 'tag'): GraphConv((-1, -1), hidden_channels, add_self_loops=False, heads=6), + ('answer', 'rev_describes', 'tag'): GraphConv((-1, -1), hidden_channels, add_self_loops=False, heads=6), + ('comment', 'rev_describes', 'tag'): GraphConv((-1, -1), hidden_channels, add_self_loops=False, heads=6), + ('question', 'rev_imported_in', 'module'): GraphConv((-1, -1), hidden_channels, add_self_loops=False, heads=6), + ('answer', 'rev_imported_in', 'module'): GraphConv((-1, -1), hidden_channels, add_self_loops=False, heads=6), + }, aggr='sum') + self.convs.append(conv) + + self.lin1 = Linear(-1, hidden_channels) + self.lin2 = Linear(hidden_channels, out_channels) + self.softmax = torch.nn.Softmax(dim=-1) + + def forward(self, x_dict, edge_index_dict, batch_dict, post_emb): + x_dict = {key: x_dict[key] for key in x_dict.keys() if key in ["question", "answer", "comment", "tag", "module"]} + + + for conv in self.convs: + break + x_dict = conv(x_dict, edge_index_dict) + x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()} + x_dict = {key: F.dropout(x, p=DROPOUT, training=self.training) for key, x in x_dict.items()} + + outs = [] + + for x, batch in zip(x_dict.values(), batch_dict.values()): + if len(x): + outs.append(global_mean_pool(x, batch=batch, size=len(post_emb)).to(device)) + else: + outs.append(torch.zeros(1, x.size(-1)).to(device)) + + + out = torch.cat(outs, dim=1).to(device) + + out = 
torch.cat([out, post_emb], dim=1).to(device) + + out = F.dropout(out, p=DROPOUT, training=self.training) + + + out = self.lin1(out) + out = F.leaky_relu(out) + + out = self.lin2(out) + out = F.leaky_relu(out) + + out = self.softmax(out) + return out """ T @@ -291,7 +378,7 @@ if __name__ == '__main__': config.initial_lr = START_LR config.gamma = GAMMA config.batch_size = TRAIN_BATCH_SIZE - + # Datasets if IN_MEMORY_DATASET: @@ -300,7 +387,7 @@ if __name__ == '__main__': else: dataset = UserGraphDataset(root=ROOT, skip_processing=True) train_dataset, test_dataset = split_test_train_pytorch(dataset, 0.7) - + if CROSS_VALIDATE: print(FOLD_FILES) folds = [UserGraphDatasetInMemory(root="../data", file_name_out=fold_path, question_ids=[]) for fold_path in FOLD_FILES] @@ -385,6 +472,13 @@ if __name__ == '__main__': if USE_WANDB: wandb.log(data_details) + # Take subset for EXP3 + if REL_SUBSET is not None: + indices = list(range(int(len(train_dataset)*REL_SUBSET))) + train_dataset = torch.utils.data.Subset(train_dataset, indices) + log.info(f"Subset contains {len(train_dataset)}") + + sampler = None class_weights = calculate_class_weights(train_dataset).to(device) @@ -392,7 +486,8 @@ if __name__ == '__main__': if USE_CLASS_WEIGHTS_SAMPLER: train_labels = [x.label for x in train_dataset] sampler = torch.utils.data.WeightedRandomSampler([class_weights[x] for x in train_labels], len(train_labels)) - + + # Dataloaders log.info(f"Train DataLoader batch size is set to {TRAIN_BATCH_SIZE}") train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=TRAIN_BATCH_SIZE, num_workers=NUM_WORKERS) @@ -403,11 +498,12 @@ if __name__ == '__main__': model = HeteroGAT(hidden_channels=HIDDEN_CHANNELS, out_channels=2, num_layers=NUM_LAYERS) elif MODEL == "SAGE": model = HeteroGraphSAGE(hidden_channels=HIDDEN_CHANNELS, out_channels=2, num_layers=NUM_LAYERS) + elif MODEL == "GC": + model = HeteroGraphConv(hidden_channels=HIDDEN_CHANNELS, out_channels=2, num_layers=NUM_LAYERS) else: log.error(f"Model does not exist! 
({MODEL})") exit(1) - model = HeteroGraphSAGE(hidden_channels=HIDDEN_CHANNELS, out_channels=2, num_layers=NUM_LAYERS) model.to(device) # To GPU if available if WARM_START_FILE is not None: diff --git a/embeddings/hetero_GAT_constants.py b/embeddings/hetero_GAT_constants.py index 4b296bb7c769afe9c25e9ef22d35771b31310d5b..6c4ddedd34dc79dc654f32fb8dd36c73b307d98a 100644 --- a/embeddings/hetero_GAT_constants.py +++ b/embeddings/hetero_GAT_constants.py @@ -11,7 +11,7 @@ USE_CLASS_WEIGHTS_LOSS = False # W&B dashboard logging USE_WANDB = False WANDB_PROJECT_NAME = "heterogeneous-GAT-model" -WANDB_RUN_NAME = "EXP1-run" # None for timestamp +WANDB_RUN_NAME = None # None for timestamp # OS OS_NAME = "linux" # "windows" or "linux" @@ -21,10 +21,11 @@ NUM_WORKERS = 14 ROOT = "../../../data/lhb1g20" TRAIN_DATA_PATH = "../../../../../data/lhb1g20/train-4175-qs.pt" TEST_DATA_PATH = "../../../../../data/lhb1g20/test-1790-qs.pt" -EPOCHS = 10 +REL_SUBSET = None +EPOCHS = 20 START_LR = 0.001 GAMMA = 0.95 -WARM_START_FILE = "../models/gat_qa_20e_64h_3l.pt" +WARM_START_FILE = None #"../models/gat_qa_20e_64h_3l.pt" # (Optional) k-fold cross validation CROSS_VALIDATE = False @@ -32,9 +33,9 @@ FOLD_FILES = ['fold-1-6001-qs.pt', 'fold-2-6001-qs.pt', 'fold-3-6001-qs.pt', 'fo PICKLE_PATH_KF = 'q_kf_results.pkl' # Model architecture -MODEL = "GAT" +MODEL = "GC" NUM_LAYERS = 3 HIDDEN_CHANNELS = 64 -FINAL_MODEL_OUT_PATH = "gat_qa_10e_64h_3l.pt" -SAVE_CHECKPOINTS = False -DROPOUT=0.0 +FINAL_MODEL_OUT_PATH = "SAGE_3l_60e_64h.pt" +SAVE_CHECKPOINTS = True +DROPOUT=0.3