diff --git a/entcl/cl.py b/entcl/cl.py
index 8b137891791fe96927ad78e64b0aad7bded08bdc..04d816ccfd618915725bdefc455a6d19e208993b 100644
--- a/entcl/cl.py
+++ b/entcl/cl.py
@@ -1 +1,200 @@
+import os
+from typing import Dict, Tuple
+from entcl.data.util import TransformedTensorDataset
+import pandas as pd
+import torch
+from loguru import logger
+from tqdm import tqdm
+
+def cl_session(args, session_dataset: TransformedTensorDataset, model: torch.nn.Module, mapping: Dict[int, int]) -> torch.nn.Module:
+    """
+    Run a continual learning session on the given session dataset using the given model.
+    :param args: Arguments object with the attributes `device`, `cl_epochs`, `current_session`, `seed`.
+    :param session_dataset: TransformedTensorDataset with the session data. Should have OOD predtypes and pseudo labels.
+    :param model: torch.nn.Module with the model to use for continual learning. `forward` should return `logits, features`.
+    :param mapping: Dict[int, int] mapping true labels to pseudo labels, used during validation.
+    :return: torch.nn.Module with the updated model.
+    """
+    logger.debug(f"Begin Continual Learning Session {args.current_session}")
+    # make sure the dataset has the correct shape
+    assert len(session_dataset.tensor_dataset.tensors) == 5, "Session Dataset should have 5 tensors (data, true labels, true types, pred types, pseudo labels). Got: " + str(len(session_dataset.tensor_dataset.tensors))
+
+    # create the required training objects
+
+    # we are only training with novel data at the moment, so we need to whittle down the dataset to only the predicted novel samples
+    novel_samples_mask = session_dataset.tensor_dataset.tensors[3] == 1 # novel samples are labelled with 1. predtypes are the 4th tensor in the dataset
+
+    novel_tensors = [tensor[novel_samples_mask] for tensor in session_dataset.tensor_dataset.tensors]
+
+    # adjust the pseudo labels (which start from 0 atm) to start from args.dataset.known + ((args.current_session - 1) * args.dataset.novel_inc)
+    adjust_value = args.dataset.known + ((args.current_session - 1) * args.dataset.novel_inc)
+    logger.debug(f"Adjusting Pseudo Labels by {adjust_value}")
+    novel_tensors[4] += adjust_value
+    logger.debug(f"Adjusted Pseudo Labels: {torch.unique(novel_tensors[4], sorted=True)}")
+
+    session_dataset = TransformedTensorDataset(
+        tensor_dataset=torch.utils.data.TensorDataset(*novel_tensors),
+        transform=session_dataset.transform,
+    )
+
+    train_loader = torch.utils.data.DataLoader(
+        dataset=session_dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.num_workers,
+        pin_memory=args.pin_memory,
+        drop_last=True,
+    )
+
+    old_test_loader = torch.utils.data.DataLoader(
+        dataset=args.dataset.get_dataset(session=args.current_session, train=False)["old"],
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        pin_memory=args.pin_memory,
+        drop_last=False,
+    )
+
+    new_test_loader = torch.utils.data.DataLoader(
+        dataset=args.dataset.get_dataset(session=args.current_session, train=False)["new"],
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        pin_memory=args.pin_memory,
+        drop_last=False,
+    )
+
+    all_test_loader = torch.utils.data.DataLoader(
+        dataset=args.dataset.get_dataset(session=args.current_session, train=False)["all"],
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        pin_memory=args.pin_memory,
+        drop_last=False,
+    )
+
+    optimiser = torch.optim.SGD(
+        model.parameters(),
+        lr=args.lr,
+        momentum=args.momentum,
+        weight_decay=args.weight_decay,
+    )
+
+    criterion = torch.nn.CrossEntropyLoss()
+
+    results = None
+
+    # train the model
+    for epoch in range(args.cl_epochs):
+        logger.debug(f"Session {args.current_session} Epoch {epoch} Started")
+
+        # train model
+        model, train_loss = _train(args, model, train_loader, optimiser, criterion)
+        logger.info(f"Epoch {epoch}: TRAINING : Train Loss: {train_loss}")
+
+        # validate model on the three test sets
+        model, old_accuracy, old_val_loss = _validate(args, model, old_test_loader, criterion, mapping)
+        logger.info(f"Epoch {epoch}: VALIDATION : OLD Accuracy: {old_accuracy}, OLD Val Loss: {old_val_loss}")
+
+        model, new_accuracy, new_val_loss = _validate(args, model, new_test_loader, criterion, mapping)
+        logger.info(f"Epoch {epoch}: VALIDATION : NEW Accuracy: {new_accuracy}, NEW Val Loss: {new_val_loss}")
+
+        model, all_accuracy, all_val_loss = _validate(args, model, all_test_loader, criterion, mapping)
+        logger.info(f"Epoch {epoch}: VALIDATION : ALL Accuracy: {all_accuracy}, ALL Val Loss: {all_val_loss}")
+
+        # just save the head (creating the session directory first, since nothing else creates it)
+        head_dir = f"{args.exp_dir}/session_{args.current_session}"
+        os.makedirs(head_dir, exist_ok=True)
+        torch.save(model.head.state_dict(), f"{head_dir}/head_s{args.current_session}_ep{epoch}.pt")
+        logger.debug(f"Session {args.current_session} Head saved to {head_dir}/head_s{args.current_session}_ep{epoch}.pt")
+
+        # create a df to hold the results
+        epoch_results = pd.DataFrame(
+            {
+                "epoch": [epoch],
+                "train_loss": [train_loss],
+                "old_accuracy": [old_accuracy],
+                "old_val_loss": [old_val_loss],
+                "new_accuracy": [new_accuracy],
+                "new_val_loss": [new_val_loss],
+                "all_accuracy": [all_accuracy],
+                "all_val_loss": [all_val_loss],
+            }
+        )
+
+        # append the results to the results dataframe
+        results = epoch_results if results is None else pd.concat([results, epoch_results], ignore_index=True)
+
+        # save the results dataframe
+        results.to_csv(f"{args.exp_dir}/results_s{args.current_session}.csv", index=False)
+        logger.debug(f"Session {args.current_session} Results saved to {args.exp_dir}/results_s{args.current_session}.csv")
+
+    return model
+
+def _train(args, model: torch.nn.Module, train_loader: torch.utils.data.DataLoader, optimiser: torch.optim.Optimizer, criterion: torch.nn.Module) -> Tuple[torch.nn.Module, float]:
+    """
+    Train the model on the given train_loader for one epoch.
+    :param args: Arguments object with the attribute `device`.
+    :param model: torch.nn.Module with the model to train.
+    :param train_loader: torch.utils.data.DataLoader with the training data.
+    :param optimiser: torch.optim.Optimizer with the optimiser to use.
+    :param criterion: torch.nn.Module with the loss function to use.
+    :return: torch.nn.Module with the updated model and float with the mean batch loss.
+ """ + model.train() + loss_total = 0 + for x, _, _, _, y in tqdm(train_loader, desc="Training", unit="batch"): # we only want psuedo labels for training + x, y = x.to(args.device), y.to(args.device) + logger.debug(f"x shape: {x.shape}, y shape: {y.shape}") + logger.debug(f"Psuedo Labels (Unique): {torch.unique(y, sorted=True)}") + optimiser.zero_grad() + logits, _ = model(x) + logger.debug(f"Logits Shape: {logits.shape}") + + loss = criterion(logits, y) + loss.backward() + optimiser.step() + loss = loss.item() + logger.debug(f"Loss: {loss}") + loss_total += loss + loss_total /= len(train_loader) + return model, loss_total + + + +def _validate(args, model: torch.nn.Module, val_loader: torch.utils.data.DataLoader, criterion: torch.nn.Module, mapping: Dict[int, int]) -> Tuple[torch.nn.Module, float, float]: + """ + Validate the model on the given val_loader + :param args: Arguments object with the attributes `device`. + :param model: torch.nn.Module with the model to validate. + :param val_loader: torch.utils.data.DataLoader with the validation data. + :param criterion: torch.nn.Module with the loss function to use. + :param mapping: Dict[int, int] with the mapping from true labels to pseudo labels. + :return: torch.nn.Module with the updated model, float with the accuracy and float with the total loss. + """ + model.eval() + correct = 0 + total = 0 + loss_total = 0 + with torch.no_grad(): + for x, y, _ in tqdm(val_loader, desc="Validating", unit="batch"): + x, y = x.to(args.device), y.to(args.device) + logger.debug(f"x shape: {x.shape}, y shape: {y.shape}") + logger.debug(f"True Labels (Unique): {torch.unique(y, sorted=True)}") + + # if the true labels are in the mapping, replace them with the pseudo labels + for true_label, pseudo_label in mapping.items(): + y[y == true_label] = pseudo_label # for each y in y, if y == true_label, replace it with pseudo_label + + logits, _ = model(x) + loss = criterion(logits, y) + loss = loss.item() + loss_total += loss + + _, predicted = torch.max(logits, 1) + + num_correct = (predicted == y).sum().item() + total += y.size(0) + correct += num_correct + + logger.debug(f"This Batch Num Correct: {num_correct}, Total: {y.size(0)}, Accuracy: {num_correct / y.size(0)}, Loss: {loss}") + + return model, correct / total, loss_total / len(val_loader) + \ No newline at end of file diff --git a/entcl/data/cifar100.py b/entcl/data/cifar100.py index bf52d3bb49e8fda2abf72bd0ff222cd9fd5e6277..5fbbf9df6df141b67f376ca197e298f999ca19c9 100644 --- a/entcl/data/cifar100.py +++ b/entcl/data/cifar100.py @@ -2,11 +2,12 @@ import os from typing import Dict, List, Union from entcl.data.util import TransformedTensorDataset import torch +from torchvision import disable_beta_transforms_warning +disable_beta_transforms_warning() from torchvision.datasets import CIFAR100 as _CIFAR100 import torchvision.transforms.v2 as transforms from entcl.config import CIFAR100_DIR from loguru import logger - CIFAR100_TRANSFORM = transforms.Compose( [ transforms.ToTensor(), @@ -35,18 +36,19 @@ class CIFAR100Dataset: :param cl_n_novel: Number of samples per novel class for each CL session. Default: 400 :param cl_n_prevnovel: Number of samples per previously novel class for each CL session. 
diff --git a/entcl/data/cifar100.py b/entcl/data/cifar100.py
index bf52d3bb49e8fda2abf72bd0ff222cd9fd5e6277..5fbbf9df6df141b67f376ca197e298f999ca19c9 100644
--- a/entcl/data/cifar100.py
+++ b/entcl/data/cifar100.py
@@ -2,11 +2,12 @@ import os
 from typing import Dict, List, Union
 from entcl.data.util import TransformedTensorDataset
 import torch
+from torchvision import disable_beta_transforms_warning
+disable_beta_transforms_warning()
 from torchvision.datasets import CIFAR100 as _CIFAR100
 import torchvision.transforms.v2 as transforms
 from entcl.config import CIFAR100_DIR
 from loguru import logger
-
 CIFAR100_TRANSFORM = transforms.Compose(
     [
         transforms.ToTensor(),
@@ -35,18 +36,19 @@ class CIFAR100Dataset:
         :param cl_n_novel: Number of samples per novel class for each CL session. Default: 400
         :param cl_n_prevnovel: Number of samples per previously novel class for each CL session. Default: 20
         """
-        if known >= 100:
+        self.num_classes = 100
+        if known >= self.num_classes:
             raise ValueError("Number of known classes cannot be greater than 100")
         self.transform = CIFAR100_TRANSFORM
         self.known = known
         self.sessions = sessions
-
+        self.novel = self.num_classes - self.known
         self.pretrain_n_known = pretrain_n_known
-        self.num_classes = 100
+
         self.cl_n_known = cl_n_known
         self.cl_n_novel = cl_n_novel
         self.cl_n_prevnovel = cl_n_prevnovel
-        self.novel_inc = (100 - self.known) // self.sessions
+        self.novel_inc = (self.num_classes - self.known) // self.sessions
         # Verify the CL settings
         logger.debug(
             "Verifying incremental learning settings\n"
@@ -190,9 +192,16 @@ class CIFAR100Dataset:
         samples = torch.cat(samples)
         labels = torch.cat(labels)
+
+        # when labelling ood samples, we do not distinguish previously-novel classes; there are only two types: known/seen (0) and novel/unseen (1)
+        ood_labels = labels.clone()
+        ood_labels[ood_labels < novel_start] = 0
+        ood_labels[ood_labels >= novel_start] = 1
+
         logger.debug(f"Creating dataset for session {session}. There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
         logger.debug(f"Classes in this Session {session}'s Train Dataset: {labels.unique(sorted=True)}")
-        datasets[session] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels), transform=self.transform)
+        datasets[session] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels, ood_labels), transform=self.transform)
 
         return datasets
@@ -204,23 +213,77 @@ class CIFAR100Dataset:
         """
         datasets = {}
+
+        logger.debug(f"Splitting test data for session 0")
+        # the pretraining session's dataset (session 0) only has known classes, so the dataset is technically an ALL dataset
+        samples, labels = [], []
+        for class_idx in range(self.known):
+            samples.append(masterlist[class_idx])
+            labels.append(torch.full((masterlist[class_idx].size(0),), class_idx, dtype=torch.long))
+
+        samples = torch.cat(samples)
+        labels = torch.cat(labels)
+        types = torch.full((labels.size(0),), 0, dtype=torch.long) # they are all known classes, so type is 0
+
+        logger.debug(f"Creating dataset for session 0 (pretraining). There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
+        logger.debug(f"Classes in Session 0's Test Dataset: {labels.unique(sorted=True)}")
+        datasets[0] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels, types), transform=self.transform)
+
+        del samples, labels, types # free up memory
+
         logger.debug(f"Splitting test data for {self.sessions} sessions")
-        for session in range(0, self.sessions + 1):
+        for session in range(1, self.sessions + 1):
+            datasets[session] = {}
             logger.debug(f"Splitting test data for session {session}")
-            samples, labels = [], []
-            # get data for all seen classes (self.known + session * self.novel_inc)
-            seen_classes_end = self.known + (session * self.novel_inc)
-            logger.debug(f"There are {seen_classes_end} seen classes. Starting at 0 (inc), ending at {seen_classes_end} (exc)")
-            for class_idx in range(seen_classes_end):
-                samples.append(masterlist[class_idx])
-                labels.append(torch.full((masterlist[class_idx].size(0),), class_idx, dtype=torch.long))
+            # for cl sessions, there are 3 datasets: old, new and all.
+            old_end_idx = self.known + (session - 1) * self.novel_inc
+            new_end_idx = self.known + session * self.novel_inc
+            logger.debug(f"Old classes end at {old_end_idx} (exc), New classes start at {old_end_idx} (inc) and end at {new_end_idx} (exc)")
 
-            samples = torch.cat(samples)
-            labels = torch.cat(labels)
-            logger.debug(f"Creating test dataset for session {session}. There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
-            logger.debug(f"Classes in Session {session}'s Test Dataset: {labels.unique(sorted=True)}")
-            datasets[session] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels), transform=self.transform)
+
+            # old dataset -------------------------------------
+            old_samples, old_labels = [], []
+            for class_idx in range(old_end_idx):
+                old_samples.append(masterlist[class_idx])
+                old_labels.append(torch.full((masterlist[class_idx].size(0),), class_idx, dtype=torch.long))
+
+            old_samples = torch.cat(old_samples)
+            old_labels = torch.cat(old_labels)
+            old_types = torch.full((old_labels.size(0),), 0, dtype=torch.long) # they are all known classes, so type is 0
+
+            logger.debug(f"Creating OLD dataset for session {session}. There are {len(old_samples)} samples, and {len(old_labels)} labels. There are {old_labels.unique().size(0)} different classes")
+            logger.debug(f"Classes in Session {session}'s OLD Dataset: {old_labels.unique(sorted=True)}")
+
+            datasets[session]["old"] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(old_samples, old_labels, old_types), transform=self.transform)
+
+            # new dataset -------------------------------------
+            new_samples, new_labels = [], []
+            for class_idx in range(old_end_idx, new_end_idx):
+                new_samples.append(masterlist[class_idx])
+                new_labels.append(torch.full((masterlist[class_idx].size(0),), class_idx, dtype=torch.long))
+
+            new_samples = torch.cat(new_samples)
+            new_labels = torch.cat(new_labels)
+            new_types = torch.full((new_labels.size(0),), 1, dtype=torch.long) # they are all novel classes, so type is 1
+
+            logger.debug(f"Creating NEW dataset for session {session}. There are {len(new_samples)} samples, and {len(new_labels)} labels. There are {new_labels.unique().size(0)} different classes")
+            logger.debug(f"Classes in Session {session}'s NEW Dataset: {new_labels.unique(sorted=True)}")
+
+            datasets[session]["new"] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(new_samples, new_labels, new_types), transform=self.transform)
+
+            # all dataset -------------------------------------
+            all_samples = torch.cat([old_samples, new_samples])
+            all_labels = torch.cat([old_labels, new_labels])
+            all_types = torch.cat([old_types, new_types])
+
+            logger.debug(f"Creating ALL dataset for session {session}. There are {len(all_samples)} samples, and {len(all_labels)} labels. There are {all_labels.unique().size(0)} different classes")
+            logger.debug(f"Classes in Session {session}'s ALL Dataset: {all_labels.unique(sorted=True)}")
+
+            datasets[session]["all"] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(all_samples, all_labels, all_types), transform=self.transform)
 
         return datasets
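The OLD/NEW split above is pure index arithmetic on the class axis. A quick sanity check of the ranges, assuming an illustrative split of known=50 and sessions=5 (so novel_inc = (100 - 50) // 5 = 10):

    known, sessions, num_classes = 50, 5, 100
    novel_inc = (num_classes - known) // sessions

    for session in range(1, sessions + 1):
        old_end_idx = known + (session - 1) * novel_inc
        new_end_idx = known + session * novel_inc
        print(f"session {session}: OLD = [0, {old_end_idx}), NEW = [{old_end_idx}, {new_end_idx})")

    # session 1: OLD = [0, 50), NEW = [50, 60)
    # session 2: OLD = [0, 60), NEW = [60, 70)
    # session 5: OLD = [0, 90), NEW = [90, 100)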
diff --git a/entcl/data/util.py b/entcl/data/util.py
index 23dc10818d04cb4bf3859620572be4cad45a732b..00c3aa191c21528a7727637140556bfc4a09731a 100644
--- a/entcl/data/util.py
+++ b/entcl/data/util.py
@@ -8,7 +8,8 @@ class TransformedTensorDataset(torch.utils.data.Dataset):
         return len(self.tensor_dataset)
 
     def __getitem__(self, idx):
-        data, target = self.tensor_dataset[idx]
+        the_tuple = self.tensor_dataset[idx]
+        data = the_tuple[0]
         if self.transform:
             data = self.transform(data)
-        return data, target
+
+        return data, *the_tuple[1:]
diff --git a/entcl/pretrain.py b/entcl/pretrain.py
index 5bfc2f83ee87d61014c6ff5a4e1719f9f6f1e546..11c2b0713fad58d08f9a562ea7843dc66ee6d7e7 100644
--- a/entcl/pretrain.py
+++ b/entcl/pretrain.py
@@ -7,7 +7,7 @@ from tqdm import tqdm
 
 def pretrain(args, model):
     train_dataset = args.dataset.get_dataset(session=0, train=True)
-    train_dataloader = torch.utils.data.DataLoader(
+    train_loader = torch.utils.data.DataLoader(
         dataset=train_dataset,
         batch_size=args.batch_size,
         shuffle=True,
@@ -17,7 +17,7 @@ def pretrain(args, model):
     )
 
     val_dataset = args.dataset.get_dataset(session=0, train=False)
-    val_dataloader = torch.utils.data.DataLoader(
+    val_loader = torch.utils.data.DataLoader(
         dataset=val_dataset,
         batch_size=args.batch_size,
         shuffle=False,
@@ -41,10 +41,10 @@ def pretrain(args, model):
         logger.debug(f"Loading pretrained model from {args.pretrain_load}")
         model.head.load_state_dict(torch.load(args.pretrain_load, weights_only=True))
         model = model.to(args.device)
-        model, accuracy = _validate(args, model, val_dataloader)
+        model, accuracy, _ = _validate(args, model, val_loader, criterion=criterion)
         logger.info(f"Loaded Pretrained Model Accuracy: {accuracy}")
         return model
-    else:
+    elif args.mode in ["pretrain", "both"]:
         logger.debug("No pretrained model to load, training from scratch")
         model = model.to(args.device)
         for epoch in range(args.pretrain_epochs):
@@ -52,15 +52,16 @@ def pretrain(args, model):
             logger.debug(f"Epoch {epoch} Started")
 
             # train model
-            model, train_loss = _train(args, model, train_dataloader, optimiser, criterion)
+            model, train_loss = _train(args, model, train_loader, optimiser, criterion)
             logger.info(f"Epoch {epoch}: TRAINING : Train Loss: {train_loss}")
 
             # validate model
-            model, accuracy, val_loss = _validate(args, model, val_dataloader, criterion)
+            model, accuracy, val_loss = _validate(args, model, val_loader, criterion)
             logger.info(f"Epoch {epoch}: VALIDATION : Accuracy: {accuracy}, Val Loss: {val_loss}")
 
             # just save the head, the backbone is frozen, and is fucking massive
-            model.head.save(f"{args.exp_root}/{args.name}/head_pretrain_{epoch}.pth")
+            torch.save(model.head.state_dict(), f"{args.exp_dir}/session_0/head_s0_ep{epoch}.pt")
+            logger.debug(f"Session 0 Head saved to {args.exp_dir}/session_0/head_s0_ep{epoch}.pt")
 
             # create a dataframe with the results
             epoch_results = pd.DataFrame(
@@ -69,22 +70,26 @@ def pretrain(args, model):
             )
 
             # append the results to the results dataframe
-            results = epoch_results if results is None else results.append(epoch_results)
+            results = epoch_results if results is None else pd.concat([results, epoch_results])
 
             # save the results dataframe
-            results.to_csv(f"{args.exp_root}/{args.name}/results_pretrain.csv", index=False)
-            logger.debug(f"Epoch {epoch} Finished. Pretrain Results Saved to {args.exp_root}/{args.name}/results_pretrain.csv")
+            results.to_csv(f"{args.exp_dir}/results_s0.csv", index=False)
+            logger.debug(f"Epoch {epoch} Finished. Pretrain Results Saved to {args.exp_dir}/results_s0.csv")
 
         return model
+    else:
+        raise ValueError(f"No model to load and mode is not pretrain or both. Mode: {args.mode}, Pretrain Load: {args.pretrain_load}")
 
-def _train(args, model, train_dataloader, optimiser, criterion):
+def _train(args, model, train_loader, optimiser, criterion):
     model.train()
     loss_total = 0
-    for x, y in tqdm(train_dataloader, desc=f"Training", unit = "batch"):
+    for x, y, _ in tqdm(train_loader, desc=f"Training", unit = "batch"):
         x, y = x.to(args.device), y.to(args.device)
-        logger.debug(f"x shape: {x.shape}, y shape: {y.shape}")
+        logger.debug(f"X Shape: {x.shape}, Y Shape: {y.shape}")
+        logger.debug(f"True Labels (Unique): {torch.unique(y, sorted=True)}")
+
         optimiser.zero_grad()
         logits, _ = model(x)
-        logger.debug(f"logits shape: {logits.shape}")
+        logger.debug(f"Logits Shape: {logits.shape}")
 
         loss = criterion(logits, y)
         loss.backward()
@@ -92,18 +97,19 @@ def _train(args, model, train_dataloader, optimiser, criterion):
         loss = loss.item()
         logger.debug(f"Loss: {loss}")
         loss_total += loss
-    loss_total /= len(train_dataloader)
+    loss_total /= len(train_loader)
     return model, loss_total
 
-def _validate(args, model, val_dataloader, criterion):
+def _validate(args, model, val_loader, criterion):
     model.eval()
     correct = 0
     total = 0
     loss_total = 0
     with torch.no_grad():
-        for x, y in tqdm(val_dataloader, desc=f"Validating", unit = "batch"):
+        for x, y, _ in tqdm(val_loader, desc=f"Validating", unit = "batch"):
             x, y = x.to(args.device), y.to(args.device)
             logger.debug(f"x shape: {x.shape}, y shape: {y.shape}")
+            logger.debug(f"True Labels (Unique): {torch.unique(y, sorted=True)}")
 
             logits, _ = model(x)
             logger.debug(f"logits shape: {logits.shape}")
@@ -120,4 +126,4 @@ def _validate(args, model, val_dataloader, criterion):
 
             logger.debug(f"This Batch Num Correct: {num_correct}, Total: {y.size(0)}, Accuracy: {num_correct / y.size(0)}, Loss: {loss}")
 
-    return model, correct / total, loss_total / len(val_dataloader)
+    return model, correct / total, loss_total / len(val_loader)
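The checkpointing convention above saves only the head's state dict, since the backbone is frozen. A sketch of the matching load path, assuming the `MLPHead` constructor arguments used in run.py (the import path is illustrative):

    import torch
    from entcl.models.model import MLPHead  # illustrative import path

    head = MLPHead(in_features=768, out_features=100, hidden_dim1=512, hidden_dim2=256)
    torch.save(head.state_dict(), "head_s0_ep0.pt")

    # later (or in another process): rebuild an identically-shaped head, then restore it
    restored = MLPHead(in_features=768, out_features=100, hidden_dim1=512, hidden_dim2=256)
    restored.load_state_dict(torch.load("head_s0_ep0.pt", weights_only=True))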
logger.info(f"Starting Continual Learning Session {session}") + args.current_session = session + session_dataset = args.dataset.get_dataset(session) + + # OOD detection + session_dataset = label_ood_for_session(args, session_dataset, model) # returns a new dataset with the OOD samples labelled + + # NCD + session_dataset, mapping = find_novel_classes_for_session(args, session_dataset, model) # returns a new dataset with the novel samples labelled + + # dataset should now have the form (data, true labels, true types, pred types, pseudo labels) + + # Expand Classification Head & Initialise + model.head.expand(args.dataset.novel_inc) # we are cheating here, we know the number of novel classes + + # freeze the weights for the existing classes. We are only training unknown samples (EG: 50 (known) + (2 (session) - 1) * 10 (novel_inc) = 60 classes have been seen in cl session 2) + model.head.freeze(start_idx=0, end_idx=args.dataset.known + ((session -1) * args.dataset.novel_inc)) + + # run continual learning session + model = cl_session(args, session_dataset, model, mapping) + + + + + + @@ -28,7 +60,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() # program args parser.add_argument('--name', type=str, default="entcl_" + datetime.now().isoformat(timespec='seconds')) - parser.add_argument('--mode', type=str, default='pretrain', help='Mode to run the program', choices=['pretrain', 'cl', 'dryrun']) + parser.add_argument('--mode', type=str, default='both', help='Mode to run the program', choices=['pretrain', 'cl', 'both']) parser.add_argument('--dryrun', action='store_true', default=False, help='Dry Run Mode. Does not save anything') parser.add_argument('--debug', action='store_true', default=False, help='Debug Mode. Epochs are only done once. 
@@ -79,6 +111,9 @@ if __name__ == "__main__":
     # ood args
     parser.add_argument('--ood_score', type=str, default='entropy', help='Changes the metric(s) to base OOD detection on', choices=['entropy', 'energy', 'both'])
     parser.add_argument('--ood_eps', type=float, default=1e-8, help='Epsilon value for computing entropy in OOD detection')
+
+    # ncd args
+    parser.add_argument('--ncd_findk_method', type=str, default='cheat', help='Method to use for finding the number of novel classes', choices=['elbow', 'silhouette', 'gap', 'cheat'])
 
     args = parser.parse_args()
     seed(args.seed) # seed everything
@@ -102,6 +137,9 @@ if __name__ == "__main__":
     # initialise directories
     os.makedirs(args.exp_root, exist_ok=True)
     args.exp_dir = os.path.join(args.exp_root, args.name)
+
+    if args.name == "debug":
+        args.exp_dir = generate_unique_path(args.exp_dir)
     os.makedirs(args.exp_dir, exist_ok=False)
 
     # initialise logger
@@ -129,6 +167,9 @@ if __name__ == "__main__":
         args.head = MLPHead(in_features=768, out_features=args.dataset.num_classes, hidden_dim1=512, hidden_dim2=256)
         logger.debug(f"Using MLP Head: {args.head}")
 
+    if args.mode == 'cl' and args.pretrain_load is None:
+        raise ValueError("Continual Learning Mode requires a pretrained model to load")
+
     argstr = "Arguments: \n"
     for arg in vars(args):
         argstr += f"{arg}: {getattr(args, arg)}\n"
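run.py now leans on two head operations the diff does not show: `expand` (grow the classifier by `novel_inc` outputs) and `freeze` (stop already-seen classes from updating). A plausible sketch of that contract, assuming the head ends in a single `torch.nn.Linear` classifier; the real MLPHead may differ:

    import torch

    class ExpandableHead(torch.nn.Module):
        def __init__(self, in_features: int, out_features: int):
            super().__init__()
            self.classifier = torch.nn.Linear(in_features, out_features)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.classifier(x)

        def expand(self, n_new: int) -> None:
            # grow the output layer by n_new classes, keeping the old weights
            old = self.classifier
            new = torch.nn.Linear(old.in_features, old.out_features + n_new)
            with torch.no_grad():
                new.weight[: old.out_features] = old.weight
                new.bias[: old.out_features] = old.bias
            self.classifier = new

        def freeze(self, start_idx: int, end_idx: int) -> None:
            # zero the gradient rows for classes [start_idx, end_idx) after every
            # backward pass, so those output units stop updating
            def zero_rows(grad: torch.Tensor) -> torch.Tensor:
                grad = grad.clone()
                grad[start_idx:end_idx] = 0
                return grad
            self.classifier.weight.register_hook(zero_rows)
            self.classifier.bias.register_hook(zero_rows)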
- """ - from sklearn.metrics import silhouette_score - silhouettes = [] - - ks = range(args.ncd_findk_mink, args.ncd_findk_maxk + 1) - logger.debug(f"Finding k using silhouette method for k in {ks}") - - # Calculate the silhouette score for each k - for k in tqdm(ks, desc="Calculating Silhouettes", leave=True, unit="k"): - kmeans = KMeans(n_clusters=k, random_state=args.seed) - pseudo_labels = kmeans.fit_predict(features) - silhouette = silhouette_score(features, pseudo_labels) - silhouettes.append(silhouette) - - logger.debug(f"Silhouette Scores: {silhouettes}") - - # Find the best k - best_k = ks[silhouettes.index(max(silhouettes))] - logger.info(f"Silhouette Method Found k: {best_k}") - return best_k - -def gap(features: torch.Tensor, args) -> int: - """ - Finds the optimal number of clusters using the gap statistic. - :param features: torch.Tensor with the features to cluster. - :param args: Arguments object with the attribute `seed`. - :return: int with the optimal number of clusters. - """ - logger.debug("Finding k using gap statistic") - features_np = features.cpu().numpy() - n_samples, n_features = features_np.shape - - min_vals = features_np.min(axis=0) - max_vals = features_np.max(axis=0) - - gap_stats = {} - wks = [] - wks_refs = [] - - for k in tqdm(range(args.ncd_findk_mink, args.ncd_findk_maxk + 1), desc="Calculating Gap Statistic", leave=True, unit="k"): - logger.debug(f"Calculating Gap Statistic for k={k}") - kmeans = KMeans(n_clusters=k, random_state=args.seed) - kmeans.fit(features_np) - wk = kmeans.inertia_ - wks.append(np.log(wk)) - - wk_refs = [] - for _ in tqdm(range(args.ncd_gap_nrefs), desc="Calculating Reference Gap Statistic", leave=True, unit="ref"): - logger.debug(f"Calculating Reference Gap Statistic for k={k}") - ref_data = np.random.uniform(min_vals, max_vals, (n_samples, n_features)) - kmeans = KMeans(n_clusters=k, random_state=args.seed) - kmeans.fit(ref_data) - wk_ref = kmeans.inertia_ - wk_refs.append(np.log(wk_ref)) - wks_refs.append(np.mean(wk_refs)) - - gap_stats[k] = wks_refs[-1] - wks[-1] - - optimal_k = max(gap_stats, key=gap_stats.get) - - logger.info(f"Gap Statistic Found k: {optimal_k}") - - return optimal_k \ No newline at end of file diff --git a/entcl/utils/ncd.py b/entcl/utils/ncd.py index 60df52c8ae2f30e471a5e1636c7532ca281087d2..8c5fb51fbab5dde6a5dcf98085be43ba186c406f 100644 --- a/entcl/utils/ncd.py +++ b/entcl/utils/ncd.py @@ -1,19 +1,21 @@ - - import os -from typing import Union +from typing import Dict, Union from entcl.data.util import TransformedTensorDataset -from entcl.utils.findk import elbow, gap, silhouette from loguru import logger import numpy as np +import pandas as pd from sklearn.cluster import KMeans from sklearn.metrics import confusion_matrix from scipy.optimize import linear_sum_assignment from entcl.utils.util import generate_unique_path import torch +import torch.utils from tqdm import tqdm -def _extract_features(args, dataset: TransformedTensorDataset, model: torch.nn.Module) -> torch.Tensor: + +def _extract_features( + args, dataset: TransformedTensorDataset, model: torch.nn.Module +) -> torch.Tensor: """ Extracts features from the data in the dataset using the provided model's backbone. :param args: Arguments object with the attributes `device`, . 
diff --git a/entcl/utils/ncd.py b/entcl/utils/ncd.py
index 60df52c8ae2f30e471a5e1636c7532ca281087d2..8c5fb51fbab5dde6a5dcf98085be43ba186c406f 100644
--- a/entcl/utils/ncd.py
+++ b/entcl/utils/ncd.py
@@ -1,19 +1,21 @@
-
-
 import os
-from typing import Union
+from typing import Dict, Tuple, Union
 from entcl.data.util import TransformedTensorDataset
-from entcl.utils.findk import elbow, gap, silhouette
 from loguru import logger
 import numpy as np
+import pandas as pd
 from sklearn.cluster import KMeans
 from sklearn.metrics import confusion_matrix
 from scipy.optimize import linear_sum_assignment
 from entcl.utils.util import generate_unique_path
 import torch
+import torch.utils
 from tqdm import tqdm
 
-def _extract_features(args, dataset: TransformedTensorDataset, model: torch.nn.Module) -> torch.Tensor:
+
+def _extract_features(
+    args, dataset: TransformedTensorDataset, model: torch.nn.Module
+) -> torch.Tensor:
     """
     Extracts features from the data in the dataset using the provided model's backbone.
     :param args: Arguments object with the attributes `device`, .
@@ -27,21 +29,26 @@ def _extract_features(args, dataset: TransformedTensorDataset, model: torch.nn.M
         pin_memory=args.pin_memory,
         shuffle=False,
     )
 
     model.eval()
     features = None
     with torch.no_grad():
-        for batch in tqdm(dataloader, desc="Extracting Features", leave=True, unit="batch"):
-            x = batch[0] # Only the data is needed, assumes batch is iterable.
+        for batch in tqdm(
+            dataloader, desc="Extracting Features", leave=True, unit="batch"
+        ):
+            x = batch[0]  # Only the data is needed, assumes batch is iterable.
             x = x.to(args.device)
-            x_proj, _= model(x) # we do not need the predicted logits, only the features.
+            _, x_proj = model(
+                x
+            )  # the model returns (logits, features); we only need the features here.
             if features is None:
                 features = x_proj
            else:
                 features = torch.cat((features, x_proj), dim=0)
 
     return features
 
+
 def plot_confmat(confmat: torch.Tensor, path: str) -> None:
     """
     Plots a confusion matrix and saves it to the specified path.
@@ -49,105 +56,125 @@ def plot_confmat(confmat: torch.Tensor, path: str) -> None:
     :param path: str with the path to save the confusion matrix plot.
     """
     try:
+        import matplotlib
+
+        matplotlib.use("Agg")
         import matplotlib.pyplot as plt
+        import seaborn as sns
         from sklearn.metrics import ConfusionMatrixDisplay
 
         fig, ax = plt.subplots(figsize=(10, 10))
-        sns.heatmap(confmat, annot=True, fmt='d', cmap='Blues', ax=ax)
-        ax.set_xlabel('Predicted Label')
-        ax.set_ylabel('True Label')
-        ax.set_title('Confusion Matrix')
+        sns.heatmap(confmat, annot=True, fmt="d", cmap="Blues", ax=ax)
+        ax.set_xlabel("Predicted Label")
+        ax.set_ylabel("True Label")
+        ax.set_title("Confusion Matrix")
         plt.savefig(path)
         plt.close()
     except ImportError:
-        logger.error("Could not import matplotlib or seaborn. Cannot plot confusion matrix. Confusion Matrix not saved.")
-
-def _calculate_clustering_accuracy(true_labels: torch.Tensor, pseudo_labels: torch.Tensor, args) -> None:
+        logger.error(
+            "Could not import matplotlib or seaborn. Cannot plot confusion matrix. Confusion Matrix not saved."
+        )
+
+
+def generate_mapping(
+    true_labels: torch.Tensor, pseudo_labels: torch.Tensor, args
+) -> Dict[int, int]:
     """
-    Calculates the clustering accuracy between the true labels and the pseudo labels.
+    Calculates the clustering accuracy between the true labels and the pseudo labels and generates a mapping between the two for validation and testing.
     :param true_labels: torch.Tensor with the true labels for the novel samples.
     :param pseudo_labels: torch.Tensor with the pseudo labels for the novel samples.
+    :return: Dict[int, int] a mapping between the true labels and the pseudo labels.
     """
-    assert true_labels.shape == pseudo_labels.shape, f"True and Pseudo labels must have the same shape. true_labels.shape: {true_labels.shape}, pseudo_labels.shape: {pseudo_labels.shape}"
+    logger.debug("Calculating Clustering Accuracy")
+    assert (
+        true_labels.shape == pseudo_labels.shape
+    ), f"True and Pseudo labels must have the same shape. true_labels.shape: {true_labels.shape}, pseudo_labels.shape: {pseudo_labels.shape}"
 
-    # true labels will be > 50 and psuedo labels will start at 0. we need to adjust the pseudo labels to match the true labels.
-    # we will assume the true labels are sequential, and the lowest true label is 0.
+
+    novel_true_min_idx = int(true_labels.min())
+    # work on 0-based numpy copies. do not mutate the caller's tensors in place:
+    # cl.py applies its own offset to the pseudo labels later
+    true_labels = (true_labels - novel_true_min_idx).cpu().numpy()
+    pseudo_labels = pseudo_labels.cpu().numpy()
+
+    conf_mat = confusion_matrix(true_labels, pseudo_labels)
+    row_idxs, col_idxs = linear_sum_assignment(
+        -conf_mat
+    )  # Hungarian Algorithm to find the best matching between the true and pseudo labels. with 0-based labels, the matrix indices are the labels themselves.
+
+    # align the pseudo labels with the true labels based on the hungarian algorithm results
+    pseudo_labels_aligned = np.zeros_like(pseudo_labels)
+    for pseudo_label, true_label in zip(col_idxs, row_idxs):
+        pseudo_labels_aligned[pseudo_labels == pseudo_label] = true_label
+
+    # create a mapping between the true labels and the pseudo labels (both offset back into dataset label space), used in validation and testing
+    mapping = {true_label + novel_true_min_idx: pseudo_label + novel_true_min_idx for pseudo_label, true_label in zip(col_idxs, row_idxs)}
 
-    true_labels -= true_labels.min() # Adjust the true labels to start at 0.
-
-    # optimal assignment of pseudo labels to true labels
-    confusion_mat = torch.from_numpy(confusion_matrix(true_labels.cpu().numpy(), pseudo_labels.cpu().numpy())) # Rows are true labels, columns are pseudo labels
-    cost_mat = -confusion_mat # Hungarian algorithm minimizes the cost, so we negate the confusion matrix.
-    row_idx, col_idx = linear_sum_assignment(cost_mat) # Hungarian algorithm to find the optimal assignment.
-
-    # Maps pseudo labels to true labels.
-    label_assignments = dict(zip(col_idx, row_idx)) # Maps pseudo labels to true labels.
-    aligned_predicted_labels = torch.tensor([label_assignments[pseudo_label] for pseudo_label in pseudo_labels]) # Aligns the predicted labels with the true labels.
-
-    # Calculate the accuracy
-    correct_assignments = (aligned_predicted_labels == true_labels).sum().item() # Number of correct assignments.
-    accuracy = correct_assignments / true_labels.shape[0] # Accuracy is the number of correct assignments divided by the number of samples.
-
-    # Plot the confusion matrix
-    confmat_path = generate_unique_path(os.path.join(args.exp_dir, args.name, "confusion_matrix.png"))
-    plot_confmat(confusion_mat, path=confmat_path)
-
-    unique_true_labels = np.unique(true_labels.cpu().numpy())
-    unique_pseudo_labels = np.unique(pseudo_labels.cpu().numpy())
-
-    # find any ignored labels
-    ignored_true_labels = set(unique_true_labels) - set(unique_true_labels[row_idx])
-    ignored_pseudo_labels = set(unique_pseudo_labels) - set(unique_pseudo_labels[col_idx])
-
-    # Log the results
-
-    string = f"Clustering Accuracy Computed: {accuracy*100:.2f}%"
-    string += f"\n True Labels (adjusted): {unique_true_labels}"
-    string += f"\n Pseudo Labels: {unique_pseudo_labels}"
-    string += f"\n # of True Labels {len(unique_true_labels)}, # of Pseudo Labels {len(unique_pseudo_labels)}"
-    string += f"\n Confusion Matrix: Plot saved to `{confmat_path}`"
-    string += f"Ignored True Labels: {ignored_true_labels} Count: {len(ignored_true_labels)}"
-    string += f"Ignored Pseudo Labels: {ignored_pseudo_labels} Count: {len(ignored_pseudo_labels)}"
-    string += f"\n If there are ignored true labels, the number of clusters has been overestimated. If there are ignored pseudo labels, the number of clusters has been underestimated."
-    string += f"\n Ignored labels are not included in the accuracy calculation."
+    # compute the overall accuracy
+    overall_accuracy = np.mean(true_labels == pseudo_labels_aligned)
+
+    # compute the per-class accuracy
+    per_class_accuracy = {}
+    for true_class in np.unique(true_labels):
+        mask = true_labels == true_class
+        per_class_accuracy[true_class] = np.mean(
+            true_labels[mask] == pseudo_labels_aligned[mask]
+        )
+
+    string = f"NCD Clustering Accuracies for Session {args.current_session}:"
+    for true_class, acc in per_class_accuracy.items():
+        string += f"\nTrue Class {true_class + novel_true_min_idx}: {acc*100:.4f}%"
     string += f"\n"
-    string += f"Per True Label Accuracy:"
-    for true_label in unique_true_labels:
-        mask = true_labels == true_label
-        true_label_accuracy = (aligned_predicted_labels[mask] == true_label).sum().item() / mask.sum().item()
-        string += f"\n True Label: {true_label} Accuracy: {true_label_accuracy*100:.2f}%"
-
+    string += f"\nOverall Accuracy: {overall_accuracy*100:.4f}%"
+
     logger.info(string)
 
+    string = f"Mapping for Session {args.current_session}:"
+    for true_label, pseudo_label in mapping.items():
+        string += f"\nTrue Label {true_label} -> Pseudo Label {pseudo_label}"
+
+    logger.info(string)
+
+    # plot the confusion matrix
+    try:
+        plot_confmat(
+            conf_mat,
+            os.path.join(args.exp_dir, f"confmat_session_{args.current_session}.png"),
+        )
+        logger.debug(
+            f"Confusion Matrix saved to {os.path.join(args.exp_dir, f'confmat_session_{args.current_session}.png')}"
+        )
+    except Exception as e:
+        logger.error(
+            f"Could not plot the confusion matrix. Error: {e}\n Confusion Matrix not saved. Continuing..."
+        )
 
     return mapping
 
-def _cluster_features(args, features:torch.Tensor) -> torch.Tensor:
+
+def _cluster_features(args, features: torch.Tensor) -> torch.Tensor:
     """
     Clusters the features using K Means.
-    :param args: Arguments object with the attribute `ncd_findk_method`, `novel_classes_per_session`, `seed`.
+    :param args: Arguments object with the attributes `novel_classes_per_session`, `seed`.
     :param features: torch.Tensor with the features to cluster. Assumes the features are from novel-classed samples only.
     :return: torch.Tensor with unadjusted psuedo-labels for given features.
""" - - if args.ncd_findk_method == "cheat": - k = args.novel_classes_per_session - elif args.ncd_findk_method == "elbow": - k = elbow(features, args) - elif args.ncd_findk_method == "silhouette": - k = silhouette(features, args) - elif args.ncd_findk_method == "gap": - k = gap(features, args) - else: - raise ValueError(f"Unknown method for finding k: {args.ncd_findk_method}") - - logger.info(f"FindK Method: {args.ncd_findk_method} k: {k} True k: {args.novel_classes_per_session}") - - kmeans = KMeans(n_clusters=k, random_state=args.seed) - pseudo_labels = torch.tensor(kmeans.fit_predict(features)) + logger.debug( + f"Clustering Features: Using K = {args.novel_classes_per_session} (args.novel_classes_per_session)" + ) + + kmeans = KMeans(n_clusters=args.novel_classes_per_session, random_state=args.seed) + + pseudo_labels = torch.tensor(kmeans.fit_predict(features.cpu().numpy())) return pseudo_labels -def find_novel_classes_for_session(args, session_dataset: TransformedTensorDataset, model: torch.nn.Module) -> Union[torch.Tensor, TransformedTensorDataset]: + +def find_novel_classes_for_session( + args, + session_dataset: TransformedTensorDataset, + model: torch.nn.Module, +) -> Union[torch.Tensor, TransformedTensorDataset]: """ Finds the novel classes in the given session dataset using KMeans clustering and the given model :param args: Arguments object with the attributes `device`, `ncd_findk_method`, `novel_classes_per_session`, `seed`. @@ -155,25 +182,66 @@ def find_novel_classes_for_session(args, session_dataset: TransformedTensorDatas :param model: torch.nn.Module with the model to use for clustering. :return: torch.Tensor with the pseudo-labels for the novel classes. """ - features = _extract_features(args, session_dataset, model) - pseudo_labels = _cluster_features(args, features) - return pseudo_labels - -def discover_classes_in_session_dataset( - args, - session_dataset: TransformedTensorDataset, - model: torch.nn.Module, - return_new_dataset: bool = True, -): - """ - Discovers classes in the session dataset using the provided model. - This step is for after OOD detection, which modifies the session dataset to include the OOD label in the third column. - - :param args: Arguments object. - :param session_dataset: TransformedTensorDataset for the session (data, target, ood). - :param model: The model to evaluate with. - :param return_new_dataset: Whether to return the new dataset ready for the session or just the predicted classes. - """ - - \ No newline at end of file + # first we need to create a TransformedTensorDataset with only the predicted novel samples + novel_samples_mask = ( + session_dataset.tensor_dataset.tensors[3] == 1 + ) # OOD samples are marked with 2. 
+    novel_samples_mask = novel_samples_mask.cpu()  # ensure the mask is on the CPU, where the dataset tensors live
+
+    novel_tensors = [
+        tensor[novel_samples_mask].cpu()
+        for tensor in session_dataset.tensor_dataset.tensors
+    ]  # get only the predicted novel samples
+
+    for i, t in enumerate(novel_tensors):
+        logger.debug(f"Tensor[{i}] Shape: {t.shape}")
+
+    novel_dataset = TransformedTensorDataset(
+        tensor_dataset=torch.utils.data.TensorDataset(*novel_tensors),
+        transform=session_dataset.transform,
+    )
+
+    # extract features from the novel samples
+    novel_features = _extract_features(args, novel_dataset, model)
+
+    # cluster the features
+    pseudo_labels = _cluster_features(args, novel_features)
+
+    # calculate the clustering accuracy (logging only) and build the true -> pseudo label mapping used in validation
+    mapping = generate_mapping(
+        novel_dataset.tensor_dataset.tensors[1], pseudo_labels, args
+    )
+
+    # next we need to align the pseudo labels tensor with the original dataset, giving known samples a pseudo label of -1
+    pseudo_labels_aligned = torch.full(
+        session_dataset.tensor_dataset.tensors[2].shape,
+        -1,
+        dtype=torch.long,
+        device=session_dataset.tensor_dataset.tensors[2].device,
+    )
+    pseudo_labels_aligned[novel_samples_mask] = pseudo_labels  # wherever the sample is novel, assign the corresponding pseudo label
+
+    # just for checking that the steps above work properly
+    clustering_df = pd.DataFrame(columns=["true_labels", "type", "pseudo_labels"])
+    clustering_df["true_labels"] = session_dataset.tensor_dataset.tensors[1].cpu().numpy()
+    clustering_df["type"] = session_dataset.tensor_dataset.tensors[2].cpu().numpy()
+    clustering_df["pseudo_labels"] = pseudo_labels_aligned.cpu().numpy()
+
+    # save the dataset to a csv file
+    dataset_path = os.path.join(args.exp_dir, f"clustering_s{args.current_session}.csv")
+    clustering_df.to_csv(dataset_path, index=False)
+
+    # create a new dataset with the pseudo labels alongside the original tensors and the same transform
+    # NOTE: the tensors attribute of a TensorDataset is a tuple. we can't append to it, so instead we create a new TensorDataset with the new pseudo labels tensor.
+    # NOTE: the dataset at the end of NCD has the form (data, true labels, true types, predicted types, pseudo labels)
+    new_dataset = TransformedTensorDataset(
+        tensor_dataset=torch.utils.data.TensorDataset(
+            *session_dataset.tensor_dataset.tensors, pseudo_labels_aligned.cpu()
+        ),
+        transform=session_dataset.transform,
+    )
+
+    return new_dataset, mapping
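generate_mapping's Hungarian step is easiest to see on a toy case. Here three novel classes come back from KMeans with permuted cluster ids; maximising the matched counts of the confusion matrix recovers the permutation, and the offset is added back afterwards (the numbers below are illustrative):

    import numpy as np
    from scipy.optimize import linear_sum_assignment
    from sklearn.metrics import confusion_matrix

    true_0based = np.array([0, 0, 1, 1, 2, 2])   # true novel labels, shifted to start at 0
    pseudo_raw  = np.array([2, 2, 0, 0, 1, 1])   # raw KMeans cluster ids

    conf_mat = confusion_matrix(true_0based, pseudo_raw)
    row_idxs, col_idxs = linear_sum_assignment(-conf_mat)  # maximise matched counts

    novel_true_min_idx = 50  # e.g. session 1 with 50 known classes
    mapping = {int(t) + novel_true_min_idx: int(p) + novel_true_min_idx
               for t, p in zip(row_idxs, col_idxs)}
    print(mapping)  # {50: 52, 51: 50, 52: 51}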
diff --git a/entcl/utils/ood.py b/entcl/utils/ood.py
index 4656bd5807be178de23bc88a4edb66731e76164d..c0b8daa21b670185283021c66e7913accce711a6 100644
--- a/entcl/utils/ood.py
+++ b/entcl/utils/ood.py
@@ -1,10 +1,13 @@
+import os
 from typing import Iterable, Tuple, Union
 from entcl.data.util import TransformedTensorDataset
+from entcl.utils.util import generate_unique_path
 from loguru import logger
+import numpy as np
+import pandas as pd
 from sklearn.mixture import GaussianMixture
 import torch
 from tqdm import tqdm
-from entcl.utils.findk import elbow, silhouette, gap
 
 def _get_scores(
     session_dataset: TransformedTensorDataset, model: torch.nn.Module, args
@@ -39,11 +42,11 @@ def _get_scores(
 
     entropies, energies = [], []
 
-    for x, _ in tqdm(session_loader, desc="Calculating Scores", leave=True, unit="batch"):
+    for x, _, _ in tqdm(session_loader, desc="Calculating Scores", leave=True, unit="batch"):
         x = x.to(args.device)
         with torch.no_grad():
             logits, _ = model(x)
-
+        logits = logits[:, : args.dataset.known] # TODO this is just a stopgap. the head needs to be expanded after all
         if compute_entropy:
             softmax = torch.nn.functional.softmax(logits, dim=1)
             entropy = -torch.sum(
@@ -72,7 +75,7 @@ def _fit_predict_gmm(data: torch.Tensor, args) -> torch.Tensor:
     :param args: Object with the attribute `seed`.
     """
     logger.debug(f"Fitting Gaussian Mixture Model to Data")
-    gmm = GaussianMixture(n_components=2, random_state=args.seed)
+    gmm = GaussianMixture(n_components=2, random_state=args.seed, max_iter=1000, init_params='kmeans', tol=1e-4)
 
     # Fit and predict using the GMM
     predtypes_hard = torch.tensor(
@@ -149,7 +152,6 @@ def label_ood_for_session(
     args,
     session_dataset: TransformedTensorDataset,
     model: torch.nn.Module,
-    return_new_dataset: bool = True,
 ) -> Union[TransformedTensorDataset, torch.Tensor]:
     """
     OOD Labelling for a session dataset. This function computes entropy and/or energy scores for the dataset based on `args.ood_score` and fits a Gaussian Mixture Model to each of the scores. The GMM has 2 components, one for in-distribution samples and one for OOD samples. The function then resolves conflicts between the entropy and energy predictions by selecting the type with the highest confidence. Finally, the function returns a new dataset for the session, including the predicted types.
@@ -196,21 +198,62 @@ def label_ood_for_session(
         )
     else:
         raise ValueError(f"Invalid OOD Score: {args.ood_score}")
+
+    logger.debug("Returning New Dataset")
+    # return the new dataset with the predicted types
+    session_dataset = TransformedTensorDataset(
+        tensor_dataset=torch.utils.data.TensorDataset(
+            session_dataset.tensor_dataset.tensors[0],  # the data
+            session_dataset.tensor_dataset.tensors[1],  # the labels
+            session_dataset.tensor_dataset.tensors[2],  # the real types
+            final_predtypes,  # the predicted types
+        ),
+        transform=session_dataset.transform,
+    )
 
-    if return_new_dataset:
-        logger.debug("Returning New Dataset")
-        # return the new dataset with the predicted types
-        session_dataset = TransformedTensorDataset(
-            tensor_dataset=torch.utils.data.TensorDataset(
-                session_dataset.tensor_dataset.tensors[0], # the data
-                session_dataset.tensor_dataset.tensors[1], # the labels
-                final_predtypes, # the predicted types (duh)
-            ),
-            transform=session_dataset.transform,
-        )
+    # compute the OOD Accuracy
+    _compute_ood_accuracy(session_dataset, args)
+
+    return session_dataset
 
-        return session_dataset
-    else:
-        logger.debug("Returning Predicted Types")
-        # return just the predicted types
-        return final_predtypes
+
+def _compute_ood_accuracy(
+    session_dataset: TransformedTensorDataset, args
+) -> None:
+    """
+    Computes the accuracy of the OOD labelling for a session dataset.
+    :param session_dataset: Dataset for the session with 4 tensors (data, labels, true types, predicted types).
+    :param args: Object with the attributes `exp_dir`, `current_session`.
+    :return: None
+    """
+    # Create the DataFrame
+    df = pd.DataFrame(
+        {
+            "label": session_dataset.tensor_dataset.tensors[1].cpu().numpy(),
+            "type": session_dataset.tensor_dataset.tensors[2].cpu().numpy(),
+            "predtype": session_dataset.tensor_dataset.tensors[3].cpu().numpy(),
+        }
+    )
+
+    df["is_correct"] = df["predtype"] == df["type"]
+
+    known = df[df["type"] == 0]
+    novel = df[df["type"] == 1]
+
+    string = f"OOD Accuracies for Session {args.current_session}:"
+    string += f"\nKnown Samples Correct/Total (Accuracy%): {known['is_correct'].sum()}/{known.shape[0]} ({known['is_correct'].mean()*100:.4f}%)"
+    string += f"\nKnown Samples Incorrect/Total (Error%) : {len(known) - known['is_correct'].sum()}/{known.shape[0]} ({(1 - known['is_correct'].mean())*100:.4f}%)"
+    string += f"\n"
+    string += f"\nNovel Samples Correct/Total (Accuracy%): {novel['is_correct'].sum()}/{novel.shape[0]} ({novel['is_correct'].mean()*100:.4f}%)"
+    string += f"\nNovel Samples Incorrect/Total (Error%) : {len(novel) - novel['is_correct'].sum()}/{novel.shape[0]} ({(1 - novel['is_correct'].mean())*100:.4f}%)"
+    string += f"\n"
+    string += f"\nOverall Accuracy: {df['is_correct'].mean()*100:.4f}%"
+    logger.info(string)
+
+    file_path = generate_unique_path(os.path.join(args.exp_dir, f'ood_accuracy_{args.current_session}.csv'))
+    df.to_csv(file_path, index=False)
+    logger.info(f"OOD Accuracy CSV saved to {file_path}")
\ No newline at end of file
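For reference, the two-component split performed by `_fit_predict_gmm` can be reproduced in isolation. A self-contained sketch on synthetic 1-D entropy scores, where the component with the higher mean is taken as the OOD one (the component indices returned by `fit_predict` are otherwise arbitrary):

    import numpy as np
    from sklearn.mixture import GaussianMixture

    rng = np.random.default_rng(0)
    scores = np.concatenate([
        rng.normal(0.5, 0.1, 500),  # in-distribution-like: low entropy
        rng.normal(2.0, 0.3, 500),  # OOD-like: high entropy
    ])

    gmm = GaussianMixture(n_components=2, random_state=0, max_iter=1000, tol=1e-4)
    hard = gmm.fit_predict(scores.reshape(-1, 1))

    ood_component = int(np.argmax(gmm.means_.ravel()))
    predtypes = (hard == ood_component).astype(np.int64)  # 0 = known, 1 = novel/OOD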