diff --git a/embeddings/__init__.py b/embeddings/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/embeddings/__pycache__/ModuleEmbeddings.cpython-39.pyc b/embeddings/__pycache__/ModuleEmbeddings.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4f517b2832d279eceff26690cf50123b46795f4a
Binary files /dev/null and b/embeddings/__pycache__/ModuleEmbeddings.cpython-39.pyc differ
diff --git a/embeddings/__pycache__/NextTagEmbedding.cpython-39.pyc b/embeddings/__pycache__/NextTagEmbedding.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5bc5968612d8fbb03504a3b43e0a545dab411d0d
Binary files /dev/null and b/embeddings/__pycache__/NextTagEmbedding.cpython-39.pyc differ
diff --git a/embeddings/__pycache__/Visualize.cpython-39.pyc b/embeddings/__pycache__/Visualize.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..74ac8f2a2d5b9996aef84b0c816ad0946d3c1e4d
Binary files /dev/null and b/embeddings/__pycache__/Visualize.cpython-39.pyc differ
diff --git a/embeddings/__pycache__/custom_logger.cpython-38.pyc b/embeddings/__pycache__/custom_logger.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a06a90bb35dfdd823d5608f747aeabf6a251fefe
Binary files /dev/null and b/embeddings/__pycache__/custom_logger.cpython-38.pyc differ
diff --git a/embeddings/__pycache__/custom_logger.cpython-39.pyc b/embeddings/__pycache__/custom_logger.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..97fd1d8280827875e666c1a6c8b9995ee9958120
Binary files /dev/null and b/embeddings/__pycache__/custom_logger.cpython-39.pyc differ
diff --git a/embeddings/__pycache__/dataset.cpython-39.pyc b/embeddings/__pycache__/dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd97a912b458d313f1ee3b398718af41bcc1a946
Binary files /dev/null and b/embeddings/__pycache__/dataset.cpython-39.pyc differ
diff --git a/embeddings/__pycache__/dataset_in_memory.cpython-39.pyc b/embeddings/__pycache__/dataset_in_memory.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b5680a08ba3da9769d0355f247721d87c640624
Binary files /dev/null and b/embeddings/__pycache__/dataset_in_memory.cpython-39.pyc differ
diff --git a/embeddings/__pycache__/post_embedding_builder.cpython-39.pyc b/embeddings/__pycache__/post_embedding_builder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c28eb16ed26e7813cbf1dd1fd1857bcb9fed14ba
Binary files /dev/null and b/embeddings/__pycache__/post_embedding_builder.cpython-39.pyc differ
diff --git a/embeddings/__pycache__/static_graph_construction.cpython-39.pyc b/embeddings/__pycache__/static_graph_construction.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c148ef5bc25a569593d938dd5ee030df7d33280
Binary files /dev/null and b/embeddings/__pycache__/static_graph_construction.cpython-39.pyc differ
diff --git a/embeddings/__pycache__/unixcoder.cpython-39.pyc b/embeddings/__pycache__/unixcoder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..328777fd5508c3dbe7b68f74baa751b377c14a97
Binary files /dev/null and b/embeddings/__pycache__/unixcoder.cpython-39.pyc differ
diff --git a/embeddings/custom_logger.py b/embeddings/custom_logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fb2391569ebb7ca2cf7bf372cb11fd708b0bfed
--- /dev/null
+++ b/embeddings/custom_logger.py
@@ -0,0 +1,16 @@
+import logging
+import sys
+
+
+def setup_custom_logger(name, level):
+    formatter = logging.Formatter(fmt='%(asctime)s %(levelname)-8s %(message)s',
+                                  datefmt='%Y-%m-%d %H:%M:%S')
+    screen_handler = logging.StreamHandler(stream=sys.stdout)
+    screen_handler.setFormatter(formatter)
+    logger = logging.getLogger(name)
+    logger.propagate = False
+    logger.setLevel(level)
+
+    logger.handlers.clear()
+    logger.addHandler(screen_handler)
+    return logger
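
A minimal usage sketch for the helper above (module name and level are illustrative):

    import logging
    from custom_logger import setup_custom_logger

    # Handlers are cleared on every call, so re-running setup in a notebook
    # or REPL does not produce duplicate log lines.
    log = setup_custom_logger('example', logging.INFO)
    log.info("messages go to stdout with a timestamped format")
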
diff --git a/embeddings/dataset_in_memory.py b/embeddings/dataset_in_memory.py
index 153055e362ba7a2164820943821e540113ece47b..1ecbf6a7b6ca609b0c1c604fbb5d643d5ea79cee 100644
--- a/embeddings/dataset_in_memory.py
+++ b/embeddings/dataset_in_memory.py
@@ -1,5 +1,7 @@
 import logging
 import os
+import re
+from typing import List
 
 import torch
 from torch_geometric.data import InMemoryDataset
@@ -9,9 +11,12 @@ from custom_logger import setup_custom_logger
 log = setup_custom_logger('in-memory-dataset', logging.INFO)
 
 class UserGraphDatasetInMemory(InMemoryDataset):
-    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
+    def __init__(self, root, file_name_out: str, question_ids: List[int] = None, transform=None, pre_transform=None, pre_filter=None):
+        self._file_name_out = file_name_out
+        self._question_ids = question_ids
         super().__init__(root, transform, pre_transform, pre_filter)
         self.data, self.slices = torch.load(self.processed_paths[0])
+        self.data = self.data.apply(lambda x: x.detach())
 
     @property
     def processed_dir(self):
@@ -27,7 +32,7 @@ class UserGraphDatasetInMemory(InMemoryDataset):
 
     @property
     def processed_file_names(self):
-        return ['in-memory-dataset.pt']
+        return [self._file_name_out]
 
     def download(self):
         pass
@@ -37,15 +42,34 @@ class UserGraphDatasetInMemory(InMemoryDataset):
         data_list = []
 
         for f in self.raw_file_names:
+            question_id_search = re.search(r"id_(\d+)", f)
+            if question_id_search and self._question_ids is not None:
+                if int(question_id_search.group(1)) not in self._question_ids:
+                    continue
+
             data = torch.load(os.path.join(self.raw_dir, f))
             data_list.append(data)
 
         data, slices = self.collate(data_list)
         torch.save((data, slices), os.path.join(self.processed_dir, self.processed_file_names[0]))
 
 
 
 if __name__ == '__main__':
-    dataset = UserGraphDatasetInMemory('../data/')
-
-    print(dataset.get(3))
\ No newline at end of file
+    question_ids = set()
+    # Split by question ids
+    for f in os.listdir("../data/processed"):
+        question_id_search = re.search(r"id_(\d+)", f)
+        if question_id_search:
+            question_ids.add(int(question_id_search.group(1)))
+
+    question_ids = sorted(question_ids)  # sort for a deterministic split
+    train_ids = question_ids[:int(len(question_ids) * 0.7)]
+    test_ids = question_ids[int(len(question_ids) * 0.7):]
+
+    log.info(f"Training question count {len(train_ids)}")
+    log.info(f"Testing question count {len(test_ids)}")
+
+    train_dataset = UserGraphDatasetInMemory('../data/', f'train-{len(train_ids)}-qs.pt', train_ids)
+    test_dataset = UserGraphDatasetInMemory('../data/', f'test-{len(test_ids)}-qs.pt', test_ids)
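
A minimal sketch of how the generated splits could be loaded elsewhere, assuming the file names produced by the script above (the counts will differ if the raw data does):

    from dataset_in_memory import UserGraphDatasetInMemory

    # With no question_ids given, the existing processed file is read directly.
    train_dataset = UserGraphDatasetInMemory(root='../data/', file_name_out='train-4175-qs.pt')
    test_dataset = UserGraphDatasetInMemory(root='../data/', file_name_out='test-1790-qs.pt')
    print(len(train_dataset), len(test_dataset))
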
diff --git a/embeddings/hetero_GAT.py b/embeddings/hetero_GAT.py
index 95a6b9983a5b8e902c12648de10c2e51c1e145dc..0943795d340e8939345c4d2893db569eea6f1cb4 100644
--- a/embeddings/hetero_GAT.py
+++ b/embeddings/hetero_GAT.py
@@ -19,6 +19,10 @@ from dataset_in_memory import UserGraphDatasetInMemory
 from Visualize import GraphVisualization
 
 log = setup_custom_logger("heterogenous_GAT_model", logging.INFO)
+torch.multiprocessing.set_sharing_strategy('file_system')  # share tensors between DataLoader workers via the file system
+import resource  # Unix-only
+rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))  # raise the soft open-file limit for many workers
 
 
 class HeteroGNN(torch.nn.Module):
@@ -28,25 +32,24 @@ class HeteroGNN(torch.nn.Module):
         self.convs = torch.nn.ModuleList()
         for _ in range(num_layers):
             conv = HeteroConv({
-                ('tag', 'describes', 'question') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('tag', 'describes', 'answer') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('tag', 'describes', 'comment') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('module', 'imported_in', 'question') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('module', 'imported_in', 'answer') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('question', 'rev_describes', 'tag') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('answer', 'rev_describes', 'tag') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('comment', 'rev_describes', 'tag') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('question', 'rev_imported_in', 'module') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('answer', 'rev_imported_in', 'module') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
+                ('tag', 'describes', 'question'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('tag', 'describes', 'answer'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('tag', 'describes', 'comment'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('module', 'imported_in', 'question'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('module', 'imported_in', 'answer'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('question', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('answer', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('comment', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('question', 'rev_imported_in', 'module'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('answer', 'rev_imported_in', 'module'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
             }, aggr='sum')
             self.convs.append(conv)
 
         self.lin = Linear(-1, out_channels)
         self.softmax = torch.nn.Softmax(dim=-1)
 
-
     def forward(self, x_dict, edge_index_dict, batch_dict, post_emb):
-        #print("IN", post_emb.shape)
+        # print("IN", post_emb.shape)
         for conv in self.convs:
             x_dict = conv(x_dict, edge_index_dict)
             x_dict = {key: x.relu() for key, x in x_dict.items()}
@@ -58,29 +61,37 @@ class HeteroGNN(torch.nn.Module):
             else:
                 outs.append(torch.zeros(1, x.size(-1)))
 
-        #print([x.shape for x in outs])
+        # print([x.shape for x in outs])
         out = torch.cat(outs, dim=1)
 
         out = torch.cat([out, post_emb], dim=1)
 
-        #print("B4 LINEAR", out.shape)
+        # print("B4 LINEAR", out.shape)
         out = self.lin(out)
         out = out.relu()
         out = self.softmax(out)
         return out
 
 
 
 def train(model, train_loader):
     running_loss = 0.0
 
     model.train()
     for i, data in enumerate(train_loader):  # Iterate in batches over the training dataset.
-        data = data.to(device)
+        data.to(device)
 
         optimizer.zero_grad()  # Clear gradients.
-        #print("DATA IN", data.question_emb.shape, data.answer_emb.shape)
-        post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1)
+        
+        if INCLUDE_ANSWER:
+            post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
+        else:
+            post_emb = data.question_emb.to(device)
+
         out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb)  # Perform a single forward pass.
 
         loss = criterion(out, torch.squeeze(data.label, -1))  # Compute the loss.
@@ -89,13 +100,12 @@ def train(model, train_loader):
 
         running_loss += loss.item()
         if i % 5 == 0:
-            log.info(f"[{i+1}] Loss: {running_loss / 2000}")
+            log.info(f"[{i + 1}] Loss: {running_loss / 5}")
             running_loss = 0.0
 
 
-
 def test(loader):
-    table = wandb.Table(columns=["graph", "ground_truth", "prediction"]) if use_wandb else None
+    table = wandb.Table(columns=["ground_truth", "prediction"]) if use_wandb else None
     model.eval()
 
     predictions = []
@@ -104,9 +114,12 @@ def test(loader):
     loss_ = 0
 
     for data in loader:  # Iterate in batches over the training/test dataset.
-        data = data.to(device)
-
-        post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
+        data.to(device)
+        
+        if INCLUDE_ANSWER:
+            post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
+        else:
+            post_emb = data.question_emb.to(device)
 
         out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb)  # Perform a single forward pass.
 
@@ -115,13 +128,26 @@ def test(loader):
         pred = out.argmax(dim=1)  # Use the class with highest probability.
         predictions += list([x.item() for x in pred])
         true_labels += list([x.item() for x in data.label])
-
+        # log.info([(x, y) for x,y in zip([x.item() for x in pred], [x.item() for x in data.label])])
         if use_wandb:
-            graph_html = wandb.Html(plotly.io.to_html(create_graph_vis(data)))
+            # graph_html = wandb.Html(plotly.io.to_html(create_graph_vis(data)))
             for pred, label in zip(pred, torch.squeeze(data.label, -1)):
-                table.add_data(graph_html, label, pred)
+                table.add_data(label, pred)
+
+    # print([(x, y) for x, y in zip(predictions, true_labels)])
+    test_results = {
+        "accuracy": accuracy_score(true_labels, predictions),
+        "f1-score": f1_score(true_labels, predictions),
+        "loss": loss_ / len(loader),
+        "table": table,
+        "preds": predictions,
+        "trues": true_labels
+    }
+    return test_results
 
-    return accuracy_score(true_labels, predictions), f1_score(true_labels, predictions), loss_ / len(loader), table
 
 def create_graph_vis(graph):
     g = to_networkx(graph.to_homogeneous())
@@ -132,6 +158,7 @@ def create_graph_vis(graph):
     fig = vis.create_figure()
     return fig
 
+
 def init_wandb(project_name: str, dataset):
     wandb.init(project=project_name, name="setup")
     # Log all the details about the data to W&B.
@@ -145,7 +172,7 @@ def init_wandb(project_name: str, dataset):
         n_edges = graph.num_edges
         label = graph.label.item()
 
-        #graph_vis = plotly.io.to_html(fig, full_html=False)
+        # graph_vis = plotly.io.to_html(fig, full_html=False)
 
         table.add_data(wandb.Plotly(fig), n_nodes, n_edges, label)
     wandb.log({"data": table})
@@ -158,106 +185,116 @@ def init_wandb(project_name: str, dataset):
     # End the W&B run
     wandb.finish()
 
+
 def start_wandb_for_training(wandb_project_name: str, wandb_run_name: str):
     wandb.init(project=wandb_project_name, name=wandb_run_name)
-    #wandb.use_artifact("static-graphs:latest")
+    # wandb.use_artifact("static-graphs:latest")
+
 
 def save_model(model, model_name: str):
     torch.save(model.state_dict(), os.path.join("..", "models", model_name))
 
+
 if __name__ == '__main__':
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     log.info(f"Proceeding with {device} . .")
 
-    in_memory_dataset = False
+    in_memory_dataset = True
     # Datasets
     if in_memory_dataset:
-        dataset = UserGraphDatasetInMemory(root="../data")
+        train_dataset = UserGraphDatasetInMemory(root="../data", file_name_out='train-4175-qs.pt')
+        test_dataset = UserGraphDatasetInMemory(root="../data", file_name_out='test-1790-qs.pt')
     else:
         dataset = UserGraphDataset(root="../data", skip_processing=True)
+        train_size = int(0.7 * len(dataset))
+        test_size = len(dataset) - train_size
 
 
-    train_size = int(0.7 * len(dataset))
-    val_size = int(0.1 * len(dataset))
-    test_size = len(dataset) - (train_size + val_size)
-
-    log.info(f"Train Dataset Size: {train_size}")
-    log.info(f"Validation Dataset Size: {val_size}")
-    log.info(f"Test Dataset Size: {test_size}")
-    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
+        train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
 
+    log.info(f"Train Dataset Size: {len(train_dataset)}")
+    log.info(f"Test Dataset Size: {len(test_dataset)}")
+    
     # Weights&Biases dashboard
     data_details = {
-        "num_node_features": dataset.num_node_features,
+        "num_node_features": train_dataset.num_node_features,
         "num_classes": 2
     }
     log.info(f"Data Details:\n{data_details}")
-
+    
+    log.info(train_dataset[0])
+    
     setup_wandb = False
     wandb_project_name = "heterogeneous-GAT-model"
     if setup_wandb:
         init_wandb(wandb_project_name, dataset)
-    use_wandb = False
+    use_wandb = True
     if use_wandb:
         wandb_run_name = f"run@{time.strftime('%Y%m%d-%H%M%S')}"
         start_wandb_for_training(wandb_project_name, wandb_run_name)
 
-
-    calculate_class_weights = False
-    #Class weights
+    calculate_class_weights = True
+    # Class weights
     sampler = None
     if calculate_class_weights:
         log.info(f"Calculating class weights")
         train_labels = [x.label for x in train_dataset]
-        counts = [train_labels.count(x) for x in [0,1]]
+        counts = [train_labels.count(x) for x in [0, 1]]
+        log.info(f"Class counts: {counts}")
         class_weights = [1 - (x / sum(counts)) for x in counts]
+        log.info(f"Class weights: {class_weights}")
         sampler = torch.utils.data.WeightedRandomSampler([class_weights[x] for x in train_labels], len(train_labels))
 
+    TRAIN_BATCH_SIZE = 512
+    log.info(f"Train DataLoader batch size is set to {TRAIN_BATCH_SIZE}")
+
     # Dataloaders
-    train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=64)
-    val_loader = DataLoader(val_dataset, batch_size=16)
-    test_loader = DataLoader(test_dataset, batch_size=16)
+    train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=TRAIN_BATCH_SIZE, num_workers=14)
+    
+    test_loader = DataLoader(test_dataset, batch_size=512, num_workers=14)
 
     # Model
-    model = HeteroGNN(hidden_channels=64, out_channels=2, num_layers=3).to(device)
-
-    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+    model = HeteroGNN(hidden_channels=64, out_channels=2, num_layers=3)
+    model.to(device)
+    
+    # Experiment config
+    INCLUDE_ANSWER = False
+    
+    optimizer = torch.optim.Adam(model.parameters())
     criterion = torch.nn.CrossEntropyLoss()
 
-    for epoch in range(1, 5):
+    for epoch in range(1, 40):
         log.info(f"Epoch: {epoch:03d} > > >")
         train(model, train_loader)
-        train_acc, train_f1, train_loss, train_table = test(train_loader)
-        val_acc, val_f1, val_loss, val_table = test(val_loader)
-        test_acc, test_f1, test_loss, test_table = test(test_loader)
+        train_info = test(train_loader)
+        test_info = test(test_loader)
 
-        print(f'Epoch: {epoch:03d}, Train F1: {train_f1:.4f}, Validation F1: {val_f1:.4f} Test F1: {test_f1:.4f}')
+        print(f'Epoch: {epoch:03d}, Train F1: {train_info["f1-score"]:.4f}, Test F1: {test_info["f1-score"]:.4f}')
         checkpoint_file_name = f"../models/model-{epoch}.pt"
         torch.save(model.state_dict(), checkpoint_file_name)
         if use_wandb:
             wandb.log({
-                "train/loss": train_loss,
-                "train/accuracy": train_acc,
-                "train/f1": train_f1,
-                "train/table": train_table,
-                "val/loss": val_loss,
-                "val/accuracy": val_acc,
-                "val/f1": val_f1,
-                "val/table": val_table,
-                "test/loss": test_loss,
-                "test/accuracy": test_acc,
-                "test/f1": test_f1,
-                "test/table": test_table,
+                "train/loss": train_info["loss"],
+                "train/accuracy": train_info["accuracy"],
+                "train/f1": train_info["f1-score"],
+                "train/table": train_info["table"],
+                "test/loss": test_info["loss"],
+                "test/accuracy": test_info["accuracy"],
+                "test/f1": test_info["f1-score"],
+                "test/table": test_info["table"]
             })
             # Log model checkpoint as an artifact to W&B
             # artifact = wandb.Artifact(name="heterogenous-GAT-static-graphs", type="model")
-            # checkpoint_file_name = f"../models/model-{epoch}.pt"
+            # checkpoint_file_name = f"../models/model-{epoch}.pt"
             # torch.save(model.state_dict(), checkpoint_file_name)
             # artifact.add_file(checkpoint_file_name)
             # wandb.log_artifact(artifact)
 
-    print(f'Test F1: {test_f1:.4f}')
+    print(f'Test F1: {test_info["f1-score"]:.4f}')
 
     save_model(model, "model.pt")
     if use_wandb:
+        wandb.log({"test/cm": wandb.plot.confusion_matrix(probs=None, y_true=test_info["trues"], preds=test_info["preds"], class_names=["neutral", "upvoted"])})
         wandb.finish()
+
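
The class-weighting above gives each class a weight of one minus its relative frequency, so minority-class graphs are oversampled. A self-contained sketch of the same arithmetic, with made-up counts:

    import torch

    counts = [900, 100]  # hypothetical label counts for classes 0 and 1
    class_weights = [1 - (c / sum(counts)) for c in counts]  # -> [0.1, 0.9]

    train_labels = [0] * 900 + [1] * 100
    # Each sample is drawn with probability proportional to its class weight,
    # so the rare class is sampled roughly nine times as often per draw.
    sampler = torch.utils.data.WeightedRandomSampler(
        weights=[class_weights[x] for x in train_labels],
        num_samples=len(train_labels),
    )
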
diff --git a/embeddings/hetero_GAT.py.save b/embeddings/hetero_GAT.py.save
new file mode 100644
index 0000000000000000000000000000000000000000..865939d3c47c49ffbf0511dc847d8d031db35564
--- /dev/null
+++ b/embeddings/hetero_GAT.py.save
@@ -0,0 +1,297 @@
+import json
+import logging
+import os
+import string
+import time
+
+import networkx as nx
+import plotly
+import torch
+from sklearn.metrics import f1_score, accuracy_score
+from torch_geometric.loader import DataLoader
+from torch_geometric.nn import HeteroConv, GATConv, Linear, global_mean_pool
+import wandb
+from torch_geometric.utils import to_networkx
+
+from custom_logger import setup_custom_logger
+from dataset import UserGraphDataset
+from dataset_in_memory import UserGraphDatasetInMemory
+from Visualize import GraphVisualization
+
+log = setup_custom_logger("heterogenous_GAT_model", logging.INFO)
+torch.multiprocessing.set_sharing_strategy('file_system')
+import resource
+rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+# resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))
+
+
+class HeteroGNN(torch.nn.Module):
+    def __init__(self, hidden_channels, out_channels, num_layers):
+        super().__init__()
+
+        self.convs = torch.nn.ModuleList()
+        for _ in range(num_layers):
+            conv = HeteroConv({
+                ('tag', 'describes', 'question'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('tag', 'describes', 'answer'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('tag', 'describes', 'comment'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('module', 'imported_in', 'question'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('module', 'imported_in', 'answer'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('question', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('answer', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('comment', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('question', 'rev_imported_in', 'module'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('answer', 'rev_imported_in', 'module'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+            }, aggr='sum')
+            self.convs.append(conv)
+
+        self.lin = Linear(-1, out_channels)
+        self.softmax = torch.nn.Softmax(dim=-1)
+
+    def forward(self, x_dict, edge_index_dict, batch_dict, post_emb):
+        # print("IN", post_emb.shape)
+        for conv in self.convs:
+            x_dict = conv(x_dict, edge_index_dict)
+            x_dict = {key: x.relu() for key, x in x_dict.items()}
+
+        outs = []
+        for x, batch in zip(x_dict.values(), batch_dict.values()):
+            if len(x):
+                outs.append(global_mean_pool(x, batch=batch, size=len(post_emb)))
+            else:
+                outs.append(torch.zeros(1, x.size(-1)))
+
+        # print([x.shape for x in outs])
+        out = torch.cat(outs, dim=1)
+
+        out = torch.cat([out, post_emb], dim=1)
+
+        # print("B4 LINEAR", out.shape)
+        out = self.lin(out)
+        out = out.relu()
+        out = self.softmax(out)
+        return out
+
+
+def train(model, train_loader):
+    running_loss = 0.0
+
+    model.train()
+    for i, data in enumerate(train_loader):  # Iterate in batches over the training dataset.
+        data.to(device)
+
+        optimizer.zero_grad()  # Clear gradients.
+        
+        if INCLUDE_ANSWER:
+            post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
+        else:
+            post_emb = data.question_emb.to(device)
+
+        out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb)  # Perform a single forward pass.
+
+        loss = criterion(out, torch.squeeze(data.label, -1))  # Compute the loss.
+        loss.backward()  # Derive gradients.
+        optimizer.step()  # Update parameters based on gradients.
+
+        running_loss += loss.item()
+        if i % 5 == 0:
+            log.info(f"[{i + 1}] Loss: {running_loss / 5}")
+            running_loss = 0.0
+
+
+def test(loader):
+    table = wandb.Table(columns=["ground_truth", "prediction"]) if use_wandb else None
+    model.eval()
+
+    predictions = []
+    true_labels = []
+
+    loss_ = 0
+
+    for data in loader:  # Iterate in batches over the training/test dataset.
+        data.to(device)
+        
+        if INCLUDE_ANSWER:
+            post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
+        else:
+            post_emb = data.question_emb.to(device)
+
+        out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb)  # Perform a single forward pass.
+
+        loss = criterion(out, torch.squeeze(data.label, -1))  # Compute the loss.
+        loss_ += loss.item()
+        pred = out.argmax(dim=1)  # Use the class with highest probability.
+        predictions += list([x.item() for x in pred])
+        true_labels += list([x.item() for x in data.label])
+        # log.info([(x, y) for x,y in zip([x.item() for x in pred], [x.item() for x in data.label])])
+        if use_wandb:
+            # graph_html = wandb.Html(plotly.io.to_html(create_graph_vis(data)))
+            for pred, label in zip(pred, torch.squeeze(data.label, -1)):
+                table.add_data(label, pred)
+
+    # print([(x, y) for x, y in zip(predictions, true_labels)])
+    test_results = {
+        "accuracy": accuracy_score(true_labels, predictions),
+        "f1-score": f1_score(true_labels, predictions),
+        "loss": loss_ / len(loader),
+        "table": table,
+        "preds": predictions,
+        "trues": true_labels
+    }
+    return test_results
+
+
+def create_graph_vis(graph):
+    g = to_networkx(graph.to_homogeneous())
+    pos = nx.spring_layout(g)
+    vis = GraphVisualization(
+        g, pos, node_text_position='top left', node_size=20,
+    )
+    fig = vis.create_figure()
+    return fig
+
+
+def init_wandb(project_name: str, dataset):
+    wandb.init(project=project_name, name="setup")
+    # Log all the details about the data to W&B.
+    wandb.log(data_details)
+
+    # Log exploratory visualizations for each data point to W&B
+    table = wandb.Table(columns=["Graph", "Number of Nodes", "Number of Edges", "Label"])
+    for graph in dataset:
+        fig = create_graph_vis(graph)
+        n_nodes = graph.num_nodes
+        n_edges = graph.num_edges
+        label = graph.label.item()
+
+        # graph_vis = plotly.io.to_html(fig, full_html=False)
+
+        table.add_data(wandb.Plotly(fig), n_nodes, n_edges, label)
+    wandb.log({"data": table})
+
+    # Log the dataset to W&B as an artifact.
+    dataset_artifact = wandb.Artifact(name="static-graphs", type="dataset", metadata=data_details)
+    dataset_artifact.add_dir("../data/")
+    wandb.log_artifact(dataset_artifact)
+
+    # End the W&B run
+    wandb.finish()
+
+
+def start_wandb_for_training(wandb_project_name: str, wandb_run_name: str):
+    wandb.init(project=wandb_project_name, name=wandb_run_name)
+    # wandb.use_artifact("static-graphs:latest")
+
+
+def save_model(model, model_name: str):
+    torch.save(model.state_dict(), os.path.join("..", "models", model_name))
+
+
+if __name__ == '__main__':
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    log.info(f"Proceeding with {device} . .")
+
+    in_memory_dataset = True
+    # Datasets
+    if in_memory_dataset:
+        train_dataset = UserGraphDatasetInMemory(root="../data", file_name_out='train-4175-qs.pt')
+        test_dataset = UserGraphDatasetInMemory(root="../data", file_name_out='test-1790-qs.pt')
+    else:
+        dataset = UserGraphDataset(root="../data", skip_processing=True)
+        train_size = int(0.7 * len(dataset))
+        test_size = len(dataset) - train_size
+
+
+        train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
+
+    log.info(f"Train Dataset Size: {len(train_dataset)}")
+    log.info(f"Test Dataset Size: {len(test_dataset)}")
+    
+    # Weights&Biases dashboard
+    data_details = {
+        "num_node_features": train_dataset.num_node_features,
+        "num_classes": 2
+    }
+    log.info(f"Data Details:\n{data_details}")
+
+    setup_wandb = False
+    wandb_project_name = "heterogeneous-GAT-model"
+    if setup_wandb:
+        init_wandb(wandb_project_name, dataset)
+    use_wandb = True
+    if use_wandb:
+        wandb_run_name = f"run@{time.strftime('%Y%m%d-%H%M%S')}"
+        start_wandb_for_training(wandb_project_name, wandb_run_name)
+
+    calculate_class_weights = True
+    # Class weights
+    sampler = None
+    if calculate_class_weights:
+        log.info(f"Calculating class weights")
+        train_labels = [x.label for x in train_dataset]
+        counts = [train_labels.count(x) for x in [0, 1]]
+        class_weights = [1 - (x / sum(counts)) for x in counts]
+        log.info(f"Class weights: {class_weights}")
+        sampler = torch.utils.data.WeightedRandomSampler([class_weights[x] for x in train_labels], len(train_labels))
+
+    TRAIN_BATCH_SIZE = 512
+    log.info(f"Train DataLoader batch size is set to {TRAIN_BATCH_SIZE}")
+
+    # Dataloaders
+    train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=TRAIN_BATCH_SIZE, num_workers=14)
+    
+    test_loader = DataLoader(test_dataset, batch_size=512, num_workers=14)
+
+    # Model
+    model = HeteroGNN(hidden_channels=64, out_channels=2, num_layers=3)
+    model.to(device)
+    
+    # Experiment config
+    INCLUDE_ANSWER = True
+    
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+    criterion = torch.nn.CrossEntropyLoss()
+
+    for epoch in range(1, 40):
+        log.info(f"Epoch: {epoch:03d} > > >")
+        train(model, train_loader)
+        train_info = test(train_loader)
+        test_info = test(test_loader)
+
+        print(f'Epoch: {epoch:03d}, Train F1: {train_info["f1-score"]:.4f}, Test F1: {test_info["f1-score"]:.4f}')
+        checkpoint_file_name = f"../models/model-{epoch}.pt"
+        torch.save(model.state_dict(), checkpoint_file_name)
+        if use_wandb:
+            wandb.log({
+                "train/loss": train_info["loss"],
+                "train/accuracy": train_info["accuracy"],
+                "train/f1": train_info["f1-score"],
+                "train/table": train_info["table"],
+                "test/loss": test_info["loss"],
+                "test/accuracy": test_info["accuracy"],
+                "test/f1": test_info["f1-score"],
+                "test/table": test_info["table"]
+            })
+            # Log model checkpoint as an artifact to W&B
+            # artifact = wandb.Artifact(name="heterogenous-GAT-static-graphs", type="model")
+            # checkpoint_file_name = f"../models/model-{epoch}.pt"
+            # torch.save(model.state_dict(), checkpoint_file_name)
+            # artifact.add_file(checkpoint_file_name)
+            # wandb.log_artifact(artifact)
+
+    print(f'Test F1: {test_info["f1-score"]:.4f}')
+
+    save_model(model, "model.pt")
+    if use_wandb:
+        wandb.log({"test/cm": wandb.plot.confusion_matrix(probs=None, y_true=test_info["trues"], preds=test_info["preds"], class_names=["neutral", "upvoted"])})
+        wandb.finish()
+
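
For reference, the pooling step that both versions of forward() rely on reduces each node type to one vector per graph before concatenation; a small sketch with illustrative sizes:

    import torch
    from torch_geometric.nn import global_mean_pool

    x = torch.randn(5, 8)                  # five nodes with 8 features each
    batch = torch.tensor([0, 0, 0, 1, 1])  # node-to-graph assignment for two graphs

    pooled = global_mean_pool(x, batch=batch, size=2)
    print(pooled.shape)                    # torch.Size([2, 8]): one vector per graph
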