diff --git a/embeddings/NextTagEmbedding.py b/embeddings/NextTagEmbedding.py
index 0d0ad727298fc98185310d3cccbfdf14a1f54a1f..6d5a8576940c8d378ca4388ccc5f21a01b1e792f 100644
--- a/embeddings/NextTagEmbedding.py
+++ b/embeddings/NextTagEmbedding.py
@@ -162,16 +162,18 @@ class NextTagEmbedding(nn.Module):
 
 if __name__ == '__main__':
     tet = NextTagEmbeddingTrainer(context_length=2, emb_size=30, excluded_tags=['python'], database_path="../stackoverflow.db")
-    tet.from_db()
-    print(len(tet.post_tags))
-    print(len(tet.tag_vocab))
+    #tet.from_db()
+    #print(len(tet.post_tags))
+    #print(len(tet.tag_vocab))
 
 
     #tet = NextTagEmbeddingTrainer(context_length=3, emb_size=50)
 
-    #tet.from_files("../data/raw/all_tags.csv", "../data/raw/tag_vocab.csv")
+    tet.from_files("../all_tags.csv", "../tag_vocab.csv")
     # assert len(tet.post_tags) == 84187510, "Incorrect number of post tags!"
     # assert len(tet.tag_vocab) == 63653, "Incorrect vocab size!"
+    
+    print(len(tet.post_tags))
 
     tet.train(1000, 1)
     # tet.to_tensorboard(f"run@{time.strftime('%Y%m%d-%H%M%S')}")
diff --git a/embeddings/gnn_sweep.py b/embeddings/gnn_sweep.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d5d830af74d2f80ab7fec622e4d9617b88a42e1
--- /dev/null
+++ b/embeddings/gnn_sweep.py
@@ -0,0 +1,321 @@
+import json
+import logging
+import os
+import string
+import time
+
+import networkx as nx
+import pandas as pd
+import plotly
+import torch
+from sklearn.metrics import f1_score, accuracy_score
+from torch_geometric.loader import DataLoader
+from torch_geometric.nn import HeteroConv, GATv2Conv, GATConv, Linear, global_mean_pool, GCNConv, SAGEConv
+from helper_functions import calculate_class_weights, split_test_train_pytorch
+import wandb
+from torch_geometric.utils import to_networkx
+import torch.nn.functional as F
+from sklearn.model_selection import KFold
+from torch.optim.lr_scheduler import ExponentialLR
+import pickle
+
+from custom_logger import setup_custom_logger
+from dataset import UserGraphDataset
+from dataset_in_memory import UserGraphDatasetInMemory
+from Visualize import GraphVisualization
+import helper_functions
+from hetero_GAT_constants import OS_NAME, TRAIN_BATCH_SIZE, TEST_BATCH_SIZE, IN_MEMORY_DATASET, INCLUDE_ANSWER, USE_WANDB, WANDB_PROJECT_NAME, NUM_WORKERS, EPOCHS, NUM_LAYERS, HIDDEN_CHANNELS, FINAL_MODEL_OUT_PATH, SAVE_CHECKPOINTS, WANDB_RUN_NAME, CROSS_VALIDATE, FOLD_FILES, USE_CLASS_WEIGHTS_SAMPLER, USE_CLASS_WEIGHTS_LOSS, DROPOUT, GAMMA, START_LR, PICKLE_PATH_KF, ROOT, TRAIN_DATA_PATH, TEST_DATA_PATH, WARM_START_FILE, MODEL, REL_SUBSET
+
+log = setup_custom_logger("heterogenous_GAT_model", logging.INFO)
+
+if OS_NAME == "linux":
+    torch.multiprocessing.set_sharing_strategy('file_system')
+    import resource
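+    # Raise the soft open-file limit: multi-worker DataLoaders with the 'file_system' sharing strategy can exhaust the default limit.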
+    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+    resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))
+
+
+"""
+G
+A
+T
+"""
+
+class HeteroGAT(torch.nn.Module):
+    """
+    Heterogeneous Graph Attentional Network (GAT) model.
+    """
+    def __init__(self, hidden_channels, out_channels, num_layers):
+        super().__init__()
+        log.info("MODEL: GAT")
+
+        self.convs = torch.nn.ModuleList()
+
+        # Create Graph Attentional layers
+        for _ in range(num_layers):
+            conv = HeteroConv({
+                ('tag', 'describes', 'question'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('tag', 'describes', 'answer'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('tag', 'describes', 'comment'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('module', 'imported_in', 'question'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('module', 'imported_in', 'answer'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('question', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('answer', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('comment', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('question', 'rev_imported_in', 'module'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('answer', 'rev_imported_in', 'module'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+            }, aggr='sum')
+            self.convs.append(conv)
+
+        self.lin1 = Linear(-1, hidden_channels)
+        self.lin2 = Linear(hidden_channels, out_channels)
+        self.softmax = torch.nn.Softmax(dim=-1)
+
+    def forward(self, x_dict, edge_index_dict, batch_dict, post_emb):
+        # Keep only the node types used by the heterogeneous convolutions
+        x_dict = {key: x_dict[key] for key in x_dict.keys() if key in ["question", "answer", "comment", "tag", "module"]}
+
+        
+        for conv in self.convs:
+            x_dict = conv(x_dict, edge_index_dict)
+            x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()}
+            x_dict = {key: F.dropout(x, p=DROPOUT, training=self.training) for key, x in x_dict.items()}
+
+        outs = []
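+        # Pool node embeddings into one vector per graph; node types with no nodes in this batch contribute a zero vector instead.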
+        
+        for x, batch in zip(x_dict.values(), batch_dict.values()):
+            if len(x):
+                outs.append(global_mean_pool(x, batch=batch, size=len(post_emb)).to(device))
+            else:
+                outs.append(torch.zeros(1, x.size(-1)).to(device))
+
+
+        out = torch.cat(outs, dim=1).to(device)
+
+        out = torch.cat([out, post_emb], dim=1).to(device)
+        
+        out = F.dropout(out, p=DROPOUT, training=self.training)
+
+
+        out = self.lin1(out)
+        out = F.leaky_relu(out)
+        
+        out = self.lin2(out)
+        out = F.leaky_relu(out)
+        
+        out = self.softmax(out)
+        return out
+
+
+"""
+T
+R
+A
+I
+N
+"""        
+def train_epoch(train_loader):
+    running_loss = 0.0
+    model.train()
+
+    for i, data in enumerate(train_loader):  # Iterate in batches over the training dataset.
+        data.to(device)
+
+        optimizer.zero_grad()  # Clear gradients.
+
+        if INCLUDE_ANSWER:
+            # Concatenate question and answer embeddings to form post embeddings
+            post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
+        else:
+            # Use only question embeddings as post embedding
+            post_emb = data.question_emb.to(device)
+        post_emb.requires_grad = True
+
+        out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb)  # Perform a single forward pass.
+
+        #y = torch.tensor([1 if x > 0 else 0 for x in data.score]).to(device)
+        loss = criterion(out, torch.squeeze(data.label, -1))  # Compute the loss.
+        loss.backward()  # Derive gradients.
+        optimizer.step()  # Update parameters based on gradients.
+
+        running_loss += loss.item()
+        if i % 5 == 0:
+            log.info(f"[{i + 1}] Loss: {running_loss / 5}")
+            running_loss = 0.0
+
+"""
+T
+E
+S
+T
+"""
+def test_full_metrics(loader):
+    table = wandb.Table(columns=["ground_truth", "prediction"]) if USE_WANDB else None
+    model.eval()
+
+    predictions = []
+    true_labels = []
+
+    cumulative_loss = 0
+
+    for data in loader:  # Iterate in batches over the training/test dataset.
+        data.to(device)
+        
+        if INCLUDE_ANSWER:
+            post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
+        else:
+            post_emb = data.question_emb.to(device)
+
+        out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb)  # Perform a single forward pass.
+
+        #y = torch.tensor([1 if x > 0 else 0 for x in data.score]).to(device)
+        loss = criterion(out, torch.squeeze(data.label, -1))  # Compute the loss.
+        cumulative_loss += loss.item()
+        
+        # Use the class with highest probability.
+        pred = out.argmax(dim=1)  
+        
+        # Cache the predictions for calculating metrics
+        predictions += list([x.item() for x in pred])
+        true_labels += list([x.item() for x in data.label])
+        
+        # Log table of predictions to WandB
+        if USE_WANDB:
+            #graph_html = wandb.Html(plotly.io.to_html(create_graph_vis(data)))
+            
+            for pred, label in zip(pred, torch.squeeze(data.label, -1)):
+                table.add_data(label, pred)
+    
+    # Collate results into a single dictionary
+    test_results = {
+        "accuracy": accuracy_score(true_labels, predictions),
+        "f1-score-weighted": f1_score(true_labels, predictions, average='weighted'),
+        "f1-score-macro": f1_score(true_labels, predictions, average='macro'),
+        "loss": cumulative_loss / len(loader),
+        "table": table,
+        "preds": predictions, 
+        "trues": true_labels 
+    }
+    return test_results
+
+
+
+
+"""
+SWEEP
+"""
+
+def build_dataset(train_batch_size):
+    train_dataset = UserGraphDatasetInMemory(root=ROOT, file_name_out=TRAIN_DATA_PATH, question_ids=[])
+    test_dataset = UserGraphDatasetInMemory(root=ROOT, file_name_out=TEST_DATA_PATH, question_ids=[])
+
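+    # Class-imbalance handling: each example is weighted by its class weight so the sampler draws rarer classes more often
+    # (assuming calculate_class_weights up-weights minority classes).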
+    class_weights = calculate_class_weights(train_dataset).to(device)
+    train_labels = [x.label for x in train_dataset]
+    sampler = torch.utils.data.WeightedRandomSampler([class_weights[x] for x in train_labels], len(train_labels))
+    
+    
+    # Dataloaders
+    log.info(f"Train DataLoader batch size is set to {TRAIN_BATCH_SIZE}")
+    train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=train_batch_size, num_workers=NUM_WORKERS)
+    test_loader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, num_workers=NUM_WORKERS)
+    return train_loader, test_loader
+
+    
+def build_network(channels, layers):
+    model = HeteroGAT(hidden_channels=channels, out_channels=2, num_layers=layers)
+    return model.to(device)
+    
+
+    
+
+def train(config=None):
+    # Initialize a new wandb run
+    with wandb.init(config=config):
+        # If called by wandb.agent, as below,
+        # this config will be set by Sweep Controller
+        config = wandb.config
+
+        train_loader, test_loader = build_dataset(config.batch_size)
+        
+        # Rebind the module-level DROPOUT so the swept value takes effect in the model's forward pass
+        global DROPOUT
+        DROPOUT = config.dropout
+        global model
+        model = build_network(config.hidden_channels, config.num_layers)
+        
+        
+        # Optimizers & Loss function
+        global optimizer
+        optimizer = torch.optim.Adam(model.parameters(), lr=config.initial_lr)
+        global scheduler
+        scheduler = ExponentialLR(optimizer, gamma=GAMMA, verbose=True)
+    
+        # Cross Entropy Loss (with optional class weights)
+        global criterion
+        criterion = torch.nn.CrossEntropyLoss()
+        
+        for epoch in range(config.epochs):
+            train_epoch(train_loader)
+            scheduler.step()  # Decay the learning rate once per epoch
+            f1 = test(test_loader)
+            wandb.log({'validation/weighted-f1': f1, "epoch": epoch})
+
+def test(loader):
+    model.eval()
+
+    predictions = []
+    true_labels = []
+
+    for data in loader:  # Iterate in batches over the training/test dataset.
+        data.to(device)
+        
+        if INCLUDE_ANSWER:
+            post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
+        else:
+            post_emb = data.question_emb.to(device)
+
+        out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb)  # Perform a single forward pass.
+
+        loss = criterion(out, torch.squeeze(data.label, -1))  # Compute the loss.
+        
+        # Use the class with highest probability.
+        pred = out.argmax(dim=1)  
+        
+        # Cache the predictions for calculating metrics
+        predictions += list([x.item() for x in pred])
+        true_labels += list([x.item() for x in data.label])
+        
+    
+    return f1_score(true_labels, predictions, average='weighted')
+
+"""
+M
+A
+I
+N
+"""
+if __name__ == '__main__':
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    log.info(f"Proceeding with {device} . .")
+
+    wandb.login()
+    
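+    # Bayesian search over batch size, epoch count, initial LR, depth, width and dropout, maximising validation/weighted-f1.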
+    sweep_configuration = {
+        'method': 'bayes',
+        'name': 'sweep',
+        'metric': {
+            'goal': 'maximize', 
+            'name': 'validation/weighted-f1'
+        },
+        'parameters': {
+            'batch_size': {'values': [32, 64, 128, 256]},
+            'epochs': {'max': 100, 'min': 5},
+            'initial_lr': {'max': 0.015, 'min': 0.0001},
+            'num_layers': {'values': [1,2,3]},
+            'hidden_channels': {'values': [32, 64, 128, 256]},
+            'dropout': {'max': 0.9, 'min': 0.2}
+        }
+    }
+    
+    sweep_id = wandb.sweep(sweep=sweep_configuration, project=WANDB_PROJECT_NAME)
+    wandb.agent(sweep_id, function=train, count=100)
+        
+
+
diff --git a/embeddings/helper_functions.py b/embeddings/helper_functions.py
index d5ec672b4dd906f5fc243551e565c0a815ce21c5..9699079ce566f037b516b3762e7f9ff831896fa5 100644
--- a/embeddings/helper_functions.py
+++ b/embeddings/helper_functions.py
@@ -65,7 +65,8 @@ def log_results_to_wandb(results_map, results_name: str):
     wandb.log({
                 f"{results_name}/loss": results_map["loss"],
                 f"{results_name}/accuracy": results_map["accuracy"],
-                f"{results_name}/f1": results_map["f1-score"],
+                f"{results_name}/f1-macro": results_map["f1-score-macro"],
+                f"{results_name}/f1-weighted": results_map["f1-score-weighted"],
                 f"{results_name}/table": results_map["table"]
             })
             
diff --git a/embeddings/hetero_GAT.py b/embeddings/hetero_GAT.py
index f9e4f171e5e8a05c0545db3571a85e43c1886929..2275741e002bb5721c601e844b5d96a84c9f9785 100644
--- a/embeddings/hetero_GAT.py
+++ b/embeddings/hetero_GAT.py
@@ -10,7 +10,7 @@ import plotly
 import torch
 from sklearn.metrics import f1_score, accuracy_score
 from torch_geometric.loader import DataLoader
-from torch_geometric.nn import HeteroConv, GATv2Conv, GATConv, Linear, global_mean_pool, GCNConv, SAGEConv
+from torch_geometric.nn import HeteroConv, GATv2Conv, GATConv, Linear, global_mean_pool, GCNConv, SAGEConv, GraphConv
 from helper_functions import calculate_class_weights, split_test_train_pytorch
 import wandb
 from torch_geometric.utils import to_networkx
@@ -24,7 +24,7 @@ from dataset import UserGraphDataset
 from dataset_in_memory import UserGraphDatasetInMemory
 from Visualize import GraphVisualization
 import helper_functions
-from hetero_GAT_constants import OS_NAME, TRAIN_BATCH_SIZE, TEST_BATCH_SIZE, IN_MEMORY_DATASET, INCLUDE_ANSWER, USE_WANDB, WANDB_PROJECT_NAME, NUM_WORKERS, EPOCHS, NUM_LAYERS, HIDDEN_CHANNELS, FINAL_MODEL_OUT_PATH, SAVE_CHECKPOINTS, WANDB_RUN_NAME, CROSS_VALIDATE, FOLD_FILES, USE_CLASS_WEIGHTS_SAMPLER, USE_CLASS_WEIGHTS_LOSS, DROPOUT, GAMMA, START_LR, PICKLE_PATH_KF, ROOT, TRAIN_DATA_PATH, TEST_DATA_PATH, WARM_START_FILE, MODEL
+from hetero_GAT_constants import OS_NAME, TRAIN_BATCH_SIZE, TEST_BATCH_SIZE, IN_MEMORY_DATASET, INCLUDE_ANSWER, USE_WANDB, WANDB_PROJECT_NAME, NUM_WORKERS, EPOCHS, NUM_LAYERS, HIDDEN_CHANNELS, FINAL_MODEL_OUT_PATH, SAVE_CHECKPOINTS, WANDB_RUN_NAME, CROSS_VALIDATE, FOLD_FILES, USE_CLASS_WEIGHTS_SAMPLER, USE_CLASS_WEIGHTS_LOSS, DROPOUT, GAMMA, START_LR, PICKLE_PATH_KF, ROOT, TRAIN_DATA_PATH, TEST_DATA_PATH, WARM_START_FILE, MODEL, REL_SUBSET
 
 log = setup_custom_logger("heterogenous_GAT_model", logging.INFO)
 
@@ -47,6 +47,7 @@ class HeteroGAT(torch.nn.Module):
     """
     def __init__(self, hidden_channels, out_channels, num_layers):
         super().__init__()
+        log.info("MODEL: GAT")
 
         self.convs = torch.nn.ModuleList()
 
@@ -71,12 +72,17 @@ class HeteroGAT(torch.nn.Module):
         self.softmax = torch.nn.Softmax(dim=-1)
 
     def forward(self, x_dict, edge_index_dict, batch_dict, post_emb):
+        x_dict = {key: x_dict[key] for key in x_dict.keys() if key in ["question", "answer", "comment", "tag", "module"]}
+
+        
         for conv in self.convs:
+            break
             x_dict = conv(x_dict, edge_index_dict)
             x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()}
             x_dict = {key: F.dropout(x, p=DROPOUT, training=self.training) for key, x in x_dict.items()}
 
         outs = []
+        
         for x, batch in zip(x_dict.values(), batch_dict.values()):
             if len(x):
                 outs.append(global_mean_pool(x, batch=batch, size=len(post_emb)).to(device))
@@ -120,7 +126,7 @@ class HeteroGraphSAGE(torch.nn.Module):
     """
     def __init__(self, hidden_channels, out_channels, num_layers):
         super().__init__()
-
+        log.info("MODEL: GraphSAGE")
         self.convs = torch.nn.ModuleList()
 
         # Create Graph Attentional layers
@@ -173,6 +179,87 @@ class HeteroGraphSAGE(torch.nn.Module):
         out = self.softmax(out)
         return out
         
+"""
+G
+R
+A
+P
+H
+C
+O
+N
+V
+"""
+"""
+G
+A
+T
+"""
+
+class HeteroGraphConv(torch.nn.Module):
+    """
+    Heterogeneous GraphConv model.
+    """
+    def __init__(self, hidden_channels, out_channels, num_layers):
+        super().__init__()
+        log.info("MODEL: GraphConv")
+
+        self.convs = torch.nn.ModuleList()
+
+        # Create GraphConv layers (one per relation; HeteroConv sums messages arriving at the same node type)
+        for _ in range(num_layers):
+            conv = HeteroConv({
+                ('tag', 'describes', 'question'): GraphConv((-1, -1), hidden_channels),
+                ('tag', 'describes', 'answer'): GraphConv((-1, -1), hidden_channels),
+                ('tag', 'describes', 'comment'): GraphConv((-1, -1), hidden_channels),
+                ('module', 'imported_in', 'question'): GraphConv((-1, -1), hidden_channels),
+                ('module', 'imported_in', 'answer'): GraphConv((-1, -1), hidden_channels),
+                ('question', 'rev_describes', 'tag'): GraphConv((-1, -1), hidden_channels),
+                ('answer', 'rev_describes', 'tag'): GraphConv((-1, -1), hidden_channels),
+                ('comment', 'rev_describes', 'tag'): GraphConv((-1, -1), hidden_channels),
+                ('question', 'rev_imported_in', 'module'): GraphConv((-1, -1), hidden_channels),
+                ('answer', 'rev_imported_in', 'module'): GraphConv((-1, -1), hidden_channels),
+            }, aggr='sum')
+            self.convs.append(conv)
+
+        self.lin1 = Linear(-1, hidden_channels)
+        self.lin2 = Linear(hidden_channels, out_channels)
+        self.softmax = torch.nn.Softmax(dim=-1)
+
+    def forward(self, x_dict, edge_index_dict, batch_dict, post_emb):
+        x_dict = {key: x_dict[key] for key in x_dict.keys() if key in ["question", "answer", "comment", "tag", "module"]}
+
+        
+        for conv in self.convs:
+            break
+            x_dict = conv(x_dict, edge_index_dict)
+            x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()}
+            x_dict = {key: F.dropout(x, p=DROPOUT, training=self.training) for key, x in x_dict.items()}
+
+        outs = []
+        
+        for x, batch in zip(x_dict.values(), batch_dict.values()):
+            if len(x):
+                outs.append(global_mean_pool(x, batch=batch, size=len(post_emb)).to(device))
+            else:
+                outs.append(torch.zeros(1, x.size(-1)).to(device))
+
+
+        out = torch.cat(outs, dim=1).to(device)
+
+        out = torch.cat([out, post_emb], dim=1).to(device)
+        
+        out = F.dropout(out, p=DROPOUT, training=self.training)
+
+
+        out = self.lin1(out)
+        out = F.leaky_relu(out)
+        
+        out = self.lin2(out)
+        out = F.leaky_relu(out)
+        
+        out = self.softmax(out)
+        return out
         
 """
 T
@@ -291,7 +378,7 @@ if __name__ == '__main__':
         config.initial_lr = START_LR
         config.gamma = GAMMA
         config.batch_size = TRAIN_BATCH_SIZE
-    
+        
 
     # Datasets
     if IN_MEMORY_DATASET:
@@ -300,7 +387,7 @@ if __name__ == '__main__':
     else:
         dataset = UserGraphDataset(root=ROOT, skip_processing=True)
         train_dataset, test_dataset = split_test_train_pytorch(dataset, 0.7)
-
+        
     if CROSS_VALIDATE:
         print(FOLD_FILES)
         folds = [UserGraphDatasetInMemory(root="../data", file_name_out=fold_path, question_ids=[]) for fold_path in FOLD_FILES]
@@ -385,6 +472,13 @@ if __name__ == '__main__':
     if USE_WANDB:
         wandb.log(data_details)
 
+    # Take subset for EXP3        
+    if REL_SUBSET is not None:
+        indices = list(range(int(len(train_dataset)*REL_SUBSET)))
+        train_dataset = torch.utils.data.Subset(train_dataset, indices)
+        log.info(f"Subset contains {len(train_dataset)}")
+
+
     sampler = None
     class_weights = calculate_class_weights(train_dataset).to(device)
     
@@ -392,7 +486,8 @@ if __name__ == '__main__':
     if USE_CLASS_WEIGHTS_SAMPLER:
         train_labels = [x.label for x in train_dataset]
         sampler = torch.utils.data.WeightedRandomSampler([class_weights[x] for x in train_labels], len(train_labels))
-
+    
+    
     # Dataloaders
     log.info(f"Train DataLoader batch size is set to {TRAIN_BATCH_SIZE}")
     train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=TRAIN_BATCH_SIZE, num_workers=NUM_WORKERS)
@@ -403,11 +498,12 @@ if __name__ == '__main__':
         model = HeteroGAT(hidden_channels=HIDDEN_CHANNELS, out_channels=2, num_layers=NUM_LAYERS)
     elif MODEL == "SAGE":
         model = HeteroGraphSAGE(hidden_channels=HIDDEN_CHANNELS, out_channels=2, num_layers=NUM_LAYERS)
+    elif MODEL == "GC":
+        model = HeteroGraphConv(hidden_channels=HIDDEN_CHANNELS, out_channels=2, num_layers=NUM_LAYERS)
     else:
         log.error(f"Model does not exist! ({MODEL})")
         exit(1)
         
-    model = HeteroGraphSAGE(hidden_channels=HIDDEN_CHANNELS, out_channels=2, num_layers=NUM_LAYERS)
     model.to(device)  # To GPU if available
     
     if WARM_START_FILE is not None:
diff --git a/embeddings/hetero_GAT_constants.py b/embeddings/hetero_GAT_constants.py
index 4b296bb7c769afe9c25e9ef22d35771b31310d5b..6c4ddedd34dc79dc654f32fb8dd36c73b307d98a 100644
--- a/embeddings/hetero_GAT_constants.py
+++ b/embeddings/hetero_GAT_constants.py
@@ -11,7 +11,7 @@ USE_CLASS_WEIGHTS_LOSS = False
 # W&B dashboard logging
 USE_WANDB = False
 WANDB_PROJECT_NAME = "heterogeneous-GAT-model"
-WANDB_RUN_NAME = "EXP1-run"  # None for timestamp
+WANDB_RUN_NAME = None  # None for timestamp
 
 # OS
 OS_NAME = "linux"  # "windows" or "linux"
@@ -21,10 +21,11 @@ NUM_WORKERS = 14
 ROOT = "../../../data/lhb1g20"
 TRAIN_DATA_PATH = "../../../../../data/lhb1g20/train-4175-qs.pt"
 TEST_DATA_PATH = "../../../../../data/lhb1g20/test-1790-qs.pt"
-EPOCHS = 10
+REL_SUBSET = None
+EPOCHS = 20  
 START_LR = 0.001
 GAMMA = 0.95
-WARM_START_FILE = "../models/gat_qa_20e_64h_3l.pt"
+WARM_START_FILE = None #"../models/gat_qa_20e_64h_3l.pt"
 
 # (Optional) k-fold cross validation
 CROSS_VALIDATE = False
@@ -32,9 +33,9 @@ FOLD_FILES = ['fold-1-6001-qs.pt', 'fold-2-6001-qs.pt', 'fold-3-6001-qs.pt', 'fo
 PICKLE_PATH_KF = 'q_kf_results.pkl'
 
 # Model architecture 
-MODEL = "GAT"
+MODEL = "GC"
 NUM_LAYERS = 3
 HIDDEN_CHANNELS = 64
-FINAL_MODEL_OUT_PATH = "gat_qa_10e_64h_3l.pt"
-SAVE_CHECKPOINTS = False
-DROPOUT=0.0
+FINAL_MODEL_OUT_PATH = "SAGE_3l_60e_64h.pt"
+SAVE_CHECKPOINTS = True
+DROPOUT=0.3