Skip to content
Snippets Groups Projects
Commit db348837 authored by Joseph Omar's avatar Joseph Omar
Browse files

somehow it is really good?

parent 959f9d54
No related branches found
No related tags found
No related merge requests found
import os
from typing import Dict, List, Union
from entcl.data.util import TransformedTensorDataset
import torch
from torchvision.datasets import CIFAR100 as _CIFAR100
import torchvision.transforms.v2 as transforms
......@@ -65,7 +66,7 @@ class CIFAR100Dataset:
# load and sort the data into master lists
logger.debug("Loading and Sorting CIFAR100 Train split")
master_train_data: Dict[int, torch.Tensor] = self._split_data_by_class(_CIFAR100(
CIFAR100_DIR, train=True, transform=self.transform, download=download
CIFAR100_DIR, train=True, transform=transforms.ToTensor(), download=download
))
# split the data into datasets for each session
logger.debug("Splitting Train Data for Sessions")
......@@ -75,7 +76,7 @@ class CIFAR100Dataset:
logger.debug("Loading and Sorting CIFAR100 Test split")
master_test_data: Dict[int, torch.Tensor] = self._split_data_by_class(_CIFAR100(
CIFAR100_DIR, train=False, transform=self.transform, download=download
CIFAR100_DIR, train=False,transform=transforms.ToTensor(), download=download
))
logger.debug("Splitting Test Data for Sessions")
self.test_datasets = self._split_test_data_for_sessions(master_test_data)
......@@ -158,7 +159,7 @@ class CIFAR100Dataset:
labels = torch.cat(labels)
logger.debug(f"Creating dataset for session 0 (pretraining). There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
logger.debug(f"Classes in Pretraining Dataset: {labels.unique(sorted=True)}")
datasets[0] = torch.utils.data.TensorDataset(samples, labels)
datasets[0] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels), transform=self.transform)
# CL sessions' datasets
logger.debug("Splitting data for CL sessions")
......@@ -191,7 +192,7 @@ class CIFAR100Dataset:
labels = torch.cat(labels)
logger.debug(f"Creating dataset for session {session}. There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
logger.debug(f"Classes in this Session {session}'s Train Dataset: {labels.unique(sorted=True)}")
datasets[session] = torch.utils.data.TensorDataset(samples, labels)
datasets[session] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels), transform=self.transform)
return datasets
......@@ -219,22 +220,30 @@ class CIFAR100Dataset:
labels = torch.cat(labels)
logger.debug(f"Creating test dataset for session {session}. There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
logger.debug(f"Classes in Session {session}'s Test Dataset: {labels.unique(sorted=True)}")
datasets[session] = torch.utils.data.TensorDataset(samples, labels)
datasets[session] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels), transform=self.transform)
return datasets
def _split_data_by_class(self, dataset: _CIFAR100):
# loop through the dataset and split the data by class
all_data = {}
for data, label in dataset:
if label not in all_data:
all_data[label] = []
all_data[label].append(data)
# stack the data for each key
def _split_data_by_class(self, dataset: _CIFAR100, batch_size=64, num_workers=0):
# Create a DataLoader to load the dataset in batches
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
# Dictionary to store data split by class
all_data = {}
# Iterate over the dataloader
for data_batch, labels_batch in dataloader:
for data, label in zip(data_batch, labels_batch):
label = label.item() # Convert tensor label to Python int for key lookup
if label not in all_data:
all_data[label] = []
all_data[label].append(data)
# Stack the data for each key
for class_id in sorted(all_data.keys()):
all_data[class_id] = torch.stack(all_data[class_id])
return all_data
......@@ -282,9 +291,12 @@ if __name__ == "__main__":
cifar100 = CIFAR100Dataset()
for session in range(cifar100.sessions + 1):
logger.debug(f"Session {session}")
logger.debug(f"Train Dataset: {cifar100.get_dataset(session, train=True)}")
logger.debug(f"Test Dataset: {cifar100.get_dataset(session, train=False)}")
train = cifar100.get_dataset(session, train=True)
test = cifar100.get_dataset(session, train=False)
logger.debug(f"Train Len: {len(train)}, Image Shape {train[0][0].shape}, Label Shape {train[0][1].shape}")
logger.debug(f"Test Len: {len(test)}, Image Shape {test[0][0].shape}, Label Shape {test[0][1].shape}")
sleep(5)
\ No newline at end of file
......@@ -9,7 +9,7 @@ t = transforms.Compose(
)
train = _CIFAR100(
"/cl/datasets/CIFAR", train=True, transform=t, download=True
"/cl/datasets/CIFAR", train=True, transform=transforms.ToTensor(), download=True
)
first_sample, first_label = train[0]
......
import torch
class TransformedTensorDataset(torch.utils.data.Dataset):
    """
    Dataset wrapper that lazily applies an optional transform.

    Items come from the wrapped dataset as (data, target) pairs; when a
    transform is set it is applied to the data component on every access,
    while the target passes through untouched.
    """

    def __init__(self, tensor_dataset, transform=None):
        self.tensor_dataset = tensor_dataset  # underlying (data, target) dataset
        self.transform = transform  # optional callable applied to each data item

    def __len__(self):
        # Length is delegated straight to the wrapped dataset.
        return len(self.tensor_dataset)

    def __getitem__(self, idx):
        sample, target = self.tensor_dataset[idx]
        # Apply the transform on access so raw tensors are stored only once.
        transformed = self.transform(sample) if self.transform else sample
        return transformed, target
......@@ -25,7 +25,8 @@ class LinearHead2(torch.nn.Module):
for layer in m:
if isinstance(layer, torch.nn.Linear):
torch.nn.init.kaiming_normal_(layer.weight)
torch.nn.init.zeros_(layer.bias)
if layer.bias is not None:
torch.nn.init.zeros_(layer.bias)
def forward(self, x):
return self.mlp(x)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment