From db348837a9cf22c774a195b55d5061dbf40ea3d0 Mon Sep 17 00:00:00 2001
From: Joseph Omar <j.omar@soton.ac.uk>
Date: Mon, 11 Nov 2024 15:21:30 +0000
Subject: [PATCH] data: apply transforms at access time instead of load time
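
The headline change: self.transform is no longer baked into the data at
load time. Previously it was passed straight to torchvision's CIFAR100
constructor, so it ran once while _split_data_by_class iterated the
dataset, and every per-session TensorDataset stored already-transformed
tensors. Now the raw data is loaded with a plain transforms.ToTensor(),
and self.transform is applied lazily in __getitem__ via the new
TransformedTensorDataset wrapper (entcl/data/util.py). If self.transform
contains random augmentations, each access now draws a fresh augmentation
instead of reusing one frozen at load time, which would explain the
surprisingly improved results.

Also in this patch:
- _split_data_by_class iterates the dataset in batches through a
  DataLoader instead of one sample at a time.
- LinearHead2._init_weights only zeroes a Linear layer's bias when the
  layer actually has one, so heads built with bias=False no longer
  raise during init.
- entcl/data/test.py loads with transforms.ToTensor() to match.

A minimal sketch of how the wrapper behaves (the tensor shapes and the
doubling transform below are placeholders for illustration only):

    import torch
    from entcl.data.util import TransformedTensorDataset

    samples = torch.rand(10, 3, 32, 32)    # placeholder images
    labels = torch.randint(0, 100, (10,))  # placeholder labels
    base = torch.utils.data.TensorDataset(samples, labels)

    # transform runs inside __getitem__, once per access
    wrapped = TransformedTensorDataset(base, transform=lambda x: x * 2)
    x, y = wrapped[0]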

---
 entcl/data/cifar100.py      | 46 +++++++++++++++++++++++--------------
 entcl/data/test.py          |  2 +-
 entcl/data/util.py          | 14 +++++++++++
 entcl/models/linear_head.py |  3 ++-
 4 files changed, 46 insertions(+), 19 deletions(-)
 create mode 100644 entcl/data/util.py

diff --git a/entcl/data/cifar100.py b/entcl/data/cifar100.py
index eba22ed..bf52d3b 100644
--- a/entcl/data/cifar100.py
+++ b/entcl/data/cifar100.py
@@ -1,5 +1,6 @@
 import os
 from typing import Dict, List, Union
+from entcl.data.util import TransformedTensorDataset
 import torch
 from torchvision.datasets import CIFAR100 as _CIFAR100
 import torchvision.transforms.v2 as transforms
@@ -65,7 +66,7 @@ class CIFAR100Dataset:
         # load and sort the data into master lists
         logger.debug("Loading and Sorting CIFAR100 Train split")
         master_train_data: Dict[int, torch.Tensor] = self._split_data_by_class(_CIFAR100(
-            CIFAR100_DIR, train=True, transform=self.transform, download=download
+            CIFAR100_DIR, train=True, transform=transforms.ToTensor(), download=download
         ))
         # split the data into datasets for each session
         logger.debug("Splitting Train Data for Sessions")
@@ -75,7 +76,7 @@ class CIFAR100Dataset:
         
         logger.debug("Loading and Sorting CIFAR100 Test split")
         master_test_data: Dict[int, torch.Tensor] = self._split_data_by_class(_CIFAR100(
-            CIFAR100_DIR, train=False, transform=self.transform, download=download
+            CIFAR100_DIR, train=False, transform=transforms.ToTensor(), download=download
         ))
         logger.debug("Splitting Test Data for Sessions")
         self.test_datasets = self._split_test_data_for_sessions(master_test_data)
@@ -158,7 +159,7 @@ class CIFAR100Dataset:
         labels = torch.cat(labels)
         logger.debug(f"Creating dataset for session 0 (pretraining). There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
         logger.debug(f"Classes in Pretraining Dataset: {labels.unique(sorted=True)}")
-        datasets[0] = torch.utils.data.TensorDataset(samples, labels)
+        datasets[0] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels), transform=self.transform)
         
         # CL sessions' datasets
         logger.debug("Splitting data for CL sessions")    
@@ -191,7 +192,7 @@ class CIFAR100Dataset:
             labels = torch.cat(labels)
             logger.debug(f"Creating dataset for session {session}. There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
             logger.debug(f"Classes in this Session {session}'s Train Dataset: {labels.unique(sorted=True)}")
-            datasets[session] = torch.utils.data.TensorDataset(samples, labels)
+            datasets[session] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels), transform=self.transform)
 
         return datasets
     
@@ -219,22 +220,30 @@ class CIFAR100Dataset:
             labels = torch.cat(labels)
             logger.debug(f"Creating test dataset for session {session}. There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
             logger.debug(f"Classes in Session {session}'s Test Dataset: {labels.unique(sorted=True)}")
-            datasets[session] = torch.utils.data.TensorDataset(samples, labels)
+            datasets[session] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels), transform=self.transform)
             
         return datasets
                 
     
-    def _split_data_by_class(self, dataset: _CIFAR100):
-        # loop through the dataset and split the data by class
-        all_data = {} 
-        for data, label in dataset:
-            if label not in all_data:
-                all_data[label] = []
-            all_data[label].append(data)
-            
-        # stack the data for each key
+    def _split_data_by_class(self, dataset: _CIFAR100, batch_size=64, num_workers=0):
+        # Create a DataLoader to load the dataset in batches
+        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+
+        # Dictionary to store data split by class
+        all_data = {}
+
+        # Iterate over the dataloader
+        for data_batch, labels_batch in dataloader:
+            for data, label in zip(data_batch, labels_batch):
+                label = label.item()  # Convert tensor label to Python int for key lookup
+                if label not in all_data:
+                    all_data[label] = []
+                all_data[label].append(data)
+
+        # Stack the data for each key
         for class_id in sorted(all_data.keys()):
             all_data[class_id] = torch.stack(all_data[class_id])
+
         return all_data
     
         
@@ -282,9 +291,12 @@ if __name__ == "__main__":
     cifar100 = CIFAR100Dataset()
     for session in range(cifar100.sessions + 1):
         logger.debug(f"Session {session}")
-        logger.debug(f"Train Dataset: {cifar100.get_dataset(session, train=True)}")
-        logger.debug(f"Test Dataset: {cifar100.get_dataset(session, train=False)}")
-    
+        train = cifar100.get_dataset(session, train=True)
+        test = cifar100.get_dataset(session, train=False)
+        logger.debug(f"Train Len: {len(train)}, Image Shape {train[0][0].shape}, Label Shape {train[0][1].shape}")
+        logger.debug(f"Test Len: {len(test)}, Image Shape {test[0][0].shape}, Label Shape {test[0][1].shape}")
+        
+
 
     sleep(5)
     
\ No newline at end of file
diff --git a/entcl/data/test.py b/entcl/data/test.py
index 36e2872..f5ebf2d 100644
--- a/entcl/data/test.py
+++ b/entcl/data/test.py
@@ -9,7 +9,7 @@ t = transforms.Compose(
 )
 
 train = _CIFAR100(
-    "/cl/datasets/CIFAR", train=True, transform=t, download=True
+    "/cl/datasets/CIFAR", train=True, transform=transforms.ToTensor(), download=True
 )
 
 first_sample, first_label = train[0]
diff --git a/entcl/data/util.py b/entcl/data/util.py
new file mode 100644
index 0000000..23dc108
--- /dev/null
+++ b/entcl/data/util.py
@@ -0,0 +1,14 @@
+import torch
+class TransformedTensorDataset(torch.utils.data.Dataset):
+    def __init__(self, tensor_dataset, transform=None):
+        self.tensor_dataset = tensor_dataset
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.tensor_dataset)
+
+    def __getitem__(self, idx):
+        data, target = self.tensor_dataset[idx]
+        if self.transform:
+            data = self.transform(data)
+        return data, target
diff --git a/entcl/models/linear_head.py b/entcl/models/linear_head.py
index 1e78799..755e0f9 100644
--- a/entcl/models/linear_head.py
+++ b/entcl/models/linear_head.py
@@ -25,7 +25,8 @@ class LinearHead2(torch.nn.Module):
         for layer in m:
             if isinstance(layer, torch.nn.Linear):
                 torch.nn.init.kaiming_normal_(layer.weight)
-                torch.nn.init.zeros_(layer.bias)
+                if layer.bias is not None:
+                    torch.nn.init.zeros_(layer.bias)
     
     def forward(self, x):
         return self.mlp(x)
\ No newline at end of file
-- 
GitLab