diff --git a/entcl/data/cifar100.py b/entcl/data/cifar100.py
index eba22ed0db17dce28a7d022d92637f2cf24db6f0..bf52d3bb49e8fda2abf72bd0ff222cd9fd5e6277 100644
--- a/entcl/data/cifar100.py
+++ b/entcl/data/cifar100.py
@@ -1,5 +1,6 @@
 import os
 from typing import Dict, List, Union
+from entcl.data.util import TransformedTensorDataset
 import torch
 from torchvision.datasets import CIFAR100 as _CIFAR100
 import torchvision.transforms.v2 as transforms
@@ -65,7 +66,7 @@ class CIFAR100Dataset:
         # load and sort the data into master lists
         logger.debug("Loading and Sorting CIFAR100 Train split")
         master_train_data: Dict[int, torch.Tensor] = self._split_data_by_class(_CIFAR100(
-            CIFAR100_DIR, train=True, transform=self.transform, download=download
+            CIFAR100_DIR, train=True, transform=transforms.ToTensor(), download=download
         ))
         # split the data into datasets for each session
         logger.debug("Splitting Train Data for Sessions")
@@ -75,7 +76,7 @@ class CIFAR100Dataset:
         
         logger.debug("Loading and Sorting CIFAR100 Test split")
         master_test_data: Dict[int, torch.Tensor] = self._split_data_by_class(_CIFAR100(
-            CIFAR100_DIR, train=False, transform=self.transform, download=download
+            CIFAR100_DIR, train=False, transform=transforms.ToTensor(), download=download
         ))
         logger.debug("Splitting Test Data for Sessions")
         self.test_datasets = self._split_test_data_for_sessions(master_test_data)
@@ -158,7 +159,7 @@ class CIFAR100Dataset:
         labels = torch.cat(labels)
         logger.debug(f"Creating dataset for session 0 (pretraining). There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
         logger.debug(f"Classes in Pretraining Dataset: {labels.unique(sorted=True)}")
-        datasets[0] = torch.utils.data.TensorDataset(samples, labels)
+        datasets[0] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels), transform=self.transform)
         
         # CL sessions' datasets
         logger.debug("Splitting data for CL sessions")    
@@ -191,7 +192,7 @@ class CIFAR100Dataset:
             labels = torch.cat(labels)
             logger.debug(f"Creating dataset for session {session}. There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
             logger.debug(f"Classes in this Session {session}'s Train Dataset: {labels.unique(sorted=True)}")
-            datasets[session] = torch.utils.data.TensorDataset(samples, labels)
+            datasets[session] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels), transform=self.transform)
 
         return datasets
     
@@ -219,22 +220,30 @@ class CIFAR100Dataset:
             labels = torch.cat(labels)
             logger.debug(f"Creating test dataset for session {session}. There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
             logger.debug(f"Classes in Session {session}'s Test Dataset: {labels.unique(sorted=True)}")
-            datasets[session] = torch.utils.data.TensorDataset(samples, labels)
+            datasets[session] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels), transform=self.transform)
             
         return datasets
                 
     
-    def _split_data_by_class(self, dataset: _CIFAR100):
-        # loop through the dataset and split the data by class
-        all_data = {} 
-        for data, label in dataset:
-            if label not in all_data:
-                all_data[label] = []
-            all_data[label].append(data)
-            
-        # stack the data for each key
+    def _split_data_by_class(self, dataset: _CIFAR100, batch_size: int = 64, num_workers: int = 0) -> Dict[int, torch.Tensor]:
+        # Iterate in batches via a DataLoader; this is far faster than indexing the dataset one sample at a time
+        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+
+        # Dictionary to store data split by class
+        all_data = {}
+
+        # Iterate over the dataloader
+        for data_batch, labels_batch in dataloader:
+            for data, label in zip(data_batch, labels_batch):
+                label = label.item()  # Convert tensor label to Python int for key lookup
+                if label not in all_data:
+                    all_data[label] = []
+                all_data[label].append(data)
+
+        # Stack the data for each key
         for class_id in sorted(all_data.keys()):
             all_data[class_id] = torch.stack(all_data[class_id])
+
         return all_data
     
         
@@ -282,9 +291,10 @@
     cifar100 = CIFAR100Dataset()
     for session in range(cifar100.sessions + 1):
         logger.debug(f"Session {session}")
-        logger.debug(f"Train Dataset: {cifar100.get_dataset(session, train=True)}")
-        logger.debug(f"Test Dataset: {cifar100.get_dataset(session, train=False)}")
-    
+        train = cifar100.get_dataset(session, train=True)
+        test = cifar100.get_dataset(session, train=False)
+        logger.debug(f"Train Len: {len(train)}, Image Shape {train[0][0].shape}, Label Shape {train[0][1].shape}")
+        logger.debug(f"Test Len: {len(test)}, Image Shape {test[0][0].shape}, Label Shape {test[0][1].shape}")
 
     sleep(5)
     
\ No newline at end of file
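The cifar100.py changes separate tensorisation from augmentation: each split is loaded once with ToTensor() so per-class stacks can be built, while self.transform is deferred to access time via TransformedTensorDataset. A minimal standalone sketch of the load-once, split-by-class step (DATA_DIR and the batch size here are illustrative, not from the repo):

```python
import torch
import torchvision.transforms.v2 as transforms
from torchvision.datasets import CIFAR100

DATA_DIR = "/tmp/cifar100"  # hypothetical path

# ToTensor at load time yields stackable (3, 32, 32) float tensors in [0, 1],
# so the default collate_fn can batch them without a custom collator.
base = CIFAR100(DATA_DIR, train=True, transform=transforms.ToTensor(), download=True)
loader = torch.utils.data.DataLoader(base, batch_size=256, shuffle=False)

by_class = {}
for images, labels in loader:
    for image, label in zip(images, labels):
        by_class.setdefault(label.item(), []).append(image)
by_class = {c: torch.stack(v) for c, v in by_class.items()}

print(len(by_class), by_class[0].shape)  # 100 classes, each (500, 3, 32, 32)
```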
diff --git a/entcl/data/test.py b/entcl/data/test.py
index 36e28723799933dd5301c4384d306423acc7bd36..f5ebf2ded52a804b6ef1a97954eb45bdf2032eaa 100644
--- a/entcl/data/test.py
+++ b/entcl/data/test.py
@@ -9,7 +9,7 @@ t = transforms.Compose(
 )
 
 train = _CIFAR100(
-    "/cl/datasets/CIFAR", train=True, transform=t, download=True
+    "/cl/datasets/CIFAR", train=True, transform=transforms.ToTensor(), download=True
 )
 
 first_sample, first_label = train[0]
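With ToTensor() as the load-time transform, indexing the dataset yields a CHW float tensor in [0, 1] and a plain int class index; a quick check against the train object defined above (a sketch, assuming the dataset path exists):

```python
first_sample, first_label = train[0]
print(first_sample.shape, first_sample.dtype)  # torch.Size([3, 32, 32]) torch.float32
print(first_sample.min().item() >= 0.0, first_sample.max().item() <= 1.0)  # True True
print(type(first_label))  # <class 'int'>: CIFAR100 returns the class index
```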
diff --git a/entcl/data/util.py b/entcl/data/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..23dc10818d04cb4bf3859620572be4cad45a732b
--- /dev/null
+++ b/entcl/data/util.py
@@ -0,0 +1,15 @@
+import torch
+class TransformedTensorDataset(torch.utils.data.Dataset):
+    """Wraps a TensorDataset and applies an optional transform to each sample on access."""
+    def __init__(self, tensor_dataset, transform=None):
+        self.tensor_dataset = tensor_dataset
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.tensor_dataset)
+
+    def __getitem__(self, idx):
+        data, target = self.tensor_dataset[idx]
+        if self.transform:
+            data = self.transform(data)
+        return data, target
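The wrapper stores the stacked tensors untouched and applies the transform per sample in __getitem__, so augmentations stay stochastic across epochs instead of being baked in at split time. A usage sketch with random stand-in tensors (not repo data):

```python
import torch
import torchvision.transforms.v2 as transforms
from entcl.data.util import TransformedTensorDataset

samples = torch.rand(10, 3, 32, 32)   # stand-in for stacked CIFAR images
labels = torch.randint(0, 100, (10,))
base = torch.utils.data.TensorDataset(samples, labels)

augment = transforms.Compose([transforms.RandomHorizontalFlip(p=1.0)])
dataset = TransformedTensorDataset(base, transform=augment)

x, y = dataset[0]  # the flip happens here, at access time
assert torch.equal(x, torch.flip(samples[0], dims=[-1]))
```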
diff --git a/entcl/models/linear_head.py b/entcl/models/linear_head.py
index 1e787990abf1dcaf8a697446833320b8160aabf7..755e0f92f63a32282ce60e2e0da6d7fcfdf8f27e 100644
--- a/entcl/models/linear_head.py
+++ b/entcl/models/linear_head.py
@@ -25,7 +25,8 @@ class LinearHead2(torch.nn.Module):
         for layer in m:
             if isinstance(layer, torch.nn.Linear):
                 torch.nn.init.kaiming_normal_(layer.weight)
-                torch.nn.init.zeros_(layer.bias)
+                if layer.bias is not None:  # Linear layers built with bias=False have no bias tensor to zero
+                    torch.nn.init.zeros_(layer.bias)
     
     def forward(self, x):
         return self.mlp(x)
\ No newline at end of file
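The added guard matters because Linear layers constructed with bias=False expose layer.bias as None, and torch.nn.init.zeros_(None) raises. A minimal repro of the case the check now covers:

```python
import torch

layer = torch.nn.Linear(16, 8, bias=False)
torch.nn.init.kaiming_normal_(layer.weight)
if layer.bias is not None:  # False here, so the zero-init is skipped instead of crashing
    torch.nn.init.zeros_(layer.bias)
```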