From 5b99f0c38243dde7863dac95ef353f2e4ed167ab Mon Sep 17 00:00:00 2001
From: Joseph Omar <j.omar@soton.ac.uk>
Date: Wed, 4 Dec 2024 20:01:27 +0000
Subject: [PATCH] Using a dataset of extracted features rather than a
 backbone-based system

---
 entcl/cl.py                                   |  22 +-
 entcl/config.py                               |   3 +-
 entcl/data/__init__.py                        |   0
 entcl/data/cifar100/__init__.py               |   2 +
 entcl/data/cifar100/cifar100feats.py          |  55 ++++
 .../cifar100partition.py}                     |  91 ++----
 entcl/data/cifar100/prep_cifar100_feats.py    | 101 ++++++
 entcl/models/model.py                         |   7 +
 entcl/pretrain.py                             |   5 +-
 entcl/run.py                                  |   4 +-
 entcl/utils/ncd.py                            |  45 ++-
 entcl/utils/ood.py                            |  30 +-
 experiments/experiments3.ipynb                | 306 ++++++++++--------
 13 files changed, 415 insertions(+), 256 deletions(-)
 delete mode 100644 entcl/data/__init__.py
 create mode 100644 entcl/data/cifar100/__init__.py
 create mode 100644 entcl/data/cifar100/cifar100feats.py
 rename entcl/data/{cifar100.py => cifar100/cifar100partition.py} (85%)
 create mode 100644 entcl/data/cifar100/prep_cifar100_feats.py

diff --git a/entcl/cl.py b/entcl/cl.py
index 2a3d403..4efdcb3 100644
--- a/entcl/cl.py
+++ b/entcl/cl.py
@@ -1,36 +1,32 @@
 from typing import Dict, Tuple
-from entcl.data.util import TransformedTensorDataset
 import pandas as pd
 import torch
 from loguru import logger
 from tqdm import tqdm
 
 
-def cl_session(args, session_dataset: TransformedTensorDataset, model: torch.nn.Module, mapping: Dict[int, int]) -> torch.nn.Module:
+def cl_session(args, session_dataset: torch.utils.data.Dataset, model: torch.nn.Module, mapping: Dict[int, int]) -> torch.nn.Module:
     """
     Run a continual learning session on the given session dataset using the given model
     :param args: Arguments object with the attributes `device`, `cl_epochs`, `seed`.
-    :param session_dataset: TransformedTensorDataset with the session data. Should have OOD predtypes and pseudo labels.
+    :param session_dataset: torch.utils.data.Dataset with the session data. Should have OOD predtypes and pseudo labels.
     :param model: torch.nn.Module with the model to use for continual learning. `forward` should return `logits, features`.
     :return: torch.nn.Module with the updated model.
     """
     logger.debug(f"Begin Continual Learning Session {args.current_session}")
     # make sure the dataset has the correct shape
-    assert len(session_dataset.tensor_dataset.tensors) == 5, "Session Dataset should have 5 tensors, (data, true labels, true types, pred types, pseudo labels). Got: " + str(len(session_dataset.tensor_dataset.tensors))
-    for i, t in enumerate(session_dataset.tensor_dataset.tensors):
+    assert len(session_dataset.tensors) == 5, "Session Dataset should have 5 tensors, (data, true labels, true types, pred types, pseudo labels). Got: " + str(len(session_dataset.tensors))
+    for i, t in enumerate(session_dataset.tensors):
         if t.device != torch.device("cpu"):
             logger.warning(f"Tensor {i} is not on CPU")
     # create the required training stuff and things and junk
     
     # we are only training with nove data at the moment, so we need to whittle down the dataset to only the predicted novel samples
-    novel_samples_mask = session_dataset.tensor_dataset.tensors[3] == 1 # novel samples are labelled with 1. predtypes are the 4th tensor in the dataset
+    novel_samples_mask = session_dataset.tensors[3] == 1 # novel samples are labelled with 1. predtypes are the 4th tensor in the dataset
     novel_samples_mask = novel_samples_mask.cpu()
-    novel_tensors = [tensor[novel_samples_mask] for tensor in session_dataset.tensor_dataset.tensors]
+    novel_tensors = [tensor[novel_samples_mask] for tensor in session_dataset.tensors]
     
