From f91daf052d433f59ff5ed677a706a54f8b3116e8 Mon Sep 17 00:00:00 2001
From: Joseph Omar <j.omar@soton.ac.uk>
Date: Mon, 18 Nov 2024 12:24:10 +0000
Subject: [PATCH] rename linear2 to MLPHead. Pretrain now gives a results csv (hopefully, untested)

---
 entcl/models/linear_head.py |   2 +-
 entcl/pretrain.py           |  48 ++++-----
 entcl/run.py                |  15 +--
 entcl/utils/findk.py        | 102 +++++++++++++++++++
 entcl/utils/ncd.py          | 190 ++++++++++++++++--------------------
 entcl/utils/ood.py          |   4 +-
 entcl/utils/util.py         |  26 ++++-
 7 files changed, 242 insertions(+), 145 deletions(-)
 create mode 100644 entcl/utils/findk.py

diff --git a/entcl/models/linear_head.py b/entcl/models/linear_head.py
index 755e0f9..6931789 100644
--- a/entcl/models/linear_head.py
+++ b/entcl/models/linear_head.py
@@ -9,7 +9,7 @@ class LinearHead(torch.nn.Module):
     def forward(self, x):
         return self.fc(x)
 
-class LinearHead2(torch.nn.Module):
+class MLPHead(torch.nn.Module):
     def __init__(self, in_features: int, out_features: int, hidden_dim1:int, hidden_dim2:int):
         super().__init__()
         self.mlp = torch.nn.Sequential(
diff --git a/entcl/pretrain.py b/entcl/pretrain.py
index 25e9295..5bfc2f8 100644
--- a/entcl/pretrain.py
+++ b/entcl/pretrain.py
@@ -1,5 +1,6 @@
 import os
 from loguru import logger
+import pandas as pd
 import torch
 from tqdm import tqdm
 
@@ -34,11 +35,11 @@ def pretrain(args, model):
 
     criterion = torch.nn.CrossEntropyLoss()
 
-    accuracies = []
+    results = None
 
     if args.pretrain_load is not None:
         logger.debug(f"Loading pretrained model from {args.pretrain_load}")
-        model.load_state_dict(torch.load(args.pretrain_load, weights_only=True))
+        model.head.load_state_dict(torch.load(args.pretrain_load, weights_only=True))
         model = model.to(args.device)
         model, accuracy = _validate(args, model, val_dataloader)
         logger.info(f"Loaded Pretrained Model Accuracy: {accuracy}")
@@ -47,39 +48,32 @@
         logger.debug("No pretrained model to load, training from scratch")
         model = model.to(args.device)
         for epoch in range(args.pretrain_epochs):
+            logger.debug(f"Epoch {epoch} Started")
+
             # train model
-            print(f"Epoch {epoch}:")
             model, train_loss = _train(args, model, train_dataloader, optimiser, criterion)
-            logger.info(f"Epoch {epoch}: Loss: {train_loss}")
+            logger.info(f"Epoch {epoch}: TRAINING : Train Loss: {train_loss}")
 
             # validate model
             model, accuracy, val_loss = _validate(args, model, val_dataloader, criterion)
-            accuracies.append(accuracy)
+            logger.info(f"Epoch {epoch}: VALIDATION : Accuracy: {accuracy}, Val Loss: {val_loss}")
+
+            # save only the head: the backbone is frozen and far too large to checkpoint every epoch
+            torch.save(model.head.state_dict(), f"{args.exp_root}/{args.name}/head_pretrain_{epoch}.pth")
+
+            # create a dataframe with the results
+            epoch_results = pd.DataFrame(
+                [[epoch, train_loss, val_loss, accuracy]],
+                columns=["epoch", "train_loss", "val_loss", "accuracy"],
+            )
 
-            logger.info(f"Epoch {epoch}: Accuracy: {accuracy}, Loss: {val_loss}")
-            torch.save(model.state_dict(), os.path.join(args.exp_dir, f"model_{epoch}.pt"))
-
-    # select the best model
-    if args.pretrain_sel_strat == 'best':
-        best_epoch = accuracies.index(max(accuracies))
-        logger.info(f"Best Epoch: {best_epoch}")
-        model.load_state_dict(torch.load(os.path.join(args.exp_dir, f"model_{best_epoch}.pt"), weights_only=True))
+            # append the results to the results dataframe (DataFrame.append was removed in pandas 2.0, so use pd.concat)
+            results = epoch_results if results is None else pd.concat([results, epoch_results], ignore_index=True)
 
-        # remove all other models
-        if not args.retain_all:
-            for epoch in range(args.pretrain_epochs):
-                if epoch != best_epoch:
-                    
os.remove(os.path.join(args.exp_dir, f"model_{epoch}.pt")) - - # delete all models except the last one, return the last model - elif args.pretrain_sel_strat == 'last': - if not args.retain_all: - for epoch in range(args.pretrain_epochs - 1): - os.remove(os.path.join(args.exp_dir, f"model_{epoch}.pt")) - elif args.pretrain_sel_strat == 'load': - model.load_state_dict(torch.load(args.pretrain_load, weights_only=True)) - + # save the results dataframe + results.to_csv(f"{args.exp_root}/{args.name}/results_pretrain.csv", index=False) + logger.debug(f"Epoch {epoch} Finished. Pretrain Results Saved to {args.exp_root}/{args.name}/results_pretrain.csv") return model def _train(args, model, train_dataloader, optimiser, criterion): diff --git a/entcl/run.py b/entcl/run.py index dcdd668..48604ef 100644 --- a/entcl/run.py +++ b/entcl/run.py @@ -16,7 +16,10 @@ def main(args: argparse.Namespace): model = ENTCLModel(head=args.head, backbone_url=args.backbone_url, backbone=args.backbone, backbone_source=args.backbone_source) logger.debug(f"Model: {model}") - model = pretrain(args, model) + if args.mode == 'pretrain': + model = pretrain(args, model) + else: + raise NotImplementedError(f"Mode {args.mode} not implemented") @@ -68,7 +71,7 @@ if __name__ == "__main__": parser.add_argument('--retain_all', action='store_true', default=False, help='Keep all model checkpoints') # model args - parser.add_argument('--head', type=str, default='linear2', help='Classification head to use', choices=['linear','linear2', 'dino_head']) + parser.add_argument('--head', type=str, default='mlp', help='Classification head to use', choices=['linear','mlp', 'dino_head']) parser.add_argument('--backbone_url', type=str, default="/cl/entcl/entcl/models/dinov2", help="URL to the repo containing the backbone model") parser.add_argument("--backbone", type=str, default="dinov2_vitb14", help="Name of the backbone model to use") parser.add_argument("--backbone_source", type=str, default="local", help="Source of the backbone model") @@ -121,10 +124,10 @@ if __name__ == "__main__": elif args.head == 'dino_head': from entcl.models.dinohead import DINOHead args.head = DINOHead(768, args.dataset.known, nlayers=3) - elif args.head == 'linear2': - from entcl.models.linear_head import LinearHead2 - args.head = LinearHead2(in_features=768, out_features=args.dataset.num_classes, hidden_dim1=512, hidden_dim2=256) - logger.debug(f"Using Linear2 Head: {args.head}") + elif args.head == 'mlp': + from entcl.models.linear_head import MLPHead + args.head = MLPHead(in_features=768, out_features=args.dataset.num_classes, hidden_dim1=512, hidden_dim2=256) + logger.debug(f"Using MLP Head: {args.head}") argstr = "Arguments: \n" for arg in vars(args): diff --git a/entcl/utils/findk.py b/entcl/utils/findk.py new file mode 100644 index 0000000..8241a22 --- /dev/null +++ b/entcl/utils/findk.py @@ -0,0 +1,102 @@ +import torch +import numpy as np +from sklearn.cluster import KMeans +from loguru import logger +from tqdm import tqdm + +def elbow(features: torch.Tensor, args) -> int: + """ + Finds the optimal number of clusters using the elbow method and the kneed library. + :param features: torch.Tensor with the features to cluster. + :param args: Arguments object with the attribute `seed`. + :return: int with the optimal number of clusters. 
+ """ + from kneed import KneeLocator + inertias = [] + + ks = range(args.ncd_findk_mink, args.ncd_findk_maxk + 1) + logger.debug(f"Finding k using elbow method for k in {ks}") + + # Calculate the inertia for each k + for k in tqdm(ks, desc="Calculating Inertias", leave=True, unit="k"): + kmeans = KMeans(n_clusters=k, random_state=args.seed) + kmeans.fit(features) + inertias.append(kmeans.inertia_) + + logger.debug(f"Elbow Inertias: {inertias}") + logger.debug(f"Running KneeLocator") + + # Find the knee + kneedle = KneeLocator(ks, inertias, curve="convex", direction="decreasing") + logger.info(f"Elbow Method Found k: {kneedle.knee}") + return kneedle.knee + +def silhouette(features: torch.Tensor, args) -> int: + """ + Finds the optimal number of clusters using the silhouette method. + :param features: torch.Tensor with the features to cluster. + :param args: Arguments object with the attribute `seed`. + :return: int with the optimal number of clusters. + """ + from sklearn.metrics import silhouette_score + silhouettes = [] + + ks = range(args.ncd_findk_mink, args.ncd_findk_maxk + 1) + logger.debug(f"Finding k using silhouette method for k in {ks}") + + # Calculate the silhouette score for each k + for k in tqdm(ks, desc="Calculating Silhouettes", leave=True, unit="k"): + kmeans = KMeans(n_clusters=k, random_state=args.seed) + pseudo_labels = kmeans.fit_predict(features) + silhouette = silhouette_score(features, pseudo_labels) + silhouettes.append(silhouette) + + logger.debug(f"Silhouette Scores: {silhouettes}") + + # Find the best k + best_k = ks[silhouettes.index(max(silhouettes))] + logger.info(f"Silhouette Method Found k: {best_k}") + return best_k + +def gap(features: torch.Tensor, args) -> int: + """ + Finds the optimal number of clusters using the gap statistic. + :param features: torch.Tensor with the features to cluster. + :param args: Arguments object with the attribute `seed`. + :return: int with the optimal number of clusters. 
+ """ + logger.debug("Finding k using gap statistic") + features_np = features.cpu().numpy() + n_samples, n_features = features_np.shape + + min_vals = features_np.min(axis=0) + max_vals = features_np.max(axis=0) + + gap_stats = {} + wks = [] + wks_refs = [] + + for k in tqdm(range(args.ncd_findk_mink, args.ncd_findk_maxk + 1), desc="Calculating Gap Statistic", leave=True, unit="k"): + logger.debug(f"Calculating Gap Statistic for k={k}") + kmeans = KMeans(n_clusters=k, random_state=args.seed) + kmeans.fit(features_np) + wk = kmeans.inertia_ + wks.append(np.log(wk)) + + wk_refs = [] + for _ in tqdm(range(args.ncd_gap_nrefs), desc="Calculating Reference Gap Statistic", leave=True, unit="ref"): + logger.debug(f"Calculating Reference Gap Statistic for k={k}") + ref_data = np.random.uniform(min_vals, max_vals, (n_samples, n_features)) + kmeans = KMeans(n_clusters=k, random_state=args.seed) + kmeans.fit(ref_data) + wk_ref = kmeans.inertia_ + wk_refs.append(np.log(wk_ref)) + wks_refs.append(np.mean(wk_refs)) + + gap_stats[k] = wks_refs[-1] - wks[-1] + + optimal_k = max(gap_stats, key=gap_stats.get) + + logger.info(f"Gap Statistic Found k: {optimal_k}") + + return optimal_k \ No newline at end of file diff --git a/entcl/utils/ncd.py b/entcl/utils/ncd.py index 2fe4a58..60df52c 100644 --- a/entcl/utils/ncd.py +++ b/entcl/utils/ncd.py @@ -1,11 +1,15 @@ +import os +from typing import Union from entcl.data.util import TransformedTensorDataset +from entcl.utils.findk import elbow, gap, silhouette from loguru import logger import numpy as np from sklearn.cluster import KMeans from sklearn.metrics import confusion_matrix from scipy.optimize import linear_sum_assignment +from entcl.utils.util import generate_unique_path import torch from tqdm import tqdm @@ -38,120 +42,84 @@ def _extract_features(args, dataset: TransformedTensorDataset, model: torch.nn.M return features -def _elbow_findk(features: torch.Tensor, args) -> int: +def plot_confmat(confmat: torch.Tensor, path: str) -> None: """ - Finds the optimal number of clusters using the elbow method and the kneed library. - :param features: torch.Tensor with the features to cluster. - :param args: Arguments object with the attribute `seed`. - :return: int with the optimal number of clusters. + Plots a confusion matrix and saves it to the specified path. + :param confmat: torch.Tensor with the confusion matrix. + :param path: str with the path to save the confusion matrix plot. """ - from kneed import KneeLocator - inertias = [] - - ks = range(args.ncd_findk_mink, args.ncd_findk_maxk + 1) - logger.debug(f"Finding k using elbow method for k in {ks}") - - # Calculate the inertia for each k - for k in ks: - kmeans = KMeans(n_clusters=k, random_state=args.seed) - kmeans.fit(features) - inertias.append(kmeans.inertia_) - - logger.debug(f"Elbow Inertias: {inertias}") - logger.debug(f"Running KneeLocator") - - # Find the knee - kneedle = KneeLocator(ks, inertias, curve="convex", direction="decreasing") - logger.info(f"Elbow Method Found k: {kneedle.knee}") - return kneedle.knee - -def _silhouette_findk(features: torch.Tensor, args) -> int: - """ - Finds the optimal number of clusters using the silhouette method. - :param features: torch.Tensor with the features to cluster. - :param args: Arguments object with the attribute `seed`. - :return: int with the optimal number of clusters. 
- """ - from sklearn.metrics import silhouette_score - silhouettes = [] - - ks = range(args.ncd_findk_mink, args.ncd_findk_maxk + 1) - logger.debug(f"Finding k using silhouette method for k in {ks}") - - # Calculate the silhouette score for each k - for k in ks: - kmeans = KMeans(n_clusters=k, random_state=args.seed) - pseudo_labels = kmeans.fit_predict(features) - silhouette = silhouette_score(features, pseudo_labels) - silhouettes.append(silhouette) - - logger.debug(f"Silhouette Scores: {silhouettes}") - - # Find the best k - best_k = ks[silhouettes.index(max(silhouettes))] - logger.info(f"Silhouette Method Found k: {best_k}") - return best_k + try: + import matplotlib.pyplot as plt + import seaborn as sns + from sklearn.metrics import ConfusionMatrixDisplay -def _gap_findk(features: torch.Tensor, args) -> int: + fig, ax = plt.subplots(figsize=(10, 10)) + sns.heatmap(confmat, annot=True, fmt='d', cmap='Blues', ax=ax) + ax.set_xlabel('Predicted Label') + ax.set_ylabel('True Label') + ax.set_title('Confusion Matrix') + plt.savefig(path) + plt.close() + except ImportError: + logger.error("Could not import matplotlib or seaborn. Cannot plot confusion matrix. Confusion Matrix not saved.") + +def _calculate_clustering_accuracy(true_labels: torch.Tensor, pseudo_labels: torch.Tensor, args) -> None: """ - Finds the optimal number of clusters using the gap statistic. - :param features: torch.Tensor with the features to cluster. - :param args: Arguments object with the attribute `seed`. - :return: int with the optimal number of clusters. + Calculates the clustering accuracy between the true labels and the pseudo labels. + :param true_labels: torch.Tensor with the true labels for the novel samples. + :param pseudo_labels: torch.Tensor with the pseudo labels for the novel samples. """ - logger.debug("Finding k using gap statistic") - features_np = features.cpu().numpy() - n_samples, n_features = features_np.shape - - min_vals = features_np.min(axis=0) - max_vals = features_np.max(axis=0) - - gap_stats = {} - wks = [] - wks_refs = [] - - for k in tqdm(range(args.ncd_findk_mink, args.ncd_findk_maxk + 1), desc="Calculating Gap Statistic", leave=True, unit="k"): - logger.debug(f"Calculating Gap Statistic for k={k}") - kmeans = KMeans(n_clusters=k, random_state=args.seed) - kmeans.fit(features_np) - wk = kmeans.inertia_ - wks.append(np.log(wk)) - - wk_refs = [] - for _ in tqdm(range(args.ncd_gap_nrefs), desc="Calculating Reference Gap Statistic", leave=True, unit="ref"): - logger.debug(f"Calculating Reference Gap Statistic for k={k}") - ref_data = np.random.uniform(min_vals, max_vals, (n_samples, n_features)) - kmeans = KMeans(n_clusters=k, random_state=args.seed) - kmeans.fit(ref_data) - wk_ref = kmeans.inertia_ - wk_refs.append(np.log(wk_ref)) - wks_refs.append(np.mean(wk_refs)) + assert true_labels.shape == pseudo_labels.shape, f"True and Pseudo labels must have the same shape. true_labels.shape: {true_labels.shape}, pseudo_labels.shape: {pseudo_labels.shape}" - gap_stats[k] = wks_refs[-1] - wks[-1] - optimal_k = max(gap_stats, key=gap_stats.get) + # true labels will be > 50 and psuedo labels will start at 0. we need to adjust the pseudo labels to match the true labels. + # we will assume the true labels are sequential, and the lowest true label is 0. - logger.info(f"Gap Statistic Found k: {optimal_k}") + true_labels -= true_labels.min() # Adjust the true labels to start at 0. 
-    return optimal_k
-
-def calculate_clustering_accuracy(true_labels: torch.Tensor, pseudo_labels: torch.Tensor) -> None:
-    """
-    Calculates the clustering accuracy between the true labels and the pseudo labels.
-    :param true_labels: torch.Tensor with the true labels.
-    :param pseudo_labels: torch.Tensor with the pseudo labels.
-    """
-    assert true_labels.shape == pseudo_labels.shape, f"True and Pseudo labels must have the same shape. true_labels.shape: {true_labels.shape}, pseudo_labels.shape: {pseudo_labels.shape}"
-    unique_true_labels = torch.unique(true_labels)
-    unique_pseudo_labels = torch.unique(pseudo_labels)
+    # optimal assignment of pseudo labels to true labels
+    confusion_mat = torch.from_numpy(confusion_matrix(true_labels.cpu().numpy(), pseudo_labels.cpu().numpy())) # Rows are true labels, columns are pseudo labels
+    cost_mat = -confusion_mat # Hungarian algorithm minimizes the cost, so we negate the confusion matrix.
+    row_idx, col_idx = linear_sum_assignment(cost_mat) # Hungarian algorithm to find the optimal assignment.
+
+    # Maps pseudo labels to true labels. int() makes the tensor elements match the integer keys; pseudo labels with no assignment map to -1.
+    label_assignments = dict(zip(col_idx, row_idx))
+    aligned_predicted_labels = torch.tensor([label_assignments.get(int(pseudo_label), -1) for pseudo_label in pseudo_labels]) # Aligns the predicted labels with the true labels.
 
+    # Calculate the accuracy
+    correct_assignments = (aligned_predicted_labels == true_labels).sum().item() # Number of correct assignments.
+    accuracy = correct_assignments / true_labels.shape[0] # Accuracy is the number of correct assignments divided by the number of samples.
 
-    confusion_mat = torch.from_numpy(confusion_matrix(true_labels.cpu().numpy(), pseudo_labels.cpu().numpy()))
-    cost_mat = -confusion_mat
+    # Plot the confusion matrix
+    confmat_path = generate_unique_path(os.path.join(args.exp_dir, args.name, "confusion_matrix.png"))
+    plot_confmat(confusion_mat, path=confmat_path)
 
-    row_idx, col_idx = linear_sum_assignment(cost_mat)
+    unique_true_labels = np.unique(true_labels.cpu().numpy())
+    unique_pseudo_labels = np.unique(pseudo_labels.cpu().numpy())
 
+    # find any ignored labels
+    ignored_true_labels = set(unique_true_labels) - set(unique_true_labels[row_idx])
+    ignored_pseudo_labels = set(unique_pseudo_labels) - set(unique_pseudo_labels[col_idx])
 
+    # Log the results
+
+    string = f"Clustering Accuracy Computed: {accuracy*100:.2f}%"
+    string += f"\n True Labels (adjusted): {unique_true_labels}"
+    string += f"\n Pseudo Labels: {unique_pseudo_labels}"
+    string += f"\n # of True Labels {len(unique_true_labels)}, # of Pseudo Labels {len(unique_pseudo_labels)}"
+    string += f"\n Confusion Matrix: Plot saved to `{confmat_path}`"
+    string += f"\n Ignored True Labels: {ignored_true_labels} Count: {len(ignored_true_labels)}"
+    string += f"\n Ignored Pseudo Labels: {ignored_pseudo_labels} Count: {len(ignored_pseudo_labels)}"
+    string += f"\n If there are ignored true labels, the number of clusters has been overestimated. If there are ignored pseudo labels, the number of clusters has been underestimated."
+    string += f"\n Samples whose pseudo label has no assignment are counted as incorrect in the accuracy calculation."
+ string += f"\n" + string += f"Per True Label Accuracy:" + for true_label in unique_true_labels: + mask = true_labels == true_label + true_label_accuracy = (aligned_predicted_labels[mask] == true_label).sum().item() / mask.sum().item() + string += f"\n True Label: {true_label} Accuracy: {true_label_accuracy*100:.2f}%" + + logger.info(string) def _cluster_features(args, features:torch.Tensor) -> torch.Tensor: @@ -165,11 +133,11 @@ def _cluster_features(args, features:torch.Tensor) -> torch.Tensor: if args.ncd_findk_method == "cheat": k = args.novel_classes_per_session elif args.ncd_findk_method == "elbow": - k = _elbow_findk(features, args) + k = elbow(features, args) elif args.ncd_findk_method == "silhouette": - k = _silhouette_findk(features, args) + k = silhouette(features, args) elif args.ncd_findk_method == "gap": - k = _gap_findk(features, args) + k = gap(features, args) else: raise ValueError(f"Unknown method for finding k: {args.ncd_findk_method}") @@ -179,11 +147,17 @@ def _cluster_features(args, features:torch.Tensor) -> torch.Tensor: pseudo_labels = torch.tensor(kmeans.fit_predict(features)) return pseudo_labels - - - - - +def find_novel_classes_for_session(args, session_dataset: TransformedTensorDataset, model: torch.nn.Module) -> Union[torch.Tensor, TransformedTensorDataset]: + """ + Finds the novel classes in the given session dataset using KMeans clustering and the given model + :param args: Arguments object with the attributes `device`, `ncd_findk_method`, `novel_classes_per_session`, `seed`. + :param session_dataset: TransformedTensorDataset with the session data. + :param model: torch.nn.Module with the model to use for clustering. + :return: torch.Tensor with the pseudo-labels for the novel classes. + """ + features = _extract_features(args, session_dataset, model) + pseudo_labels = _cluster_features(args, features) + return pseudo_labels def discover_classes_in_session_dataset( diff --git a/entcl/utils/ood.py b/entcl/utils/ood.py index 18dff0f..4656bd5 100644 --- a/entcl/utils/ood.py +++ b/entcl/utils/ood.py @@ -4,7 +4,7 @@ from loguru import logger from sklearn.mixture import GaussianMixture import torch from tqdm import tqdm - +from entcl.utils.findk import elbow, silhouette, gap def _get_scores( session_dataset: TransformedTensorDataset, model: torch.nn.Module, args @@ -39,7 +39,7 @@ def _get_scores( entropies, energies = [], [] - for x, _ in tqdm(loader, desc="Calculating Scores", leave=True, unit="batch"): + for x, _ in tqdm(session_loader, desc="Calculating Scores", leave=True, unit="batch"): x = x.to(args.device) with torch.no_grad(): logits, _ = model(x) diff --git a/entcl/utils/util.py b/entcl/utils/util.py index fae6827..4fa4500 100644 --- a/entcl/utils/util.py +++ b/entcl/utils/util.py @@ -12,4 +12,28 @@ def seed(seed=8008135): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True \ No newline at end of file + torch.backends.cudnn.deterministic = True + +def generate_unique_path(base_path): + """ + Generate a unique path by appending an integer if the base path exists. + :param base_path: str with the base path. + :return: str with the unique path. 
+ """ + if not os.path.exists(base_path): + return base_path + + # Split the path into name and extension + file_dir, file_name = os.path.split(base_path) + name, ext = os.path.splitext(file_name) + + # Start the counter and generate unique file name + counter = 0 + while True: + new_name = f"{name}_{counter}{ext}" + new_path = os.path.join(file_dir, new_name) + if not os.path.exists(new_path): + return new_path + counter += 1 + + -- GitLab
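
The patch renames LinearHead2 to MLPHead, but the linear_head.py hunk stops at the opening of the torch.nn.Sequential, so the hidden activations and the forward pass are not visible here. A minimal sketch of a head with this constructor signature, matching the sizes run.py passes in (768 -> 512 -> 256 -> num_classes); the ReLU activations are an assumption, not something shown in the diff:

    import torch

    class MLPHead(torch.nn.Module):
        # Two-hidden-layer classification head; the activation choice is assumed, not shown in the diff.
        def __init__(self, in_features: int, out_features: int, hidden_dim1: int, hidden_dim2: int):
            super().__init__()
            self.mlp = torch.nn.Sequential(
                torch.nn.Linear(in_features, hidden_dim1),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_dim1, hidden_dim2),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_dim2, out_features),
            )

        def forward(self, x):
            return self.mlp(x)

    # The configuration used in run.py for 768-d DINOv2 features (100 classes here is a stand-in).
    head = MLPHead(in_features=768, out_features=100, hidden_dim1=512, hidden_dim2=256)
    logits = head(torch.randn(4, 768))  # -> shape (4, 100)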
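pretrain() now builds a one-row pandas DataFrame per epoch, concatenates it onto the running results, and rewrites results_pretrain.csv every epoch so a partial CSV survives an interrupted run. Since DataFrame.append was removed in pandas 2.0, pd.concat is the durable way to accumulate the rows. A self-contained sketch of that bookkeeping pattern, with dummy metrics and a hypothetical output path:

    import pandas as pd

    results = None
    for epoch in range(3):
        # stand-in metrics; in pretrain.py these come from _train() and _validate()
        train_loss, val_loss, accuracy = 1.0 / (epoch + 1), 1.2 / (epoch + 1), 0.5 + 0.1 * epoch

        epoch_results = pd.DataFrame(
            [[epoch, train_loss, val_loss, accuracy]],
            columns=["epoch", "train_loss", "val_loss", "accuracy"],
        )
        results = epoch_results if results is None else pd.concat([results, epoch_results], ignore_index=True)

        # rewriting the file each epoch keeps a usable CSV even if training dies early
        results.to_csv("results_pretrain.csv", index=False)

    print(results)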
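findk.elbow() and findk.silhouette() both sweep KMeans over ncd_findk_mink..ncd_findk_maxk; the former takes the knee of the inertia curve via kneed, the latter the k with the highest mean silhouette score. A small demonstration of the same two criteria on synthetic blobs (the range and seed here are illustrative, standing in for the real args namespace):

    import numpy as np
    from kneed import KneeLocator
    from sklearn.cluster import KMeans
    from sklearn.datasets import make_blobs
    from sklearn.metrics import silhouette_score

    X, _ = make_blobs(n_samples=500, centers=5, n_features=8, random_state=0)
    ks = list(range(2, 11))

    inertias, sils = [], []
    for k in ks:
        km = KMeans(n_clusters=k, random_state=0, n_init=10)
        labels = km.fit_predict(X)
        inertias.append(km.inertia_)
        sils.append(silhouette_score(X, labels))

    # Elbow: knee of the (convex, decreasing) inertia curve, as in findk.elbow
    elbow_k = KneeLocator(ks, inertias, curve="convex", direction="decreasing").knee

    # Silhouette: k with the highest mean silhouette score, as in findk.silhouette
    sil_k = ks[int(np.argmax(sils))]

    print(f"elbow k = {elbow_k}, silhouette k = {sil_k}")  # both should recover ~5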
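findk.gap() estimates gap(k) = mean(log W_k_ref) - log W_k against uniform reference data and returns the k with the largest gap. The original Tibshirani, Walther and Hastie (2001) formulation instead keeps the per-reference spread and picks the smallest k with gap(k) >= gap(k+1) - s_{k+1}. A sketch of that alternative stopping rule with made-up inputs (the patch itself uses the simpler argmax):

    import numpy as np

    def pick_k_gap(ks, gaps, ref_log_wk_std, n_refs):
        # Tibshirani et al. rule: smallest k such that gap(k) >= gap(k+1) - s_{k+1},
        # where s_k inflates the reference std to account for simulation error.
        s = np.asarray(ref_log_wk_std) * np.sqrt(1.0 + 1.0 / n_refs)
        gaps = np.asarray(gaps)
        for i in range(len(ks) - 1):
            if gaps[i] >= gaps[i + 1] - s[i + 1]:
                return ks[i]
        return ks[-1]  # no elbow found in the tested range; fall back to the largest k

    # Illustrative gap values for k = 2..6: the curve flattens after k = 3.
    print(pick_k_gap([2, 3, 4, 5, 6], [0.10, 0.25, 0.24, 0.23, 0.22], [0.02] * 5, n_refs=10))  # -> 3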
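_calculate_clustering_accuracy in ncd.py aligns arbitrary cluster ids with class ids by running the Hungarian algorithm on the negated confusion matrix, then scores the aligned predictions. A toy version of that alignment on nine samples, using only numpy, scipy and scikit-learn:

    import numpy as np
    from scipy.optimize import linear_sum_assignment
    from sklearn.metrics import confusion_matrix

    true_labels   = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])
    pseudo_labels = np.array([2, 2, 2, 0, 0, 1, 1, 1, 1])  # KMeans cluster ids, arbitrary order

    # Rows are true labels, columns are pseudo labels.
    confmat = confusion_matrix(true_labels, pseudo_labels)

    # The Hungarian algorithm maximises the matched counts by minimising the negated matrix.
    row_idx, col_idx = linear_sum_assignment(-confmat)
    mapping = dict(zip(col_idx, row_idx))  # pseudo label -> true label

    aligned = np.array([mapping[p] for p in pseudo_labels])
    accuracy = (aligned == true_labels).mean()
    print(mapping, accuracy)  # {2: 0, 0: 1, 1: 2} and 8/9 ~= 0.89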
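The ood.py hunk only renames the loader variable inside _get_scores, so the score computation itself is not visible in this patch; the collected quantities are per-sample entropies and energies derived from the model's logits. Assuming the usual definitions (softmax predictive entropy, and the negative log-sum-exp energy score of Liu et al., 2020), a sketch of those two scores:

    import torch
    import torch.nn.functional as F

    def entropy_and_energy(logits: torch.Tensor, temperature: float = 1.0):
        # logits: (N, C). Higher entropy / higher energy both suggest a sample is less like the known classes.
        log_probs = F.log_softmax(logits, dim=-1)
        entropy = -(log_probs.exp() * log_probs).sum(dim=-1)
        energy = -temperature * torch.logsumexp(logits / temperature, dim=-1)
        return entropy, energy

    logits = torch.randn(8, 50)          # e.g. a batch of 8 samples over 50 known classes
    ent, en = entropy_and_energy(logits)
    print(ent.shape, en.shape)           # torch.Size([8]) torch.Size([8])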
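generate_unique_path in util.py avoids overwriting existing artefacts (it is used above for the confusion-matrix plot) by appending a counter before the file extension. A quick usage example with throwaway files in a temporary directory, assuming the entcl package is importable:

    import os
    import tempfile

    from entcl.utils.util import generate_unique_path

    with tempfile.TemporaryDirectory() as tmp:
        target = os.path.join(tmp, "confusion_matrix.png")
        open(target, "w").close()                      # simulate an existing file
        print(generate_unique_path(target))            # .../confusion_matrix_0.png
        open(os.path.join(tmp, "confusion_matrix_0.png"), "w").close()
        print(generate_unique_path(target))            # .../confusion_matrix_1.png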