diff --git a/__pycache__/custom_logger.cpython-38.pyc b/__pycache__/custom_logger.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bd1b7e74e21f2765fb90fb38781bc9be52b35809
Binary files /dev/null and b/__pycache__/custom_logger.cpython-38.pyc differ
diff --git a/embeddings/GAT.py b/archived/GAT.py
similarity index 100%
rename from embeddings/GAT.py
rename to archived/GAT.py
diff --git a/embeddings/db_handler.py b/archived/db_handler.py
similarity index 100%
rename from embeddings/db_handler.py
rename to archived/db_handler.py
diff --git a/embeddings/embeddings.ipynb b/archived/embeddings.ipynb
similarity index 100%
rename from embeddings/embeddings.ipynb
rename to archived/embeddings.ipynb
diff --git a/embeddings/pyg_construction_demo.ipynb b/archived/pyg_construction_demo.ipynb
similarity index 100%
rename from embeddings/pyg_construction_demo.ipynb
rename to archived/pyg_construction_demo.ipynb
diff --git a/embeddings/tag_embedding_trainer.py b/archived/tag_embedding_trainer.py
similarity index 100%
rename from embeddings/tag_embedding_trainer.py
rename to archived/tag_embedding_trainer.py
diff --git a/custom_logger.py b/custom_logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..428f5a43902b960d46ee84d8064d47173258d58c
--- /dev/null
+++ b/custom_logger.py
@@ -0,0 +1,17 @@
+import logging
+import sys
+
+
+def setup_custom_logger(name, level):
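+    """Return a stdout logger with a uniform format; clears old handlers so repeated calls don't duplicate output."""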
+    formatter = logging.Formatter(fmt='%(asctime)s %(levelname)-8s %(message)s',
+                                  datefmt='%Y-%m-%d %H:%M:%S')
+    screen_handler = logging.StreamHandler(stream=sys.stdout)
+    screen_handler.setFormatter(formatter)
+    logger = logging.getLogger(name)
+    logger.propagate = False
+    logger.setLevel(level)
+
+    logger.handlers.clear()
+    logger.addHandler(screen_handler)
+    return logger
diff --git a/embeddings/__pycache__/dataset.cpython-38.pyc b/embeddings/__pycache__/dataset.cpython-38.pyc
index 05d3eee3ff36256b85b90bfb9541b5349d0b6872..8f00b97e8ca224ab5e5ca43a353fac98b07d13c1 100644
Binary files a/embeddings/__pycache__/dataset.cpython-38.pyc and b/embeddings/__pycache__/dataset.cpython-38.pyc differ
diff --git a/embeddings/__pycache__/dataset_in_memory.cpython-38.pyc b/embeddings/__pycache__/dataset_in_memory.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f733627988e7838d5f142171ab9486d0d03313e1
Binary files /dev/null and b/embeddings/__pycache__/dataset_in_memory.cpython-38.pyc differ
diff --git a/embeddings/__pycache__/post_embedding_builder.cpython-38.pyc b/embeddings/__pycache__/post_embedding_builder.cpython-38.pyc
index a5e9dda385534af7818ba061bb26aabf70615bf9..9d69ec7710c3ec43bf3d6926c232566a68e672bc 100644
Binary files a/embeddings/__pycache__/post_embedding_builder.cpython-38.pyc and b/embeddings/__pycache__/post_embedding_builder.cpython-38.pyc differ
diff --git a/embeddings/__pycache__/static_graph_construction.cpython-38.pyc b/embeddings/__pycache__/static_graph_construction.cpython-38.pyc
index 25ddf94535803d6606156bd35f630605d2bf1fbe..5eae8edcf42817130d409e7c8ba9f90b9149efdb 100644
Binary files a/embeddings/__pycache__/static_graph_construction.cpython-38.pyc and b/embeddings/__pycache__/static_graph_construction.cpython-38.pyc differ
diff --git a/embeddings/dataset.py b/embeddings/dataset.py
index 319af3ea3f12a3cae30ee59ca39c13bcb09c910b..165ccbfd85e18e9168b0459904b992ac91a597f7 100644
--- a/embeddings/dataset.py
+++ b/embeddings/dataset.py
@@ -13,15 +13,15 @@ from torch_geometric.data import Dataset, download_url, Data, HeteroData
 from torch_geometric.data.hetero_data import NodeOrEdgeStorage
 from tqdm import tqdm
 import warnings
+
+from custom_logger import setup_custom_logger
+
 warnings.filterwarnings('ignore', category=MarkupResemblesLocatorWarning)
 
 from post_embedding_builder import PostEmbedding
 from static_graph_construction import StaticGraphConstruction
 
-logging.basicConfig()
-logging.getLogger().setLevel(logging.INFO)
-log = logging.getLogger("dataset")
-
+log = setup_custom_logger('dataset', logging.INFO)
 
 class UserGraphDataset(Dataset):
     def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, db_address:str=None, skip_processing=False):
diff --git a/embeddings/dataset_in_memory.py b/embeddings/dataset_in_memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..153055e362ba7a2164820943821e540113ece47b
--- /dev/null
+++ b/embeddings/dataset_in_memory.py
@@ -0,0 +1,53 @@
+import logging
+import os
+
+import torch
+from torch_geometric.data import InMemoryDataset
+
+from custom_logger import setup_custom_logger
+
+log = setup_custom_logger('in-memory-dataset', logging.INFO)
+
+class UserGraphDatasetInMemory(InMemoryDataset):
+    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
+        super().__init__(root, transform, pre_transform, pre_filter)
+        self.data, self.slices = torch.load(self.processed_paths[0])
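+        # processed_paths[0] holds the collated (data, slices) pair written by process().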
+
+    @property
+    def processed_dir(self):
+        return os.path.join(self.root, 'processed_in_memory')
+
+    @property
+    def raw_dir(self):
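+        # Reuse the per-graph files produced by UserGraphDataset as this dataset's raw input.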
+        return os.path.join(self.root, 'processed')
+
+    @property
+    def raw_file_names(self):
+        return [x for x in os.listdir("../data/processed") if x not in ['pre_filter.pt', 'pre_transform.pt']]
+
+    @property
+    def processed_file_names(self):
+        return ['in-memory-dataset.pt']
+
+    def download(self):
+        pass
+
+    def process(self):
+        # Read data into huge `Data` list.
+        data_list = []
+
+        for f in self.raw_file_names:
+            data = torch.load(os.path.join(self.raw_dir, f))
+            data_list.append(data)
+
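+        # collate() merges the Data list into one contiguous store plus slice indices for indexing.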
+        data, slices = self.collate(data_list)
+        torch.save((data, slices), os.path.join(self.processed_dir, self.processed_file_names[0]))
+
+
+if __name__ == '__main__':
+    dataset = UserGraphDatasetInMemory('../data/')
+
+    print(dataset.get(3))
\ No newline at end of file
diff --git a/embeddings/hetero_GAT.py b/embeddings/hetero_GAT.py
index 98ca211618753b93a0ea461364f516558adc3251..f4c9d33bf4a99e9a85668c2152832dd52414b4ab 100644
--- a/embeddings/hetero_GAT.py
+++ b/embeddings/hetero_GAT.py
@@ -12,12 +12,13 @@ from torch_geometric.nn import HeteroConv, GATConv, Linear, global_mean_pool
 import wandb
 from torch_geometric.utils import to_networkx
 
+from custom_logger import setup_custom_logger
 from dataset import UserGraphDataset
+from dataset_in_memory import UserGraphDatasetInMemory
 from Visualize import GraphVisualization
 
-logging.basicConfig()
-logging.getLogger().setLevel(logging.INFO)
-log = logging.getLogger('heterogeneous-GAT-model')
+log = setup_custom_logger("heterogenous_GAT_model", logging.INFO)
+
 
 class HeteroGNN(torch.nn.Module):
     def __init__(self, hidden_channels, out_channels, num_layers):
@@ -54,7 +55,6 @@ class HeteroGNN(torch.nn.Module):
             if len(x):
                 outs.append(global_mean_pool(x, batch=batch, size=len(post_emb)))
             else:
-                #print("EMPTY")
                 outs.append(torch.zeros(1, x.size(-1)))
 
         #print([x.shape for x in outs])
@@ -75,7 +75,6 @@ def train(model, train_loader):
 
     model.train()
     for i, data in enumerate(train_loader):  # Iterate in batches over the training dataset.
-        #print(data)
         data = data.to(device)
 
         optimizer.zero_grad()  # Clear gradients.
@@ -121,7 +120,6 @@ def test(loader):
             for pred, label in zip(pred, torch.squeeze(data.label, -1)):
                 table.add_data(graph_html, label, pred)
 
-    #print("PRED", predictions, true_labels)
     return accuracy_score(true_labels, predictions), f1_score(true_labels, predictions), loss_ / len(loader), table
 
 def create_graph_vis(graph):
@@ -146,9 +144,9 @@ def init_wandb(project_name: str, dataset):
         n_edges = graph.num_edges
         label = graph.label.item()
 
-        graph_vis = plotly.io.to_html(fig, full_html=False)
+        #graph_vis = plotly.io.to_html(fig, full_html=False)
 
-        table.add_data(graph_vis, n_nodes, n_edges, label)
+        table.add_data(wandb.Plotly(fig), n_nodes, n_edges, label)
     wandb.log({"data": table})
 
     # Log the dataset to W&B as an artifact.
@@ -159,17 +157,30 @@ def init_wandb(project_name: str, dataset):
     # End the W&B run
     wandb.finish()
 
-def start_wandb_for_training(wandb_project_name: str):
-    wandb.init(project=wandb_project_name)
-    wandb.use_artifact("static-graphs:latest")
+def start_wandb_for_training(wandb_project_name: str, wandb_run_name: str):
+    wandb.init(project=wandb_project_name, name=wandb_run_name)
+    #wandb.use_artifact("static-graphs:latest")
 
 if __name__ == '__main__':
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     log.info(f"Proceeding with {device} . .")
 
+    in_memory_dataset = False
     # Datasets
-    dataset = UserGraphDataset(root="../data", skip_processing=True)
-    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [0.6, 0.1, 0.3])
+    if in_memory_dataset:
+        dataset = UserGraphDatasetInMemory(root="../data")
+    else:
+        dataset = UserGraphDataset(root="../data", skip_processing=True)
+
+    train_size = int(0.7 * len(dataset))
+    val_size = int(0.1 * len(dataset))
+    test_size = len(dataset) - (train_size + val_size)
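+    # 70/10/20 split; test takes the remainder so the three sizes always sum to len(dataset).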
+
+    log.info(f"Train Dataset Size: {train_size}")
+    log.info(f"Validation Dataset Size: {val_size}")
+    log.info(f"Test Dataset Size: {test_size}")
+    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
 
     # Weights&Biases dashboard
     data_details = {
@@ -182,30 +193,37 @@
     wandb_project_name = "heterogeneous-GAT-model"
     if setup_wandb:
         init_wandb(wandb_project_name, dataset)
-    use_wandb = True
+    use_wandb = False
     if use_wandb:
         wandb_run_name = f"run@{time.strftime('%Y%m%d-%H%M%S')}"
         start_wandb_for_training(wandb_project_name, wandb_run_name)
 
-    # Class weights
-    train_labels = [x.label for x in train_dataset]
-    counts = [train_labels.count(x) for x in [0,1]]
-    class_weights = [1 - (x / sum(counts)) for x in counts]
-    sampler = torch.utils.data.WeightedRandomSampler([class_weights[x] for x in train_labels], len(train_labels))
+
+    # Class weights (optional; builds a weighted sampler when enabled)
+    calculate_class_weights = False
+    sampler = None
+    if calculate_class_weights:
+        log.info(f"Calculating class weights")
+        train_labels = [x.label for x in train_dataset]
+        counts = [train_labels.count(x) for x in [0,1]]
+        class_weights = [1 - (x / sum(counts)) for x in counts]
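+        # Inverse-frequency weights: with two classes, 1 - freq(c) oversamples the minority class.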
+        sampler = torch.utils.data.WeightedRandomSampler([class_weights[x] for x in train_labels], len(train_labels))
 
     # Dataloaders
-    train_loader = DataLoader(train_dataset, sampler=None, batch_size=16)
-    val_loader = DataLoader(val_dataset, batch_size=1)
-    test_loader = DataLoader(test_dataset, batch_size=1)
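+    # Without a sampler the loader iterates sequentially; pass shuffle=True for random batches.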
+    train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=512)
+    val_loader = DataLoader(val_dataset, batch_size=16)
+    test_loader = DataLoader(test_dataset, batch_size=16)
 
     # Model
-    model = HeteroGNN(hidden_channels=64, out_channels=2, num_layers=3)
-    model.to(device)
+    model = HeteroGNN(hidden_channels=64, out_channels=2, num_layers=3).to(device)
 
     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
     criterion = torch.nn.CrossEntropyLoss()
 
     for epoch in range(1, 10):
+        log.info(f"Epoch: {epoch:03d} > > >")
         train(model, train_loader)
         train_acc, train_f1, train_loss, train_table = test(train_loader)
         val_acc, val_f1, val_loss, val_table = test(val_loader)
diff --git a/embeddings/post_embedding_builder.py b/embeddings/post_embedding_builder.py
index 2a3674ca5c28ed0f5c8e5aea8e32c670639a84ba..29fa4ff950dc933323b5d8a2e3ecd6877da7e3cc 100644
--- a/embeddings/post_embedding_builder.py
+++ b/embeddings/post_embedding_builder.py
@@ -9,17 +9,18 @@ from typing import List
 
 from NextTagEmbedding import NextTagEmbedding, NextTagEmbeddingTrainer
 
-logging.basicConfig()
-logging.getLogger().setLevel(logging.DEBUG)
-log = logging.getLogger("PostEmbedding")
-
 from bs4 import BeautifulSoup
 import spacy
 import torch
 import torch.nn as nn
 from transformers import BertTokenizer, BertModel, RobertaTokenizer, RobertaModel, AutoTokenizer, AutoModel, AutoConfig
+
+from custom_logger import setup_custom_logger
 from unixcoder import UniXcoder
 
+log = setup_custom_logger('post_embedding_builder', logging.INFO)
+
+
 Import = namedtuple("Import", ["module", "name", "alias"])
 Function = namedtuple("Function", ["function_name", "parameter_names"])
 
diff --git a/embeddings/static_graph_construction.py b/embeddings/static_graph_construction.py
index e8e9f3baf6e6b3ef13b6154a7956e6b9fcbfd958..eefa12f9eee0de73b1c811c3e709be5ce2d02a30 100644
--- a/embeddings/static_graph_construction.py
+++ b/embeddings/static_graph_construction.py
@@ -30,12 +30,11 @@ class BatchedHeteroData(HeteroData):
 
 class StaticGraphConstruction:
 
-    # PostEmbedding is costly to instantiate in each StaticGraphConstruction instance.
-    post_embedding_builder = PostEmbedding()
-    tag_embedding_model = NextTagEmbeddingTrainer.load_model("../models/tag-emb-7_5mil-50d-63653-3.pt", embedding_dim=50, vocab_size=63654, context_length=3)
-    module_embedding_model = ModuleEmbeddingTrainer.load_model("../models/module-emb-1milx5-30d-49911.pt", embedding_dim=30, vocab_size=49911)
-
     def __init__(self):
+        # PostEmbedding is costly to construct; each StaticGraphConstruction instance now builds its own copy.
+        self.post_embedding_builder = PostEmbedding()
+        self.tag_embedding_model = NextTagEmbeddingTrainer.load_model("../models/tag-emb-7_5mil-50d-63653-3.pt", embedding_dim=50, vocab_size=63654, context_length=3)
+        self.module_embedding_model = ModuleEmbeddingTrainer.load_model("../models/module-emb-1milx5-30d-49911.pt", embedding_dim=30, vocab_size=49911)
         self._known_tags = {}  # tag_name -> index
         self._known_modules = {}  # module_name -> index
         self._data = BatchedHeteroData()
diff --git a/embeddings/tag_embedding_models/10mil_500d_embd.pt b/embeddings/tag_embedding_models/10mil_500d_embd.pt
deleted file mode 100644
index 91881da44787cc40eec3e91f2139414a0b7258ba..0000000000000000000000000000000000000000
Binary files a/embeddings/tag_embedding_models/10mil_500d_embd.pt and /dev/null differ
diff --git a/embeddings/tag_embedding_models/10mil_embd.pt b/embeddings/tag_embedding_models/10mil_embd.pt
deleted file mode 100644
index 46c0789a42dc92fcdf73ae2286c1bb88172de0c0..0000000000000000000000000000000000000000
Binary files a/embeddings/tag_embedding_models/10mil_embd.pt and /dev/null differ
diff --git a/embeddings/tag_embedding_models/1mil_750d_embd.pt b/embeddings/tag_embedding_models/1mil_750d_embd.pt
deleted file mode 100644
index b266910b1bd67b214f235a5df5aba06707098871..0000000000000000000000000000000000000000
Binary files a/embeddings/tag_embedding_models/1mil_750d_embd.pt and /dev/null differ
diff --git a/embeddings/user-graph.svg b/figures/user-graph.svg
similarity index 100%
rename from embeddings/user-graph.svg
rename to figures/user-graph.svg