-    session_dataset = TransformedTensorDataset(
-        tensor_dataset=torch.utils.data.TensorDataset(*novel_tensors),
-        transform=session_dataset.transform,
-    )
+    session_dataset = torch.utils.data.TensorDataset(*novel_tensors)
     
     train_loader = torch.utils.data.DataLoader(
         dataset=session_dataset,
@@ -154,7 +150,7 @@ def _train(args, model: torch.nn.Module, train_loader: torch.utils.data.DataLoad
         logger.debug(f"x shape: {x.shape}, y shape: {y.shape}")
         logger.debug(f"Psuedo Labels (Unique): {torch.unique(y, sorted=True)}")
         optimiser.zero_grad()
-        logits, _ = model(x)
+        logits = model.forward_head(x)
         logger.debug(f"Logits Shape: {logits.shape}")
             
         loss = criterion(logits, y)
@@ -192,7 +188,7 @@ def _validate(args, model: torch.nn.Module, val_loader: torch.utils.data.DataLoa
             for true_label, pseudo_label in mapping.items():
                 y[y == true_label] = pseudo_label # for each y in y, if y == true_label, replace it with pseudo_label
             
-            logits, _ = model(x)
+            logits = model.forward_head(x)
             loss = criterion(logits, y)
             loss = loss.item()
             loss_total += loss
diff --git a/entcl/config.py b/entcl/config.py
index 68bca1b..1e4e3c3 100644
--- a/entcl/config.py
+++ b/entcl/config.py
@@ -1 +1,2 @@
-CIFAR100_DIR = '/cl/datasets/CIFAR/'
\ No newline at end of file
+CIFAR100_DIR = '/cl/datasets/CIFAR/'
+CIFAR100FEATSDIR = '/cl/datasets/CIFAR100Features/'
\ No newline at end of file
diff --git a/entcl/data/__init__.py b/entcl/data/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/entcl/data/cifar100/__init__.py b/entcl/data/cifar100/__init__.py
new file mode 100644
index 0000000..c584990
--- /dev/null
+++ b/entcl/data/cifar100/__init__.py
@@ -0,0 +1,2 @@
+from .cifar100feats import CIFAR100Features
+from .cifar100partition import PartitionedCIFAR100FeaturesDataset
\ No newline at end of file
diff --git a/entcl/data/cifar100/cifar100feats.py b/entcl/data/cifar100/cifar100feats.py
new file mode 100644
index 0000000..8700507
--- /dev/null
+++ b/entcl/data/cifar100/cifar100feats.py
@@ -0,0 +1,55 @@
+import torch
+import os
+from typing import Optional, Callable
+
+class CIFAR100Features(torch.utils.data.Dataset):
+    """
+    Dataset of CIFAR100 features and labels that were saved to a single file per split
+    :param root: str with the directory holding the "cifar100feats_train" / "cifar100feats_test" files
+    :param train: bool indicating whether to load the training or testing set
+    :param transform / target_transform: Optional[Callable]s applied to the features / the label in __getitem__
+    """
+    
+    def __init__(self, root:str, train:bool = True, transform:Optional[Callable] = None, target_transform: Optional[Callable] = None):
+        self.root = root
+        self.train = train
+        self.transform = transform
+        self.target_transform = target_transform
+        # the split file is expected to unpack into exactly two objects: (features, labels)
+        self.file_path = os.path.join(self.root, "cifar100feats_train" if self.train else "cifar100feats_test")
+        self.feats, self.labels = torch.load(self.file_path)
+        
+    def __len__(self):
+        return len(self.labels)
+    
+    def __getitem__(self, idx):
+        """
+        Get an item from the dataset
+        :param idx: int with the index of the item to retrieve
+        :return: (feats, label) at idx, each passed through its transform when one was given
+        """
+        feats = self.feats[idx]
+        label = self.labels[idx]
+        
+        if self.transform is not None:
+            feats = self.transform(feats)
+        
+        if self.target_transform is not None:
+            label = self.target_transform(label)
+        
+        return feats, label
+    
+if __name__ == "__main__":
+    # Manual smoke test. NOTE(review): root is hard-coded; it duplicates CIFAR100FEATSDIR in entcl/config.py
+    train_dataset = CIFAR100Features(root="/cl/datasets/CIFAR100Features/", train=True)
+    test_dataset = CIFAR100Features(root="/cl/datasets/CIFAR100Features/", train=False)
+    # report how many samples each split holds
+    print(f"Train Dataset Length: {len(train_dataset)}")
+    print(f"Test Dataset Length: {len(test_dataset)}")
+    # then inspect the first (features, label) pair of each split
+    train_feats, train_labels = train_dataset[0]
+    test_feats, test_labels = test_dataset[0]
+    
+    print(f"Train Features Shape: {train_feats.shape}")
+    print(f"Train Label: {train_labels}")
+    print(f"Test Features Shape: {test_feats.shape}")
+    print(f"Test Label: {test_labels}")
\ No newline at end of file
diff --git a/entcl/data/cifar100.py b/entcl/data/cifar100/cifar100partition.py
similarity index 85%
rename from entcl/data/cifar100.py
rename to entcl/data/cifar100/cifar100partition.py
index 4e9d751..418384a 100644
--- a/entcl/data/cifar100.py
+++ b/entcl/data/cifar100/cifar100partition.py
@@ -1,28 +1,15 @@
 import os
-from typing import Dict, List, Union
-from entcl.data.util import TransformedTensorDataset
+from typing import Dict, List
+from entcl.data.cifar100.cifar100feats import CIFAR100Features
 import torch
 from torchvision import disable_beta_transforms_warning
 disable_beta_transforms_warning()
-from torchvision.datasets import CIFAR100 as _CIFAR100
+from tqdm import tqdm
 import torchvision.transforms.v2 as transforms
-from entcl.config import CIFAR100_DIR
+from entcl.config import CIFAR100FEATSDIR
 from loguru import logger
 
-CIFAR100_TRANSFORM = transforms.Compose(
-    [
-        transforms.Resize(int(224/0.875), interpolation=3),
-        transforms.CenterCrop(224),
-        transforms.ToTensor(),
-        transforms.Normalize(
-            mean=[0.485, 0.456, 0.406],
-            std=[0.229, 0.224, 0.225]
-        )
-        
-    ]
-)
-
-class CIFAR100Dataset:
+class PartitionedCIFAR100FeaturesDataset:
     def __init__(
         self,
         known: int = 50,
@@ -32,29 +19,22 @@ class CIFAR100Dataset:
         cl_n_prevnovel: int = 20,
         sessions: int = 5,
         mutex: bool = True,
-        force_download = False,
     ):
-        """
-        CIFAR100 dataset with incremental learning settings.
-        :param known: Number of known classes. Default: 50
-        :param pretrain_n_known: Number of samples per known class for pretraining. Default: 400
-        :param cl_n_known: Number of samples per known class for each CL session. Default: 20
-        :param cl_n_novel: Number of samples per novel class for each CL session. Default: 400
-        :param cl_n_prevnovel: Number of samples per previously novel class for each CL session. Default: 20
-        """
         self.num_classes = 100
+        
         if known >= self.num_classes:
             raise ValueError("Number of known classes cannot be greater than 100")
-        self.transform = CIFAR100_TRANSFORM
+        
         self.known = known
         self.sessions = sessions
         self.novel = self.known - self.known
         self.pretrain_n_known = pretrain_n_known
-
+    
         self.cl_n_known = cl_n_known
         self.cl_n_novel = cl_n_novel
         self.cl_n_prevnovel = cl_n_prevnovel
         self.novel_inc = (self.num_classes - self.known) // self.sessions
+        
         # Verify the CL settings
         logger.debug(
             "Verifying incremental learning settings\n"
@@ -65,27 +45,23 @@ class CIFAR100Dataset:
             + f"Samples per previously novel class per CL session: {self.cl_n_prevnovel}\n"
             + f"CL sessions: {self.sessions}"
         )
-        self._verify_splits()
-        
         
-        download = (not os.path.exists(os.path.join(CIFAR100_DIR, "cifar-100-python"))) or force_download
-        logger.debug(f"Download: {download}")
+        self._verify_splits()
         
+
+        # Train Set --------------------------------------------------------------
         # load and sort the data into master lists
         logger.debug("Loading and Sorting CIFAR100 Train split")
-        master_train_data: Dict[int, torch.Tensor] = self._split_data_by_class(_CIFAR100(
-            CIFAR100_DIR, train=True, transform=transforms.ToTensor(), download=download
-        ))
+        master_train_data: Dict[int, torch.Tensor] = self._split_data_by_class(CIFAR100Features(root=CIFAR100FEATSDIR, train=True))
         # split the data into datasets for each session
         logger.debug("Splitting Train Data for Sessions")
         self.train_datasets = self._split_train_data_for_sessions(master_train_data, mutex=mutex)
-        
         del master_train_data
         
+        # Test Set --------------------------------------------------------------
         logger.debug("Loading and Sorting CIFAR100 Test split")
-        master_test_data: Dict[int, torch.Tensor] = self._split_data_by_class(_CIFAR100(
-            CIFAR100_DIR, train=False,transform=transforms.ToTensor(), download=download
-        ))
+        master_test_data: Dict[int, torch.Tensor] = self._split_data_by_class(CIFAR100Features(root=CIFAR100FEATSDIR, train=False))
+        
         logger.debug("Splitting Test Data for Sessions")
         self.test_datasets = self._split_test_data_for_sessions(master_test_data)
         
@@ -99,8 +75,6 @@ class CIFAR100Dataset:
         :param train: Whether to get the training set. Default: True
         :return: Dataset for the given session
         """
-        if session == "pretrain":
-            session = 0
         if session not in self.train_datasets:
             raise ValueError(f"Session {session} does not exist, only sessions {list(self.train_datasets.keys())} exist")
         
@@ -168,7 +142,7 @@ class CIFAR100Dataset:
         types = torch.full((labels.size(0),), 0, dtype=torch.long) # they are all known classes, so type is 0
         logger.debug(f"Creating dataset for session 0 (pretraining). There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
         logger.debug(f"Classes in Pretraining Dataset: {labels.unique(sorted=True)}")
-        datasets[0] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels, types), transform=self.transform)
+        datasets[0] = torch.utils.data.TensorDataset(samples, labels, types)
         
         # CL sessions' datasets
         logger.debug("Splitting data for CL sessions")    
@@ -208,7 +182,7 @@ class CIFAR100Dataset:
             
             logger.debug(f"Creating dataset for session {session}. There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
             logger.debug(f"Classes in this Session {session}'s Train Dataset: {labels.unique(sorted=True)}")
-            datasets[session] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels, ood_labels), transform=self.transform)
+            datasets[session] = torch.utils.data.TensorDataset(samples, labels, ood_labels)
 
         return datasets
     
@@ -235,7 +209,7 @@ class CIFAR100Dataset:
         
         logger.debug(f"Creating dataset for session 0 (pretraining). There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
         logger.debug(f"Classes in Session 0's Test Dataset: {labels.unique(sorted=True)}")
-        datasets[0] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels, types), transform=self.transform)
+        datasets[0] = torch.utils.data.TensorDataset(samples, labels, types)
         
         del samples, labels, types # free up memory
         
@@ -263,7 +237,7 @@ class CIFAR100Dataset:
             logger.debug(f"Creating OLD dataset for session {session}. There are {len(old_samples)} samples, and {len(old_labels)} labels. There are {old_labels.unique().size(0)} different classes")
             logger.debug(f"Classes in Session {session}'s OLD Dataset: {old_labels.unique(sorted=True)}")
             
-            datasets[session] = {"old": TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(old_samples, old_labels, old_types), transform=self.transform)}
+            datasets[session] = {"old": torch.utils.data.TensorDataset(old_samples, old_labels, old_types)}
             
             
             # new dataset -------------------------------------
@@ -279,7 +253,7 @@ class CIFAR100Dataset:
             logger.debug(f"Creating NEW dataset for session {session}. There are {len(new_samples)} samples, and {len(new_labels)} labels. There are {new_labels.unique().size(0)} different classes")
             logger.debug(f"Classes in Session {session}'s NEW Dataset: {new_labels.unique(sorted=True)}")
             
-            datasets[session]["new"] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(new_samples, new_labels, new_types), transform=self.transform)
+            datasets[session]["new"] = torch.utils.data.TensorDataset(new_samples, new_labels, new_types)
             
             
             # all dataset -------------------------------------
@@ -290,12 +264,12 @@ class CIFAR100Dataset:
             logger.debug(f"Creating ALL dataset for session {session}. There are {len(all_samples)} samples, and {len(all_labels)} labels. There are {all_labels.unique().size(0)} different classes")
             logger.debug(f"Classes in Session {session}'s ALL Dataset: {all_labels.unique(sorted=True)}")
             
-            datasets[session]["all"] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(all_samples, all_labels, all_types), transform=self.transform)
+            datasets[session]["all"] = torch.utils.data.TensorDataset(all_samples, all_labels, all_types)
             
         return datasets
                 
     
-    def _split_data_by_class(self, dataset: _CIFAR100, batch_size=64, num_workers=0):
+    def _split_data_by_class(self, dataset: CIFAR100Features, batch_size=128, num_workers=4):
         # Create a DataLoader to load the dataset in batches
         dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
 
@@ -315,7 +289,6 @@ class CIFAR100Dataset:
             all_data[class_id] = torch.stack(all_data[class_id])
 
         return all_data
-    
         
     def _verify_splits(self):
         # verify sessions
@@ -352,21 +325,23 @@ class CIFAR100Dataset:
             raise ValueError(
                 f'Number of samples per previously novel class for each CL session should be between 0 and (500 - cl_n_novel) / sessions. Given "cl_n_prevnovel": {self.cl_n_prevnovel}, "cl_n_novel": {self.cl_n_novel}, "sessions": {self.sessions}'
             )
-
+            
+            
 if __name__ == "__main__":
     from time import sleep
     logger.debug("Entry Point: cifar100.py")
-    if CIFAR100_DIR is None:
+    if CIFAR100FEATSDIR is None:
         raise ValueError("CIFAR100_DIR is not set. Please set it in entcl/config.py")
-    cifar100 = CIFAR100Dataset()
+    
+    cifar100 = PartitionedCIFAR100FeaturesDataset()
     for session in range(cifar100.sessions + 1):
         logger.debug(f"Session {session}")
         train = cifar100.get_dataset(session, train=True)
         test = cifar100.get_dataset(session, train=False)
         logger.debug(f"Train Len: {len(train)}, Image Shape {train[0][0].shape}, Label Shape {train[0][1].shape}")
-        logger.debug(f"Test Len: {len(test)}, Image Shape {test[0][0].shape}, Label Shape {test[0][1].shape}")
-        
-
-
-    sleep(5)
+        if session == 0:
+            logger.debug(f"Test Len: {len(test)}, Image Shape {test[0][0].shape}, Label Shape {test[0][1].shape}")
+        else:
+            for key in test.keys():
+                logger.debug(f"Test {key} Len: {len(test[key])}, Image Shape {test[key][0][0].shape}, Label Shape {test[key][0][1].shape}")
     
\ No newline at end of file
diff --git a/entcl/data/cifar100/prep_cifar100_feats.py b/entcl/data/cifar100/prep_cifar100_feats.py
new file mode 100644
index 0000000..258128f
--- /dev/null
+++ b/entcl/data/cifar100/prep_cifar100_feats.py
@@ -0,0 +1,101 @@
+import argparse
+import os
+import torch
+from entcl.config import CIFAR100_DIR, CIFAR100FEATSDIR
+from torchvision.datasets import CIFAR100
+from tqdm import tqdm
+
+from torchvision import disable_beta_transforms_warning
+disable_beta_transforms_warning()
+import torchvision.transforms.v2 as transforms
+
+CIFAR100_TRANSFORM = transforms.Compose(  # preprocessing applied to the raw CIFAR100 images before they reach the backbone
+    [
+        transforms.Resize(int(224/0.875), interpolation=3, antialias=True),  # int(224/0.875) == 256; 3 is presumably the bicubic mode code - TODO confirm
+        transforms.CenterCrop(224),
+        transforms.ToImageTensor(),
+        transforms.ConvertImageDtype(torch.float32),
+        transforms.Normalize(  # NOTE(review): these look like the usual ImageNet mean/std - confirm they match the backbone
+            mean=[0.485, 0.456, 0.406],
+            std=[0.229, 0.224, 0.225]
+        )
+        
+    ]
+)
+
+def prep_cifar100_feats(
+    device: torch.device,
+    num_workers: int,
+    batch_size: int,
+    backbone_path: str,
+    backbone_name: str
+):
+    """
+    Extract backbone outputs for the CIFAR100 train and test splits and save them under CIFAR100FEATSDIR
+    :param device: device the backbone and every batch are moved to; num_workers / batch_size: DataLoader settings
+    :param backbone_path / backbone_name: repo and model name handed to torch.hub.load (with pretrained=True)
+    """
+    model = torch.hub.load(backbone_path, backbone_name, pretrained=True)
+    model.to(device)
+    
+    train_dataset = CIFAR100(root=CIFAR100_DIR, train=True, transform=CIFAR100_TRANSFORM, download=True)
+    test_dataset = CIFAR100(root=CIFAR100_DIR, train=False, transform=CIFAR100_TRANSFORM, download=True)
+    # shuffle=False and drop_last=False: every sample is kept, in dataset order
+    train_loader = torch.utils.data.DataLoader(
+        dataset=train_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=num_workers,
+        pin_memory=True,
+        drop_last=False,
+    )
+    test_loader = torch.utils.data.DataLoader(
+        dataset=test_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=num_workers,
+        pin_memory=True,
+        drop_last=False,
+    )
+    train_feats, train_labels = _extract_feats(model, train_loader, device)
+    test_feats, test_labels = _extract_feats(model, test_loader, device)
+    
+    os.makedirs(CIFAR100FEATSDIR, exist_ok=True)
+    train_path = os.path.join(CIFAR100FEATSDIR, "cifar100feats_train")
+    test_path = os.path.join(CIFAR100FEATSDIR, "cifar100feats_test")
+    torch.save((train_feats, train_labels), train_path)
+    torch.save((test_feats, test_labels), test_path)
+    print(f"Saved CIFAR100 Features (train) to {train_path}\nSize on Disk is {os.path.getsize(train_path)}")
+    print("")
+    print(f"Saved CIFAR100 Features (test) to {test_path}\nSize on Disk is {os.path.getsize(test_path)}")
+    
+def _extract_feats(model, loader, device):
+    """Run model (eval mode, no grad) over loader; return (feats, labels) concatenated over batches, feats moved to CPU."""
+    feats = []
+    labels = []
+    
+    model.eval()
+    with torch.no_grad():
+        for data, target in tqdm(loader, desc="Extracting Features"):
+            data = data.to(device)
+            feats.append(model(data).cpu())
+            labels.append(target)
+    
+    feats = torch.cat(feats)
+    labels = torch.cat(labels)
+    
+    return feats, labels
+    
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--batch_size", type=int, default=128, help="Batch size for training")
+    parser.add_argument("--num_workers", type=int,default=4, help="Number of workers for the DataLoader")
+    parser.add_argument("--device", type=int, default=0, help="Device to Use")
+    parser.add_argument("--backbone_path", type=str, default="facebookresearch/dino:main", help="torch hub url for backbone")
+    parser.add_argument("--backbone_name", type=str, default="dino_vitb16", help="Name of the backbone model")
+    args = parser.parse_args()
+    # --device is an integer index that is always turned into a "cuda:<index>" device; there is no CPU option here
+    args.device = torch.device(f"cuda:{args.device}")
+    
+    prep_cifar100_feats(device=args.device, num_workers=args.num_workers, batch_size=args.batch_size, backbone_path=args.backbone_path, backbone_name=args.backbone_name)
\ No newline at end of file
diff --git a/entcl/models/model.py b/entcl/models/model.py
index bb4d0ac..3d5d9b4 100644
--- a/entcl/models/model.py
+++ b/entcl/models/model.py
@@ -24,6 +24,13 @@ class ENTCLModel(torch.nn.Module):
         feats = self.backbone(x)
         logits = self.head(feats)
         return logits, feats
+    
+    def forward_backbone(self, x):  # backbone only: returns self.backbone(x), the same features forward() feeds to the head
+        return self.backbone(x)
+    
+    def forward_head(self, x):  # head only: x goes straight to self.head, so it is presumably already backbone features - confirm at call sites
+        return self.head(x)
+    
 
     def train(self, mode=True):
         super().train(mode)
diff --git a/entcl/pretrain.py b/entcl/pretrain.py
index 28a8701..3d44526 100644
--- a/entcl/pretrain.py
+++ b/entcl/pretrain.py
@@ -58,6 +58,7 @@ def pretrain(args, model):
     elif args.mode in ["pretrain", "both"]:
         logger.debug("No pretrained model to load, training from scratch")
         model = model.to(args.device)
+        
         for epoch in range(args.pretrain_epochs):
             
             logger.debug(f"Epoch {epoch} Started")
@@ -103,7 +104,7 @@ def _train(args, model, train_loader, optimiser, criterion):
         logger.debug(f"True Labels (Unique): {torch.unique(y, sorted=True)}")
         
         optimiser.zero_grad()
-        logits, _ = model(x)
+        logits = model.forward_head(x)
         logger.debug(f"Logits Shape: {logits.shape}")
             
         loss = criterion(logits, y)
@@ -126,7 +127,7 @@ def _validate(args, model, val_loader, criterion):
             logger.debug(f"x shape: {x.shape}, y shape: {y.shape}")
             logger.debug(f"True Labels (Unique): {torch.unique(y, sorted=True)}")
                 
-            logits, _ = model(x)
+            logits = model.forward_head(x)
             logger.debug(f"logits shape: {logits.shape}")
             
             loss = criterion(logits, y)
diff --git a/entcl/run.py b/entcl/run.py
index 269f8f8..5cc75c5 100644
--- a/entcl/run.py
+++ b/entcl/run.py
@@ -148,9 +148,9 @@ if __name__ == "__main__":
     
     # initialise dataset
     if args.dataset == 'cifar100':
-        from entcl.data.cifar100 import CIFAR100Dataset
+        from entcl.data.cifar100 import PartitionedCIFAR100FeaturesDataset
         args.novel_classes_per_session = (100 - args.known) // args.sessions
-        args.dataset = CIFAR100Dataset(known=args.known, pretrain_n_known=args.pretrain_n_known, cl_n_known=args.cl_n_known, cl_n_novel=args.cl_n_novel, cl_n_prevnovel=args.cl_n_prevnovel, sessions=5)
+        args.dataset = PartitionedCIFAR100FeaturesDataset(known=args.known, pretrain_n_known=args.pretrain_n_known, cl_n_known=args.cl_n_known, cl_n_novel=args.cl_n_novel, cl_n_prevnovel=args.cl_n_prevnovel, sessions=5)
         
         
     if args.head == 'linear':
diff --git a/entcl/utils/ncd.py b/entcl/utils/ncd.py
index 5bdb669..12ad6e6 100644
--- a/entcl/utils/ncd.py
+++ b/entcl/utils/ncd.py
@@ -1,6 +1,5 @@
 import os
-from typing import Dict, Union
-from entcl.data.util import TransformedTensorDataset
+from typing import Dict, Tuple, Union
 from loguru import logger
 import numpy as np
 import pandas as pd
@@ -14,13 +13,14 @@ from tqdm import tqdm
 
 
 def _extract_features(
-    args, dataset: TransformedTensorDataset, model: torch.nn.Module
+    args, dataset: torch.utils.data.Dataset, model: torch.nn.Module
 ) -> torch.Tensor:
     """
     Extracts features from the data in the dataset using the provided model's backbone.
     :param args: Arguments object with the attributes `device`, .
-    :param dataset: TransformedTensorDataset with the data to extract features from.
+    :param dataset: torch.utils.data.Dataset with the data to extract features from.
     """
+    raise DeprecationWarning("This function is not needed anymore, dataset should already be a feature dataset")
     logger.debug("Extracting Features")
     dataloader = torch.utils.data.DataLoader(
         dataset,
@@ -188,38 +188,37 @@ def _cluster_features(args, features: torch.Tensor) -> torch.Tensor:
 
 def find_novel_classes_for_session(
     args,
-    session_dataset: TransformedTensorDataset,
+    session_dataset: torch.utils.data.Dataset,
     model: torch.nn.Module,
-) -> Union[torch.Tensor, TransformedTensorDataset]:
+) -> Tuple[torch.utils.data.Dataset, Dict[int, int]]:
     """
     Finds the novel classes in the given session dataset using KMeans clustering and the given model
     :param args: Arguments object with the attributes `device`, `ncd_findk_method`, `novel_classes_per_session`, `seed`.
-    :param session_dataset: TransformedTensorDataset with the session data.
+    :param session_dataset: torch.utils.data.Dataset with the session data.
     :param model: torch.nn.Module with the model to use for clustering.
     :return: torch.Tensor with the pseudo-labels for the novel classes.
     """
 
-    # first we need to create a TransformedTensorDataset with only the predicted novel samples
+    # first we need to create a torch.utils.data.Dataset with only the predicted novel samples
     novel_samples_mask = (
-        session_dataset.tensor_dataset.tensors[3] == 1
+        session_dataset.tensors[3] == 1
     )  # OOD samples are marked with 2. the predicted types tensor (known/novel labels) is the third tensor in the dataset
     novel_samples_mask = novel_samples_mask.cpu()  # idk why this is not already on cpu
 
     novel_tensors = [
         tensor[novel_samples_mask].cpu()
-        for tensor in session_dataset.tensor_dataset.tensors
+        for tensor in session_dataset.tensors
     ]  # get only the predicted novel samples
 
     for i, t in enumerate(novel_tensors):
         logger.debug(f"Tensor[{i}] Shape: {t.shape}")
 
-    novel_dataset = TransformedTensorDataset(
-        tensor_dataset=torch.utils.data.TensorDataset(*novel_tensors),
-        transform=session_dataset.transform,
-    )
+    novel_dataset = torch.utils.data.TensorDataset(*novel_tensors)
 
     # extract features from the novel samples
-    novel_features = _extract_features(args, novel_dataset, model)
+    #novel_features = _extract_features(args, novel_dataset, model) # using a feature dataset so this is not needed
+    
+    novel_features = novel_dataset.tensors[0]  # the first tensor in the dataset is the data tensor
 
     # cluster the features
     pseudo_labels = _cluster_features(args, novel_features)
@@ -230,15 +229,15 @@ def find_novel_classes_for_session(
     
     # calculate the clustering accuracy (not used in the dataset, only for logging and testing)
     mapping = generate_mapping(
-        novel_dataset.tensor_dataset.tensors[1], pseudo_labels, args
+        novel_dataset.tensors[1], pseudo_labels, args
     )
 
     # next we need to align the pseudo labels tensor with the original dataset, giving known samples a pseudo label of -1
     pseudo_labels_aligned = torch.full(
-        session_dataset.tensor_dataset.tensors[2].shape,
+        session_dataset.tensors[2].shape,
         -1,
         dtype=torch.long,
-        device=session_dataset.tensor_dataset.tensors[2].device,
+        device=session_dataset.tensors[2].device,
     )
     pseudo_labels_aligned[novel_samples_mask] = (
         pseudo_labels  # whereever the sample is novel, assign the corresponding pseudo label.
@@ -247,11 +246,7 @@ def find_novel_classes_for_session(
     # create a new dataset with the pseudo labels alongside the original tensors and the same transform
     # NOTE: the tensors attribute of a TensorDataset is a tuple. we can't append to it so instead we create a new TensorDataset with the new pseudo labels tensor.
     # NOTE: The dataset at the end of NCD is in the for data, true labels, true type, predicted type, psuedo labels
-    new_dataset = TransformedTensorDataset(
-        tensor_dataset=torch.utils.data.TensorDataset(
-            *session_dataset.tensor_dataset.tensors, pseudo_labels_aligned.cpu()
-        ),
-        transform=session_dataset.transform,
-    )
-
+    new_dataset = torch.utils.data.TensorDataset(
+            *session_dataset.tensors, pseudo_labels_aligned.cpu()
+        )
     return new_dataset, mapping
diff --git a/entcl/utils/ood.py b/entcl/utils/ood.py
index aae758c..76c1e93 100644
--- a/entcl/utils/ood.py
+++ b/entcl/utils/ood.py
@@ -1,6 +1,5 @@
 import os
 from typing import Iterable, Tuple, Union
-from entcl.data.util import TransformedTensorDataset
 from entcl.utils.util import generate_unique_path
 from loguru import logger
 import numpy as np
@@ -10,7 +9,7 @@ import torch
 from tqdm import tqdm
 
 def _get_scores(
-    session_dataset: TransformedTensorDataset, model: torch.nn.Module, args
+    session_dataset: torch.utils.data.Dataset, model: torch.nn.Module, args
 ) -> Union[
     Tuple[torch.Tensor, torch.Tensor],
     Tuple[None, None],
@@ -45,7 +44,7 @@ def _get_scores(
     for x, _, _ in tqdm(session_loader, desc="Calculating Scores", leave=True, unit="batch"):
         x = x.to(args.device)
         with torch.no_grad():
-            logits, _ = model(x)
+            logits = model.forward_head(x)
             logits = logits[:, : args.dataset.known]  # TODO this is just a stopgap. the head needs to be expanded after all            
             if compute_entropy:
                 softmax = torch.nn.functional.softmax(logits, dim=1)
@@ -150,16 +149,16 @@ def _resolve_conflicts(
 
 def label_ood_for_session(
     args,
-    session_dataset: TransformedTensorDataset,
+    session_dataset: torch.utils.data.Dataset,
     model: torch.nn.Module,
-) -> Union[TransformedTensorDataset, torch.Tensor]:
+) -> torch.utils.data.Dataset:
     """
     OOD Labelling for a session dataset. This function computes entropy and/or energy scores for the dataset based on `args.ood_score` and fits a Gaussian Mixture Model to each of the scores. The GMM has 2 components, one for in-distribution samples and one for OOD samples. The function then resolves conflicts between the entropy and energy predictions by selecting the type with the highest confidence. Finally, the function returns a new dataset for the session, including the predicted types.
     :param args: Objects with the attributes `ood_score`, `ood_eps`, `seed` and `device` (Program Arguments).
     :param session_dataset: Dataset for the session.
     :param model: The model to evaluate.
     :param return_new_dataset: Whether to return the new dataset ready for the session or just the predicted types.
-    :return: A TransformedTensorDataset with the predicted types or just a torch.Tensor of the predicted types.
+    :return: A torch.utils.data.Dataset (a TensorDataset of data, labels, real types and predicted types).
     """
     logger.debug("Starting OOD Labelling for Session")
 
@@ -201,14 +200,11 @@ def label_ood_for_session(
     
     logger.debug("Returning New Dataset")
     # return the new dataset with the predicted types
-    session_dataset = TransformedTensorDataset(
-        tensor_dataset=torch.utils.data.TensorDataset(
-            session_dataset.tensor_dataset.tensors[0],  # the data
-            session_dataset.tensor_dataset.tensors[1],  # the labels
-            session_dataset.tensor_dataset.tensors[2],  # the real types
+    session_dataset = torch.utils.data.TensorDataset(
+            session_dataset.tensors[0],  # the data
+            session_dataset.tensors[1],  # the labels
+            session_dataset.tensors[2],  # the real types
             final_predtypes.cpu(),  # the predicted types (duh)
-        ),
-        transform=session_dataset.transform,
     )
 
     # compute the OOD Accuracy
@@ -218,7 +214,7 @@ def label_ood_for_session(
 
 
 def _compute_ood_accuracy(
-    session_dataset: TransformedTensorDataset, entropies, energies, args
+    session_dataset: torch.utils.data.Dataset, entropies, energies, args
 ) -> None:
     """
     Computes the Accuracy of the OOD Labelling for a session dataset.
@@ -230,9 +226,9 @@ def _compute_ood_accuracy(
     # Create the DataFrame
     df = pd.DataFrame(
         {
-            "label": session_dataset.tensor_dataset.tensors[1].cpu().numpy(),
-            "type": session_dataset.tensor_dataset.tensors[2].cpu().numpy(),
-            "predtype": session_dataset.tensor_dataset.tensors[3].cpu().numpy(),
+            "label": session_dataset.tensors[1].cpu().numpy(),
+            "type": session_dataset.tensors[2].cpu().numpy(),
+            "predtype": session_dataset.tensors[3].cpu().numpy(),
         }
     )
     
diff --git a/experiments/experiments3.ipynb b/experiments/experiments3.ipynb
index ac9d55c..287d60c 100644
--- a/experiments/experiments3.ipynb
+++ b/experiments/experiments3.ipynb
@@ -25,160 +25,158 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {},
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "/root/miniconda3/envs/entcl/lib/python3.10/site-packages/torchvision/transforms/v2/_deprecated.py:41: UserWarning: The transform `ToTensor()` is deprecated and will be removed in a future release. Instead, please use `transforms.Compose([transforms.ToImageTensor(), transforms.ConvertImageDtype()])`.\n",
-      "  warnings.warn(\n",
-      "\u001b[32m2024-12-02 15:25:35.075\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m53\u001b[0m - \u001b[34m\u001b[1mVerifying incremental learning settings\n",
+      "\u001b[32m2024-12-03 12:06:35.001\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m59\u001b[0m - \u001b[34m\u001b[1mVerifying incremental learning settings\n",
       "Known classes: 50\n",
       "Pretraining samples per known class: 400\n",
       "Samples per known class per CL session: 20\n",
       "Samples per novel class per CL session: 400\n",
       "Samples per previously novel class per CL session: 20\n",
       "CL sessions: 5\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:35.076\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m66\u001b[0m - \u001b[34m\u001b[1mDownload: False\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:35.077\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m69\u001b[0m - \u001b[34m\u001b[1mLoading and Sorting CIFAR100 Train split\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:35.002\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m72\u001b[0m - \u001b[34m\u001b[1mDownload: False\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:35.003\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m75\u001b[0m - \u001b[34m\u001b[1mLoading and Sorting CIFAR100 Train split\u001b[0m\n",
       "/root/miniconda3/envs/entcl/lib/python3.10/site-packages/torchvision/transforms/v2/_deprecated.py:41: UserWarning: The transform `ToTensor()` is deprecated and will be removed in a future release. Instead, please use `transforms.Compose([transforms.ToImageTensor(), transforms.ConvertImageDtype()])`.\n",
       "  warnings.warn(\n",
-      "\u001b[32m2024-12-02 15:25:42.365\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m74\u001b[0m - \u001b[34m\u001b[1mSplitting Train Data for Sessions\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.366\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m152\u001b[0m - \u001b[34m\u001b[1mSplitting data for 5 sessions\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.367\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m155\u001b[0m - \u001b[34m\u001b[1mSplitting data for session 0 (pretraining)\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.411\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m163\u001b[0m - \u001b[34m\u001b[1mCreating dataset for session 0 (pretraining). There are 20000 samples, and 20000 labels. There are 50 different classes\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.412\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m164\u001b[0m - \u001b[34m\u001b[1mClasses in Pretraining Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
+      "\u001b[32m2024-12-03 12:06:41.400\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m80\u001b[0m - \u001b[34m\u001b[1mSplitting Train Data for Sessions\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.401\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m158\u001b[0m - \u001b[34m\u001b[1mSplitting data for 5 sessions\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.401\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m161\u001b[0m - \u001b[34m\u001b[1mSplitting data for session 0 (pretraining)\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.440\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m169\u001b[0m - \u001b[34m\u001b[1mCreating dataset for session 0 (pretraining). There are 20000 samples, and 20000 labels. There are 50 different classes\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.442\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m170\u001b[0m - \u001b[34m\u001b[1mClasses in Pretraining Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
       "        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,\n",
       "        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49])\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.413\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m168\u001b[0m - \u001b[34m\u001b[1mSplitting data for CL sessions\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.414\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m170\u001b[0m - \u001b[34m\u001b[1mSplitting data for session 1\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.414\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m174\u001b[0m - \u001b[34m\u001b[1mThere are 50 known classes. Starting at 0 (inc), ending at 50 (exc)\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.421\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m180\u001b[0m - \u001b[34m\u001b[1mThere are 10 novel classes. Starting at 50 (inc), ending at 60 (exc)\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.436\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m203\u001b[0m - \u001b[34m\u001b[1mCreating dataset for session 1. There are 5000 samples, and 5000 labels. There are 60 different classes\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.438\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m204\u001b[0m - \u001b[34m\u001b[1mClasses in this Session 1's Train Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
+      "\u001b[32m2024-12-03 12:06:41.443\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m174\u001b[0m - \u001b[34m\u001b[1mSplitting data for CL sessions\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.443\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m176\u001b[0m - \u001b[34m\u001b[1mSplitting data for session 1\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.443\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m180\u001b[0m - \u001b[34m\u001b[1mThere are 50 known classes. Starting at 0 (inc), ending at 50 (exc)\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.450\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m186\u001b[0m - \u001b[34m\u001b[1mThere are 10 novel classes. Starting at 50 (inc), ending at 60 (exc)\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.460\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m209\u001b[0m - \u001b[34m\u001b[1mCreating dataset for session 1. There are 5000 samples, and 5000 labels. There are 60 different classes\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.462\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m210\u001b[0m - \u001b[34m\u001b[1mClasses in this Session 1's Train Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
       "        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,\n",
       "        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,\n",
       "        54, 55, 56, 57, 58, 59])\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.439\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m170\u001b[0m - \u001b[34m\u001b[1mSplitting data for session 2\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.440\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m174\u001b[0m - \u001b[34m\u001b[1mThere are 50 known classes. Starting at 0 (inc), ending at 50 (exc)\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.446\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m180\u001b[0m - \u001b[34m\u001b[1mThere are 10 novel classes. Starting at 60 (inc), ending at 70 (exc)\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.451\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m189\u001b[0m - \u001b[34m\u001b[1mThere are 10 previously novel classes. Starting at 50 (inc), ending at 60 (exc)\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.463\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m203\u001b[0m - \u001b[34m\u001b[1mCreating dataset for session 2. There are 5200 samples, and 5200 labels. There are 70 different classes\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.465\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m204\u001b[0m - \u001b[34m\u001b[1mClasses in this Session 2's Train Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
+      "\u001b[32m2024-12-03 12:06:41.463\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m176\u001b[0m - \u001b[34m\u001b[1mSplitting data for session 2\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.463\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m180\u001b[0m - \u001b[34m\u001b[1mThere are 50 known classes. Starting at 0 (inc), ending at 50 (exc)\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.469\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m186\u001b[0m - \u001b[34m\u001b[1mThere are 10 novel classes. Starting at 60 (inc), ending at 70 (exc)\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.473\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m195\u001b[0m - \u001b[34m\u001b[1mThere are 10 previously novel classes. Starting at 50 (inc), ending at 60 (exc)\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.480\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m209\u001b[0m - \u001b[34m\u001b[1mCreating dataset for session 2. There are 5200 samples, and 5200 labels. There are 70 different classes\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.482\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m210\u001b[0m - \u001b[34m\u001b[1mClasses in this Session 2's Train Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
       "        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,\n",
       "        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,\n",
       "        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69])\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.465\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m170\u001b[0m - \u001b[34m\u001b[1mSplitting data for session 3\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.466\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m174\u001b[0m - \u001b[34m\u001b[1mThere are 50 known classes. Starting at 0 (inc), ending at 50 (exc)\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.471\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m180\u001b[0m - \u001b[34m\u001b[1mThere are 10 novel classes. Starting at 70 (inc), ending at 80 (exc)\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.476\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m189\u001b[0m - \u001b[34m\u001b[1mThere are 10 previously novel classes. Starting at 50 (inc), ending at 70 (exc)\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.490\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m203\u001b[0m - \u001b[34m\u001b[1mCreating dataset for session 3. There are 5400 samples, and 5400 labels. There are 80 different classes\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.492\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m204\u001b[0m - \u001b[34m\u001b[1mClasses in this Session 3's Train Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
+      "\u001b[32m2024-12-03 12:06:41.483\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m176\u001b[0m - \u001b[34m\u001b[1mSplitting data for session 3\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.484\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m180\u001b[0m - \u001b[34m\u001b[1mThere are 50 known classes. Starting at 0 (inc), ending at 50 (exc)\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.490\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m186\u001b[0m - \u001b[34m\u001b[1mThere are 10 novel classes. Starting at 70 (inc), ending at 80 (exc)\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.495\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m195\u001b[0m - \u001b[34m\u001b[1mThere are 10 previously novel classes. Starting at 50 (inc), ending at 70 (exc)\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.504\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m209\u001b[0m - \u001b[34m\u001b[1mCreating dataset for session 3. There are 5400 samples, and 5400 labels. There are 80 different classes\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.507\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m210\u001b[0m - \u001b[34m\u001b[1mClasses in this Session 3's Train Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
       "        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,\n",
       "        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,\n",
       "        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,\n",
       "        72, 73, 74, 75, 76, 77, 78, 79])\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.493\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m170\u001b[0m - \u001b[34m\u001b[1mSplitting data for session 4\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.494\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m174\u001b[0m - \u001b[34m\u001b[1mThere are 50 known classes. Starting at 0 (inc), ending at 50 (exc)\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.498\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m180\u001b[0m - \u001b[34m\u001b[1mThere are 10 novel classes. Starting at 80 (inc), ending at 90 (exc)\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.503\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m189\u001b[0m - \u001b[34m\u001b[1mThere are 10 previously novel classes. Starting at 50 (inc), ending at 80 (exc)\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.518\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m203\u001b[0m - \u001b[34m\u001b[1mCreating dataset for session 4. There are 5600 samples, and 5600 labels. There are 90 different classes\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.520\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m204\u001b[0m - \u001b[34m\u001b[1mClasses in this Session 4's Train Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
+      "\u001b[32m2024-12-03 12:06:41.508\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m176\u001b[0m - \u001b[34m\u001b[1mSplitting data for session 4\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.508\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m180\u001b[0m - \u001b[34m\u001b[1mThere are 50 known classes. Starting at 0 (inc), ending at 50 (exc)\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.513\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m186\u001b[0m - \u001b[34m\u001b[1mThere are 10 novel classes. Starting at 80 (inc), ending at 90 (exc)\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.517\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m195\u001b[0m - \u001b[34m\u001b[1mThere are 10 previously novel classes. Starting at 50 (inc), ending at 80 (exc)\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.526\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m209\u001b[0m - \u001b[34m\u001b[1mCreating dataset for session 4. There are 5600 samples, and 5600 labels. There are 90 different classes\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.528\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m210\u001b[0m - \u001b[34m\u001b[1mClasses in this Session 4's Train Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
       "        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,\n",
       "        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,\n",
       "        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,\n",
       "        72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89])\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.521\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m170\u001b[0m - \u001b[34m\u001b[1mSplitting data for session 5\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.521\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m174\u001b[0m - \u001b[34m\u001b[1mThere are 50 known classes. Starting at 0 (inc), ending at 50 (exc)\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.525\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m180\u001b[0m - \u001b[34m\u001b[1mThere are 10 novel classes. Starting at 90 (inc), ending at 100 (exc)\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.530\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m189\u001b[0m - \u001b[34m\u001b[1mThere are 10 previously novel classes. Starting at 50 (inc), ending at 90 (exc)\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.548\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m203\u001b[0m - \u001b[34m\u001b[1mCreating dataset for session 5. There are 5800 samples, and 5800 labels. There are 100 different classes\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.549\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m204\u001b[0m - \u001b[34m\u001b[1mClasses in this Session 5's Train Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
+      "\u001b[32m2024-12-03 12:06:41.528\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m176\u001b[0m - \u001b[34m\u001b[1mSplitting data for session 5\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.529\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m180\u001b[0m - \u001b[34m\u001b[1mThere are 50 known classes. Starting at 0 (inc), ending at 50 (exc)\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.532\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m186\u001b[0m - \u001b[34m\u001b[1mThere are 10 novel classes. Starting at 90 (inc), ending at 100 (exc)\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.537\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m195\u001b[0m - \u001b[34m\u001b[1mThere are 10 previously novel classes. Starting at 50 (inc), ending at 90 (exc)\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.547\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m209\u001b[0m - \u001b[34m\u001b[1mCreating dataset for session 5. There are 5800 samples, and 5800 labels. There are 100 different classes\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:41.549\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_train_data_for_sessions\u001b[0m:\u001b[36m210\u001b[0m - \u001b[34m\u001b[1mClasses in this Session 5's Train Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
       "        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,\n",
       "        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,\n",
       "        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,\n",
       "        72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,\n",
       "        90, 91, 92, 93, 94, 95, 96, 97, 98, 99])\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:42.557\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m79\u001b[0m - \u001b[34m\u001b[1mLoading and Sorting CIFAR100 Test split\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.017\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m83\u001b[0m - \u001b[34m\u001b[1mSplitting Test Data for Sessions\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.018\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m219\u001b[0m - \u001b[34m\u001b[1mSplitting test data for session 0\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.024\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m230\u001b[0m - \u001b[34m\u001b[1mCreating dataset for session 0 (pretraining). There are 5000 samples, and 5000 labels. There are 50 different classes\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.025\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m231\u001b[0m - \u001b[34m\u001b[1mClasses in Session 0's Test Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
+      "\u001b[32m2024-12-03 12:06:41.550\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m85\u001b[0m - \u001b[34m\u001b[1mLoading and Sorting CIFAR100 Test split\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:42.969\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m__init__\u001b[0m:\u001b[36m89\u001b[0m - \u001b[34m\u001b[1mSplitting Test Data for Sessions\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:42.970\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m225\u001b[0m - \u001b[34m\u001b[1mSplitting test data for session 0\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:42.976\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m236\u001b[0m - \u001b[34m\u001b[1mCreating dataset for session 0 (pretraining). There are 5000 samples, and 5000 labels. There are 50 different classes\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:42.978\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m237\u001b[0m - \u001b[34m\u001b[1mClasses in Session 0's Test Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
       "        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,\n",
       "        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49])\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.026\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m236\u001b[0m - \u001b[34m\u001b[1mSplitting test data for 5 sessions\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.027\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m239\u001b[0m - \u001b[34m\u001b[1mSplitting test data for session 1\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.027\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m244\u001b[0m - \u001b[34m\u001b[1mOld classes end at 50 (exc), New classes start at 50 (inc) and end at 60 (exc)\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.031\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m257\u001b[0m - \u001b[34m\u001b[1mCreating OLD dataset for session 1. There are 5000 samples, and 5000 labels. There are 50 different classes\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.033\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m258\u001b[0m - \u001b[34m\u001b[1mClasses in Session 1's OLD Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
+      "\u001b[32m2024-12-03 12:06:42.979\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m242\u001b[0m - \u001b[34m\u001b[1mSplitting test data for 5 sessions\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:42.980\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m245\u001b[0m - \u001b[34m\u001b[1mSplitting test data for session 1\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:42.980\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m250\u001b[0m - \u001b[34m\u001b[1mOld classes end at 50 (exc), New classes start at 50 (inc) and end at 60 (exc)\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:42.986\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m263\u001b[0m - \u001b[34m\u001b[1mCreating OLD dataset for session 1. There are 5000 samples, and 5000 labels. There are 50 different classes\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:42.988\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m264\u001b[0m - \u001b[34m\u001b[1mClasses in Session 1's OLD Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
       "        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,\n",
       "        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49])\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.035\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m273\u001b[0m - \u001b[34m\u001b[1mCreating NEW dataset for session 1. There are 1000 samples, and 1000 labels. There are 10 different classes\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.036\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m274\u001b[0m - \u001b[34m\u001b[1mClasses in Session 1's NEW Dataset: tensor([50, 51, 52, 53, 54, 55, 56, 57, 58, 59])\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.042\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m284\u001b[0m - \u001b[34m\u001b[1mCreating ALL dataset for session 1. There are 6000 samples, and 6000 labels. There are 60 different classes\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.044\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m285\u001b[0m - \u001b[34m\u001b[1mClasses in Session 1's ALL Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
+      "\u001b[32m2024-12-03 12:06:42.990\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m279\u001b[0m - \u001b[34m\u001b[1mCreating NEW dataset for session 1. There are 1000 samples, and 1000 labels. There are 10 different classes\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:42.991\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m280\u001b[0m - \u001b[34m\u001b[1mClasses in Session 1's NEW Dataset: tensor([50, 51, 52, 53, 54, 55, 56, 57, 58, 59])\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:42.996\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m290\u001b[0m - \u001b[34m\u001b[1mCreating ALL dataset for session 1. There are 6000 samples, and 6000 labels. There are 60 different classes\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:42.997\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m291\u001b[0m - \u001b[34m\u001b[1mClasses in Session 1's ALL Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
       "        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,\n",
       "        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,\n",
       "        54, 55, 56, 57, 58, 59])\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.044\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m239\u001b[0m - \u001b[34m\u001b[1mSplitting test data for session 2\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.045\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m244\u001b[0m - \u001b[34m\u001b[1mOld classes end at 60 (exc), New classes start at 60 (inc) and end at 70 (exc)\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.054\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m257\u001b[0m - \u001b[34m\u001b[1mCreating OLD dataset for session 2. There are 6000 samples, and 6000 labels. There are 60 different classes\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.055\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m258\u001b[0m - \u001b[34m\u001b[1mClasses in Session 2's OLD Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
+      "\u001b[32m2024-12-03 12:06:42.998\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m245\u001b[0m - \u001b[34m\u001b[1mSplitting test data for session 2\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:42.998\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m250\u001b[0m - \u001b[34m\u001b[1mOld classes end at 60 (exc), New classes start at 60 (inc) and end at 70 (exc)\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:43.004\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m263\u001b[0m - \u001b[34m\u001b[1mCreating OLD dataset for session 2. There are 6000 samples, and 6000 labels. There are 60 different classes\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:43.005\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m264\u001b[0m - \u001b[34m\u001b[1mClasses in Session 2's OLD Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
       "        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,\n",
       "        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,\n",
       "        54, 55, 56, 57, 58, 59])\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.057\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m273\u001b[0m - \u001b[34m\u001b[1mCreating NEW dataset for session 2. There are 1000 samples, and 1000 labels. There are 10 different classes\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.058\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m274\u001b[0m - \u001b[34m\u001b[1mClasses in Session 2's NEW Dataset: tensor([60, 61, 62, 63, 64, 65, 66, 67, 68, 69])\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.063\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m284\u001b[0m - \u001b[34m\u001b[1mCreating ALL dataset for session 2. There are 7000 samples, and 7000 labels. There are 70 different classes\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.065\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m285\u001b[0m - \u001b[34m\u001b[1mClasses in Session 2's ALL Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
+      "\u001b[32m2024-12-03 12:06:43.007\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m279\u001b[0m - \u001b[34m\u001b[1mCreating NEW dataset for session 2. There are 1000 samples, and 1000 labels. There are 10 different classes\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:43.008\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m280\u001b[0m - \u001b[34m\u001b[1mClasses in Session 2's NEW Dataset: tensor([60, 61, 62, 63, 64, 65, 66, 67, 68, 69])\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:43.013\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m290\u001b[0m - \u001b[34m\u001b[1mCreating ALL dataset for session 2. There are 7000 samples, and 7000 labels. There are 70 different classes\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:43.014\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m291\u001b[0m - \u001b[34m\u001b[1mClasses in Session 2's ALL Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
       "        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,\n",
       "        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,\n",
       "        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69])\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.065\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m239\u001b[0m - \u001b[34m\u001b[1mSplitting test data for session 3\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.066\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m244\u001b[0m - \u001b[34m\u001b[1mOld classes end at 70 (exc), New classes start at 70 (inc) and end at 80 (exc)\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.076\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m257\u001b[0m - \u001b[34m\u001b[1mCreating OLD dataset for session 3. There are 7000 samples, and 7000 labels. There are 70 different classes\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.078\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m258\u001b[0m - \u001b[34m\u001b[1mClasses in Session 3's OLD Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
+      "\u001b[32m2024-12-03 12:06:43.015\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m245\u001b[0m - \u001b[34m\u001b[1mSplitting test data for session 3\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:43.015\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m250\u001b[0m - \u001b[34m\u001b[1mOld classes end at 70 (exc), New classes start at 70 (inc) and end at 80 (exc)\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:43.026\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m263\u001b[0m - \u001b[34m\u001b[1mCreating OLD dataset for session 3. There are 7000 samples, and 7000 labels. There are 70 different classes\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:43.027\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m264\u001b[0m - \u001b[34m\u001b[1mClasses in Session 3's OLD Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
       "        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,\n",
       "        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,\n",
       "        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69])\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.079\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m273\u001b[0m - \u001b[34m\u001b[1mCreating NEW dataset for session 3. There are 1000 samples, and 1000 labels. There are 10 different classes\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.080\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m274\u001b[0m - \u001b[34m\u001b[1mClasses in Session 3's NEW Dataset: tensor([70, 71, 72, 73, 74, 75, 76, 77, 78, 79])\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.088\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m284\u001b[0m - \u001b[34m\u001b[1mCreating ALL dataset for session 3. There are 8000 samples, and 8000 labels. There are 80 different classes\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.089\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m285\u001b[0m - \u001b[34m\u001b[1mClasses in Session 3's ALL Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
+      "\u001b[32m2024-12-03 12:06:43.029\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m279\u001b[0m - \u001b[34m\u001b[1mCreating NEW dataset for session 3. There are 1000 samples, and 1000 labels. There are 10 different classes\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:43.030\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m280\u001b[0m - \u001b[34m\u001b[1mClasses in Session 3's NEW Dataset: tensor([70, 71, 72, 73, 74, 75, 76, 77, 78, 79])\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:43.037\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m290\u001b[0m - \u001b[34m\u001b[1mCreating ALL dataset for session 3. There are 8000 samples, and 8000 labels. There are 80 different classes\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:43.038\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m291\u001b[0m - \u001b[34m\u001b[1mClasses in Session 3's ALL Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
       "        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,\n",
       "        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,\n",
       "        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,\n",
       "        72, 73, 74, 75, 76, 77, 78, 79])\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.090\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m239\u001b[0m - \u001b[34m\u001b[1mSplitting test data for session 4\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.090\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m244\u001b[0m - \u001b[34m\u001b[1mOld classes end at 80 (exc), New classes start at 80 (inc) and end at 90 (exc)\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.102\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m257\u001b[0m - \u001b[34m\u001b[1mCreating OLD dataset for session 4. There are 8000 samples, and 8000 labels. There are 80 different classes\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.104\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m258\u001b[0m - \u001b[34m\u001b[1mClasses in Session 4's OLD Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
+      "\u001b[32m2024-12-03 12:06:43.039\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m245\u001b[0m - \u001b[34m\u001b[1mSplitting test data for session 4\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:43.039\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m250\u001b[0m - \u001b[34m\u001b[1mOld classes end at 80 (exc), New classes start at 80 (inc) and end at 90 (exc)\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:43.052\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m263\u001b[0m - \u001b[34m\u001b[1mCreating OLD dataset for session 4. There are 8000 samples, and 8000 labels. There are 80 different classes\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:43.054\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m264\u001b[0m - \u001b[34m\u001b[1mClasses in Session 4's OLD Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
       "        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,\n",
       "        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,\n",
       "        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,\n",
       "        72, 73, 74, 75, 76, 77, 78, 79])\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.106\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m273\u001b[0m - \u001b[34m\u001b[1mCreating NEW dataset for session 4. There are 1000 samples, and 1000 labels. There are 10 different classes\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.107\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m274\u001b[0m - \u001b[34m\u001b[1mClasses in Session 4's NEW Dataset: tensor([80, 81, 82, 83, 84, 85, 86, 87, 88, 89])\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.114\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m284\u001b[0m - \u001b[34m\u001b[1mCreating ALL dataset for session 4. There are 9000 samples, and 9000 labels. There are 90 different classes\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.116\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m285\u001b[0m - \u001b[34m\u001b[1mClasses in Session 4's ALL Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
+      "\u001b[32m2024-12-03 12:06:43.055\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m279\u001b[0m - \u001b[34m\u001b[1mCreating NEW dataset for session 4. There are 1000 samples, and 1000 labels. There are 10 different classes\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:43.056\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m280\u001b[0m - \u001b[34m\u001b[1mClasses in Session 4's NEW Dataset: tensor([80, 81, 82, 83, 84, 85, 86, 87, 88, 89])\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:43.064\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m290\u001b[0m - \u001b[34m\u001b[1mCreating ALL dataset for session 4. There are 9000 samples, and 9000 labels. There are 90 different classes\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:43.065\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m291\u001b[0m - \u001b[34m\u001b[1mClasses in Session 4's ALL Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
       "        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,\n",
       "        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,\n",
       "        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,\n",
       "        72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89])\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.117\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m239\u001b[0m - \u001b[34m\u001b[1mSplitting test data for session 5\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.117\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m244\u001b[0m - \u001b[34m\u001b[1mOld classes end at 90 (exc), New classes start at 90 (inc) and end at 100 (exc)\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.132\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m257\u001b[0m - \u001b[34m\u001b[1mCreating OLD dataset for session 5. There are 9000 samples, and 9000 labels. There are 90 different classes\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.134\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m258\u001b[0m - \u001b[34m\u001b[1mClasses in Session 5's OLD Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
+      "\u001b[32m2024-12-03 12:06:43.066\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m245\u001b[0m - \u001b[34m\u001b[1mSplitting test data for session 5\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:43.066\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m250\u001b[0m - \u001b[34m\u001b[1mOld classes end at 90 (exc), New classes start at 90 (inc) and end at 100 (exc)\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:43.080\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m263\u001b[0m - \u001b[34m\u001b[1mCreating OLD dataset for session 5. There are 9000 samples, and 9000 labels. There are 90 different classes\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:43.082\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m264\u001b[0m - \u001b[34m\u001b[1mClasses in Session 5's OLD Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
       "        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,\n",
       "        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,\n",
       "        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,\n",
       "        72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89])\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.136\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m273\u001b[0m - \u001b[34m\u001b[1mCreating NEW dataset for session 5. There are 1000 samples, and 1000 labels. There are 10 different classes\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.137\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m274\u001b[0m - \u001b[34m\u001b[1mClasses in Session 5's NEW Dataset: tensor([90, 91, 92, 93, 94, 95, 96, 97, 98, 99])\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.150\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m284\u001b[0m - \u001b[34m\u001b[1mCreating ALL dataset for session 5. There are 10000 samples, and 10000 labels. There are 100 different classes\u001b[0m\n",
-      "\u001b[32m2024-12-02 15:25:44.151\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m285\u001b[0m - \u001b[34m\u001b[1mClasses in Session 5's ALL Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
+      "\u001b[32m2024-12-03 12:06:43.084\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m279\u001b[0m - \u001b[34m\u001b[1mCreating NEW dataset for session 5. There are 1000 samples, and 1000 labels. There are 10 different classes\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:43.084\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m280\u001b[0m - \u001b[34m\u001b[1mClasses in Session 5's NEW Dataset: tensor([90, 91, 92, 93, 94, 95, 96, 97, 98, 99])\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:43.093\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m290\u001b[0m - \u001b[34m\u001b[1mCreating ALL dataset for session 5. There are 10000 samples, and 10000 labels. There are 100 different classes\u001b[0m\n",
+      "\u001b[32m2024-12-03 12:06:43.094\u001b[0m | \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mentcl.data.cifar100\u001b[0m:\u001b[36m_split_test_data_for_sessions\u001b[0m:\u001b[36m291\u001b[0m - \u001b[34m\u001b[1mClasses in Session 5's ALL Dataset: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
       "        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,\n",
       "        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,\n",
       "        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,\n",
@@ -192,20 +190,20 @@
     "seed(8008135)\n",
     "from entcl.models.model import ENTCLModel\n",
     "from entcl.models.linear_head import LinearHead\n",
-    "from entcl.data.cifar100 import CIFAR100Dataset\n",
+    "from entcl.data.cifar100feats import CIFAR100FeatureDataset\n",
     "import torch\n",
     "import pandas as pd\n",
     "import numpy as np\n",
     "from tqdm.notebook import tqdm\n",
     "\n",
-    "device = torch.device('cpu')\n",
+    "device = torch.device('cuda:0')\n",
     "eps = 1e-8\n",
     "\n",
     "pretrained_model = ENTCLModel(LinearHead(768, 50), backbone_version=1)\n",
     "pretrained_model.head.load_state_dict(torch.load('/cl/entcl_LFS/experiments/dino_nosched_bb/session_0/head_s0_ep99.pt'))\n",
     "pretrained_model = pretrained_model.to(device)\n",
     "\n",
-    "dataset_master = CIFAR100Dataset()\n",
+    "dataset_master = CIFAR100FeatureDataset(backbone=)\n",
     "dataset = dataset_master.get_dataset(1)"
    ]
   },
@@ -218,13 +216,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 14,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "119c4a9fd83d4d23a73c1af59a9d43d4",
+       "model_id": "cba9e1a8a0eb49eb881081a14a2fd8ca",
        "version_major": 2,
        "version_minor": 0
       },
@@ -234,6 +232,20 @@
      },
      "metadata": {},
      "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/root/miniconda3/envs/entcl/lib/python3.10/site-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n",
