diff --git a/entcl/cl.py b/entcl/cl.py
index 2aea96b164189a49d9ab57c24403faa201e43afd..8b137891791fe96927ad78e64b0aad7bded08bdc 100644
--- a/entcl/cl.py
+++ b/entcl/cl.py
@@ -1,8 +1 @@
-from typing import Optional, Union
-from entcl.data.util import TransformedTensorDataset
-from entcl.models.model import ENTCLModel
-import torch
-from loguru import logger
-from sklearn.mixture import GaussianMixture
-import tqdm
diff --git a/entcl/run.py b/entcl/run.py
index a27f8c61316b3b4755f924a8d29cb92e247609ed..dcdd668d652b9ed9184f43d1944e2b94c63003b1 100644
--- a/entcl/run.py
+++ b/entcl/run.py
@@ -110,6 +110,7 @@ if __name__ == "__main__":
     # initialise dataset
     if args.dataset == 'cifar100':
        from entcl.data.cifar100 import CIFAR100Dataset
+        args.novel_classes_per_session = (100 - args.known) // args.sessions
         args.dataset = CIFAR100Dataset(known=args.known, pretrain_n_known=args.pretrain_n_known, cl_n_known=args.cl_n_known, cl_n_novel=args.cl_n_novel, cl_n_prevnovel=args.cl_n_prevnovel, sessions=5)
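The new `novel_classes_per_session` attribute is simply the remaining CIFAR-100 classes split evenly across the continual-learning sessions. A minimal sketch of the arithmetic, assuming the illustrative values `known = 50` and `sessions = 5` (these values are not taken from the diff):

```python
# Illustrative only: how args.novel_classes_per_session is derived for CIFAR-100.
known = 50      # assumed value of args.known
sessions = 5    # matches the hard-coded sessions=5 passed to CIFAR100Dataset above
novel_classes_per_session = (100 - known) // sessions
print(novel_classes_per_session)  # 10 novel classes introduced per session
```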
+ """ + from kneed import KneeLocator + inertias = [] + + ks = range(args.ncd_findk_mink, args.ncd_findk_maxk + 1) + logger.debug(f"Finding k using elbow method for k in {ks}") + + # Calculate the inertia for each k + for k in ks: + kmeans = KMeans(n_clusters=k, random_state=args.seed) + kmeans.fit(features) + inertias.append(kmeans.inertia_) + + logger.debug(f"Elbow Inertias: {inertias}") + logger.debug(f"Running KneeLocator") + + # Find the knee + kneedle = KneeLocator(ks, inertias, curve="convex", direction="decreasing") + logger.info(f"Elbow Method Found k: {kneedle.knee}") + return kneedle.knee + +def _silhouette_findk(features: torch.Tensor, args) -> int: + """ + Finds the optimal number of clusters using the silhouette method. + :param features: torch.Tensor with the features to cluster. + :param args: Arguments object with the attribute `seed`. + :return: int with the optimal number of clusters. + """ + from sklearn.metrics import silhouette_score + silhouettes = [] + + ks = range(args.ncd_findk_mink, args.ncd_findk_maxk + 1) + logger.debug(f"Finding k using silhouette method for k in {ks}") + + # Calculate the silhouette score for each k + for k in ks: + kmeans = KMeans(n_clusters=k, random_state=args.seed) + pseudo_labels = kmeans.fit_predict(features) + silhouette = silhouette_score(features, pseudo_labels) + silhouettes.append(silhouette) + + logger.debug(f"Silhouette Scores: {silhouettes}") + + # Find the best k + best_k = ks[silhouettes.index(max(silhouettes))] + logger.info(f"Silhouette Method Found k: {best_k}") + return best_k + +def _gap_findk(features: torch.Tensor, args) -> int: + """ + Finds the optimal number of clusters using the gap statistic. + :param features: torch.Tensor with the features to cluster. + :param args: Arguments object with the attribute `seed`. + :return: int with the optimal number of clusters. + """ + logger.debug("Finding k using gap statistic") + features_np = features.cpu().numpy() + n_samples, n_features = features_np.shape + + min_vals = features_np.min(axis=0) + max_vals = features_np.max(axis=0) + + gap_stats = {} + wks = [] + wks_refs = [] + + for k in tqdm(range(args.ncd_findk_mink, args.ncd_findk_maxk + 1), desc="Calculating Gap Statistic", leave=True, unit="k"): + logger.debug(f"Calculating Gap Statistic for k={k}") + kmeans = KMeans(n_clusters=k, random_state=args.seed) + kmeans.fit(features_np) + wk = kmeans.inertia_ + wks.append(np.log(wk)) + + wk_refs = [] + for _ in tqdm(range(args.ncd_gap_nrefs), desc="Calculating Reference Gap Statistic", leave=True, unit="ref"): + logger.debug(f"Calculating Reference Gap Statistic for k={k}") + ref_data = np.random.uniform(min_vals, max_vals, (n_samples, n_features)) + kmeans = KMeans(n_clusters=k, random_state=args.seed) + kmeans.fit(ref_data) + wk_ref = kmeans.inertia_ + wk_refs.append(np.log(wk_ref)) + wks_refs.append(np.mean(wk_refs)) + + gap_stats[k] = wks_refs[-1] - wks[-1] + + optimal_k = max(gap_stats, key=gap_stats.get) + + logger.info(f"Gap Statistic Found k: {optimal_k}") + + return optimal_k + +def calculate_clustering_accuracy(true_labels: torch.Tensor, pseudo_labels: torch.Tensor) -> None: + """ + Calculates the clustering accuracy between the true labels and the pseudo labels. + :param true_labels: torch.Tensor with the true labels. + :param pseudo_labels: torch.Tensor with the pseudo labels. + """ + assert true_labels.shape == pseudo_labels.shape, f"True and Pseudo labels must have the same shape. 
diff --git a/entcl/utils/ood.py b/entcl/utils/ood.py
index c31ac1e0d46e5c460b6f17af85e591bb142b4738..18dff0f1675ef76474ccfd1d4a3a2432423bfdd9 100644
--- a/entcl/utils/ood.py
+++ b/entcl/utils/ood.py
@@ -1,6 +1,5 @@
 from typing import Iterable, Tuple, Union
 from entcl.data.util import TransformedTensorDataset
-from entcl.models.model import ENTCLModel
 from loguru import logger
 from sklearn.mixture import GaussianMixture
 import torch
@@ -8,7 +7,7 @@ from tqdm import tqdm


 def _get_scores(
-    loader: Union[torch.utils.data.DataLoader, Iterable], model: ENTCLModel, args
+    session_dataset: TransformedTensorDataset, model: torch.nn.Module, args
 ) -> Union[
     Tuple[torch.Tensor, torch.Tensor],
     Tuple[None, None],
@@ -23,6 +22,15 @@
     :param args: Object with the attributes `ood_score` and `device`.
     :return: Tuple of torch.Tensors with the entropy and energy scores respectively.
     """
+
+    session_loader = torch.utils.data.DataLoader(
+        session_dataset,
+        batch_size=args.batch_size,
+        num_workers=args.num_workers,
+        pin_memory=args.pin_memory,
+        shuffle=False,
+    )
+
     logger.debug(f"Getting Scores: {args.ood_score}")

     model.eval()
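`_get_scores` now takes the dataset and builds the loader itself; the scoring logic is unchanged and not shown in this hunk. For context only, a sketch of how per-sample entropy and energy scores are conventionally computed from logits (illustrative, not code from entcl):

```python
# Illustrative only: standard definitions of the two OOD scores referenced above.
import torch
import torch.nn.functional as F

def _entropy_energy_scores(logits: torch.Tensor):
    log_probs = F.log_softmax(logits, dim=1)
    entropy = -(log_probs.exp() * log_probs).sum(dim=1)  # high entropy -> uncertain, likely OOD
    energy = -torch.logsumexp(logits, dim=1)             # high energy -> likely OOD
    return entropy, energy
```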
""" + + session_loader = torch.utils.data.DataLoader( + session_dataset, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=args.pin_memory, + shuffle=False, + ) + logger.debug(f"Getting Scores: {args.ood_score}") model.eval() @@ -140,30 +148,21 @@ def _resolve_conflicts( def label_ood_for_session( args, session_dataset: TransformedTensorDataset, - model: ENTCLModel, - return_new_dataset: bool = False, + model: torch.nn.Module, + return_new_dataset: bool = True, ) -> Union[TransformedTensorDataset, torch.Tensor]: """ OOD Labelling for a session dataset. This function computes entropy and/or energy scores for the dataset based on `args.ood_score` and fits a Gaussian Mixture Model to each of the scores. The GMM has 2 components, one for in-distribution samples and one for OOD samples. The function then resolves conflicts between the entropy and energy predictions by selecting the type with the highest confidence. Finally, the function returns a new dataset for the session, including the predicted types. :param args: Objects with the attributes `ood_score`, `ood_eps`, `seed` and `device` (Program Arguments). :param session_dataset: Dataset for the session. :param model: The model to evaluate. - :param return_new_dataset: Whether to return the new dataset ready for the session or just the predicted types (useful when theres more to do before training). + :param return_new_dataset: Whether to return the new dataset ready for the session or just the predicted types. :return: A TransformedTensorDataset with the predicted types or just a torch.Tensor of the predicted types. """ logger.debug("Starting OOD Labelling for Session") - # first we dataload the session dataset, for memory efficiency - session_loader = torch.utils.data.DataLoader( - session_dataset, - batch_size=args.batch_size, - num_workers=args.num_workers, - pin_memory=args.pin_memory, - shuffle=False, - ) - - # next we run the dataset through the model to retrieve an entropy/energy score for each sample - entropies, energies = _get_scores(session_loader, model, args) + # first we run the dataset through the model to retrieve an entropy/energy score for each sample + entropies, energies = _get_scores(session_dataset, model, args) # next step depends on the selected OOD Score # placeholder for the final predictions