From 0c28c917ee9f0e797c969be6844b67d3e17f7dcb Mon Sep 17 00:00:00 2001
From: Joseph Omar <j.omar@soton.ac.uk>
Date: Fri, 15 Nov 2024 13:49:17 +0000
Subject: [PATCH] Add OOD detection

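Adds out-of-distribution (OOD) detection utilities for labelling novel
samples in a session dataset. Each sample is scored with predictive
entropy (H = -sum_c p_c * log(p_c + eps), p = softmax(logits)) and/or
an energy score (E = -logsumexp(logits)), per --ood_score. A
two-component Gaussian Mixture Model is fitted to each score to
separate in-distribution (type 0) from OOD (type 1) samples, and
disagreements between the two scores are resolved by prediction
confidence. Also threads backbone selection through the CLI, seeds all
RNGs for reproducibility, and logs validation loss during pretraining.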
---
 entcl/cl.py             |   8 ++
 entcl/models/model.py   |   7 +-
 entcl/pretrain.py       |  22 ++--
 entcl/run.py            |  40 +++++---
 entcl/utils/__init__.py |   0
 entcl/utils/ood.py      | 228 ++++++++++++++++++++++++++++++++++++++++
 entcl/utils/util.py     |  16 +++
 7 files changed, 298 insertions(+), 23 deletions(-)
 create mode 100644 entcl/cl.py
 create mode 100644 entcl/utils/__init__.py
 create mode 100644 entcl/utils/ood.py
 create mode 100644 entcl/utils/util.py

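Note (not part of the commit): a minimal sketch of the intended call
site for the new labelling entry point. The session loop itself is not
in this patch, so the surrounding names (args, model, session_dataset)
are assumed to come from the existing pipeline in run.py/pretrain.py:

    from entcl.utils.ood import label_ood_for_session

    # hypothetical wiring inside a future CL session loop
    session_dataset = label_ood_for_session(
        args, session_dataset, model, return_new_dataset=True
    )
    # each sample now carries a predicted type: 0 = in-distribution, 1 = OOD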
diff --git a/entcl/cl.py b/entcl/cl.py
new file mode 100644
index 0000000..2aea96b
--- /dev/null
+++ b/entcl/cl.py
@@ -0,0 +1,8 @@
+from typing import Optional, Union
+from entcl.data.util import TransformedTensorDataset
+from entcl.models.model import ENTCLModel
+import torch
+from loguru import logger
+from sklearn.mixture import GaussianMixture
+import tqdm
+
diff --git a/entcl/models/model.py b/entcl/models/model.py
index d1327f3..9c9f4ba 100644
--- a/entcl/models/model.py
+++ b/entcl/models/model.py
@@ -3,11 +3,14 @@ from loguru import logger
 import torch
 
 class ENTCLModel(torch.nn.Module):
-    def __init__(self, head: torch.nn.Module):
+    def __init__(self, head: torch.nn.Module, backbone_url: str, backbone: str, backbone_source: str):
         super().__init__()
         
         # load the backbone
-        self.backbone = torch.hub.load(os.path.join(os.path.dirname(__file__), 'dinov2'), 'dinov2_vitb14', source='local')
+        self.backbone = torch.hub.load(backbone_url, backbone, source=backbone_source)
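+        # e.g. the run.py defaults: backbone="dinov2_vitb14", backbone_source="local",
+        # and backbone_url pointing at a local clone of the dinov2 repo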
+        logger.debug(f"Loaded backbone: {backbone} from {backbone_url} (src: {backbone_source})")
         
         # freeze the backbone
         for param in self.backbone.parameters():
diff --git a/entcl/pretrain.py b/entcl/pretrain.py
index 26e4858..25e9295 100644
--- a/entcl/pretrain.py
+++ b/entcl/pretrain.py
@@ -50,14 +50,14 @@ def pretrain(args, model):
             logger.debug(f"Epoch {epoch} Started")
             # train model
             print(f"Epoch {epoch}:")
-            model, loss_total = _train(args, model, train_dataloader, optimiser, criterion)
-            logger.info(f"Epoch {epoch}: Loss: {loss_total}")
+            model, train_loss = _train(args, model, train_dataloader, optimiser, criterion)
+            logger.info(f"Epoch {epoch}: Loss: {train_loss}")
             
             # validate model
-            model, accuracy = _validate(args, model, val_dataloader)
+            model, accuracy, val_loss = _validate(args, model, val_dataloader, criterion)
             accuracies.append(accuracy)
             
-            logger.info(f"Epoch {epoch}: Accuracy: {accuracy}")
+            logger.info(f"Epoch {epoch}: Accuracy: {accuracy}, Loss: {val_loss}")
             torch.save(model.state_dict(), os.path.join(args.exp_dir, f"model_{epoch}.pt"))
         
         # select the best model
@@ -101,10 +101,11 @@ def _train(args, model, train_dataloader, optimiser, criterion):
     loss_total /= len(train_dataloader)
     return model, loss_total
 
-def _validate(args, model, val_dataloader):
+def _validate(args, model, val_dataloader, criterion):
     model.eval()
     correct = 0
     total = 0
+    loss_total = 0
     with torch.no_grad():
         for x, y in tqdm(val_dataloader, desc=f"Validating", unit = "batch"):
             x, y = x.to(args.device), y.to(args.device)
@@ -112,12 +113,17 @@ def _validate(args, model, val_dataloader):
                 
             logits, _ = model(x)
             logger.debug(f"logits shape: {logits.shape}")
-                
+            
+            loss = criterion(logits, y)
+            loss = loss.item()
+            loss_total += loss
+            
             _, predicted = torch.max(logits, 1)
                 
             num_correct = (predicted == y).sum().item()
             total += y.size(0)
             correct += num_correct
-            logger.debug(f"This Batch Num Correct: {num_correct}, Total: {y.size(0)}, Accuracy: {num_correct / y.size(0)}")
             
-    return model, correct / total
+            logger.debug(f"This Batch Num Correct: {num_correct}, Total: {y.size(0)}, Accuracy: {num_correct / y.size(0)}, Loss: {loss}")
+            
+    return model, correct / total, loss_total / len(val_dataloader)
diff --git a/entcl/run.py b/entcl/run.py
index b1fff58..a27f8c6 100644
--- a/entcl/run.py
+++ b/entcl/run.py
@@ -5,14 +5,23 @@ from loguru import logger
 from datetime import datetime
 import torch
 
+from entcl.utils.util import seed
 
 from entcl.models.model import ENTCLModel
 from entcl.pretrain import pretrain
 
 
 @logger.catch
-def main():
+def main(args: argparse.Namespace):
+    model = ENTCLModel(head=args.head, backbone_url=args.backbone_url, backbone=args.backbone, backbone_source=args.backbone_source)
+    logger.debug(f"Model: {model}")
     
+    model = pretrain(args, model)
+
+
+
+if __name__ == "__main__":
+    logger.debug("Entry Point: run.py")
     parser = argparse.ArgumentParser()
     # program args
     parser.add_argument('--name', type=str, default="entcl_" + datetime.now().isoformat(timespec='seconds'))
@@ -35,10 +44,10 @@ def main():
     parser.add_argument('--dataset', type=str, default='cifar100', help='Dataset to use', choices=['cifar100'])
     
     # optimiser args
-    parser.add_argument('--lr', type=float, default=0.1, help='Learning Rate for all optimisers')
+    parser.add_argument('--lr', type=float, default=0.001, help='Learning Rate for all optimisers')
     parser.add_argument('--gamma', type=float, default=0.1, help='Gamma for all optimisers')
-    parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for all optimisers')
-    parser.add_argument('--weight_decay', type=float, default=5e-5, help='Weight Decay for all optimisers')
+    parser.add_argument('--momentum', type=float, default=0, help='Momentum for all optimisers')
+    parser.add_argument('--weight_decay', type=float, default=0, help='Weight Decay for all optimisers')
     
     # cl args
     parser.add_argument('--known', type=int, default=50, help='Number of known classes. The rest are novel classes')
@@ -60,8 +69,17 @@ def main():
     
     # model args
     parser.add_argument('--head', type=str, default='linear2', help='Classification head to use', choices=['linear','linear2', 'dino_head'])
+    parser.add_argument('--backbone_url', type=str, default="/cl/entcl/entcl/models/dinov2", help="Path or URL to the repo containing the backbone model")
+    parser.add_argument("--backbone", type=str, default="dinov2_vitb14", help="Name of the backbone model to use")
+    parser.add_argument("--backbone_source", type=str, default="local", help="Source of the backbone model")
+    
+    # ood args
+    parser.add_argument('--ood_score', type=str, default='entropy', help='Changes the metric(s) to base OOD detection on', choices=['entropy', 'energy', 'both'])
+    parser.add_argument('--ood_eps', type=float, default=1e-8, help='Epsilon value for computing entropy in OOD detection')
     args = parser.parse_args()
     
+    seed(args.seed) # seed everything
+    
     # setup device
     if not torch.cuda.is_available():
         raise ValueError("CUDA not available")
@@ -107,12 +125,8 @@ def main():
         args.head = LinearHead2(in_features=768, out_features=args.dataset.num_classes, hidden_dim1=512, hidden_dim2=256)
         logger.debug(f"Using Linear2 Head: {args.head}")
         
-    model = ENTCLModel(head=args.head)
-    
-    logger.debug(f"Model: {model}")
-    
-    model = pretrain(args, model)
-
-if __name__ == "__main__":
-    logger.debug("Entry Point: run.py")
-    main()
\ No newline at end of file
+    argstr = "Arguments: \n"
+    for arg in vars(args):
+        argstr += f"{arg}: {getattr(args, arg)}\n"
+        
+    main(args)
\ No newline at end of file
diff --git a/entcl/utils/__init__.py b/entcl/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/entcl/utils/ood.py b/entcl/utils/ood.py
new file mode 100644
index 0000000..c31ac1e
--- /dev/null
+++ b/entcl/utils/ood.py
@@ -0,0 +1,228 @@
+from typing import Iterable, Tuple, Union
+from entcl.data.util import TransformedTensorDataset
+from entcl.models.model import ENTCLModel
+from loguru import logger
+from sklearn.mixture import GaussianMixture
+import torch
+from tqdm import tqdm
+
+
+def _get_scores(
+    loader: Union[torch.utils.data.DataLoader, Iterable], model: ENTCLModel, args
+) -> Union[
+    Tuple[torch.Tensor, torch.Tensor],
+    Tuple[None, None],
+    Tuple[torch.Tensor, None],
+    Tuple[None, torch.Tensor],
+]:
+    """
+    Computes entropy and/or energy scores for the dataset based on `args.ood_score`.
+
+    :param loader: a Dataloader or Iterable for the dataset.
+    :param model: The model to evaluate.
+    :param args: Object with the attributes `ood_score`, `ood_eps` and `device`.
+    :return: Tuple of torch.Tensors with the entropy and energy scores respectively (None for any score not computed).
+    """
+    logger.debug(f"Getting Scores: {args.ood_score}")
+    model.eval()
+
+    compute_entropy = args.ood_score in ["entropy", "both"]
+    compute_energy = args.ood_score in ["energy", "both"]
+
+    entropies, energies = [], []
+
+    for x, _ in tqdm(loader, desc="Calculating Scores", leave=True, unit="batch"):
+        x = x.to(args.device)
+        with torch.no_grad():
+            logits, _ = model(x)
+
+            if compute_entropy:
+                softmax = torch.nn.functional.softmax(logits, dim=1)
+                entropy = -torch.sum(
+                    softmax * torch.log(softmax + args.ood_eps), dim=1
+                )  # epsilon (args.ood_eps, 1e-8 by default) added for numerical stability
+                entropies.append(entropy)
+
+            if compute_energy:
+                energy = -torch.logsumexp(logits, dim=1)
+                energies.append(energy)
+
+    logger.debug(
+        f"Scores Calculated: Entropy: {len(entropies)} batches, Energy: {len(energies)} batches"
+    )
+
+    entropies = torch.cat(entropies) if compute_entropy else None
+    energies = torch.cat(energies) if compute_energy else None
+
+    return entropies, energies
+
+
+def _fit_predict_gmm(data: torch.Tensor, args) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Helper function to fit a two-component Gaussian Mixture Model to the data.
+    :param data: Tensor of shape [N] with the data to fit the GMM to.
+    :param args: Object with the attribute `seed`.
+    :return: Tuple of hard type predictions [N] and soft type predictions [N, 2].
+    """
+    logger.debug("Fitting Gaussian Mixture Model to data")
+    gmm = GaussianMixture(n_components=2, random_state=args.seed)
+
+    # Fit and predict using the GMM
+    predtypes_hard = torch.tensor(
+        gmm.fit_predict(data.view(-1, 1).cpu().numpy()), device=data.device
+    )
+    predtypes_soft = torch.tensor(
+        gmm.predict_proba(data.view(-1, 1).cpu().numpy()), device=data.device
+    )
+
+    # Retrieve the means of the two clusters
+    mean_0 = torch.mean(data[predtypes_hard == 0], dim=0)
+    mean_1 = torch.mean(data[predtypes_hard == 1], dim=0)
+
+    logger.debug(f"Mean for Type 0: {mean_0}, Mean for Type 1: {mean_1}")
+
+    # Swapping clusters if necessary
+    if mean_1 < mean_0:
+        logger.debug("Type 1 has lower mean than Type 0. Swapping types")
+        predtypes_hard = 1 - predtypes_hard  # Swap the types
+        predtypes_soft = predtypes_soft[
+            :, [1, 0]
+        ]  # Swap probability columns so column 0 holds P(type 0) and column 1 holds P(type 1)
+    else:
+        logger.debug("Type 0 has lower mean than Type 1. Keeping types")
+
+    return predtypes_hard, predtypes_soft
+
+
+def _resolve_conflicts(
+    entropy_predtypes_soft: torch.Tensor, energy_predtypes_soft: torch.Tensor
+) -> torch.Tensor:
+    """
+    Resolves conflicts between entropy and energy predictions by selecting the type with the highest confidence.
+
+    :param entropy_predtypes_soft: Tensor of shape [N, 2] with soft predictions from the entropy GMM.
+    :param energy_predtypes_soft: Tensor of shape [N, 2] with soft predictions from the energy GMM.
+    :return: Tensor of shape [N] with resolved hard predictions (0 or 1).
+    """
+    logger.debug("Resolving Conflicts")
+    assert (
+        entropy_predtypes_soft.shape == energy_predtypes_soft.shape
+    ), f"Entropy and Energy predictions must have the same shape. Got Entropy.shape: {entropy_predtypes_soft.shape}, Energy.shape: {energy_predtypes_soft.shape}"
+
+    # Compute hard predictions and their confidence scores
+    entropy_predtypes_hard = torch.argmax(entropy_predtypes_soft, dim=1)
+    energy_predtypes_hard = torch.argmax(energy_predtypes_soft, dim=1)
+
+    # for each sample, get the confidence of the predicted type
+    entropy_confidence = entropy_predtypes_soft[
+        torch.arange(entropy_predtypes_soft.size(0)), entropy_predtypes_hard
+    ]
+    energy_confidence = energy_predtypes_soft[
+        torch.arange(energy_predtypes_soft.size(0)), energy_predtypes_hard
+    ]
+
+    # Resolve conflicts by selecting the type with the highest confidence
+    # torch.where(condition, x, y) returns x if condition is True, otherwise y
+    resolved_predictions = torch.where(
+        entropy_predtypes_hard
+        == energy_predtypes_hard,  # if the predictions are the same
+        entropy_predtypes_hard,  # return the prediction
+        torch.where(  # otherwise
+            energy_confidence
+            > entropy_confidence,  # if the energy confidence is higher
+            energy_predtypes_hard,  # use the energy prediction
+            entropy_predtypes_hard,
+        ),  # otherwise use the entropy prediction
+    )
+
+    return resolved_predictions
+
+
+def label_ood_for_session(
+    args,
+    session_dataset: TransformedTensorDataset,
+    model: ENTCLModel,
+    return_new_dataset: bool = False,
+) -> Union[TransformedTensorDataset, torch.Tensor]:
+    """
+    OOD labelling for a session dataset. Computes entropy and/or energy scores for the
+    dataset (per `args.ood_score`) and fits a two-component Gaussian Mixture Model to
+    each score, with one component for in-distribution samples and one for OOD samples.
+    When both scores are used, conflicts between the entropy and energy predictions are
+    resolved by selecting the type with the highest confidence. Finally, returns a new
+    dataset for the session, including the predicted types.
+    :param args: Objects with the attributes `ood_score`, `ood_eps`, `seed` and `device` (Program Arguments).
+    :param session_dataset: Dataset for the session.
+    :param model: The model to evaluate.
+    :param return_new_dataset: Whether to return the new dataset ready for the session or just the predicted types (useful when there's more to do before training).
+    :return: A TransformedTensorDataset with the predicted types or just a torch.Tensor of the predicted types.
+    """
+    logger.debug("Starting OOD Labelling for Session")
+
+    # first we dataload the session dataset, for memory efficiency
+    session_loader = torch.utils.data.DataLoader(
+        session_dataset,
+        batch_size=args.batch_size,
+        num_workers=args.num_workers,
+        pin_memory=args.pin_memory,
+        shuffle=False,
+    )
+
+    # next we run the dataset through the model to retrieve an entropy/energy score for each sample
+    entropies, energies = _get_scores(session_loader, model, args)
+
+    # next step depends on the selected OOD Score
+    # placeholder for the final predictions
+    final_predtypes = None
+    # if we are using entropy only
+    if args.ood_score == "entropy":
+        logger.debug("Using Entropy Only")
+        predtypes_hard, _ = _fit_predict_gmm(
+            entropies, args
+        )  # fit a GMM to the entropy scores, we do not care about the soft predictions
+        final_predtypes = predtypes_hard
+
+    # if we are using energy only
+    elif args.ood_score == "energy":
+        logger.debug("Using Energy Only")
+        predtypes_hard, _ = _fit_predict_gmm(energies, args)
+        final_predtypes = predtypes_hard
+
+    # if we are using both entropy and energy
+    elif args.ood_score == "both":
+        logger.debug("Using Both Entropy and Energy")
+        _, entropy_predtypes_soft = _fit_predict_gmm(
+            entropies, args
+        )  # we do not care about the hard predictions
+        _, energy_predtypes_soft = _fit_predict_gmm(
+            energies, args
+        )  # we do not care about the hard predictions
+
+        final_predtypes = _resolve_conflicts(
+            entropy_predtypes_soft, energy_predtypes_soft
+        )
+    else:
+        raise ValueError(f"Invalid OOD Score: {args.ood_score}")
+
+    if return_new_dataset:
+        logger.debug("Returning New Dataset")
+        # return the new dataset with the predicted types
+        session_dataset = TransformedTensorDataset(
+            tensor_dataset=torch.utils.data.TensorDataset(
+                session_dataset.tensor_dataset.tensors[0],  # the data
+                session_dataset.tensor_dataset.tensors[1],  # the labels
+                final_predtypes,  # the predicted types from OOD labelling
+            ),
+            transform=session_dataset.transform,
+        )
+
+        return session_dataset
+    else:
+        logger.debug("Returning Predicted Types")
+        # return just the predicted types
+        return final_predtypes
diff --git a/entcl/utils/util.py b/entcl/utils/util.py
new file mode 100644
index 0000000..fae6827
--- /dev/null
+++ b/entcl/utils/util.py
@@ -0,0 +1,16 @@
+
+import torch
+import numpy as np
+import os
+import random
+
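+# Seeds python, numpy and torch (CPU and all GPUs) RNGs and makes cuDNN deterministic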
+def seed(seed=8008135):
+    random.seed(seed)
+    os.environ["PYTHONHASHSEED"] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.deterministic = True
\ No newline at end of file
-- 
GitLab