From db348837a9cf22c774a195b55d5061dbf40ea3d0 Mon Sep 17 00:00:00 2001
From: Joseph Omar <j.omar@soton.ac.uk>
Date: Mon, 11 Nov 2024 15:21:30 +0000
Subject: [PATCH] somehow it is really good?

---
 entcl/data/cifar100.py      | 46 +++++++++++++++++++++++--------------
 entcl/data/test.py          |  2 +-
 entcl/data/util.py          | 14 +++++++++++++
 entcl/models/linear_head.py |  3 ++-
 4 files changed, 46 insertions(+), 19 deletions(-)
 create mode 100644 entcl/data/util.py

diff --git a/entcl/data/cifar100.py b/entcl/data/cifar100.py
index eba22ed..bf52d3b 100644
--- a/entcl/data/cifar100.py
+++ b/entcl/data/cifar100.py
@@ -1,5 +1,6 @@
 import os
 from typing import Dict, List, Union
+from entcl.data.util import TransformedTensorDataset
 import torch
 from torchvision.datasets import CIFAR100 as _CIFAR100
 import torchvision.transforms.v2 as transforms
@@ -65,7 +66,7 @@ class CIFAR100Dataset:
         # load and sort the data into master lists
         logger.debug("Loading and Sorting CIFAR100 Train split")
         master_train_data: Dict[int, torch.Tensor] = self._split_data_by_class(_CIFAR100(
-            CIFAR100_DIR, train=True, transform=self.transform, download=download
+            CIFAR100_DIR, train=True, transform=transforms.ToTensor(), download=download
         ))
         # split the data into datasets for each session
         logger.debug("Splitting Train Data for Sessions")
@@ -75,7 +76,7 @@ class CIFAR100Dataset:
 
         logger.debug("Loading and Sorting CIFAR100 Test split")
         master_test_data: Dict[int, torch.Tensor] = self._split_data_by_class(_CIFAR100(
-            CIFAR100_DIR, train=False, transform=self.transform, download=download
+            CIFAR100_DIR, train=False, transform=transforms.ToTensor(), download=download
         ))
         logger.debug("Splitting Test Data for Sessions")
         self.test_datasets = self._split_test_data_for_sessions(master_test_data)
@@ -158,7 +159,7 @@ class CIFAR100Dataset:
         labels = torch.cat(labels)
         logger.debug(f"Creating dataset for session 0 (pretraining). There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
         logger.debug(f"Classes in Pretraining Dataset: {labels.unique(sorted=True)}")
-        datasets[0] = torch.utils.data.TensorDataset(samples, labels)
+        datasets[0] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels), transform=self.transform)
 
         # CL sessions' datasets
         logger.debug("Splitting data for CL sessions")
@@ -191,7 +192,7 @@ class CIFAR100Dataset:
         labels = torch.cat(labels)
         logger.debug(f"Creating dataset for session {session}. There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
         logger.debug(f"Classes in this Session {session}'s Train Dataset: {labels.unique(sorted=True)}")
-        datasets[session] = torch.utils.data.TensorDataset(samples, labels)
+        datasets[session] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels), transform=self.transform)
 
         return datasets
 
@@ -219,22 +220,30 @@ class CIFAR100Dataset:
         labels = torch.cat(labels)
         logger.debug(f"Creating test dataset for session {session}. There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
         logger.debug(f"Classes in Session {session}'s Test Dataset: {labels.unique(sorted=True)}")
-        datasets[session] = torch.utils.data.TensorDataset(samples, labels)
+        datasets[session] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels), transform=self.transform)
 
         return datasets
 
-    def _split_data_by_class(self, dataset: _CIFAR100):
-        # loop through the dataset and split the data by class
-        all_data = {}
-        for data, label in dataset:
-            if label not in all_data:
-                all_data[label] = []
-            all_data[label].append(data)
-
-        # stack the data for each key
+    def _split_data_by_class(self, dataset: _CIFAR100, batch_size=64, num_workers=0):
+        # Create a DataLoader to load the dataset in batches
+        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+
+        # Dictionary to store data split by class
+        all_data = {}
+
+        # Iterate over the dataloader
+        for data_batch, labels_batch in dataloader:
+            for data, label in zip(data_batch, labels_batch):
+                label = label.item()  # Convert tensor label to Python int for key lookup
+                if label not in all_data:
+                    all_data[label] = []
+                all_data[label].append(data)
+
+        # Stack the data for each key
         for class_id in sorted(all_data.keys()):
             all_data[class_id] = torch.stack(all_data[class_id])
+
         return all_data
 
@@ -282,9 +291,12 @@ if __name__ == "__main__":
     cifar100 = CIFAR100Dataset()
     for session in range(cifar100.sessions + 1):
         logger.debug(f"Session {session}")
-        logger.debug(f"Train Dataset: {cifar100.get_dataset(session, train=True)}")
-        logger.debug(f"Test Dataset: {cifar100.get_dataset(session, train=False)}")
-
+        train = cifar100.get_dataset(session, train=True)
+        test = cifar100.get_dataset(session, train=False)
+        logger.debug(f"Train Len: {len(train)}, Image Shape {train[0][0].shape}, Label Shape {train[0][1].shape}")
+        logger.debug(f"Test Len: {len(test)}, Image Shape {test[0][0].shape}, Label Shape {test[0][1].shape}")
+
+        sleep(5)
\ No newline at end of file
diff --git a/entcl/data/test.py b/entcl/data/test.py
index 36e2872..f5ebf2d 100644
--- a/entcl/data/test.py
+++ b/entcl/data/test.py
@@ -9,7 +9,7 @@ t = transforms.Compose(
 )
 
 train = _CIFAR100(
-    "/cl/datasets/CIFAR", train=True, transform=t, download=True
+    "/cl/datasets/CIFAR", train=True, transform=transforms.ToTensor(), download=True
 )
 
 first_sample, first_label = train[0]
diff --git a/entcl/data/util.py b/entcl/data/util.py
new file mode 100644
index 0000000..23dc108
--- /dev/null
+++ b/entcl/data/util.py
@@ -0,0 +1,14 @@
+import torch
+class TransformedTensorDataset(torch.utils.data.Dataset):
+    def __init__(self, tensor_dataset, transform=None):
+        self.tensor_dataset = tensor_dataset
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.tensor_dataset)
+
+    def __getitem__(self, idx):
+        data, target = self.tensor_dataset[idx]
+        if self.transform:
+            data = self.transform(data)
+        return data, target
diff --git a/entcl/models/linear_head.py b/entcl/models/linear_head.py
index 1e78799..755e0f9 100644
--- a/entcl/models/linear_head.py
+++ b/entcl/models/linear_head.py
@@ -25,7 +25,8 @@ class LinearHead2(torch.nn.Module):
         for layer in m:
             if isinstance(layer, torch.nn.Linear):
                 torch.nn.init.kaiming_normal_(layer.weight)
-                torch.nn.init.zeros_(layer.bias)
+                if layer.bias is not None:
+                    torch.nn.init.zeros_(layer.bias)
 
     def forward(self, x):
         return self.mlp(x)
\ No newline at end of file
-- 
GitLab
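A minimal usage sketch of the pattern this patch introduces, for reference: the raw CIFAR100 images are materialised once as tensors via transforms.ToTensor(), and the session transform is applied lazily, per sample, by TransformedTensorDataset. The augmentation pipeline and the random tensors below are illustrative stand-ins, not values taken from the repository:

    import torch
    import torchvision.transforms.v2 as transforms
    from entcl.data.util import TransformedTensorDataset

    # Illustrative transform; the real pipeline lives on CIFAR100Dataset.transform.
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.Normalize(mean=(0.5071, 0.4865, 0.4409), std=(0.2673, 0.2564, 0.2762)),
    ])

    # Stand-ins for the stacked CIFAR100 samples produced by _split_data_by_class.
    images = torch.rand(10, 3, 32, 32)
    labels = torch.randint(0, 100, (10,))
    base = torch.utils.data.TensorDataset(images, labels)

    # The transform runs inside __getitem__, at access time rather than at
    # construction, so each epoch sees freshly augmented copies of the same
    # stored tensors.
    dataset = TransformedTensorDataset(tensor_dataset=base, transform=transform)
    image, label = dataset[0]

This is why the patch swaps the CIFAR100 constructors over to transforms.ToTensor(): the expensive PIL-to-tensor conversion happens once at load time, while random augmentations still vary per access.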