diff --git a/entcl/cl.py b/entcl/cl.py
new file mode 100644
index 0000000000000000000000000000000000000000..2aea96b164189a49d9ab57c24403faa201e43afd
--- /dev/null
+++ b/entcl/cl.py
@@ -0,0 +1,8 @@
+from typing import Optional, Union
+from entcl.data.util import TransformedTensorDataset
+from entcl.models.model import ENTCLModel
+import torch
+from loguru import logger
+from sklearn.mixture import GaussianMixture
+import tqdm
+
diff --git a/entcl/models/model.py b/entcl/models/model.py
index d1327f3fdbabb442d2686e368362afb344130ce4..9c9f4ba9177d583b2ad4304e89be697d8e4f4cae 100644
--- a/entcl/models/model.py
+++ b/entcl/models/model.py
@@ -3,11 +3,12 @@ from loguru import logger
 import torch
 
 class ENTCLModel(torch.nn.Module):
-    def __init__(self, head: torch.nn.Module):
+    def __init__(self, head: torch.nn.Module, backbone_url: str, backbone: str, backbone_source: str):
         super().__init__()
 
         # load the backbone
-        self.backbone = torch.hub.load(os.path.join(os.path.dirname(__file__), 'dinov2'), 'dinov2_vitb14', source='local')
+        self.backbone = torch.hub.load(backbone_url, backbone, source=backbone_source)
+        logger.debug(f"Loaded backbone: {backbone} from {backbone_url} (src: {backbone_source})")
 
         # freeze the backbone
         for param in self.backbone.parameters():
diff --git a/entcl/pretrain.py b/entcl/pretrain.py
index 26e48587615a08edbbfee5b3a5a610744c1115e4..25e9295baed72906f271a720f552874c39037d17 100644
--- a/entcl/pretrain.py
+++ b/entcl/pretrain.py
@@ -50,14 +50,14 @@ def pretrain(args, model):
         logger.debug(f"Epoch {epoch} Started")
 
         # train model
         print(f"Epoch {epoch}:")
-        model, loss_total = _train(args, model, train_dataloader, optimiser, criterion)
-        logger.info(f"Epoch {epoch}: Loss: {loss_total}")
+        model, train_loss = _train(args, model, train_dataloader, optimiser, criterion)
+        logger.info(f"Epoch {epoch}: Loss: {train_loss}")
 
         # validate model
-        model, accuracy = _validate(args, model, val_dataloader)
+        model, accuracy, val_loss = _validate(args, model, val_dataloader, criterion)
         accuracies.append(accuracy)
-        logger.info(f"Epoch {epoch}: Accuracy: {accuracy}")
+        logger.info(f"Epoch {epoch}: Accuracy: {accuracy}, Loss: {val_loss}")
 
         torch.save(model.state_dict(), os.path.join(args.exp_dir, f"model_{epoch}.pt"))
 
     # select the best model
@@ -101,10 +101,11 @@ def _train(args, model, train_dataloader, optimiser, criterion):
     loss_total /= len(train_dataloader)
     return model, loss_total
 
-def _validate(args, model, val_dataloader):
+def _validate(args, model, val_dataloader, criterion):
     model.eval()
     correct = 0
     total = 0
+    loss_total = 0
     with torch.no_grad():
         for x, y in tqdm(val_dataloader, desc=f"Validating", unit = "batch"):
             x, y = x.to(args.device), y.to(args.device)
@@ -112,12 +113,17 @@
             logits, _ = model(x)
             logger.debug(f"logits shape: {logits.shape}")
-
+
+            loss = criterion(logits, y)
+            loss = loss.item()
+            loss_total += loss
+
             _, predicted = torch.max(logits, 1)
             num_correct = (predicted == y).sum().item()
             total += y.size(0)
             correct += num_correct
-            logger.debug(f"This Batch Num Correct: {num_correct}, Total: {y.size(0)}, Accuracy: {num_correct / y.size(0)}")
-    return model, correct / total
+            logger.debug(f"This Batch Num Correct: {num_correct}, Total: {y.size(0)}, Accuracy: {num_correct / y.size(0)}, Loss: {loss}")
+
+    return model, correct / total, loss_total / len(val_dataloader)
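For context, a minimal sketch of how the new three-value return of `_validate` is consumed; the `args`, `model` and `val_dataloader` objects are assumed to be set up as in `pretrain` above, and `CrossEntropyLoss` is just an illustrative choice of criterion:

```python
import torch

# Hypothetical usage sketch: the criterion is now passed in so that
# validation loss can be computed alongside accuracy.
criterion = torch.nn.CrossEntropyLoss()
model, accuracy, val_loss = _validate(args, model, val_dataloader, criterion)
print(f"val accuracy: {accuracy:.4f}, mean val loss: {val_loss:.4f}")
```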
diff --git a/entcl/run.py b/entcl/run.py
index b1fff5811902eaea321e31770add2beec1b15106..a27f8c61316b3b4755f924a8d29cb92e247609ed 100644
--- a/entcl/run.py
+++ b/entcl/run.py
@@ -5,14 +5,23 @@ from loguru import logger
 from datetime import datetime
 import torch
 
+from entcl.utils.util import seed
 from entcl.models.model import ENTCLModel
 from entcl.pretrain import pretrain
 
 @logger.catch
-def main():
+def main(args: argparse.Namespace):
+    model = ENTCLModel(head=args.head, backbone_url=args.backbone_url, backbone=args.backbone, backbone_source=args.backbone_source)
+    logger.debug(f"Model: {model}")
+    model = pretrain(args, model)
+
+
+if __name__ == "__main__":
+    logger.debug("Entry Point: run.py")
     parser = argparse.ArgumentParser()
     # program args
     parser.add_argument('--name', type=str, default="entcl_" + datetime.now().isoformat(timespec='seconds'))
@@ -35,10 +44,10 @@ def main():
     parser.add_argument('--dataset', type=str, default='cifar100', help='Dataset to use', choices=['cifar100'])
 
     # optimiser args
-    parser.add_argument('--lr', type=float, default=0.1, help='Learning Rate for all optimisers')
+    parser.add_argument('--lr', type=float, default=0.001, help='Learning Rate for all optimisers')
     parser.add_argument('--gamma', type=float, default=0.1, help='Gamma for all optimisers')
-    parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for all optimisers')
-    parser.add_argument('--weight_decay', type=float, default=5e-5, help='Weight Decay for all optimisers')
+    parser.add_argument('--momentum', type=float, default=0, help='Momentum for all optimisers')
+    parser.add_argument('--weight_decay', type=float, default=0, help='Weight Decay for all optimisers')
 
     # cl args
     parser.add_argument('--known', type=int, default=50, help='Number of known classes. The rest are novel classes')
@@ -60,8 +69,17 @@ def main():
     # model args
     parser.add_argument('--head', type=str, default='linear2', help='Classification head to use', choices=['linear', 'linear2', 'dino_head'])
+    parser.add_argument('--backbone_url', type=str, default="/cl/entcl/entcl/models/dinov2", help="URL or local path of the repo containing the backbone model")
+    parser.add_argument('--backbone', type=str, default="dinov2_vitb14", help="Name of the backbone model to use")
+    parser.add_argument('--backbone_source', type=str, default="local", help="Source of the backbone model")
+
+    # ood args
+    parser.add_argument('--ood_score', type=str, default='entropy', help='Metric(s) to base OOD detection on', choices=['entropy', 'energy', 'both'])
+    parser.add_argument('--ood_eps', type=float, default=1e-8, help='Epsilon value for computing entropy in OOD detection')
 
     args = parser.parse_args()
+    seed(args.seed)  # seed everything
+
     # setup device
     if not torch.cuda.is_available():
         raise ValueError("CUDA not available")
@@ -107,12 +125,8 @@ def main():
         args.head = LinearHead2(in_features=768, out_features=args.dataset.num_classes, hidden_dim1=512, hidden_dim2=256)
         logger.debug(f"Using Linear2 Head: {args.head}")
 
-    model = ENTCLModel(head=args.head)
-
-    logger.debug(f"Model: {model}")
-
-    model = pretrain(args, model)
-
-if __name__ == "__main__":
-    logger.debug("Entry Point: run.py")
-    main()
\ No newline at end of file
+    argstr = "Arguments: \n"
+    for arg in vars(args):
+        argstr += f"{arg}: {getattr(args, arg)}\n"
+    logger.info(argstr)  # log the assembled argument summary
+
+    main(args)
\ No newline at end of file
diff --git a/entcl/utils/__init__.py b/entcl/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
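In short, the restructured entry point boils down to this order of operations (a condensed, non-verbatim sketch of run.py above):

```python
# Condensed sketch of the new run.py flow: parse, seed, then dispatch.
if __name__ == "__main__":
    args = parser.parse_args()  # parser built as above
    seed(args.seed)             # seed python/numpy/torch before any RNG is used
    # device setup, dataset loading and head selection happen here, as above
    main(args)                  # model construction + pretraining now live in main()
```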
diff --git a/entcl/utils/ood.py b/entcl/utils/ood.py
new file mode 100644
index 0000000000000000000000000000000000000000..c31ac1e0d46e5c460b6f17af85e591bb142b4738
--- /dev/null
+++ b/entcl/utils/ood.py
@@ -0,0 +1,217 @@
+from typing import Iterable, Tuple, Union
+from entcl.data.util import TransformedTensorDataset
+from entcl.models.model import ENTCLModel
+from loguru import logger
+from sklearn.mixture import GaussianMixture
+import torch
+from tqdm import tqdm
+
+
+def _get_scores(
+    loader: Union[torch.utils.data.DataLoader, Iterable], model: ENTCLModel, args
+) -> Union[
+    Tuple[torch.Tensor, torch.Tensor],
+    Tuple[None, None],
+    Tuple[torch.Tensor, None],
+    Tuple[None, torch.Tensor],
+]:
+    """
+    Computes entropy and/or energy scores for the dataset based on `args.ood_score`.
+
+    :param loader: A DataLoader or Iterable over the dataset.
+    :param model: The model to evaluate.
+    :param args: Object with the attributes `ood_score`, `ood_eps` and `device`.
+    :return: Tuple of torch.Tensors with the entropy and energy scores respectively; a score that was not requested is None.
+    """
+    logger.debug(f"Getting Scores: {args.ood_score}")
+    model.eval()
+
+    compute_entropy = args.ood_score in ["entropy", "both"]
+    compute_energy = args.ood_score in ["energy", "both"]
+
+    entropies, energies = [], []
+
+    for x, _ in tqdm(loader, desc="Calculating Scores", leave=True, unit="batch"):
+        x = x.to(args.device)
+        with torch.no_grad():
+            logits, _ = model(x)
+
+        if compute_entropy:
+            softmax = torch.nn.functional.softmax(logits, dim=1)
+            # args.ood_eps (1e-8 by default) is added inside the log for numerical stability
+            entropy = -torch.sum(softmax * torch.log(softmax + args.ood_eps), dim=1)
+            entropies.append(entropy)
+
+        if compute_energy:
+            energy = -torch.logsumexp(logits, dim=1)
+            energies.append(energy)
+
+    logger.debug(
+        f"Scores Calculated: Entropy: {len(entropies)} batches, Energy: {len(energies)} batches"
+    )
+
+    entropies = torch.cat(entropies) if compute_entropy else None
+    energies = torch.cat(energies) if compute_energy else None
+
+    return entropies, energies
+
+
+def _fit_predict_gmm(data: torch.Tensor, args) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Helper function to fit a two-component Gaussian Mixture Model to the data.
+
+    :param data: Tensor of shape [N] with the data to fit the GMM to.
+    :param args: Object with the attribute `seed`.
+    :return: Tuple of hard predictions (shape [N]) and soft predictions (shape [N, 2]), ordered so that type 0 is always the lower-mean component.
+    """
+    logger.debug("Fitting Gaussian Mixture Model to Data")
+    gmm = GaussianMixture(n_components=2, random_state=args.seed)
+
+    # Fit and predict using the GMM
+    predtypes_hard = torch.tensor(
+        gmm.fit_predict(data.view(-1, 1).cpu().numpy()), device=data.device
+    )
+    predtypes_soft = torch.tensor(
+        gmm.predict_proba(data.view(-1, 1).cpu().numpy()), device=data.device
+    )
+
+    # Retrieve the means of the two clusters
+    mean_0 = torch.mean(data[predtypes_hard == 0], dim=0)
+    mean_1 = torch.mean(data[predtypes_hard == 1], dim=0)
+
+    logger.debug(f"Mean for Type 0: {mean_0}, Mean for Type 1: {mean_1}")
+
+    # Swap the clusters if necessary, so type 0 is always the lower-mean cluster
+    if mean_1 < mean_0:
+        logger.debug("Type 1 has lower mean than Type 0. Swapping types")
+        predtypes_hard = 1 - predtypes_hard  # swap the hard labels
+        # Swap the probability columns so that column 0 holds the probability
+        # of type 0 and column 1 the probability of type 1
+        predtypes_soft = predtypes_soft[:, [1, 0]]
+    else:
+        logger.debug("Type 0 has lower mean than Type 1. Keeping types")
+
+    return predtypes_hard, predtypes_soft
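A toy check of the ordering convention `_fit_predict_gmm` enforces; the synthetic scores and the `SimpleNamespace` stand-in for `args` are illustrative only:

```python
import torch
from types import SimpleNamespace

# Two well-separated 1-D clusters: low-mean (in-distribution) and high-mean (OOD).
scores = torch.cat([torch.randn(100) * 0.1 + 0.5, torch.randn(100) * 0.1 + 2.0])

hard, soft = _fit_predict_gmm(scores, SimpleNamespace(seed=0))
# After the mean-based swap, type 0 is always the lower-mean cluster, so
# soft[:, 0] ~ P(in-distribution) and soft[:, 1] ~ P(OOD), regardless of
# which component the GMM happened to label 0 during fitting.
assert hard.shape == (200,) and soft.shape == (200, 2)
```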
+
+
+def _resolve_conflicts(
+    entropy_predtypes_soft: torch.Tensor, energy_predtypes_soft: torch.Tensor
+) -> torch.Tensor:
+    """
+    Resolves conflicts between the entropy and energy predictions by selecting the type predicted with the higher confidence.
+
+    :param entropy_predtypes_soft: Tensor of shape [N, 2] with soft predictions from the entropy GMM.
+    :param energy_predtypes_soft: Tensor of shape [N, 2] with soft predictions from the energy GMM.
+    :return: Tensor of shape [N] with resolved hard predictions (0 or 1).
+    """
+    logger.debug("Resolving Conflicts")
+    assert (
+        entropy_predtypes_soft.shape == energy_predtypes_soft.shape
+    ), f"Entropy and Energy predictions must have the same shape. Got Entropy.shape: {entropy_predtypes_soft.shape}, Energy.shape: {energy_predtypes_soft.shape}"
+
+    # Compute hard predictions
+    entropy_predtypes_hard = torch.argmax(entropy_predtypes_soft, dim=1)
+    energy_predtypes_hard = torch.argmax(energy_predtypes_soft, dim=1)
+
+    # For each sample, get the confidence of the predicted type
+    entropy_confidence = entropy_predtypes_soft[
+        torch.arange(entropy_predtypes_soft.size(0)), entropy_predtypes_hard
+    ]
+    energy_confidence = energy_predtypes_soft[
+        torch.arange(energy_predtypes_soft.size(0)), energy_predtypes_hard
+    ]
+
+    # Resolve conflicts by selecting the type with the higher confidence.
+    # torch.where(condition, x, y) returns x where condition is True, otherwise y.
+    resolved_predictions = torch.where(
+        entropy_predtypes_hard == energy_predtypes_hard,  # if the predictions agree,
+        entropy_predtypes_hard,                           # keep the shared prediction;
+        torch.where(                                      # otherwise,
+            energy_confidence > entropy_confidence,       # if energy is more confident,
+            energy_predtypes_hard,                        # use the energy prediction,
+            entropy_predtypes_hard,                       # else use the entropy prediction
+        ),
+    )
+
+    return resolved_predictions
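A worked example of the conflict-resolution rule on two samples, with soft predictions chosen by hand for illustration:

```python
import torch

entropy_soft = torch.tensor([[0.9, 0.1],   # sample 0: entropy GMM says type 0 (conf 0.9)
                             [0.6, 0.4]])  # sample 1: entropy GMM says type 0 (conf 0.6)
energy_soft  = torch.tensor([[0.8, 0.2],   # sample 0: energy GMM agrees, type 0
                             [0.2, 0.8]])  # sample 1: energy GMM says type 1 (conf 0.8)

resolved = _resolve_conflicts(entropy_soft, energy_soft)
# Sample 0: both agree -> 0. Sample 1: energy is more confident (0.8 > 0.6) -> 1.
print(resolved)  # tensor([0, 1])
```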
+ """ + logger.debug("Starting OOD Labelling for Session") + + # first we dataload the session dataset, for memory efficiency + session_loader = torch.utils.data.DataLoader( + session_dataset, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=args.pin_memory, + shuffle=False, + ) + + # next we run the dataset through the model to retrieve an entropy/energy score for each sample + entropies, energies = _get_scores(session_loader, model, args) + + # next step depends on the selected OOD Score + # placeholder for the final predictions + final_predtypes = None + # if we are using entropy only + if args.ood_score == "entropy": + logger.debug("Using Entropy Only") + predtypes_hard, _ = _fit_predict_gmm( + entropies, args + ) # fit a GMM to the entropy scores, we do not care about the soft predictions + final_predtypes = predtypes_hard + + # if we are using energy only + elif args.ood_score == "energy": + logger.debug("Using Energy Only") + predtypes_hard, _ = _fit_predict_gmm(energies, args) + final_predtypes = predtypes_hard + + # if we are using both entropy and energy + elif args.ood_score == "both": + logger.debug("Using Both Entropy and Energy") + _, entropy_predtypes_soft = _fit_predict_gmm( + entropies, args + ) # we do not care about the hard predictions + _, energy_predtypes_soft = _fit_predict_gmm( + energies, args + ) # we do not care about the hard predictions + + final_predtypes = _resolve_conflicts( + entropy_predtypes_soft, energy_predtypes_soft + ) + else: + raise ValueError(f"Invalid OOD Score: {args.ood_score}") + + if return_new_dataset: + logger.debug("Returning New Dataset") + # return the new dataset with the predicted types + session_dataset = TransformedTensorDataset( + tensor_dataset=torch.utils.data.TensorDataset( + session_dataset.tensor_dataset.tensors[0], # the data + session_dataset.tensor_dataset.tensors[1], # the labels + final_predtypes, # the predicted types (duh) + ), + transform=session_dataset.transform, + ) + + return session_dataset + else: + logger.debug("Returning Predicted Types") + # return just the predicted types + return final_predtypes diff --git a/entcl/utils/util.py b/entcl/utils/util.py new file mode 100644 index 0000000000000000000000000000000000000000..fae6827dce29e499ed6b0928759562154283ffda --- /dev/null +++ b/entcl/utils/util.py @@ -0,0 +1,15 @@ + +import torch +import numpy as np +import os +import random + +def seed(seed=8008135): + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True \ No newline at end of file