+      "  warnings.warn(\n",
+      "/root/miniconda3/envs/entcl/lib/python3.10/site-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n",
+      "  warnings.warn(\n",
+      "/root/miniconda3/envs/entcl/lib/python3.10/site-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n",
+      "  warnings.warn(\n",
+      "/root/miniconda3/envs/entcl/lib/python3.10/site-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n",
+      "  warnings.warn(\n"
+     ]
     }
    ],
    "source": [
@@ -295,12 +307,12 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 15,
    "metadata": {},
    "outputs": [
     {
      "data": {
-      "image/png": "",
+      "image/png": "",
       "text/plain": [
        "<Figure size 1000x600 with 1 Axes>"
       ]
@@ -335,12 +347,12 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 16,
    "metadata": {},
    "outputs": [
     {
      "data": {
-      "image/png": "",
+      "image/png": "",
       "text/plain": [
        "<Figure size 1000x600 with 1 Axes>"
       ]
@@ -368,13 +380,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 17,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "926a9415c404460c9a45984d3350c36e",
+       "model_id": "0299d4800a744441ad923b5d9530a21c",
        "version_major": 2,
        "version_minor": 0
       },
@@ -385,11 +397,25 @@
      "metadata": {},
      "output_type": "display_data"
     },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/root/miniconda3/envs/entcl/lib/python3.10/site-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n",
+      "  warnings.warn(\n",
+      "/root/miniconda3/envs/entcl/lib/python3.10/site-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n",
+      "  warnings.warn(\n",
+      "/root/miniconda3/envs/entcl/lib/python3.10/site-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n",
+      "  warnings.warn(\n",
+      "/root/miniconda3/envs/entcl/lib/python3.10/site-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n",
+      "  warnings.warn(\n"
+     ]
+    },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Accuracy of the model on the testset: 85.20%\n"
+      "Accuracy of the model on the testset: 86.96%\n"
      ]
     }
    ],
