From c4535f84dd50e9359e2e7d4bd6db14cd0d6ac64a Mon Sep 17 00:00:00 2001
From: Joseph Omar <j.omar@soton.ac.uk>
Date: Mon, 18 Nov 2024 10:16:29 +0000
Subject: [PATCH] working on ncd. should run offline okay tho :)

---
 entcl/cl.py        |   7 --
 entcl/run.py       |   4 +++-
 entcl/utils/ncd.py | 231 +++++++++++++++++++++++++++++++++++++++++++++++
 entcl/utils/ood.py |  34 +++++---
 4 files changed, 251 insertions(+), 25 deletions(-)
 create mode 100644 entcl/utils/ncd.py

diff --git a/entcl/cl.py b/entcl/cl.py
index 2aea96b..8b13789 100644
--- a/entcl/cl.py
+++ b/entcl/cl.py
@@ -1,8 +1 @@
-from typing import Optional, Union
-from entcl.data.util import TransformedTensorDataset
-from entcl.models.model import ENTCLModel
-import torch
-from loguru import logger
-from sklearn.mixture import GaussianMixture
-import tqdm
 
diff --git a/entcl/run.py b/entcl/run.py
index a27f8c6..dcdd668 100644
--- a/entcl/run.py
+++ b/entcl/run.py
@@ -110,6 +110,8 @@ if __name__ == "__main__":
     # initialise dataset
     if args.dataset == 'cifar100':
         from entcl.data.cifar100 import CIFAR100Dataset
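+        # e.g. known=50 over sessions=5 -> (100 - 50) // 5 = 10 novel classes per session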
+        args.novel_classes_per_session = (100 - args.known) // args.sessions
-        args.dataset = CIFAR100Dataset(known=args.known, pretrain_n_known=args.pretrain_n_known, cl_n_known=args.cl_n_known, cl_n_novel=args.cl_n_novel, cl_n_prevnovel=args.cl_n_prevnovel, sessions=5)
+        args.dataset = CIFAR100Dataset(known=args.known, pretrain_n_known=args.pretrain_n_known, cl_n_known=args.cl_n_known, cl_n_novel=args.cl_n_novel, cl_n_prevnovel=args.cl_n_prevnovel, sessions=args.sessions)
         
         
diff --git a/entcl/utils/ncd.py b/entcl/utils/ncd.py
new file mode 100644
index 0000000..7093397
--- /dev/null
+++ b/entcl/utils/ncd.py
@@ -0,0 +1,231 @@
+from entcl.data.util import TransformedTensorDataset
+from loguru import logger
+import numpy as np
+from sklearn.cluster import KMeans
+from sklearn.metrics import confusion_matrix
+from scipy.optimize import linear_sum_assignment
+import torch
+from tqdm import tqdm
+
+def _extract_features(args, dataset: TransformedTensorDataset, model: torch.nn.Module) -> torch.Tensor:
+    """
+    Extracts features from the data in the dataset using the provided model's backbone.
+    :param args: Arguments object with the attributes `device`, `batch_size`, `num_workers` and `pin_memory`.
+    :param dataset: TransformedTensorDataset with the data to extract features from.
+    :param model: Model whose backbone produces the features to extract.
+    :return: torch.Tensor with the extracted features, in dataset order.
+    """
+    logger.debug("Extracting Features")
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=args.batch_size,
+        num_workers=args.num_workers,
+        pin_memory=args.pin_memory,
+        shuffle=False,
+    )
+    
+    model.eval()
+    features = []
+    with torch.no_grad():
+        for batch in tqdm(dataloader, desc="Extracting Features", leave=True, unit="batch"):
+            x = batch[0]  # only the data is needed; assumes the batch is iterable
+            x = x.to(args.device)
+            x_proj, _ = model(x)  # we do not need the predicted logits, only the features
+            features.append(x_proj)  # collect per-batch features, concatenate once at the end
+
+    return torch.cat(features, dim=0)
+
+def _elbow_findk(features: torch.Tensor, args) -> int:
+    """
+    Finds the optimal number of clusters using the elbow method and the kneed library.
+    :param features: torch.Tensor with the features to cluster.
+    :param args: Arguments object with the attributes `seed`, `ncd_findk_mink` and `ncd_findk_maxk`.
+    :return: int with the optimal number of clusters.
+    """
+    from kneed import KneeLocator
+    inertias = []
+
+    # KMeans expects a CPU NumPy array, so move the features off the device first
+    features_np = features.cpu().numpy()
+    ks = range(args.ncd_findk_mink, args.ncd_findk_maxk + 1)
+    logger.debug(f"Finding k using elbow method for k in {ks}")
+
+    # Calculate the inertia for each k
+    for k in ks:
+        kmeans = KMeans(n_clusters=k, random_state=args.seed)
+        kmeans.fit(features_np)
+        inertias.append(kmeans.inertia_)
+        
+    logger.debug(f"Elbow Inertias: {inertias}")
+    logger.debug(f"Running KneeLocator")
+    
+    # Find the knee
+    kneedle = KneeLocator(ks, inertias, curve="convex", direction="decreasing")
+    logger.info(f"Elbow Method Found k: {kneedle.knee}")
+    return kneedle.knee
+    
+def _silhouette_findk(features: torch.Tensor, args) -> int:
+    """
+    Finds the optimal number of clusters using the silhouette method.
+    :param features: torch.Tensor with the features to cluster.
+    :param args: Arguments object with the attributes `seed`, `ncd_findk_mink` and `ncd_findk_maxk`.
+    :return: int with the optimal number of clusters.
+    """
+    from sklearn.metrics import silhouette_score
+    silhouettes = []
+
+    # KMeans and silhouette_score expect a CPU NumPy array
+    features_np = features.cpu().numpy()
+    ks = range(args.ncd_findk_mink, args.ncd_findk_maxk + 1)
+    logger.debug(f"Finding k using silhouette method for k in {ks}")
+
+    # Calculate the silhouette score for each k
+    for k in ks:
+        kmeans = KMeans(n_clusters=k, random_state=args.seed)
+        pseudo_labels = kmeans.fit_predict(features_np)
+        silhouette = silhouette_score(features_np, pseudo_labels)
+        silhouettes.append(silhouette)
+        
+    logger.debug(f"Silhouette Scores: {silhouettes}")
+    
+    # Find the best k
+    best_k = ks[silhouettes.index(max(silhouettes))]
+    logger.info(f"Silhouette Method Found k: {best_k}")
+    return best_k
+
+def _gap_findk(features: torch.Tensor, args) -> int:
+    """
+    Finds the optimal number of clusters using the gap statistic.
+    :param features: torch.Tensor with the features to cluster.
+    :param args: Arguments object with the attributes `seed`, `ncd_findk_mink`, `ncd_findk_maxk` and `ncd_gap_nrefs`.
+    :return: int with the optimal number of clusters.
+    """
+    logger.debug("Finding k using gap statistic")
+    features_np = features.cpu().numpy()
+    n_samples, n_features = features_np.shape
+    
+    min_vals = features_np.min(axis=0)
+    max_vals = features_np.max(axis=0)
+    
+    gap_stats = {}
+    wks = []
+    wks_refs = []
+    
+    for k in tqdm(range(args.ncd_findk_mink, args.ncd_findk_maxk + 1), desc="Calculating Gap Statistic", leave=True, unit="k"):
+        logger.debug(f"Calculating Gap Statistic for k={k}")
+        kmeans = KMeans(n_clusters=k, random_state=args.seed)
+        kmeans.fit(features_np)
+        wk = kmeans.inertia_
+        wks.append(np.log(wk))
+        
+        wk_refs = []
+        for _ in tqdm(range(args.ncd_gap_nrefs), desc="Calculating Reference Gap Statistic", leave=True, unit="ref"):
+            logger.debug(f"Calculating Reference Gap Statistic for k={k}")
+            ref_data = np.random.uniform(min_vals, max_vals, (n_samples, n_features))
+            kmeans = KMeans(n_clusters=k, random_state=args.seed)
+            kmeans.fit(ref_data)
+            wk_ref = kmeans.inertia_
+            wk_refs.append(np.log(wk_ref))
+        wks_refs.append(np.mean(wk_refs))
+
+        gap_stats[k] = wks_refs[-1] - wks[-1]
+    
+    optimal_k = max(gap_stats, key=gap_stats.get)
+    
+    logger.info(f"Gap Statistic Found k: {optimal_k}")
+    
+    return optimal_k
+        
+def calculate_clustering_accuracy(true_labels: torch.Tensor, pseudo_labels: torch.Tensor) -> float:
+    """
+    Calculates the clustering accuracy between the true labels and the pseudo labels,
+    using the Hungarian algorithm to find the best cluster-to-class assignment.
+    :param true_labels: torch.Tensor with the true labels.
+    :param pseudo_labels: torch.Tensor with the pseudo labels.
+    :return: float with the clustering accuracy in [0, 1].
+    """
+    assert true_labels.shape == pseudo_labels.shape, f"True and Pseudo labels must have the same shape. true_labels.shape: {true_labels.shape}, pseudo_labels.shape: {pseudo_labels.shape}"
+    unique_true_labels = torch.unique(true_labels)
+    unique_pseudo_labels = torch.unique(pseudo_labels)
+
+    # square matrix over the union of true and pseudo label values
+    confusion_mat = confusion_matrix(true_labels.cpu().numpy(), pseudo_labels.cpu().numpy())
+
+    # Hungarian algorithm: maximise the matched counts by minimising their negation
+    cost_mat = -confusion_mat
+    row_idx, col_idx = linear_sum_assignment(cost_mat)
+
+    accuracy = float(confusion_mat[row_idx, col_idx].sum() / confusion_mat.sum())
+    logger.info(f"Clustering Accuracy: {accuracy:.4f} ({len(unique_true_labels)} true labels, {len(unique_pseudo_labels)} pseudo labels)")
+    return accuracy
+
+def _cluster_features(args, features:torch.Tensor) -> torch.Tensor:
+    """
+    Clusters the features using K Means.
+    :param args: Arguments object with the attributes `ncd_findk_method`, `novel_classes_per_session` and `seed`.
+    :param features: torch.Tensor with the features to cluster. Assumes the features come from novel-class samples only.
+    :return: torch.Tensor with unadjusted pseudo-labels for the given features.
+    """
+    
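+    # "cheat" is an oracle baseline: use the ground-truth number of novel classes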
+    if args.ncd_findk_method == "cheat":
+        k = args.novel_classes_per_session
+    elif args.ncd_findk_method == "elbow":
+        k = _elbow_findk(features, args)
+    elif args.ncd_findk_method == "silhouette":
+        k = _silhouette_findk(features, args)
+    elif args.ncd_findk_method == "gap":
+        k = _gap_findk(features, args)
+    else:
+        raise ValueError(f"Unknown method for finding k: {args.ncd_findk_method}")
+    
+    logger.info(f"FindK Method: {args.ncd_findk_method} k: {k} True k: {args.novel_classes_per_session}")
+    
+    kmeans = KMeans(n_clusters=k, random_state=args.seed)
+    pseudo_labels = torch.tensor(kmeans.fit_predict(features.cpu().numpy()))
+    return pseudo_labels
+
+
+def discover_classes_in_session_dataset(
+    args,
+    session_dataset: TransformedTensorDataset,
+    model: torch.nn.Module,
+    return_new_dataset: bool = True,
+):
+    """
+    Discovers classes in the session dataset using the provided model.
+    This step runs after OOD detection, which modifies the session dataset to include the OOD label in the third column.
+    
+    :param args: Arguments object.
+    :param session_dataset: TransformedTensorDataset for the session (data, target, ood).
+    :param model: The model to evaluate with.
+    :param return_new_dataset: Whether to return the new dataset ready for the session or just the predicted classes.
+    """
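+    # WIP: a minimal sketch of the intended flow, not the final implementation.
+    # It assumes the OOD step wrote 1 into the third (ood) column for novel
+    # samples, and that the dataset exposes `tensor_dataset` and `transform`
+    # attributes (names assumed here).
+    features = _extract_features(args, session_dataset, model).cpu()
+    data, targets, ood_labels = session_dataset.tensor_dataset.tensors
+    novel_mask = ood_labels == 1
+
+    # cluster only the samples flagged as novel and report how well the
+    # clusters line up with the held-out true labels
+    pseudo_labels = _cluster_features(args, features[novel_mask])
+    calculate_clustering_accuracy(targets[novel_mask], pseudo_labels)
+
+    # offset the pseudo-labels so they land after the known classes
+    # (assumes the first CL session; later sessions would offset further)
+    predicted_labels = targets.clone()
+    predicted_labels[novel_mask] = pseudo_labels.to(targets.dtype) + args.known
+
+    if not return_new_dataset:
+        return predicted_labels
+
+    return TransformedTensorDataset(
+        torch.utils.data.TensorDataset(data, predicted_labels),
+        transform=session_dataset.transform,
+    )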
diff --git a/entcl/utils/ood.py b/entcl/utils/ood.py
index c31ac1e..18dff0f 100644
--- a/entcl/utils/ood.py
+++ b/entcl/utils/ood.py
@@ -1,6 +1,5 @@
 from typing import Iterable, Tuple, Union
 from entcl.data.util import TransformedTensorDataset
-from entcl.models.model import ENTCLModel
 from loguru import logger
 from sklearn.mixture import GaussianMixture
 import torch
@@ -8,7 +7,7 @@ from tqdm import tqdm
 
 
 def _get_scores(
-    loader: Union[torch.utils.data.DataLoader, Iterable], model: ENTCLModel, args
+    session_dataset: TransformedTensorDataset, model: torch.nn.Module, args
 ) -> Union[
     Tuple[torch.Tensor, torch.Tensor],
     Tuple[None, None],
@@ -23,6 +22,16 @@
-    :param args: Object with the attributes `ood_score` and `device`.
+    :param args: Object with the attributes `ood_score`, `device`, `batch_size`, `num_workers` and `pin_memory`.
     :return: Tuple of torch.Tensors with the entropy and energy scores respectively.
     """
+
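+    # batch the session dataset for memory efficiency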
+    session_loader = torch.utils.data.DataLoader(
+        session_dataset,
+        batch_size=args.batch_size,
+        num_workers=args.num_workers,
+        pin_memory=args.pin_memory,
+        shuffle=False,
+    )
+
     logger.debug(f"Getting Scores: {args.ood_score}")
     model.eval()
 
@@ -140,30 +149,21 @@ def _resolve_conflicts(
 def label_ood_for_session(
     args,
     session_dataset: TransformedTensorDataset,
-    model: ENTCLModel,
-    return_new_dataset: bool = False,
+    model: torch.nn.Module,
+    return_new_dataset: bool = True,
 ) -> Union[TransformedTensorDataset, torch.Tensor]:
     """
     OOD Labelling for a session dataset. This function computes entropy and/or energy scores for the dataset based on `args.ood_score` and fits a Gaussian Mixture Model to each of the scores. The GMM has 2 components, one for in-distribution samples and one for OOD samples. The function then resolves conflicts between the entropy and energy predictions by selecting the type with the highest confidence. Finally, the function returns a new dataset for the session, including the predicted types.
     :param args: Objects with the attributes `ood_score`, `ood_eps`, `seed` and `device` (Program Arguments).
     :param session_dataset: Dataset for the session.
     :param model: The model to evaluate.
-    :param return_new_dataset: Whether to return the new dataset ready for the session or just the predicted types (useful when theres more to do before training).
+    :param return_new_dataset: Whether to return the new dataset ready for the session or just the predicted types.
     :return: A TransformedTensorDataset with the predicted types or just a torch.Tensor of the predicted types.
     """
     logger.debug("Starting OOD Labelling for Session")
 
-    # first we dataload the session dataset, for memory efficiency
-    session_loader = torch.utils.data.DataLoader(
-        session_dataset,
-        batch_size=args.batch_size,
-        num_workers=args.num_workers,
-        pin_memory=args.pin_memory,
-        shuffle=False,
-    )
-
-    # next we run the dataset through the model to retrieve an entropy/energy score for each sample
-    entropies, energies = _get_scores(session_loader, model, args)
+    # first we run the dataset through the model to retrieve an entropy/energy score for each sample
+    entropies, energies = _get_scores(session_dataset, model, args)
 
     # next step depends on the selected OOD Score
     # placeholder for the final predictions
-- 
GitLab