diff --git a/entcl/cl.py b/entcl/cl.py
index 8b137891791fe96927ad78e64b0aad7bded08bdc..04d816ccfd618915725bdefc455a6d19e208993b 100644
--- a/entcl/cl.py
+++ b/entcl/cl.py
@@ -1 +1,200 @@
+from typing import Dict, Tuple
+from entcl.data.util import TransformedTensorDataset
+import pandas as pd
+import torch
+from loguru import logger
+from tqdm import tqdm
 
+
+def cl_session(args, session_dataset: TransformedTensorDataset, model: torch.nn.Module, mapping: Dict[int, int]) -> torch.nn.Module:
+    """
+    Run a continual learning session on the given session dataset using the given model
+    :param args: Arguments object with the attributes `device`, `cl_epochs`, `current_session`, `batch_size`, `num_workers`, `pin_memory`, `lr`, `momentum`, `weight_decay`, `exp_dir`, and `dataset`.
+    :param session_dataset: TransformedTensorDataset with the session data. Should have OOD predtypes and pseudo labels.
+    :param model: torch.nn.Module with the model to use for continual learning. `forward` should return `logits, features`.
+    :param mapping: Dict[int, int] mapping true labels to the pseudo labels used for the novel classes.
+    :return: torch.nn.Module with the updated model.
+    """
+    logger.debug(f"Begin Continual Learning Session {args.current_session}")
+    # make sure the dataset has the correct shape
+    assert len(session_dataset.tensor_dataset.tensors) == 5, "Session Dataset should have 5 tensors, (data, true labels, true types, pred types, pseudo labels). Got: " + str(len(session_dataset.tensor_dataset.tensors))
+    
+    # set up the data, loaders, optimiser, and criterion for this session
+    
+    # we only train on novel data at the moment, so restrict the dataset to the predicted-novel samples
+    novel_samples_mask = session_dataset.tensor_dataset.tensors[3] == 1 # novel samples are labelled with 1; predtypes are the 4th tensor in the dataset
+    
+    novel_tensors = [tensor[novel_samples_mask] for tensor in session_dataset.tensor_dataset.tensors]
+    
+    # adjust the pseudo labels (which currently start from 0) to start from args.dataset.known + ((args.current_session - 1) * args.dataset.novel_inc)
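+    # e.g. with known=50 and novel_inc=10, session 2's cluster labels 0..9 become 60..69,
+    # matching the 50 + (2 - 1) * 10 = 60 classes already seen (see run.py)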
+    adjust_value = args.dataset.known + ((args.current_session - 1) * args.dataset.novel_inc)
+    logger.debug(f"Adjusting Pseudo Labels by {adjust_value}")
+    novel_tensors[4] += adjust_value
+    logger.debug(f"Adjusted Pseudo Labels: {torch.unique(novel_tensors[4], sorted=True)}")
+    
+    session_dataset = TransformedTensorDataset(
+        tensor_dataset=torch.utils.data.TensorDataset(*novel_tensors),
+        transform=session_dataset.transform,
+    )
+    
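+    # training loader over the predicted-novel subset; drop_last=True drops the final incomplete batch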
+    train_loader = torch.utils.data.DataLoader(
+        dataset=session_dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.num_workers,
+        pin_memory=args.pin_memory,
+        drop_last=True,
+    )
+    
+    old_test_loader = torch.utils.data.DataLoader(
+        dataset=args.dataset.get_dataset(session=args.current_session, train=False)["old"],
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        pin_memory=args.pin_memory,
+        drop_last=False,
+    )
+    
+    new_test_loader = torch.utils.data.DataLoader(
+        dataset=args.dataset.get_dataset(session=args.current_session, train=False)["new"],
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        pin_memory=args.pin_memory,
+        drop_last=False,
+    )
+    
+    all_test_loader = torch.utils.data.DataLoader(
+        dataset=args.dataset.get_dataset(session=args.current_session, train=False)["all"],
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        pin_memory=args.pin_memory,
+        drop_last=False,
+    )
+    
+    optimiser = torch.optim.SGD(
+        model.parameters(),
+        lr=args.lr,
+        momentum=args.momentum,
+        weight_decay=args.weight_decay,
+    )
+    
+    criterion = torch.nn.CrossEntropyLoss()
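+    # plain cross-entropy over the pseudo labels; this assumes the head has already been
+    # expanded for this session's novel classes (run.py calls model.head.expand beforehand)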
+    
+    results = None
+    
+    # train the model
+    for epoch in range(args.cl_epochs):
+        logger.debug(f"Session {args.current_session} Epoch {epoch} Started")
+        
+        # train model
+        model, train_loss = _train(args, model, train_loader, optimiser, criterion)
+        logger.info(f"Epoch {epoch}: TRAINING : Train Loss: {train_loss}")
+        
+        # validate model on the three test sets
+        model, old_accuracy, old_val_loss = _validate(args, model, old_test_loader, criterion, mapping)
+        logger.info(f"Epoch {epoch}: VALIDATION : OLD Accuracy: {old_accuracy}, OLD Val Loss: {old_val_loss}")
+        
+        model, new_accuracy, new_val_loss = _validate(args, model, new_test_loader, criterion, mapping)
+        logger.info(f"Epoch {epoch}: VALIDATION : NEW Accuracy: {new_accuracy}, NEW Val Loss: {new_val_loss}")
+        
+        model, all_accuracy, all_val_loss = _validate(args, model, all_test_loader, criterion, mapping)
+        logger.info(f"Epoch {epoch}: VALIDATION : ALL Accuracy: {all_accuracy}, ALL Val Loss: {all_val_loss}")
+        
+        # just save the head
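+        # (the backbone is frozen, so the head is the only part that changes between epochs)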
+        torch.save(model.head.state_dict(), f"{args.exp_dir}/session_{args.current_session}/head_s{args.current_session}_ep{epoch}.pt")
+        logger.debug(f"Session {args.current_session} Head saved to {args.exp_dir}/session_{args.current_session}/head_s{args.current_session}_ep{epoch}.pt")
+        
+        # create a df to hold the results
+        epoch_results = pd.DataFrame(
+            {
+                "epoch": [epoch],
+                "train_loss": [train_loss],
+                "old_accuracy": [old_accuracy],
+                "old_val_loss": [old_val_loss],
+                "new_accuracy": [new_accuracy],
+                "new_val_loss": [new_val_loss],
+                "all_accuracy": [all_accuracy],
+                "all_val_loss": [all_val_loss],
+            }
+        )
+        
+        # append the results to the results dataframe
+        results = epoch_results if results is None else pd.concat([results, epoch_results], ignore_index=True)
+        
+        # save the results dataframe
+        results.to_csv(f"{args.exp_dir}/results_s{args.current_session}.csv", index=False)
+        logger.debug(f"Session {args.current_session} Results saved to {args.exp_dir}/results_s{args.current_session}.csv")
+        
+    return model
+
+def _train(args, model: torch.nn.Module, train_loader: torch.utils.data.DataLoader, optimiser: torch.optim.Optimizer, criterion: torch.nn.Module) -> Tuple[torch.nn.Module, float]:
+    """
+    Train the model on the given train_loader for one epoch
+    :param args: Arguments object with the attributes `device`.
+    :param model: torch.nn.Module with the model to train.
+    :param train_loader: torch.utils.data.DataLoader with the training data.
+    :param optimiser: torch.optim.Optimiser with the optimiser to use.
+    :param criterion: torch.nn.Module with the loss function to use.
+    :return: torch.nn.Module with the updated model and float with the total loss.
+    """
+    model.train()
+    loss_total = 0
+    for x, _, _, _, y in tqdm(train_loader, desc="Training", unit="batch"): # we only want pseudo labels for training
+        x, y = x.to(args.device), y.to(args.device)
+        logger.debug(f"x shape: {x.shape}, y shape: {y.shape}")
+        logger.debug(f"Pseudo Labels (Unique): {torch.unique(y, sorted=True)}")
+        optimiser.zero_grad()
+        logits, _ = model(x)
+        logger.debug(f"Logits Shape: {logits.shape}")
+            
+        loss = criterion(logits, y)
+        loss.backward()
+        optimiser.step()
+        loss = loss.item()
+        logger.debug(f"Loss: {loss}")
+        loss_total += loss
+    loss_total /= len(train_loader)
+    return model, loss_total
+
+
+def _validate(args, model: torch.nn.Module, val_loader: torch.utils.data.DataLoader, criterion: torch.nn.Module, mapping: Dict[int, int]) -> Tuple[torch.nn.Module, float, float]:
+    """
+    Validate the model on the given val_loader
+    :param args: Arguments object with the attributes `device`.
+    :param model: torch.nn.Module with the model to validate.
+    :param val_loader: torch.utils.data.DataLoader with the validation data.
+    :param criterion: torch.nn.Module with the loss function to use.
+    :param mapping: Dict[int, int] with the mapping from true labels to pseudo labels.
+    :return: torch.nn.Module with the updated model, float with the accuracy and float with the total loss.
+    """
+    model.eval()
+    correct = 0
+    total = 0
+    loss_total = 0
+    with torch.no_grad():
+        for x, y, _ in tqdm(val_loader, desc="Validating", unit="batch"):
+            x, y = x.to(args.device), y.to(args.device)
+            logger.debug(f"x shape: {x.shape}, y shape: {y.shape}")
+            logger.debug(f"True Labels (Unique): {torch.unique(y, sorted=True)}")
+            
+            # map true labels onto the pseudo labels the head was trained with;
+            # remap through a clone so one substitution cannot clobber a later one
+            y_mapped = y.clone()
+            for true_label, pseudo_label in mapping.items():
+                y_mapped[y == true_label] = pseudo_label
+            y = y_mapped
+            
+            logits, _ = model(x)
+            loss = criterion(logits, y)
+            loss = loss.item()
+            loss_total += loss
+            
+            _, predicted = torch.max(logits, 1)
+            
+            num_correct = (predicted == y).sum().item()
+            total += y.size(0)
+            correct += num_correct
+            
+            logger.debug(f"This Batch Num Correct: {num_correct}, Total: {y.size(0)}, Accuracy: {num_correct / y.size(0)}, Loss: {loss}")
+            
+    return model, correct / total, loss_total / len(val_loader)
\ No newline at end of file
diff --git a/entcl/data/cifar100.py b/entcl/data/cifar100.py
index bf52d3bb49e8fda2abf72bd0ff222cd9fd5e6277..5fbbf9df6df141b67f376ca197e298f999ca19c9 100644
--- a/entcl/data/cifar100.py
+++ b/entcl/data/cifar100.py
@@ -2,11 +2,12 @@ import os
 from typing import Dict, List, Union
 from entcl.data.util import TransformedTensorDataset
 import torch
+from torchvision import disable_beta_transforms_warning
+disable_beta_transforms_warning()
 from torchvision.datasets import CIFAR100 as _CIFAR100
 import torchvision.transforms.v2 as transforms
 from entcl.config import CIFAR100_DIR
 from loguru import logger
-
 CIFAR100_TRANSFORM = transforms.Compose(
     [
         transforms.ToTensor(),
@@ -35,18 +36,19 @@ class CIFAR100Dataset:
         :param cl_n_novel: Number of samples per novel class for each CL session. Default: 400
         :param cl_n_prevnovel: Number of samples per previously novel class for each CL session. Default: 20
         """
-        if known >= 100:
+        self.num_classes = 100
+        if known >= self.num_classes:
             raise ValueError("Number of known classes cannot be greater than 100")
         self.transform = CIFAR100_TRANSFORM
         self.known = known
         self.sessions = sessions
-
+        self.novel = self.num_classes - self.known
         self.pretrain_n_known = pretrain_n_known
-        self.num_classes = 100
+
         self.cl_n_known = cl_n_known
         self.cl_n_novel = cl_n_novel
         self.cl_n_prevnovel = cl_n_prevnovel
-        self.novel_inc = (100 - self.known) // self.sessions
+        self.novel_inc = (self.num_classes - self.known) // self.sessions
         # Verify the CL settings
         logger.debug(
             "Verifying incremental learning settings\n"
@@ -190,9 +192,16 @@ class CIFAR100Dataset:
                     
             samples = torch.cat(samples)
             labels = torch.cat(labels)
+            
+            # when labelling OOD samples we ignore prevnovel: there are only two types, known/seen (0) and novel/unseen (1)
+            ood_labels = labels.clone()
+            ood_labels[ood_labels < novel_start] = 0
+            ood_labels[ood_labels >= novel_start] = 1
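+            # e.g. with known=50 and novel_inc=10, session 1 yields ood_labels 0 for classes 0..49 and 1 for classes 50..59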
+            
             logger.debug(f"Creating dataset for session {session}. There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
             logger.debug(f"Classes in this Session {session}'s Train Dataset: {labels.unique(sorted=True)}")
-            datasets[session] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels), transform=self.transform)
+            datasets[session] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels, ood_labels), transform=self.transform)
 
         return datasets
     
@@ -204,23 +213,77 @@ class CIFAR100Dataset:
         """
         
         datasets = {}
+        
+        
+        logger.debug(f"Splitting test data for session 0")
+        # pretraining session's dataset (session 0) only has known classes, and the dataset is technically an ALL dataset
+        samples, labels = [], []
+        for class_idx in range(self.known):
+            samples.append(masterlist[class_idx])
+            labels.append(torch.full((masterlist[class_idx].size(0),), class_idx, dtype=torch.long))
+        
+        samples = torch.cat(samples)
+        labels = torch.cat(labels)
+        types = torch.full((labels.size(0),), 0, dtype=torch.long) # they are all known classes, so type is 0
+        
+        logger.debug(f"Creating dataset for session 0 (pretraining). There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
+        logger.debug(f"Classes in Session 0's Test Dataset: {labels.unique(sorted=True)}")
+        datasets[0] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels, types), transform=self.transform)
+        
+        del samples, labels, types # free up memory
+        
         logger.debug(f"Splitting test data for {self.sessions} sessions")
-        for session in range(0, self.sessions + 1):
+        for session in range(1, self.sessions + 1):
+            datasets[session] = {}
             logger.debug(f"Splitting test data for session {session}")
-            samples, labels = [], []
             
-            # get data for all seen classes (self.known + session * self.novel_inc)
-            seen_classes_end = self.known + (session * self.novel_inc)
-            logger.debug(f"There are {seen_classes_end} seen classes. Starting at 0 (inc), ending at {seen_classes_end} (exc)")
-            for class_idx in range(seen_classes_end):
-                samples.append(masterlist[class_idx])
-                labels.append(torch.full((masterlist[class_idx].size(0),), class_idx, dtype=torch.long))
+            # for cl sessions, there are 3 datasets old, new and all.
+            old_end_idx = self.known + (session - 1) * self.novel_inc
+            new_end_idx = self.known + session * self.novel_inc
+            logger.debug(f"Old classes end at {old_end_idx} (exc), New classes start at {old_end_idx} (inc) and end at {new_end_idx} (exc)")
             
-            samples = torch.cat(samples)
-            labels = torch.cat(labels)
-            logger.debug(f"Creating test dataset for session {session}. There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
-            logger.debug(f"Classes in Session {session}'s Test Dataset: {labels.unique(sorted=True)}")
-            datasets[session] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels), transform=self.transform)
+            
+            # old dataset -------------------------------------
+            old_samples, old_labels = [], []
+            for class_idx in range(old_end_idx):
+                old_samples.append(masterlist[class_idx])
+                old_labels.append(torch.full((masterlist[class_idx].size(0),), class_idx, dtype=torch.long))
+            
+            old_samples = torch.cat(old_samples)
+            old_labels = torch.cat(old_labels)
+            old_types = torch.full((old_labels.size(0),), 0, dtype=torch.long) # they are all known classes, so type is 0
+            
+            logger.debug(f"Creating OLD dataset for session {session}. There are {len(old_samples)} samples, and {len(old_labels)} labels. There are {old_labels.unique().size(0)} different classes")
+            logger.debug(f"Classes in Session {session}'s OLD Dataset: {old_labels.unique(sorted=True)}")
+            
+            datasets[session] = {"old": TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(old_samples, old_labels, old_types), transform=self.transform)}
+            
+            
+            # new dataset -------------------------------------
+            new_samples, new_labels = [], []
+            for class_idx in range(old_end_idx, new_end_idx):
+                new_samples.append(masterlist[class_idx])
+                new_labels.append(torch.full((masterlist[class_idx].size(0),), class_idx, dtype=torch.long))
+            
+            new_samples = torch.cat(new_samples)
+            new_labels = torch.cat(new_labels)
+            new_types = torch.full((new_labels.size(0),), 1, dtype=torch.long) # they are all novel classes, so type is 1
+            
+            logger.debug(f"Creating NEW dataset for session {session}. There are {len(new_samples)} samples, and {len(new_labels)} labels. There are {new_labels.unique().size(0)} different classes")
+            logger.debug(f"Classes in Session {session}'s NEW Dataset: {new_labels.unique(sorted=True)}")
+            
+            datasets[session]["new"] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(new_samples, new_labels, new_types), transform=self.transform)
+            
+            
+            # all dataset -------------------------------------
+            all_samples = torch.cat([old_samples, new_samples])
+            all_labels = torch.cat([old_labels, new_labels])
+            all_types = torch.cat([old_types, new_types])
+            
+            logger.debug(f"Creating ALL dataset for session {session}. There are {len(all_samples)} samples, and {len(all_labels)} labels. There are {all_labels.unique().size(0)} different classes")
+            logger.debug(f"Classes in Session {session}'s ALL Dataset: {all_labels.unique(sorted=True)}")
+            
+            datasets[session]["all"] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(all_samples, all_labels, all_types), transform=self.transform)
             
         return datasets
                 
diff --git a/entcl/data/util.py b/entcl/data/util.py
index 23dc10818d04cb4bf3859620572be4cad45a732b..00c3aa191c21528a7727637140556bfc4a09731a 100644
--- a/entcl/data/util.py
+++ b/entcl/data/util.py
@@ -8,7 +8,8 @@ class TransformedTensorDataset(torch.utils.data.Dataset):
         return len(self.tensor_dataset)
 
     def __getitem__(self, idx):
-        data, target = self.tensor_dataset[idx]
+        the_tuple = self.tensor_dataset[idx]
+        data = the_tuple[0]
         if self.transform:
-            data = self.transform(data)
-        return data, target
+            data = self.transform(data)
+
+        return (data, *the_tuple[1:])
diff --git a/entcl/pretrain.py b/entcl/pretrain.py
index 5bfc2f83ee87d61014c6ff5a4e1719f9f6f1e546..11c2b0713fad58d08f9a562ea7843dc66ee6d7e7 100644
--- a/entcl/pretrain.py
+++ b/entcl/pretrain.py
@@ -7,7 +7,7 @@ from tqdm import tqdm
 
 def pretrain(args, model):
     train_dataset = args.dataset.get_dataset(session=0, train=True)
-    train_dataloader = torch.utils.data.DataLoader(
+    train_loader = torch.utils.data.DataLoader(
         dataset=train_dataset,
         batch_size=args.batch_size,
         shuffle=True,
@@ -17,7 +17,7 @@ def pretrain(args, model):
     )
     
     val_dataset = args.dataset.get_dataset(session=0, train=False)
-    val_dataloader = torch.utils.data.DataLoader(
+    val_loader = torch.utils.data.DataLoader(
         dataset=val_dataset,
         batch_size=args.batch_size,
         shuffle=False,
@@ -41,10 +41,10 @@ def pretrain(args, model):
         logger.debug(f"Loading pretrained model from {args.pretrain_load}")
         model.head.load_state_dict(torch.load(args.pretrain_load, weights_only=True))
         model = model.to(args.device)
-        model, accuracy = _validate(args, model, val_dataloader)
+        model, accuracy, _ = _validate(args, model, val_loader, criterion=criterion)
         logger.info(f"Loaded Pretrained Model Accuracy: {accuracy}")
         return model
-    else:
+    elif args.mode in ["pretrain", "both"]:
         logger.debug("No pretrained model to load, training from scratch")
         model = model.to(args.device)
         for epoch in range(args.pretrain_epochs):
@@ -52,15 +52,16 @@ def pretrain(args, model):
             logger.debug(f"Epoch {epoch} Started")
             
             # train model
-            model, train_loss = _train(args, model, train_dataloader, optimiser, criterion)
+            model, train_loss = _train(args, model, train_loader, optimiser, criterion)
             logger.info(f"Epoch {epoch}: TRAINING : Train Loss: {train_loss}")
             
             # validate model
-            model, accuracy, val_loss = _validate(args, model, val_dataloader, criterion)
+            model, accuracy, val_loss = _validate(args, model, val_loader, criterion)
             logger.info(f"Epoch {epoch}: VALIDATION : Accuracy: {accuracy}, Val Loss: {val_loss}")
             
+            # just save the head; the backbone is frozen and far too large to checkpoint every epoch
-            model.head.save(f"{args.exp_root}/{args.name}/head_pretrain_{epoch}.pth")
+            torch.save(model.head.state_dict(), f"{args.exp_dir}/session_0/head_s0_ep{epoch}.pt")
+            logger.debug(f"Session 0 Head saved to {args.exp_dir}/session_0/head_s0_ep{epoch}.pt")
             
             # create a dataframe with the results
             epoch_results = pd.DataFrame(
@@ -69,22 +70,26 @@ def pretrain(args, model):
             )
             
             # append the results to the results dataframe
-            results = epoch_results if results is None else results.append(epoch_results)
+            results = epoch_results if results is None else pd.concat([results, epoch_results], ignore_index=True)
             
             # save the results dataframe
-            results.to_csv(f"{args.exp_root}/{args.name}/results_pretrain.csv", index=False)
-            logger.debug(f"Epoch {epoch} Finished. Pretrain Results Saved to {args.exp_root}/{args.name}/results_pretrain.csv")
+            results.to_csv(f"{args.exp_dir}/results_s0.csv", index=False)
+            logger.debug(f"Epoch {epoch} Finished. Pretrain Results Saved to {args.exp_dir}/results_s0.csv")
         return model
+    else:
+        raise ValueError(f"No Model to load and mode is not pretrain or both. Mode: {args.mode}, Pretrain Load: {args.pretrain_load}")
 
-def _train(args, model, train_dataloader, optimiser, criterion):
+def _train(args, model, train_loader, optimiser, criterion):
     model.train()
     loss_total = 0
-    for x, y in tqdm(train_dataloader, desc=f"Training", unit = "batch"):
+    for x, y, _ in tqdm(train_loader, desc="Training", unit="batch"):
         x, y = x.to(args.device), y.to(args.device)
-        logger.debug(f"x shape: {x.shape}, y shape: {y.shape}")
+        logger.debug(f"X Shape: {x.shape}, Y Shape: {y.shape}")
+        logger.debug(f"True Labels (Unique): {torch.unique(y, sorted=True)}")
+        
         optimiser.zero_grad()
         logits, _ = model(x)
-        logger.debug(f"logits shape: {logits.shape}")
+        logger.debug(f"Logits Shape: {logits.shape}")
             
         loss = criterion(logits, y)
         loss.backward()
@@ -92,18 +97,19 @@ def _train(args, model, train_dataloader, optimiser, criterion):
         loss = loss.item()
         logger.debug(f"Loss: {loss}")
         loss_total += loss
-    loss_total /= len(train_dataloader)
+    loss_total /= len(train_loader)
     return model, loss_total
 
-def _validate(args, model, val_dataloader, criterion):
+def _validate(args, model, val_loader, criterion):
     model.eval()
     correct = 0
     total = 0
     loss_total = 0
     with torch.no_grad():
-        for x, y in tqdm(val_dataloader, desc=f"Validating", unit = "batch"):
+        for x, y, _ in tqdm(val_loader, desc="Validating", unit="batch"):
             x, y = x.to(args.device), y.to(args.device)
             logger.debug(f"x shape: {x.shape}, y shape: {y.shape}")
+            logger.debug(f"True Labels (Unique): {torch.unique(y, sorted=True)}")
                 
             logits, _ = model(x)
             logger.debug(f"logits shape: {logits.shape}")
@@ -120,4 +126,4 @@ def _validate(args, model, val_dataloader, criterion):
             
             logger.debug(f"This Batch Num Correct: {num_correct}, Total: {y.size(0)}, Accuracy: {num_correct / y.size(0)}, Loss: {loss}")
             
-    return model, correct / total, loss_total / len(val_dataloader)
+    return model, correct / total, loss_total / len(val_loader)
diff --git a/entcl/run.py b/entcl/run.py
index 48604efeab1fd5a1ed4a301d894c6f3949907ece..073a505a840f30381da69a2601700d9853055bcd 100644
--- a/entcl/run.py
+++ b/entcl/run.py
@@ -1,11 +1,14 @@
 import argparse
 import os
 import sys
+from entcl.cl import cl_session
+from entcl.utils.ncd import find_novel_classes_for_session
+from entcl.utils.ood import label_ood_for_session
 from loguru import logger
 from datetime import datetime
 import torch
 
-from entcl.utils.util import seed
+from entcl.utils.util import generate_unique_path, seed
 
 from entcl.models.model import ENTCLModel
 from entcl.pretrain import pretrain
@@ -16,10 +19,39 @@ def main(args: argparse.Namespace):
     model = ENTCLModel(head=args.head, backbone_url=args.backbone_url, backbone=args.backbone, backbone_source=args.backbone_source)
     logger.debug(f"Model: {model}")
     
-    if args.mode == 'pretrain':
-        model = pretrain(args, model)
-    else:
-        raise NotImplementedError(f"Mode {args.mode} not implemented")
+    logger.info("Pretraining Model (Session 0)")
+    model = pretrain(args, model)
+    
+    if args.mode not in ['cl', 'both']:
+        return  # pretrain-only run: skip the continual learning sessions
+    
+    logger.info("Starting Continual Learning")
+    for session in range(1, args.sessions + 1):
+        logger.info(f"Starting Continual Learning Session {session}")
+        args.current_session = session
+        session_dataset = args.dataset.get_dataset(session)
+        
+        # OOD detection
+        session_dataset = label_ood_for_session(args, session_dataset, model) # returns a new dataset with the OOD samples labelled 
+        
+        # NCD
+        session_dataset, mapping = find_novel_classes_for_session(args, session_dataset, model) # returns a new dataset with the novel samples labelled
+        
+        # dataset should now have the form (data, true labels, true types, pred types, pseudo labels)
+        
+        # Expand Classification Head & Initialise
+        model.head.expand(args.dataset.novel_inc) # we are cheating here, we know the number of novel classes
+        
+        # freeze the weights for the classes seen so far; we only train the novel outputs (e.g. 50 (known) + (2 (session) - 1) * 10 (novel_inc) = 60 classes have been seen entering CL session 2)
+        model.head.freeze(start_idx=0, end_idx=args.dataset.known + ((session - 1) * args.dataset.novel_inc))
+        
+        # run continual learning session
+        model = cl_session(args, session_dataset, model, mapping)
 
 
 
@@ -28,7 +60,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     # program args
     parser.add_argument('--name', type=str, default="entcl_" + datetime.now().isoformat(timespec='seconds'))
-    parser.add_argument('--mode', type=str, default='pretrain', help='Mode to run the program', choices=['pretrain', 'cl', 'dryrun'])
+    parser.add_argument('--mode', type=str, default='both', help='Mode to run the program', choices=['pretrain', 'cl', 'both'])
     parser.add_argument('--dryrun', action='store_true', default=False, help='Dry Run Mode. Does not save anything')
     
     parser.add_argument('--debug', action='store_true', default=False, help='Debug Mode. Epochs are only done once. Enables Verbose Mode automatically')
@@ -79,6 +111,9 @@ if __name__ == "__main__":
     # ood args
     parser.add_argument('--ood_score', type=str, default='entropy', help='Changes the metric(s) to base OOD detection on', choices=['entropy', 'energy', 'both'])
     parser.add_argument('--ood_eps', type=float, default=1e-8, help='Epsilon value for computing entropy in OOD detection')
+    
+    # ncd args
+    parser.add_argument('--ncd_findk_method', type=str, default='cheat', help='Method for finding the number of novel classes (only "cheat" is currently wired up; findk.py has been removed)', choices=['elbow', 'silhouette', 'gap', 'cheat'])
     args = parser.parse_args()
     
     seed(args.seed) # seed everything
@@ -102,6 +137,9 @@ if __name__ == "__main__":
     # initialise directories
     os.makedirs(args.exp_root, exist_ok=True)
     args.exp_dir = os.path.join(args.exp_root, args.name)
+
+    if args.name == "debug":
+        args.exp_dir = generate_unique_path(args.exp_dir)
     os.makedirs(args.exp_dir, exist_ok=False)
     
     # initialise logger
@@ -129,6 +167,9 @@ if __name__ == "__main__":
         args.head = MLPHead(in_features=768, out_features=args.dataset.num_classes, hidden_dim1=512, hidden_dim2=256)
         logger.debug(f"Using MLP Head: {args.head}")
         
+    if args.mode == 'cl' and args.pretrain_load is None:
+        raise ValueError("Continual Learning Mode requires a pretrained model to load")
+        
     argstr = "Arguments: \n"
     for arg in vars(args):
         argstr += f"{arg}: {getattr(args, arg)}\n"
diff --git a/entcl/utils/findk.py b/entcl/utils/findk.py
deleted file mode 100644
index 8241a2236aefcaf53e6ef53379555b1af2f26766..0000000000000000000000000000000000000000
--- a/entcl/utils/findk.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import torch
-import numpy as np
-from sklearn.cluster import KMeans
-from loguru import logger
-from tqdm import tqdm
-
-def elbow(features: torch.Tensor, args) -> int:
-    """
-    Finds the optimal number of clusters using the elbow method and the kneed library.
-    :param features: torch.Tensor with the features to cluster.
-    :param args: Arguments object with the attribute `seed`.
-    :return: int with the optimal number of clusters.
-    """
-    from kneed import KneeLocator
-    inertias = []
-    
-    ks = range(args.ncd_findk_mink, args.ncd_findk_maxk + 1)
-    logger.debug(f"Finding k using elbow method for k in {ks}")
-    
-    # Calculate the inertia for each k
-    for k in tqdm(ks, desc="Calculating Inertias", leave=True, unit="k"):
-        kmeans = KMeans(n_clusters=k, random_state=args.seed)
-        kmeans.fit(features)
-        inertias.append(kmeans.inertia_)
-        
-    logger.debug(f"Elbow Inertias: {inertias}")
-    logger.debug(f"Running KneeLocator")
-    
-    # Find the knee
-    kneedle = KneeLocator(ks, inertias, curve="convex", direction="decreasing")
-    logger.info(f"Elbow Method Found k: {kneedle.knee}")
-    return kneedle.knee
-    
-def silhouette(features: torch.Tensor, args) -> int:
-    """
-    Finds the optimal number of clusters using the silhouette method.
-    :param features: torch.Tensor with the features to cluster.
-    :param args: Arguments object with the attribute `seed`.
-    :return: int with the optimal number of clusters.
-    """
-    from sklearn.metrics import silhouette_score
-    silhouettes = []
-    
-    ks = range(args.ncd_findk_mink, args.ncd_findk_maxk + 1)
-    logger.debug(f"Finding k using silhouette method for k in {ks}")
-    
-    # Calculate the silhouette score for each k
-    for k in tqdm(ks, desc="Calculating Silhouettes", leave=True, unit="k"):
-        kmeans = KMeans(n_clusters=k, random_state=args.seed)
-        pseudo_labels = kmeans.fit_predict(features)
-        silhouette = silhouette_score(features, pseudo_labels)
-        silhouettes.append(silhouette)
-        
-    logger.debug(f"Silhouette Scores: {silhouettes}")
-    
-    # Find the best k
-    best_k = ks[silhouettes.index(max(silhouettes))]
-    logger.info(f"Silhouette Method Found k: {best_k}")
-    return best_k
-
-def gap(features: torch.Tensor, args) -> int:
-    """
-    Finds the optimal number of clusters using the gap statistic.
-    :param features: torch.Tensor with the features to cluster.
-    :param args: Arguments object with the attribute `seed`.
-    :return: int with the optimal number of clusters.
-    """
-    logger.debug("Finding k using gap statistic")
-    features_np = features.cpu().numpy()
-    n_samples, n_features = features_np.shape
-    
-    min_vals = features_np.min(axis=0)
-    max_vals = features_np.max(axis=0)
-    
-    gap_stats = {}
-    wks = []
-    wks_refs = []
-    
-    for k in tqdm(range(args.ncd_findk_mink, args.ncd_findk_maxk + 1), desc="Calculating Gap Statistic", leave=True, unit="k"):
-        logger.debug(f"Calculating Gap Statistic for k={k}")
-        kmeans = KMeans(n_clusters=k, random_state=args.seed)
-        kmeans.fit(features_np)
-        wk = kmeans.inertia_
-        wks.append(np.log(wk))
-        
-        wk_refs = []
-        for _ in tqdm(range(args.ncd_gap_nrefs), desc="Calculating Reference Gap Statistic", leave=True, unit="ref"):
-            logger.debug(f"Calculating Reference Gap Statistic for k={k}")
-            ref_data = np.random.uniform(min_vals, max_vals, (n_samples, n_features))
-            kmeans = KMeans(n_clusters=k, random_state=args.seed)
-            kmeans.fit(ref_data)
-            wk_ref = kmeans.inertia_
-            wk_refs.append(np.log(wk_ref))
-        wks_refs.append(np.mean(wk_refs))
-
-        gap_stats[k] = wks_refs[-1] - wks[-1]
-    
-    optimal_k = max(gap_stats, key=gap_stats.get)
-    
-    logger.info(f"Gap Statistic Found k: {optimal_k}")
-    
-    return optimal_k
\ No newline at end of file
diff --git a/entcl/utils/ncd.py b/entcl/utils/ncd.py
index 60df52c8ae2f30e471a5e1636c7532ca281087d2..8c5fb51fbab5dde6a5dcf98085be43ba186c406f 100644
--- a/entcl/utils/ncd.py
+++ b/entcl/utils/ncd.py
@@ -1,19 +1,21 @@
-
-
 import os
-from typing import Union
+from typing import Dict, Tuple, Union
 from entcl.data.util import TransformedTensorDataset
-from entcl.utils.findk import elbow, gap, silhouette
 from loguru import logger
 import numpy as np
+import pandas as pd
 from sklearn.cluster import KMeans
 from sklearn.metrics import confusion_matrix
 from scipy.optimize import linear_sum_assignment
 from entcl.utils.util import generate_unique_path
 import torch
+import torch.utils
 from tqdm import tqdm
 
-def _extract_features(args, dataset: TransformedTensorDataset, model: torch.nn.Module) -> torch.Tensor:
+
+def _extract_features(
+    args, dataset: TransformedTensorDataset, model: torch.nn.Module
+) -> torch.Tensor:
     """
     Extracts features from the data in the dataset using the provided model's backbone.
     :param args: Arguments object with the attributes `device`, .
@@ -27,21 +29,26 @@ def _extract_features(args, dataset: TransformedTensorDataset, model: torch.nn.M
         pin_memory=args.pin_memory,
         shuffle=False,
     )
-    
+
     model.eval()
     features = None
     with torch.no_grad():
-        for batch in tqdm(dataloader, desc="Extracting Features", leave=True, unit="batch"):
-            x = batch[0] # Only the data is needed, assumes batch is iterable.
+        for batch in tqdm(
+            dataloader, desc="Extracting Features", leave=True, unit="batch"
+        ):
+            x = batch[0]  # Only the data is needed, assumes batch is iterable.
             x = x.to(args.device)
-            x_proj, _= model(x) # we do not need the predicted logits, only the features.
+            _, x_proj = model(
+                x
+            )  # `forward` returns `logits, features`; we want the backbone features, not the logits
             if features is None:
                 features = x_proj
             else:
                 features = torch.cat((features, x_proj), dim=0)
-    
+
     return features
 
+
 def plot_confmat(confmat: torch.Tensor, path: str) -> None:
     """
     Plots a confusion matrix and saves it to the specified path.
@@ -49,105 +56,125 @@ def plot_confmat(confmat: torch.Tensor, path: str) -> None:
     :param path: str with the path to save the confusion matrix plot.
     """
     try:
+        import matplotlib
+
+        matplotlib.use("Agg")
         import matplotlib.pyplot as plt
+
         import seaborn as sns
         from sklearn.metrics import ConfusionMatrixDisplay
 
         fig, ax = plt.subplots(figsize=(10, 10))
-        sns.heatmap(confmat, annot=True, fmt='d', cmap='Blues', ax=ax)
-        ax.set_xlabel('Predicted Label')
-        ax.set_ylabel('True Label')
-        ax.set_title('Confusion Matrix')
+        sns.heatmap(confmat, annot=True, fmt="d", cmap="Blues", ax=ax)
+        ax.set_xlabel("Predicted Label")
+        ax.set_ylabel("True Label")
+        ax.set_title("Confusion Matrix")
         plt.savefig(path)
         plt.close()
     except ImportError:
-        logger.error("Could not import matplotlib or seaborn. Cannot plot confusion matrix. Confusion Matrix not saved.")
-        
-def _calculate_clustering_accuracy(true_labels: torch.Tensor, pseudo_labels: torch.Tensor, args) -> None:
+        logger.error(
+            "Could not import matplotlib or seaborn. Cannot plot confusion matrix. Confusion Matrix not saved."
+        )
+
+
+def generate_mapping(
+    true_labels: torch.Tensor, pseudo_labels: torch.Tensor, args
+) -> Dict[int, int]:
     """
-    Calculates the clustering accuracy between the true labels and the pseudo labels.
+    Calculates the clustering accuracy between the true labels and the pseudo labels and generates a mapping between the two for validation and testing.
     :param true_labels: torch.Tensor with the true labels for the novel samples.
     :param pseudo_labels: torch.Tensor with the pseudo labels for the novel samples.
+    :return: Dict[int, int] a mapping between the true labels and the pseudo labels.
     """
-    assert true_labels.shape == pseudo_labels.shape, f"True and Pseudo labels must have the same shape. true_labels.shape: {true_labels.shape}, pseudo_labels.shape: {pseudo_labels.shape}"
+    logger.debug("Calculating Clustering Accuracy")
+    assert (
+        true_labels.shape == pseudo_labels.shape
+    ), f"True and Pseudo labels must have the same shape. true_labels.shape: {true_labels.shape}, pseudo_labels.shape: {pseudo_labels.shape}"
 
-    
     # true labels will be > 50 and psuedo labels will start at 0. we need to adjust the pseudo labels to match the true labels.
     # we will assume the true labels are sequential, and the lowest true label is 0.
+
+    novel_true_min_idx = int(true_labels.min())
+    pseudo_labels = pseudo_labels + novel_true_min_idx  # shift a copy; the caller still needs its 0-based pseudo labels
+
+    true_labels = true_labels.cpu().numpy()
+    pseudo_labels = pseudo_labels.cpu().numpy()
+
+    conf_mat = confusion_matrix(true_labels, pseudo_labels)
+    row_idxs, col_idxs = linear_sum_assignment(
+        -conf_mat
+    )  # Hungarian Algorithm to find the best matching between the true and pseudo labels.
+
+    # align the pseudo labels with the true labels based on the Hungarian algorithm results;
+    # row/col indices index into the confusion matrix, so shift them back into label space
+    pseudo_labels_aligned = np.zeros_like(pseudo_labels)
+    for pseudo_idx, true_idx in zip(col_idxs, row_idxs):
+        pseudo_labels_aligned[pseudo_labels == pseudo_idx + novel_true_min_idx] = true_idx + novel_true_min_idx
+
+    # create a mapping between the true labels and the pseudo labels, used in validation and testing
+    mapping = {int(true_idx + novel_true_min_idx): int(pseudo_idx + novel_true_min_idx) for pseudo_idx, true_idx in zip(col_idxs, row_idxs)}
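+    # e.g. with novel classes 60..69 this might look like {60: 63, 61: 60, ...}: cluster 63's
+    # samples were mostly drawn from true class 60, and so on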
     
-    true_labels -= true_labels.min() # Adjust the true labels to start at 0.
-    
-    # optimal assignment of pseudo labels to true labels
-    confusion_mat = torch.from_numpy(confusion_matrix(true_labels.cpu().numpy(), pseudo_labels.cpu().numpy())) # Rows are true labels, columns are pseudo labels
-    cost_mat = -confusion_mat # Hungarian algorithm minimizes the cost, so we negate the confusion matrix.
-    row_idx, col_idx = linear_sum_assignment(cost_mat) # Hungarian algorithm to find the optimal assignment.
-    
-    # Maps pseudo labels to true labels.
-    label_assignments = dict(zip(col_idx, row_idx)) # Maps pseudo labels to true labels.
-    aligned_predicted_labels = torch.tensor([label_assignments[pseudo_label] for pseudo_label in pseudo_labels]) # Aligns the predicted labels with the true labels.
-    
-    # Calculate the accuracy
-    correct_assignments = (aligned_predicted_labels == true_labels).sum().item() # Number of correct assignments.
-    accuracy = correct_assignments / true_labels.shape[0] # Accuracy is the number of correct assignments divided by the number of samples.
-    
-    # Plot the confusion matrix
-    confmat_path = generate_unique_path(os.path.join(args.exp_dir, args.name, "confusion_matrix.png"))
-    plot_confmat(confusion_mat, path=confmat_path)
-    
-    unique_true_labels = np.unique(true_labels.cpu().numpy())
-    unique_pseudo_labels = np.unique(pseudo_labels.cpu().numpy())
-    
-    # find any ignored labels
-    ignored_true_labels = set(unique_true_labels) - set(unique_true_labels[row_idx])
-    ignored_pseudo_labels = set(unique_pseudo_labels) - set(unique_pseudo_labels[col_idx])
-    
-    # Log the results
-    
-    string = f"Clustering Accuracy Computed: {accuracy*100:.2f}%"
-    string += f"\n True Labels (adjusted): {unique_true_labels}"
-    string += f"\n Pseudo Labels: {unique_pseudo_labels}"
-    string += f"\n # of True Labels {len(unique_true_labels)}, # of Pseudo Labels {len(unique_pseudo_labels)}"
-    string += f"\n Confusion Matrix: Plot saved to `{confmat_path}`"
-    string += f"Ignored True Labels: {ignored_true_labels} Count: {len(ignored_true_labels)}"
-    string += f"Ignored Pseudo Labels: {ignored_pseudo_labels} Count: {len(ignored_pseudo_labels)}"
-    string += f"\n If there are ignored true labels, the number of clusters has been overestimated. If there are ignored pseudo labels, the number of clusters has been underestimated."
-    string += f"\n Ignored labels are not included in the accuracy calculation."
+    # compute the overall accuracy
+    overall_accuracy = np.mean(true_labels == pseudo_labels_aligned)
+
+    # compute the per-class accuracy
+    per_class_accuracy = {}
+    for true_class in np.unique(true_labels):
+        mask = true_labels == true_class
+        per_class_accuracy[true_class] = np.mean(
+            true_labels[mask] == pseudo_labels_aligned[mask]
+        )
+
+    string = f"NCD Clustering Accuracies for Session {args.current_session}:"
+    for true_class, acc in per_class_accuracy.items():
+        string += f"\nTrue Class {true_class}: {acc*100:4f}%"
     string += f"\n"
-    string += f"Per True Label Accuracy:"
-    for true_label in unique_true_labels:
-        mask = true_labels == true_label
-        true_label_accuracy = (aligned_predicted_labels[mask] == true_label).sum().item() / mask.sum().item()
-        string += f"\n True Label: {true_label} Accuracy: {true_label_accuracy*100:.2f}%"
-    
+    string += f"\nOverall Accuracy: {overall_accuracy*100:4f}%"
+
     logger.info(string)
     
+    string = f"Mapping for Session {args.current_session}:"
+    for true_label, pseudo_label in mapping.items():
+        string += f"\nTrue Label {true_label} -> Pseudo Label {pseudo_label}"
+    
+    logger.info(string)
+
+    # plot the confusion matrix
+    try:
+        plot_confmat(
+            conf_mat,
+            os.path.join(args.exp_dir, f"confmat_session_{args.current_session}.png"),
+        )
+        logger.debug(
+            f"Confusion Matrix saved to {os.path.join(args.exp_dir, f'confmat_session_{args.current_session}.png')}"
+        )
+    except Exception as e:
+        logger.error(
+            f"Could not plot the confusion matrix. Error: {e}\n Confusion Matrix not saved. Continuing..."
+        )
+    return mapping
 
-def _cluster_features(args, features:torch.Tensor) -> torch.Tensor:
+def _cluster_features(args, features: torch.Tensor) -> torch.Tensor:
     """
     Clusters the features using K Means.
-    :param args: Arguments object with the attribute `ncd_findk_method`, `novel_classes_per_session`, `seed`.
+    :param args: Arguments object with the attribute `novel_classes_per_session`, `seed`.
     :param features: torch.Tensor with the features to cluster. Assumes the features are from novel-classed samples only.
     :return: torch.Tensor with unadjusted psuedo-labels for given features.
     """
-    
-    if args.ncd_findk_method == "cheat":
-        k = args.novel_classes_per_session
-    elif args.ncd_findk_method == "elbow":
-        k = elbow(features, args)
-    elif args.ncd_findk_method == "silhouette":
-        k = silhouette(features, args)
-    elif args.ncd_findk_method == "gap":
-        k = gap(features, args)
-    else:
-        raise ValueError(f"Unknown method for finding k: {args.ncd_findk_method}")
-    
-    logger.info(f"FindK Method: {args.ncd_findk_method} k: {k} True k: {args.novel_classes_per_session}")
-    
-    kmeans = KMeans(n_clusters=k, random_state=args.seed)
-    pseudo_labels = torch.tensor(kmeans.fit_predict(features))
+    logger.debug(
+        f"Clustering Features: Using K = {args.novel_classes_per_session} (args.novel_classes_per_session)"
+    )
+
+    kmeans = KMeans(n_clusters=args.novel_classes_per_session, random_state=args.seed)
+
+    pseudo_labels = torch.tensor(kmeans.fit_predict(features.cpu().numpy()))
     return pseudo_labels
 
-def find_novel_classes_for_session(args, session_dataset: TransformedTensorDataset, model: torch.nn.Module) -> Union[torch.Tensor, TransformedTensorDataset]:
+
+def find_novel_classes_for_session(
+    args,
+    session_dataset: TransformedTensorDataset,
+    model: torch.nn.Module,
+) -> Tuple[TransformedTensorDataset, Dict[int, int]]:
     """
     Finds the novel classes in the given session dataset using KMeans clustering and the given model
     :param args: Arguments object with the attributes `device`, `ncd_findk_method`, `novel_classes_per_session`, `seed`.
@@ -155,25 +182,66 @@ def find_novel_classes_for_session(args, session_dataset: TransformedTensorDatas
     :param model: torch.nn.Module with the model to use for clustering.
     :return: torch.Tensor with the pseudo-labels for the novel classes.
     """
-    features = _extract_features(args, session_dataset, model)
-    pseudo_labels = _cluster_features(args, features)
-    return pseudo_labels
-    
 
-def discover_classes_in_session_dataset(
-    args,
-    session_dataset: TransformedTensorDataset,
-    model: torch.nn.Module,
-    return_new_dataset: bool = True,
-):
-    """
-    Discovers classes in the session dataset using the provided model.
-    This step is for after OOD detection, which modifies the session dataset to include the OOD label in the third column.
-    
-    :param args: Arguments object.
-    :param session_dataset: TransformedTensorDataset for the session (data, target, ood).
-    :param model: The model to evaluate with.
-    :param return_new_dataset: Whether to return the new dataset ready for the session or just the predicted classes.
-    """
-    
-    
\ No newline at end of file
+    # first we need to create a TransformedTensorDataset with only the predicted novel samples
+    novel_samples_mask = (
+        session_dataset.tensor_dataset.tensors[3] == 1
+    )  # predicted-novel samples are marked with 1; the predicted types are the 4th tensor in the dataset
+    novel_samples_mask = novel_samples_mask.cpu()  # move to CPU so it can index the CPU-resident tensors
+
+    novel_tensors = [
+        tensor[novel_samples_mask].cpu()
+        for tensor in session_dataset.tensor_dataset.tensors
+    ]  # get only the predicted novel samples
+
+    for i, t in enumerate(novel_tensors):
+        logger.debug(f"Tensor[{i}] Shape: {t.shape}")
+
+    novel_dataset = TransformedTensorDataset(
+        tensor_dataset=torch.utils.data.TensorDataset(*novel_tensors),
+        transform=session_dataset.transform,
+    )
+
+    # extract features from the novel samples
+    novel_features = _extract_features(args, novel_dataset, model)
+
+    # cluster the features
+    pseudo_labels = _cluster_features(args, novel_features)
+
+    # compute the clustering accuracy (logged only) and the true-to-pseudo label mapping used at validation time
+    mapping = generate_mapping(
+        novel_dataset.tensor_dataset.tensors[1], pseudo_labels, args
+    )
+
+    # next we need to align the pseudo labels tensor with the original dataset, giving known samples a pseudo label of -1
+    pseudo_labels_aligned = torch.full(
+        session_dataset.tensor_dataset.tensors[2].shape,
+        -1,
+        dtype=torch.long,
+        device=session_dataset.tensor_dataset.tensors[2].device,
+    )
+    pseudo_labels_aligned[novel_samples_mask] = (
+        pseudo_labels  # wherever the sample is predicted novel, assign its pseudo label
+    )
+
+    # sanity check: dump the label columns to a CSV for manual inspection
+    clustering_df = pd.DataFrame(columns=["true_labels", "type", "pseudo_labels"])
+    clustering_df["true_labels"] = session_dataset.tensor_dataset.tensors[1].cpu().numpy()
+    clustering_df["type"] = session_dataset.tensor_dataset.tensors[2].cpu().numpy()
+    clustering_df["pseudo_labels"] = pseudo_labels_aligned.cpu().numpy()
+
+    # save the dataset to a csv file
+    dataset_path = os.path.join(args.exp_dir, f"clustering_s{args.current_session}.csv")
+    clustering_df.to_csv(dataset_path, index=False)
+
+    # create a new dataset with the pseudo labels alongside the original tensors and the same transform
+    # NOTE: the tensors attribute of a TensorDataset is a tuple. we can't append to it so instead we create a new TensorDataset with the new pseudo labels tensor.
+    # NOTE: the dataset at the end of NCD has the form (data, true labels, true type, predicted type, pseudo labels)
+    new_dataset = TransformedTensorDataset(
+        tensor_dataset=torch.utils.data.TensorDataset(
+            *session_dataset.tensor_dataset.tensors, pseudo_labels_aligned.cpu()
+        ),
+        transform=session_dataset.transform,
+    )
+
+    return new_dataset, mapping
diff --git a/entcl/utils/ood.py b/entcl/utils/ood.py
index 4656bd5807be178de23bc88a4edb66731e76164d..c0b8daa21b670185283021c66e7913accce711a6 100644
--- a/entcl/utils/ood.py
+++ b/entcl/utils/ood.py
@@ -1,10 +1,13 @@
+import os
 from typing import Iterable, Tuple, Union
 from entcl.data.util import TransformedTensorDataset
+from entcl.utils.util import generate_unique_path
 from loguru import logger
+import numpy as np
+import pandas as pd
 from sklearn.mixture import GaussianMixture
 import torch
 from tqdm import tqdm
-from entcl.utils.findk import elbow, silhouette, gap
 
 def _get_scores(
     session_dataset: TransformedTensorDataset, model: torch.nn.Module, args
@@ -39,11 +42,11 @@ def _get_scores(
 
     entropies, energies = [], []
 
-    for x, _ in tqdm(session_loader, desc="Calculating Scores", leave=True, unit="batch"):
+    for x, _, _ in tqdm(session_loader, desc="Calculating Scores", leave=True, unit="batch"):
         x = x.to(args.device)
         with torch.no_grad():
             logits, _ = model(x)
-
+            logits = logits[:, : args.dataset.known]  # TODO: stopgap; score only the known-class logits until head expansion is handled here
             if compute_entropy:
                 softmax = torch.nn.functional.softmax(logits, dim=1)
                 entropy = -torch.sum(
@@ -72,7 +75,7 @@ def _fit_predict_gmm(data: torch.Tensor, args) -> torch.Tensor:
     :param args: Object with the attribute `seed`.
     """
     logger.debug(f"Fitting Gaussian Mixture Model to Data")
-    gmm = GaussianMixture(n_components=2, random_state=args.seed)
+    gmm = GaussianMixture(n_components=2, random_state=args.seed, max_iter=1000, init_params='kmeans', tol=1e-4)
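+    # n_components=2: one component models the in-distribution score mode, the other the OOD mode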
 
     # Fit and predict using the GMM
     predtypes_hard = torch.tensor(
@@ -149,7 +152,6 @@ def label_ood_for_session(
     args,
     session_dataset: TransformedTensorDataset,
     model: torch.nn.Module,
-    return_new_dataset: bool = True,
 ) -> Union[TransformedTensorDataset, torch.Tensor]:
     """
     OOD Labelling for a session dataset. This function computes entropy and/or energy scores for the dataset based on `args.ood_score` and fits a Gaussian Mixture Model to each of the scores. The GMM has 2 components, one for in-distribution samples and one for OOD samples. The function then resolves conflicts between the entropy and energy predictions by selecting the type with the highest confidence. Finally, the function returns a new dataset for the session, including the predicted types.
@@ -196,21 +198,62 @@ def label_ood_for_session(
         )
     else:
         raise ValueError(f"Invalid OOD Score: {args.ood_score}")
+    
+    logger.debug("Returning New Dataset")
+    # return the new dataset with the predicted types
+    session_dataset = TransformedTensorDataset(
+        tensor_dataset=torch.utils.data.TensorDataset(
+            session_dataset.tensor_dataset.tensors[0],  # the data
+            session_dataset.tensor_dataset.tensors[1],  # the labels
+            session_dataset.tensor_dataset.tensors[2],  # the real types
+            final_predtypes,  # the predicted types
+        ),
+        transform=session_dataset.transform,
+    )
 
-    if return_new_dataset:
-        logger.debug("Returning New Dataset")
-        # return the new dataset with the predicted types
-        session_dataset = TransformedTensorDataset(
-            tensor_dataset=torch.utils.data.TensorDataset(
-                session_dataset.tensor_dataset.tensors[0],  # the data
-                session_dataset.tensor_dataset.tensors[1],  # the labels
-                final_predtypes,  # the predicted types (duh)
-            ),
-            transform=session_dataset.transform,
-        )
+    # compute the OOD Accuracy
+    _compute_ood_accuracy(session_dataset, args)
+    
+    return session_dataset
 
-        return session_dataset
-    else:
-        logger.debug("Returning Predicted Types")
-        # return just the predicted types
-        return final_predtypes
+
+def _compute_ood_accuracy(
+    session_dataset: TransformedTensorDataset, args
+) -> None:
+    """
+    Computes the Accuracy of the OOD Labelling for a session dataset.
+    :param session_dataset: Dataset for the session with 4 tensors (data, labels, true types, predicted types).
+    :param args: Object with the attributes `current_session` and `exp_dir`.
+    :return: None
+    """
+
+    # Create the DataFrame
+    df = pd.DataFrame(
+        {
+            "label": session_dataset.tensor_dataset.tensors[1].cpu().numpy(),
+            "type": session_dataset.tensor_dataset.tensors[2].cpu().numpy(),
+            "predtype": session_dataset.tensor_dataset.tensors[3].cpu().numpy(),
+        }
+    )
+    
+    df["is_correct"] = df["predtype"] == df["type"]
+
+    known = df[df["type"] == 0]
+    novel = df[df["type"] == 1]
+    
+    string = f"OOD Accuracies for Session {args.current_session}:"
+    string += f"\nKnown Samples Correct/Total (Accuracy%): {known['is_correct'].sum()}/{known.shape[0]} ({known['is_correct'].mean()*100:.4f}%)"
+    string += f"\nKnown Samples Incorrect/Total (Error%) : {len(known) - known['is_correct'].sum()}/{known.shape[0]} ({1 - known['is_correct'].mean()*100:.4f}%)"
+    string += f"\n"
+    string += f"\nNovel Samples Correct/Total (Accuracy%): {novel['is_correct'].sum()}/{novel.shape[0]} ({novel['is_correct'].mean()*100:.4f}%)"
+    string += f"\nNovel Samples Incorrect/Total (Error%) : {len(novel) - novel['is_correct'].sum()}/{novel.shape[0]} ({1 - novel['is_correct'].mean()*100:.4f}%)"
+    string += f"\n"
+    string += f"\nOverall Accuracy: {df['is_correct'].mean()*100:.4f}%"
+    logger.info(string)
+    
+    file_path = generate_unique_path(os.path.join(args.exp_dir, f'ood_accuracy_{args.current_session}.csv'))
+    df.to_csv(file_path, index=False)
+    logger.info(f"OOD Accuracy CSV saved to {file_path}")
\ No newline at end of file