Commit c4535f84 authored by Joseph Omar

working on ncd. should run offline okay tho :)

parent 0c28c917
from typing import Optional, Union
from entcl.data.util import TransformedTensorDataset
from entcl.models.model import ENTCLModel
import torch
from loguru import logger
from sklearn.mixture import GaussianMixture
import tqdm
@@ -110,6 +110,7 @@ if __name__ == "__main__":
# initialise dataset
if args.dataset == 'cifar100':
from entcl.data.cifar100 import CIFAR100Dataset
args.novel_classes_per_session = (100 - args.known) // args.sessions
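        # e.g. known=50 and sessions=5 give (100 - 50) // 5 = 10 novel classes per session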
        args.dataset = CIFAR100Dataset(known=args.known, pretrain_n_known=args.pretrain_n_known, cl_n_known=args.cl_n_known, cl_n_novel=args.cl_n_novel, cl_n_prevnovel=args.cl_n_prevnovel, sessions=args.sessions)
......
from entcl.data.util import TransformedTensorDataset
from loguru import logger
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import confusion_matrix
from scipy.optimize import linear_sum_assignment
import torch
from tqdm import tqdm
def _extract_features(args, dataset: TransformedTensorDataset, model: torch.nn.Module) -> torch.Tensor:
"""
Extracts features from the data in the dataset using the provided model's backbone.
    :param args: Arguments object with the attributes `device`, `batch_size`, `num_workers` and `pin_memory`.
    :param dataset: TransformedTensorDataset with the data to extract features from.
    :return: torch.Tensor with the extracted features.
"""
logger.debug("Extracting Features")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_memory,
shuffle=False,
)
model.eval()
    features = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Extracting Features", leave=True, unit="batch"):
            x = batch[0]  # only the data is needed; assumes the batch is indexable
            x = x.to(args.device)
            x_proj, _ = model(x)  # only the projected features are needed, not the logits
            features.append(x_proj)
    return torch.cat(features, dim=0)
def _elbow_findk(features: torch.Tensor, args) -> int:
"""
Finds the optimal number of clusters using the elbow method and the kneed library.
:param features: torch.Tensor with the features to cluster.
    :param args: Arguments object with the attributes `seed`, `ncd_findk_mink` and `ncd_findk_maxk`.
:return: int with the optimal number of clusters.
"""
from kneed import KneeLocator
inertias = []
ks = range(args.ncd_findk_mink, args.ncd_findk_maxk + 1)
logger.debug(f"Finding k using elbow method for k in {ks}")
    # calculate the inertia for each k
    features_np = features.cpu().numpy()  # sklearn expects a CPU numpy array, not a (possibly CUDA) tensor
    for k in ks:
        kmeans = KMeans(n_clusters=k, random_state=args.seed)
        kmeans.fit(features_np)
        inertias.append(kmeans.inertia_)
logger.debug(f"Elbow Inertias: {inertias}")
logger.debug(f"Running KneeLocator")
# Find the knee
kneedle = KneeLocator(ks, inertias, curve="convex", direction="decreasing")
logger.info(f"Elbow Method Found k: {kneedle.knee}")
return kneedle.knee
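# Illustrative sketch (not part of the commit): how KneeLocator behaves on a
# synthetic, convex, decreasing inertia curve, mirroring the call above. The
# inertia values are fabricated, and the located knee also depends on
# KneeLocator's sensitivity parameter S.
def _elbow_demo() -> int:
    from kneed import KneeLocator
    ks = range(2, 10)
    inertias = [1000, 400, 180, 120, 100, 90, 85, 82]  # sharp bend around k=4
    # same curve-shape arguments as in _elbow_findk above
    return KneeLocator(ks, inertias, curve="convex", direction="decreasing").knee  # ~4 here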
def _silhouette_findk(features: torch.Tensor, args) -> int:
"""
Finds the optimal number of clusters using the silhouette method.
:param features: torch.Tensor with the features to cluster.
    :param args: Arguments object with the attributes `seed`, `ncd_findk_mink` and `ncd_findk_maxk`.
:return: int with the optimal number of clusters.
"""
from sklearn.metrics import silhouette_score
silhouettes = []
ks = range(args.ncd_findk_mink, args.ncd_findk_maxk + 1)
logger.debug(f"Finding k using silhouette method for k in {ks}")
    # calculate the silhouette score for each k
    features_np = features.cpu().numpy()  # sklearn expects a CPU numpy array
    for k in ks:
        kmeans = KMeans(n_clusters=k, random_state=args.seed)
        pseudo_labels = kmeans.fit_predict(features_np)
        silhouette = silhouette_score(features_np, pseudo_labels)
silhouettes.append(silhouette)
logger.debug(f"Silhouette Scores: {silhouettes}")
# Find the best k
best_k = ks[silhouettes.index(max(silhouettes))]
logger.info(f"Silhouette Method Found k: {best_k}")
return best_k
def _gap_findk(features: torch.Tensor, args) -> int:
"""
Finds the optimal number of clusters using the gap statistic.
:param features: torch.Tensor with the features to cluster.
    :param args: Arguments object with the attributes `seed`, `ncd_findk_mink`, `ncd_findk_maxk` and `ncd_gap_nrefs`.
:return: int with the optimal number of clusters.
"""
logger.debug("Finding k using gap statistic")
features_np = features.cpu().numpy()
n_samples, n_features = features_np.shape
min_vals = features_np.min(axis=0)
max_vals = features_np.max(axis=0)
gap_stats = {}
wks = []
wks_refs = []
for k in tqdm(range(args.ncd_findk_mink, args.ncd_findk_maxk + 1), desc="Calculating Gap Statistic", leave=True, unit="k"):
logger.debug(f"Calculating Gap Statistic for k={k}")
kmeans = KMeans(n_clusters=k, random_state=args.seed)
kmeans.fit(features_np)
wk = kmeans.inertia_
wks.append(np.log(wk))
wk_refs = []
for _ in tqdm(range(args.ncd_gap_nrefs), desc="Calculating Reference Gap Statistic", leave=True, unit="ref"):
logger.debug(f"Calculating Reference Gap Statistic for k={k}")
ref_data = np.random.uniform(min_vals, max_vals, (n_samples, n_features))
kmeans = KMeans(n_clusters=k, random_state=args.seed)
kmeans.fit(ref_data)
wk_ref = kmeans.inertia_
wk_refs.append(np.log(wk_ref))
wks_refs.append(np.mean(wk_refs))
gap_stats[k] = wks_refs[-1] - wks[-1]
optimal_k = max(gap_stats, key=gap_stats.get)
logger.info(f"Gap Statistic Found k: {optimal_k}")
return optimal_k
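# Illustrative note: _gap_findk returns argmax_k Gap(k), where
# Gap(k) = mean(log(Wk_ref)) - log(Wk). The original rule of Tibshirani et al.
# (2001) instead picks the smallest k with Gap(k) >= Gap(k+1) - s_{k+1}, where
# s_k scales the standard deviation of the reference log-inertias. A minimal
# sketch of that rule (hypothetical helper, not used elsewhere in this module):
def _gap_rule(ks: list, gaps: list, sds: list) -> int:
    # ks, gaps and sds are parallel lists over the candidate k values
    for i in range(len(ks) - 1):
        if gaps[i] >= gaps[i + 1] - sds[i + 1]:
            return ks[i]
    return ks[-1]  # fall back to the largest candidate if the rule never fires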
def calculate_clustering_accuracy(true_labels: torch.Tensor, pseudo_labels: torch.Tensor) -> None:
"""
Calculates the clustering accuracy between the true labels and the pseudo labels.
:param true_labels: torch.Tensor with the true labels.
    :param pseudo_labels: torch.Tensor with the pseudo labels.
    :return: None. The accuracy after Hungarian matching of cluster ids to true labels is logged.
    """
    assert true_labels.shape == pseudo_labels.shape, f"True and Pseudo labels must have the same shape. true_labels.shape: {true_labels.shape}, pseudo_labels.shape: {pseudo_labels.shape}"
    # confusion matrix over the union of both label sets (square, so it defines
    # a balanced assignment problem between clusters and true classes)
    confusion_mat = confusion_matrix(true_labels.cpu().numpy(), pseudo_labels.cpu().numpy())
    # Hungarian algorithm: maximise the matched counts by minimising their negation
    row_idx, col_idx = linear_sum_assignment(-confusion_mat)
    accuracy = confusion_mat[row_idx, col_idx].sum() / true_labels.numel()
    logger.info(f"Clustering Accuracy (after Hungarian matching): {accuracy:.4f}")
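# Illustrative usage (not part of the commit): cluster ids are arbitrary, so
# accuracy is only meaningful after optimally permuting them onto the true
# labels. Here `pseudo` is a pure relabelling of `true` (0->2, 1->0, 2->1), so
# the matched accuracy is 1.0:
#
#     true = torch.tensor([0, 0, 1, 1, 2, 2])
#     pseudo = torch.tensor([2, 2, 0, 0, 1, 1])
#     calculate_clustering_accuracy(true, pseudo)  # logs accuracy 1.0000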
def _cluster_features(args, features:torch.Tensor) -> torch.Tensor:
"""
Clusters the features using K Means.
    :param args: Arguments object with the attributes `ncd_findk_method`, `novel_classes_per_session` and `seed`.
    :param features: torch.Tensor with the features to cluster. Assumes the features are from novel-classed samples only.
    :return: torch.Tensor with unadjusted pseudo-labels for the given features.
"""
if args.ncd_findk_method == "cheat":
k = args.novel_classes_per_session
elif args.ncd_findk_method == "elbow":
k = _elbow_findk(features, args)
elif args.ncd_findk_method == "silhouette":
k = _silhouette_findk(features, args)
elif args.ncd_findk_method == "gap":
k = _gap_findk(features, args)
else:
raise ValueError(f"Unknown method for finding k: {args.ncd_findk_method}")
logger.info(f"FindK Method: {args.ncd_findk_method} k: {k} True k: {args.novel_classes_per_session}")
kmeans = KMeans(n_clusters=k, random_state=args.seed)
    pseudo_labels = torch.from_numpy(kmeans.fit_predict(features.cpu().numpy()))  # sklearn needs a CPU numpy array
return pseudo_labels
def discover_classes_in_session_dataset(
args,
session_dataset: TransformedTensorDataset,
model: torch.nn.Module,
return_new_dataset: bool = True,
):
"""
Discovers classes in the session dataset using the provided model.
    This step runs after OOD detection, which adds the predicted OOD label as the third column of the session dataset.
:param args: Arguments object.
:param session_dataset: TransformedTensorDataset for the session (data, target, ood).
:param model: The model to evaluate with.
:param return_new_dataset: Whether to return the new dataset ready for the session or just the predicted classes.
"""
\ No newline at end of file
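# Illustrative sketch (not from the commit; the body above is still to be
# written). One plausible flow using the helpers in this module: take the
# samples that OOD detection flagged as novel, extract and cluster their
# features, then shift the pseudo-labels past the known-class ids. The
# `.tensors` attribute access on `session_dataset` is a hypothetical layout:
#
#     data, targets, ood = session_dataset.tensors   # assumed (data, target, ood) layout
#     novel_features = _extract_features(args, session_dataset, model)[ood == 1]
#     pseudo_labels = _cluster_features(args, novel_features) + args.known  # shift past known classes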
from typing import Iterable, Tuple, Union
from entcl.data.util import TransformedTensorDataset
from entcl.models.model import ENTCLModel
from loguru import logger
from sklearn.mixture import GaussianMixture
import torch
@@ -8,7 +7,7 @@ from tqdm import tqdm
def _get_scores(
-    loader: Union[torch.utils.data.DataLoader, Iterable], model: ENTCLModel, args
+    session_dataset: TransformedTensorDataset, model: torch.nn.Module, args
) -> Union[
Tuple[torch.Tensor, torch.Tensor],
Tuple[None, None],
@@ -23,6 +22,15 @@ def _get_scores(
    :param args: Object with the attributes `ood_score`, `device`, `batch_size`, `num_workers` and `pin_memory`.
:return: Tuple of torch.Tensors with the entropy and energy scores respectively.
"""
session_loader = torch.utils.data.DataLoader(
session_dataset,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_memory,
shuffle=False,
)
logger.debug(f"Getting Scores: {args.ood_score}")
model.eval()
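    # Illustrative note: the remainder of this function is collapsed in the
    # diff above. For reference, the usual definitions of the two OOD scores
    # from the literature are sketched below; this is a hedged reconstruction,
    # not necessarily the collapsed implementation:
    #
    #     probs = torch.softmax(logits, dim=1)
    #     entropy = -(probs * torch.log(probs + args.ood_eps)).sum(dim=1)  # higher => more likely OOD
    #     energy = -torch.logsumexp(logits, dim=1)                         # higher => more likely OOD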
@@ -140,30 +148,21 @@ def _resolve_conflicts(
def label_ood_for_session(
args,
session_dataset: TransformedTensorDataset,
-    model: ENTCLModel,
-    return_new_dataset: bool = False,
+    model: torch.nn.Module,
+    return_new_dataset: bool = True,
) -> Union[TransformedTensorDataset, torch.Tensor]:
"""
OOD Labelling for a session dataset. This function computes entropy and/or energy scores for the dataset based on `args.ood_score` and fits a Gaussian Mixture Model to each of the scores. The GMM has 2 components, one for in-distribution samples and one for OOD samples. The function then resolves conflicts between the entropy and energy predictions by selecting the type with the highest confidence. Finally, the function returns a new dataset for the session, including the predicted types.
    :param args: Object with the attributes `ood_score`, `ood_eps`, `seed` and `device` (program arguments).
:param session_dataset: Dataset for the session.
:param model: The model to evaluate.
-    :param return_new_dataset: Whether to return the new dataset ready for the session or just the predicted types (useful when theres more to do before training).
+    :param return_new_dataset: Whether to return the new dataset ready for the session or just the predicted types.
:return: A TransformedTensorDataset with the predicted types or just a torch.Tensor of the predicted types.
"""
logger.debug("Starting OOD Labelling for Session")
-    # first we dataload the session dataset, for memory efficiency
-    session_loader = torch.utils.data.DataLoader(
-        session_dataset,
-        batch_size=args.batch_size,
-        num_workers=args.num_workers,
-        pin_memory=args.pin_memory,
-        shuffle=False,
-    )
-    # next we run the dataset through the model to retrieve an entropy/energy score for each sample
-    entropies, energies = _get_scores(session_loader, model, args)
+    # first we run the dataset through the model to retrieve an entropy/energy score for each sample
+    entropies, energies = _get_scores(session_dataset, model, args)
# next step depends on the selected OOD Score
# placeholder for the final predictions
......
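# Illustrative sketch (not part of the commit): the 2-component GMM split
# described in the docstring of label_ood_for_session, applied to a single
# score vector. The component with the lower mean is taken as in-distribution,
# since ID samples should score lower on both entropy and energy. `gmm_split`
# is a hypothetical name, not a function from this repository:
def gmm_split(scores, seed):
    import numpy as np
    from sklearn.mixture import GaussianMixture
    gmm = GaussianMixture(n_components=2, random_state=seed)
    components = gmm.fit_predict(np.asarray(scores).reshape(-1, 1))
    id_component = int(np.argmin(gmm.means_.ravel()))  # lower-mean component = ID
    return (components != id_component).astype(np.int64)  # 1 = predicted OOD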