@@ -427,27 +453,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 18,
    "metadata": {},
-   "outputs": [
-    {
-     "ename": "ValueError",
-     "evalue": "Input X contains infinity or a value too large for dtype('float32').",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
-      "Cell \u001b[0;32mIn[7], line 21\u001b[0m\n\u001b[1;32m     19\u001b[0m \u001b[38;5;66;03m# GMM 2: energy\u001b[39;00m\n\u001b[1;32m     20\u001b[0m energy_gmm \u001b[38;5;241m=\u001b[39m GaussianMixture(n_components\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m, random_state\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m8008135\u001b[39m, max_iter\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1000\u001b[39m, init_params\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mk-means++\u001b[39m\u001b[38;5;124m'\u001b[39m, tol\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-4\u001b[39m)\n\u001b[0;32m---> 21\u001b[0m df[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124menergy_cluster\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[43menergy_gmm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_predict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdf\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43menergy\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreshape\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     22\u001b[0m soft_clusters \u001b[38;5;241m=\u001b[39m energy_gmm\u001b[38;5;241m.\u001b[39mpredict_proba(df[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124menergy\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;241m.\u001b[39mvalues\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m1\u001b[39m))\n\u001b[1;32m     24\u001b[0m mean_0 \u001b[38;5;241m=\u001b[39m df[df[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124menergy_cluster\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m==\u001b[39m 
\u001b[38;5;241m0\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124menergy\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mmean()\n",
-      "File \u001b[0;32m~/miniconda3/envs/entcl/lib/python3.10/site-packages/sklearn/base.py:1473\u001b[0m, in \u001b[0;36m_fit_context.<locals>.decorator.<locals>.wrapper\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1466\u001b[0m     estimator\u001b[38;5;241m.\u001b[39m_validate_params()\n\u001b[1;32m   1468\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[1;32m   1469\u001b[0m     skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[1;32m   1470\u001b[0m         prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[1;32m   1471\u001b[0m     )\n\u001b[1;32m   1472\u001b[0m ):\n\u001b[0;32m-> 1473\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfit_method\u001b[49m\u001b[43m(\u001b[49m\u001b[43mestimator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
-      "File \u001b[0;32m~/miniconda3/envs/entcl/lib/python3.10/site-packages/sklearn/mixture/_base.py:212\u001b[0m, in \u001b[0;36mBaseMixture.fit_predict\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m    184\u001b[0m \u001b[38;5;129m@_fit_context\u001b[39m(prefer_skip_nested_validation\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m    185\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfit_predict\u001b[39m(\u001b[38;5;28mself\u001b[39m, X, y\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m    186\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m\"\"\"Estimate model parameters using X and predict the labels for X.\u001b[39;00m\n\u001b[1;32m    187\u001b[0m \n\u001b[1;32m    188\u001b[0m \u001b[38;5;124;03m    The method fits the model n_init times and sets the parameters with\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    210\u001b[0m \u001b[38;5;124;03m        Component labels.\u001b[39;00m\n\u001b[1;32m    211\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[0;32m--> 212\u001b[0m     X \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_validate_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat64\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat32\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mensure_min_samples\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m    213\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m X\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m<\u001b[39m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_components:\n\u001b[1;32m    214\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m    215\u001b[0m             \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mExpected n_samples >= n_components \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    216\u001b[0m             \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbut got n_components = \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_components\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    217\u001b[0m             \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mn_samples = \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mX\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    218\u001b[0m         )\n",
-      "File \u001b[0;32m~/miniconda3/envs/entcl/lib/python3.10/site-packages/sklearn/base.py:633\u001b[0m, in \u001b[0;36mBaseEstimator._validate_data\u001b[0;34m(self, X, y, reset, validate_separately, cast_to_ndarray, **check_params)\u001b[0m\n\u001b[1;32m    631\u001b[0m         out \u001b[38;5;241m=\u001b[39m X, y\n\u001b[1;32m    632\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m no_val_X \u001b[38;5;129;01mand\u001b[39;00m no_val_y:\n\u001b[0;32m--> 633\u001b[0m     out \u001b[38;5;241m=\u001b[39m \u001b[43mcheck_array\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minput_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mX\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mcheck_params\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    634\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m no_val_X \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m no_val_y:\n\u001b[1;32m    635\u001b[0m     out \u001b[38;5;241m=\u001b[39m _check_y(y, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mcheck_params)\n",
-      "File \u001b[0;32m~/miniconda3/envs/entcl/lib/python3.10/site-packages/sklearn/utils/validation.py:1064\u001b[0m, in \u001b[0;36mcheck_array\u001b[0;34m(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_writeable, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator, input_name)\u001b[0m\n\u001b[1;32m   1058\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m   1059\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFound array with dim \u001b[39m\u001b[38;5;132;01m%d\u001b[39;00m\u001b[38;5;124m. \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m expected <= 2.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   1060\u001b[0m         \u001b[38;5;241m%\u001b[39m (array\u001b[38;5;241m.\u001b[39mndim, estimator_name)\n\u001b[1;32m   1061\u001b[0m     )\n\u001b[1;32m   1063\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m force_all_finite:\n\u001b[0;32m-> 1064\u001b[0m     \u001b[43m_assert_all_finite\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1065\u001b[0m \u001b[43m        \u001b[49m\u001b[43marray\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1066\u001b[0m \u001b[43m        \u001b[49m\u001b[43minput_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1067\u001b[0m \u001b[43m        \u001b[49m\u001b[43mestimator_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mestimator_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1068\u001b[0m \u001b[43m        \u001b[49m\u001b[43mallow_nan\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_all_finite\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mallow-nan\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1069\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   
1071\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m copy:\n\u001b[1;32m   1072\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m _is_numpy_namespace(xp):\n\u001b[1;32m   1073\u001b[0m         \u001b[38;5;66;03m# only make a copy if `array` and `array_orig` may share memory`\u001b[39;00m\n",
-      "File \u001b[0;32m~/miniconda3/envs/entcl/lib/python3.10/site-packages/sklearn/utils/validation.py:123\u001b[0m, in \u001b[0;36m_assert_all_finite\u001b[0;34m(X, allow_nan, msg_dtype, estimator_name, input_name)\u001b[0m\n\u001b[1;32m    120\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m first_pass_isfinite:\n\u001b[1;32m    121\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[0;32m--> 123\u001b[0m \u001b[43m_assert_all_finite_element_wise\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    124\u001b[0m \u001b[43m    \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    125\u001b[0m \u001b[43m    \u001b[49m\u001b[43mxp\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mxp\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    126\u001b[0m \u001b[43m    \u001b[49m\u001b[43mallow_nan\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mallow_nan\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    127\u001b[0m \u001b[43m    \u001b[49m\u001b[43mmsg_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmsg_dtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    128\u001b[0m \u001b[43m    \u001b[49m\u001b[43mestimator_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mestimator_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    129\u001b[0m \u001b[43m    \u001b[49m\u001b[43minput_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    130\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
-      "File \u001b[0;32m~/miniconda3/envs/entcl/lib/python3.10/site-packages/sklearn/utils/validation.py:172\u001b[0m, in \u001b[0;36m_assert_all_finite_element_wise\u001b[0;34m(X, xp, allow_nan, msg_dtype, estimator_name, input_name)\u001b[0m\n\u001b[1;32m    155\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m estimator_name \u001b[38;5;129;01mand\u001b[39;00m input_name \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m has_nan_error:\n\u001b[1;32m    156\u001b[0m     \u001b[38;5;66;03m# Improve the error message on how to handle missing values in\u001b[39;00m\n\u001b[1;32m    157\u001b[0m     \u001b[38;5;66;03m# scikit-learn.\u001b[39;00m\n\u001b[1;32m    158\u001b[0m     msg_err \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m    159\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mestimator_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m does not accept missing values\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    160\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m encoded as NaN natively. For supervised learning, you might want\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    170\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m#estimators-that-handle-nan-values\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    171\u001b[0m     )\n\u001b[0;32m--> 172\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(msg_err)\n",
-      "\u001b[0;31mValueError\u001b[0m: Input X contains infinity or a value too large for dtype('float32')."
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "from sklearn.mixture import GaussianMixture\n",
     "# GMM 1: entropy\n",
@@ -493,23 +501,23 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 19,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Entropy-based Known Correct/Total (Accuracy%) 574/1000 (57.4000%)\n",
-      "Entropy-based Novel Correct/Total (Accuracy%) 3704/4000 (92.6000%)\n",
+      "Entropy-based Known Correct/Total (Accuracy%) 538/1000 (53.8000%)\n",
+      "Entropy-based Novel Correct/Total (Accuracy%) 3278/4000 (81.9500%)\n",
       "\n",
-      "Energy-based Known Correct/Total (Accuracy%) 891/1000 (89.1000%)\n",
-      "Energy-based Novel Correct/Total (Accuracy%) 2639/4000 (65.9750%)\n"
+      "Energy-based Known Correct/Total (Accuracy%) 801/1000 (80.1000%)\n",
+      "Energy-based Novel Correct/Total (Accuracy%) 2369/4000 (59.2250%)\n"
      ]
     },
     {
      "data": {
-      "image/png": "",
+      "image/png": "",
       "text/plain": [
        "<Figure size 1000x600 with 1 Axes>"
       ]
@@ -519,7 +527,7 @@
     },
     {
      "data": {
-      "image/png": "",
+      "image/png": "",
       "text/plain": [
        "<Figure size 1000x600 with 1 Axes>"
       ]
@@ -562,15 +570,15 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 20,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Known Correct/Total (Accuracy%) 686/1000 (68.6000%)\n",
-      "Novel Correct/Total (Accuracy%) 3504/4000 (87.6000%)\n"
+      "Known Correct/Total (Accuracy%) 617/1000 (61.7000%)\n",
+      "Novel Correct/Total (Accuracy%) 3093/4000 (77.3250%)\n"
      ]
     }
    ],
@@ -608,15 +616,15 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 21,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Known Correct/Total (Accuracy%) 610/1000 (61.0000%)\n",
-      "Novel Correct/Total (Accuracy%) 3640/4000 (91.0000%)\n"
+      "Known Correct/Total (Accuracy%) 582/1000 (58.2000%)\n",
+      "Novel Correct/Total (Accuracy%) 3154/4000 (78.8500%)\n"
      ]
     }
    ],
@@ -675,7 +683,27 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Means 0, -1, 1: 0.06324542313814163, 0.5723875164985657, 1.943918228149414\n"
+     ]
+    },
+    {
+     "ename": "ValueError",
+     "evalue": "The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
+      "\u001b[0;32m/tmp/ipykernel_1131/4139231074.py\u001b[0m in \u001b[0;36m?\u001b[0;34m()\u001b[0m\n\u001b[1;32m     14\u001b[0m \u001b[0mdf\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"cluster\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdf\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'cluster'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrename_mapping\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Means 0, -1, 1: {df[df['cluster'] == 0]['entropy'].mean()}, {df[df['cluster'] == -1]['entropy'].mean()}, {df[df['cluster'] == 1]['entropy'].mean()}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0mknown\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdf\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mdf\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"true_type\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mdf\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"cluster\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     19\u001b[0m \u001b[0mnovel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdf\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mdf\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"true_type\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mdf\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"cluster\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m!=\u001b[0m 
\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     21\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Non -1 Known Correct/Total (Accuracy%) {known[known['cluster'] == 0].shape[0]}/{known.shape[0]} ({ 100 * (known[known['cluster'] == 0].shape[0]/known.shape[0]):.4f}%)\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/miniconda3/envs/entcl/lib/python3.10/site-packages/pandas/core/generic.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1575\u001b[0m     \u001b[0;34m@\u001b[0m\u001b[0mfinal\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1576\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__nonzero__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mNoReturn\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1577\u001b[0;31m         raise ValueError(\n\u001b[0m\u001b[1;32m   1578\u001b[0m             \u001b[0;34mf\"The truth value of a {type(self).__name__} is ambiguous. \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1579\u001b[0m             \u001b[0;34m\"Use a.empty, a.bool(), a.item(), a.any() or a.all().\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1580\u001b[0m         )\n",
+      "\u001b[0;31mValueError\u001b[0m: The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all()."
+     ]
+    }
+   ],
    "source": [
     "df = master_df.copy()\n",
     "\n",
@@ -685,7 +713,7 @@
     "df[\"cluster\"] = gmm.fit_predict(df['entropy'].values.reshape(-1, 1))\n",
     "soft_clusters = gmm.predict_proba(df['entropy'].values.reshape(-1, 1))\n",
     "\n",
-    "cluster_means = df.group_by('cluster')['entropy'].mean()\n",
+    "cluster_means = df.groupby('cluster')['entropy'].mean()\n",
     "sorted_clusters = cluster_means.sort_values().index\n",
     "\n",
     "rename_mapping = {sorted_clusters[0]: 0, sorted_clusters[1]: -1, sorted_clusters[2]: 1}\n",
@@ -694,8 +722,10 @@
     "\n",
     "print(f\"Means 0, -1, 1: {df[df['cluster'] == 0]['entropy'].mean()}, {df[df['cluster'] == -1]['entropy'].mean()}, {df[df['cluster'] == 1]['entropy'].mean()}\")\n",
     "\n",
-    "known = df[df[\"true_type\"] == 0 and df[\"cluster\"] != -1]\n",
-    "novel = df[df[\"true_type\"] == 1 and df[\"cluster\"] != -1]\n",
+    "known = df[df[\"true_type\"] == 0]\n",
+    "known = known[known[\"cluster\"] != -1]\n",
+    "novel = df[df[\"true_type\"] == 1]\n",
+    "novel = novel[novel[\"cluster\"] != -1]\n",
     "\n",
     "print(f\"Non -1 Known Correct/Total (Accuracy%) {known[known['cluster'] == 0].shape[0]}/{known.shape[0]} ({ 100 * (known[known['cluster'] == 0].shape[0]/known.shape[0]):.4f}%)\")\n",
     "print(f\"Non -1 Novel Correct/Total (Accuracy%) {novel[novel['cluster'] == 1].shape[0]}/{novel.shape[0]} ({ 100 * (novel[novel['cluster'] == 1].shape[0]/novel.shape[0]):.4f}%)\")\n",
@@ -725,7 +755,7 @@
     "# for the feature distance sorting of cluster -1, we need an exemplar set. We will randomly sample 32 images per class from the session 0 training dataset\n",
     "session_0_trainset = dataset_master.get_dataset(session=0)\n",
     "\n",
-    "labels_np = session_0_trainset.tensor_dataset.tensors[1].cpu().numpy()\n",
+    "labels_np = session_0_trainset.tensors[1].cpu().numpy()\n",
     "\n",
     "samples_per_label = 32\n",
     "\n",
@@ -737,7 +767,7 @@
     "    label_indices = np.where(labels_np == label)[0]\n",
     "    sample_indices.extend(np.random.choice(label_indices, samples_per_label, replace=False))\n",
     "\n",
-    "subset = torch.utils.data.Subset(session_0_trainset.tensor_dataset, sample_indices)\n",
+    "subset = torch.utils.data.Subset(session_0_trainset, sample_indices)\n",
     "subset_loader = torch.utils.data.DataLoader(subset, batch_size=512, shuffle=False, num_workers=4, pin_memory=True)\n",
     "\n",
     "# get the features, logits, entropies and energies of the exemplar set\n",
-- 
GitLab