From f91daf052d433f59ff5ed677a706a54f8b3116e8 Mon Sep 17 00:00:00 2001
From: Joseph Omar <j.omar@soton.ac.uk>
Date: Mon, 18 Nov 2024 12:24:10 +0000
Subject: [PATCH] rename linear2 to MLPHead. Pretrain now gives a results csv (hopefully, untested)

---
 entcl/models/linear_head.py |   2 +-
 entcl/pretrain.py           |  48 ++++-----
 entcl/run.py                |  15 +--
 entcl/utils/findk.py        | 102 +++++++++++++++++++
 entcl/utils/ncd.py          | 190 ++++++++++++++++--------------------
 entcl/utils/ood.py          |   4 +-
 entcl/utils/util.py         |  26 ++++-
 7 files changed, 242 insertions(+), 145 deletions(-)
 create mode 100644 entcl/utils/findk.py

diff --git a/entcl/models/linear_head.py b/entcl/models/linear_head.py
index 755e0f9..6931789 100644
--- a/entcl/models/linear_head.py
+++ b/entcl/models/linear_head.py
@@ -9,7 +9,7 @@ class LinearHead(torch.nn.Module):
     def forward(self, x):
         return self.fc(x)
 
-class LinearHead2(torch.nn.Module):
+class MLPHead(torch.nn.Module):
     def __init__(self, in_features: int, out_features: int, hidden_dim1:int, hidden_dim2:int):
         super().__init__()
         self.mlp = torch.nn.Sequential(
diff --git a/entcl/pretrain.py b/entcl/pretrain.py
index 25e9295..5bfc2f8 100644
--- a/entcl/pretrain.py
+++ b/entcl/pretrain.py
@@ -1,5 +1,6 @@
 import os
 from loguru import logger
+import pandas as pd
 import torch
 from tqdm import tqdm
 
@@ -34,11 +35,11 @@ def pretrain(args, model):
 
     criterion = torch.nn.CrossEntropyLoss()
 
-    accuracies = []
+    results = None
 
     if args.pretrain_load is not None:
         logger.debug(f"Loading pretrained model from {args.pretrain_load}")
-        model.load_state_dict(torch.load(args.pretrain_load, weights_only=True))
+        model.head.load_state_dict(torch.load(args.pretrain_load, weights_only=True))
         model = model.to(args.device)
         model, accuracy = _validate(args, model, val_dataloader)
         logger.info(f"Loaded Pretrained Model Accuracy: {accuracy}")
@@ -47,39 +48,32 @@
         logger.debug("No pretrained model to load, training from scratch")
         model = model.to(args.device)
         for epoch in range(args.pretrain_epochs):
+            logger.debug(f"Epoch {epoch} Started")
+
             # train model
-            print(f"Epoch {epoch}:")
             model, train_loss = _train(args, model, train_dataloader, optimiser, criterion)
-            logger.info(f"Epoch {epoch}: Loss: {train_loss}")
+            logger.info(f"Epoch {epoch}: TRAINING : Train Loss: {train_loss}")
 
             # validate model
             model, accuracy, val_loss = _validate(args, model, val_dataloader, criterion)
-            accuracies.append(accuracy)
+            logger.info(f"Epoch {epoch}: VALIDATION : Accuracy: {accuracy}, Val Loss: {val_loss}")
+
+            # save only the head: the backbone is frozen and far too large to checkpoint every epoch
+            torch.save(model.head.state_dict(), f"{args.exp_root}/{args.name}/head_pretrain_{epoch}.pth")
+
+            # create a dataframe with the results
+            epoch_results = pd.DataFrame(
+                [[epoch, train_loss, val_loss, accuracy]],
+                columns=["epoch", "train_loss", "val_loss", "accuracy"],
+            )
 
-            logger.info(f"Epoch {epoch}: Accuracy: {accuracy}, Loss: {val_loss}")
-            torch.save(model.state_dict(), os.path.join(args.exp_dir, f"model_{epoch}.pt"))
-
-    # select the best model
-    if args.pretrain_sel_strat == 'best':
-        best_epoch = accuracies.index(max(accuracies))
-        logger.info(f"Best Epoch: {best_epoch}")
-        model.load_state_dict(torch.load(os.path.join(args.exp_dir, f"model_{best_epoch}.pt"), weights_only=True))
+            # append the results to the results dataframe (DataFrame.append was removed in pandas 2.0, so use pd.concat)
+            results = epoch_results if results is None else pd.concat([results, epoch_results], ignore_index=True)
 
-        # remove all other models
-        if not args.retain_all:
-            for epoch in range(args.pretrain_epochs):
-                if epoch != best_epoch:
-                    
os.remove(os.path.join(args.exp_dir, f"model_{epoch}.pt")) - - # delete all models except the last one, return the last model - elif args.pretrain_sel_strat == 'last': - if not args.retain_all: - for epoch in range(args.pretrain_epochs - 1): - os.remove(os.path.join(args.exp_dir, f"model_{epoch}.pt")) - elif args.pretrain_sel_strat == 'load': - model.load_state_dict(torch.load(args.pretrain_load, weights_only=True)) - + # save the results dataframe + results.to_csv(f"{args.exp_root}/{args.name}/results_pretrain.csv", index=False) + logger.debug(f"Epoch {epoch} Finished. Pretrain Results Saved to {args.exp_root}/{args.name}/results_pretrain.csv") return model def _train(args, model, train_dataloader, optimiser, criterion): diff --git a/entcl/run.py b/entcl/run.py index dcdd668..48604ef 100644 --- a/entcl/run.py +++ b/entcl/run.py @@ -16,7 +16,10 @@ def main(args: argparse.Namespace): model = ENTCLModel(head=args.head, backbone_url=args.backbone_url, backbone=args.backbone, backbone_source=args.backbone_source) logger.debug(f"Model: {model}") - model = pretrain(args, model) + if args.mode == 'pretrain': + model = pretrain(args, model) + else: + raise NotImplementedError(f"Mode {args.mode} not implemented") @@ -68,7 +71,7 @@ if __name__ == "__main__": parser.add_argument('--retain_all', action='store_true', default=False, help='Keep all model checkpoints') # model args - parser.add_argument('--head', type=str, default='linear2', help='Classification head to use', choices=['linear','linear2', 'dino_head']) + parser.add_argument('--head', type=str, default='mlp', help='Classification head to use', choices=['linear','mlp', 'dino_head']) parser.add_argument('--backbone_url', type=str, default="/cl/entcl/entcl/models/dinov2", help="URL to the repo containing the backbone model") parser.add_argument("--backbone", type=str, default="dinov2_vitb14", help="Name of the backbone model to use") parser.add_argument("--backbone_source", type=str, default="local", help="Source of the backbone model") @@ -121,10 +124,10 @@ if __name__ == "__main__": elif args.head == 'dino_head': from entcl.models.dinohead import DINOHead args.head = DINOHead(768, args.dataset.known, nlayers=3) - elif args.head == 'linear2': - from entcl.models.linear_head import LinearHead2 - args.head = LinearHead2(in_features=768, out_features=args.dataset.num_classes, hidden_dim1=512, hidden_dim2=256) - logger.debug(f"Using Linear2 Head: {args.head}") + elif args.head == 'mlp': + from entcl.models.linear_head import MLPHead + args.head = MLPHead(in_features=768, out_features=args.dataset.num_classes, hidden_dim1=512, hidden_dim2=256) + logger.debug(f"Using MLP Head: {args.head}") argstr = "Arguments: \n" for arg in vars(args): diff --git a/entcl/utils/findk.py b/entcl/utils/findk.py new file mode 100644 index 0000000..8241a22 --- /dev/null +++ b/entcl/utils/findk.py @@ -0,0 +1,102 @@ +import torch +import numpy as np +from sklearn.cluster import KMeans +from loguru import logger +from tqdm import tqdm + +def elbow(features: torch.Tensor, args) -> int: + """ + Finds the optimal number of clusters using the elbow method and the kneed library. + :param features: torch.Tensor with the features to cluster. + :param args: Arguments object with the attribute `seed`. + :return: int with the optimal number of clusters. 
+ """ + from kneed import KneeLocator + inertias = [] + + ks = range(args.ncd_findk_mink, args.ncd_findk_maxk + 1) + logger.debug(f"Finding k using elbow method for k in {ks}") + + # Calculate the inertia for each k + for k in tqdm(ks, desc="Calculating Inertias", leave=True, unit="k"): + kmeans = KMeans(n_clusters=k, random_state=args.seed) + kmeans.fit(features) + inertias.append(kmeans.inertia_) + + logger.debug(f"Elbow Inertias: {inertias}") + logger.debug(f"Running KneeLocator") + + # Find the knee + kneedle = KneeLocator(ks, inertias, curve="convex", direction="decreasing") + logger.info(f"Elbow Method Found k: {kneedle.knee}") + return kneedle.knee + +def silhouette(features: torch.Tensor, args) -> int: + """ + Finds the optimal number of clusters using the silhouette method. + :param features: torch.Tensor with the features to cluster. + :param args: Arguments object with the attribute `seed`. + :return: int with the optimal number of clusters. + """ + from sklearn.metrics import silhouette_score + silhouettes = [] + + ks = range(args.ncd_findk_mink, args.ncd_findk_maxk + 1) + logger.debug(f"Finding k using silhouette method for k in {ks}") + + # Calculate the silhouette score for each k + for k in tqdm(ks, desc="Calculating Silhouettes", leave=True, unit="k"): + kmeans = KMeans(n_clusters=k, random_state=args.seed) + pseudo_labels = kmeans.fit_predict(features) + silhouette = silhouette_score(features, pseudo_labels) + silhouettes.append(silhouette) + + logger.debug(f"Silhouette Scores: {silhouettes}") + + # Find the best k + best_k = ks[silhouettes.index(max(silhouettes))] + logger.info(f"Silhouette Method Found k: {best_k}") + return best_k + +def gap(features: torch.Tensor, args) -> int: + """ + Finds the optimal number of clusters using the gap statistic. + :param features: torch.Tensor with the features to cluster. + :param args: Arguments object with the attribute `seed`. + :return: int with the optimal number of clusters. 
+ """ + logger.debug("Finding k using gap statistic") + features_np = features.cpu().numpy() + n_samples, n_features = features_np.shape + + min_vals = features_np.min(axis=0) + max_vals = features_np.max(axis=0) + + gap_stats = {} + wks = [] + wks_refs = [] + + for k in tqdm(range(args.ncd_findk_mink, args.ncd_findk_maxk + 1), desc="Calculating Gap Statistic", leave=True, unit="k"): + logger.debug(f"Calculating Gap Statistic for k={k}") + kmeans = KMeans(n_clusters=k, random_state=args.seed) + kmeans.fit(features_np) + wk = kmeans.inertia_ + wks.append(np.log(wk)) + + wk_refs = [] + for _ in tqdm(range(args.ncd_gap_nrefs), desc="Calculating Reference Gap Statistic", leave=True, unit="ref"): + logger.debug(f"Calculating Reference Gap Statistic for k={k}") + ref_data = np.random.uniform(min_vals, max_vals, (n_samples, n_features)) + kmeans = KMeans(n_clusters=k, random_state=args.seed) + kmeans.fit(ref_data) + wk_ref = kmeans.inertia_ + wk_refs.append(np.log(wk_ref)) + wks_refs.append(np.mean(wk_refs)) + + gap_stats[k] = wks_refs[-1] - wks[-1] + + optimal_k = max(gap_stats, key=gap_stats.get) + + logger.info(f"Gap Statistic Found k: {optimal_k}") + + return optimal_k \ No newline at end of file diff --git a/entcl/utils/ncd.py b/entcl/utils/ncd.py index 2fe4a58..60df52c 100644 --- a/entcl/utils/ncd.py +++ b/entcl/utils/ncd.py @@ -1,11 +1,15 @@ +import os +from typing import Union from entcl.data.util import TransformedTensorDataset +from entcl.utils.findk import elbow, gap, silhouette from loguru import logger import numpy as np from sklearn.cluster import KMeans from sklearn.metrics import confusion_matrix from scipy.optimize import linear_sum_assignment +from entcl.utils.util import generate_unique_path import torch from tqdm import tqdm @@ -38,120 +42,84 @@ def _extract_features(args, dataset: TransformedTensorDataset, model: torch.nn.M return features -def _elbow_findk(features: torch.Tensor, args) -> int: +def plot_confmat(confmat: torch.Tensor, path: str) -> None: """ - Finds the optimal number of clusters using the elbow method and the kneed library. - :param features: torch.Tensor with the features to cluster. - :param args: Arguments object with the attribute `seed`. - :return: int with the optimal number of clusters. + Plots a confusion matrix and saves it to the specified path. + :param confmat: torch.Tensor with the confusion matrix. + :param path: str with the path to save the confusion matrix plot. """ - from kneed import KneeLocator - inertias = [] - - ks = range(args.ncd_findk_mink, args.ncd_findk_maxk + 1) - logger.debug(f"Finding k using elbow method for k in {ks}") - - # Calculate the inertia for each k - for k in ks: - kmeans = KMeans(n_clusters=k, random_state=args.seed) - kmeans.fit(features) - inertias.append(kmeans.inertia_) - - logger.debug(f"Elbow Inertias: {inertias}") - logger.debug(f"Running KneeLocator") - - # Find the knee - kneedle = KneeLocator(ks, inertias, curve="convex", direction="decreasing") - logger.info(f"Elbow Method Found k: {kneedle.knee}") - return kneedle.knee - -def _silhouette_findk(features: torch.Tensor, args) -> int: - """ - Finds the optimal number of clusters using the silhouette method. - :param features: torch.Tensor with the features to cluster. - :param args: Arguments object with the attribute `seed`. - :return: int with the optimal number of clusters. 
- """ - from sklearn.metrics import silhouette_score - silhouettes = [] - - ks = range(args.ncd_findk_mink, args.ncd_findk_maxk + 1) - logger.debug(f"Finding k using silhouette method for k in {ks}") - - # Calculate the silhouette score for each k - for k in ks: - kmeans = KMeans(n_clusters=k, random_state=args.seed) - pseudo_labels = kmeans.fit_predict(features) - silhouette = silhouette_score(features, pseudo_labels) - silhouettes.append(silhouette) - - logger.debug(f"Silhouette Scores: {silhouettes}") - - # Find the best k - best_k = ks[silhouettes.index(max(silhouettes))] - logger.info(f"Silhouette Method Found k: {best_k}") - return best_k + try: + import matplotlib.pyplot as plt + import seaborn as sns + from sklearn.metrics import ConfusionMatrixDisplay -def _gap_findk(features: torch.Tensor, args) -> int: + fig, ax = plt.subplots(figsize=(10, 10)) + sns.heatmap(confmat, annot=True, fmt='d', cmap='Blues', ax=ax) + ax.set_xlabel('Predicted Label') + ax.set_ylabel('True Label') + ax.set_title('Confusion Matrix') + plt.savefig(path) + plt.close() + except ImportError: + logger.error("Could not import matplotlib or seaborn. Cannot plot confusion matrix. Confusion Matrix not saved.") + +def _calculate_clustering_accuracy(true_labels: torch.Tensor, pseudo_labels: torch.Tensor, args) -> None: """ - Finds the optimal number of clusters using the gap statistic. - :param features: torch.Tensor with the features to cluster. - :param args: Arguments object with the attribute `seed`. - :return: int with the optimal number of clusters. + Calculates the clustering accuracy between the true labels and the pseudo labels. + :param true_labels: torch.Tensor with the true labels for the novel samples. + :param pseudo_labels: torch.Tensor with the pseudo labels for the novel samples. """ - logger.debug("Finding k using gap statistic") - features_np = features.cpu().numpy() - n_samples, n_features = features_np.shape - - min_vals = features_np.min(axis=0) - max_vals = features_np.max(axis=0) - - gap_stats = {} - wks = [] - wks_refs = [] - - for k in tqdm(range(args.ncd_findk_mink, args.ncd_findk_maxk + 1), desc="Calculating Gap Statistic", leave=True, unit="k"): - logger.debug(f"Calculating Gap Statistic for k={k}") - kmeans = KMeans(n_clusters=k, random_state=args.seed) - kmeans.fit(features_np) - wk = kmeans.inertia_ - wks.append(np.log(wk)) - - wk_refs = [] - for _ in tqdm(range(args.ncd_gap_nrefs), desc="Calculating Reference Gap Statistic", leave=True, unit="ref"): - logger.debug(f"Calculating Reference Gap Statistic for k={k}") - ref_data = np.random.uniform(min_vals, max_vals, (n_samples, n_features)) - kmeans = KMeans(n_clusters=k, random_state=args.seed) - kmeans.fit(ref_data) - wk_ref = kmeans.inertia_ - wk_refs.append(np.log(wk_ref)) - wks_refs.append(np.mean(wk_refs)) + assert true_labels.shape == pseudo_labels.shape, f"True and Pseudo labels must have the same shape. true_labels.shape: {true_labels.shape}, pseudo_labels.shape: {pseudo_labels.shape}" - gap_stats[k] = wks_refs[-1] - wks[-1] - optimal_k = max(gap_stats, key=gap_stats.get) + # true labels will be > 50 and psuedo labels will start at 0. we need to adjust the pseudo labels to match the true labels. + # we will assume the true labels are sequential, and the lowest true label is 0. - logger.info(f"Gap Statistic Found k: {optimal_k}") + true_labels -= true_labels.min() # Adjust the true labels to start at 0. 
-    return optimal_k
-
-def calculate_clustering_accuracy(true_labels: torch.Tensor, pseudo_labels: torch.Tensor) -> None:
-    """
-    Calculates the clustering accuracy between the true labels and the pseudo labels.
-    :param true_labels: torch.Tensor with the true labels.
-    :param pseudo_labels: torch.Tensor with the pseudo labels.
-    """
-    assert true_labels.shape == pseudo_labels.shape, f"True and Pseudo labels must have the same shape. true_labels.shape: {true_labels.shape}, pseudo_labels.shape: {pseudo_labels.shape}"
-    unique_true_labels = torch.unique(true_labels)
-    unique_pseudo_labels = torch.unique(pseudo_labels)
+    # optimal assignment of pseudo labels to true labels
+    confusion_mat = torch.from_numpy(confusion_matrix(true_labels.cpu().numpy(), pseudo_labels.cpu().numpy())) # Rows are true labels, columns are pseudo labels
+    cost_mat = -confusion_mat # Hungarian algorithm minimizes the cost, so we negate the confusion matrix.
+    row_idx, col_idx = linear_sum_assignment(cost_mat) # Hungarian algorithm to find the optimal assignment.
+
+    # Maps pseudo labels to true labels. int() makes the tensor elements match the integer keys; pseudo labels with no assignment map to -1.
+    label_assignments = dict(zip(col_idx, row_idx))
+    aligned_predicted_labels = torch.tensor([label_assignments.get(int(pseudo_label), -1) for pseudo_label in pseudo_labels]) # Aligns the predicted labels with the true labels.
 
+    # Calculate the accuracy
+    correct_assignments = (aligned_predicted_labels == true_labels).sum().item() # Number of correct assignments.
+    accuracy = correct_assignments / true_labels.shape[0] # Accuracy is the number of correct assignments divided by the number of samples.
 
-    confusion_mat = torch.from_numpy(confusion_matrix(true_labels.cpu().numpy(), pseudo_labels.cpu().numpy()))
-    cost_mat = -confusion_mat
+    # Plot the confusion matrix
+    confmat_path = generate_unique_path(os.path.join(args.exp_dir, args.name, "confusion_matrix.png"))
+    plot_confmat(confusion_mat, path=confmat_path)
 
-    row_idx, col_idx = linear_sum_assignment(cost_mat)
+    unique_true_labels = np.unique(true_labels.cpu().numpy())
+    unique_pseudo_labels = np.unique(pseudo_labels.cpu().numpy())
 
+    # find any ignored labels
+    ignored_true_labels = set(unique_true_labels) - set(unique_true_labels[row_idx])
+    ignored_pseudo_labels = set(unique_pseudo_labels) - set(unique_pseudo_labels[col_idx])
 
+    # Log the results
+
+    string = f"Clustering Accuracy Computed: {accuracy*100:.2f}%"
+    string += f"\n True Labels (adjusted): {unique_true_labels}"
+    string += f"\n Pseudo Labels: {unique_pseudo_labels}"
+    string += f"\n # of True Labels {len(unique_true_labels)}, # of Pseudo Labels {len(unique_pseudo_labels)}"
+    string += f"\n Confusion Matrix: Plot saved to `{confmat_path}`"
+    string += f"\n Ignored True Labels: {ignored_true_labels} Count: {len(ignored_true_labels)}"
+    string += f"\n Ignored Pseudo Labels: {ignored_pseudo_labels} Count: {len(ignored_pseudo_labels)}"
+    string += f"\n If there are ignored true labels, the number of clusters has been overestimated. If there are ignored pseudo labels, the number of clusters has been underestimated."
+    string += f"\n Samples whose pseudo label has no assignment are counted as incorrect in the accuracy calculation."
+ string += f"\n" + string += f"Per True Label Accuracy:" + for true_label in unique_true_labels: + mask = true_labels == true_label + true_label_accuracy = (aligned_predicted_labels[mask] == true_label).sum().item() / mask.sum().item() + string += f"\n True Label: {true_label} Accuracy: {true_label_accuracy*100:.2f}%" + + logger.info(string) def _cluster_features(args, features:torch.Tensor) -> torch.Tensor: @@ -165,11 +133,11 @@ def _cluster_features(args, features:torch.Tensor) -> torch.Tensor: if args.ncd_findk_method == "cheat": k = args.novel_classes_per_session elif args.ncd_findk_method == "elbow": - k = _elbow_findk(features, args) + k = elbow(features, args) elif args.ncd_findk_method == "silhouette": - k = _silhouette_findk(features, args) + k = silhouette(features, args) elif args.ncd_findk_method == "gap": - k = _gap_findk(features, args) + k = gap(features, args) else: raise ValueError(f"Unknown method for finding k: {args.ncd_findk_method}") @@ -179,11 +147,17 @@ def _cluster_features(args, features:torch.Tensor) -> torch.Tensor: pseudo_labels = torch.tensor(kmeans.fit_predict(features)) return pseudo_labels - - - - - +def find_novel_classes_for_session(args, session_dataset: TransformedTensorDataset, model: torch.nn.Module) -> Union[torch.Tensor, TransformedTensorDataset]: + """ + Finds the novel classes in the given session dataset using KMeans clustering and the given model + :param args: Arguments object with the attributes `device`, `ncd_findk_method`, `novel_classes_per_session`, `seed`. + :param session_dataset: TransformedTensorDataset with the session data. + :param model: torch.nn.Module with the model to use for clustering. + :return: torch.Tensor with the pseudo-labels for the novel classes. + """ + features = _extract_features(args, session_dataset, model) + pseudo_labels = _cluster_features(args, features) + return pseudo_labels def discover_classes_in_session_dataset( diff --git a/entcl/utils/ood.py b/entcl/utils/ood.py index 18dff0f..4656bd5 100644 --- a/entcl/utils/ood.py +++ b/entcl/utils/ood.py @@ -4,7 +4,7 @@ from loguru import logger from sklearn.mixture import GaussianMixture import torch from tqdm import tqdm - +from entcl.utils.findk import elbow, silhouette, gap def _get_scores( session_dataset: TransformedTensorDataset, model: torch.nn.Module, args @@ -39,7 +39,7 @@ def _get_scores( entropies, energies = [], [] - for x, _ in tqdm(loader, desc="Calculating Scores", leave=True, unit="batch"): + for x, _ in tqdm(session_loader, desc="Calculating Scores", leave=True, unit="batch"): x = x.to(args.device) with torch.no_grad(): logits, _ = model(x) diff --git a/entcl/utils/util.py b/entcl/utils/util.py index fae6827..4fa4500 100644 --- a/entcl/utils/util.py +++ b/entcl/utils/util.py @@ -12,4 +12,28 @@ def seed(seed=8008135): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True \ No newline at end of file + torch.backends.cudnn.deterministic = True + +def generate_unique_path(base_path): + """ + Generate a unique path by appending an integer if the base path exists. + :param base_path: str with the base path. + :return: str with the unique path. 
+ """ + if not os.path.exists(base_path): + return base_path + + # Split the path into name and extension + file_dir, file_name = os.path.split(base_path) + name, ext = os.path.splitext(file_name) + + # Start the counter and generate unique file name + counter = 0 + while True: + new_name = f"{name}_{counter}{ext}" + new_path = os.path.join(file_dir, new_name) + if not os.path.exists(new_path): + return new_path + counter += 1 + + -- GitLab
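
The patch renames LinearHead2 to MLPHead, but the linear_head.py hunk stops at the opening of the torch.nn.Sequential, so the hidden activations and the forward pass are not visible here. A minimal sketch of a head with this constructor signature, matching the sizes run.py passes in (768 -> 512 -> 256 -> num_classes); the ReLU activations are an assumption, not something shown in the diff:

    import torch

    class MLPHead(torch.nn.Module):
        # Two-hidden-layer classification head; the activation choice is assumed, not shown in the diff.
        def __init__(self, in_features: int, out_features: int, hidden_dim1: int, hidden_dim2: int):
            super().__init__()
            self.mlp = torch.nn.Sequential(
                torch.nn.Linear(in_features, hidden_dim1),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_dim1, hidden_dim2),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_dim2, out_features),
            )

        def forward(self, x):
            return self.mlp(x)

    # The configuration used in run.py for 768-d DINOv2 features (100 classes here is a stand-in).
    head = MLPHead(in_features=768, out_features=100, hidden_dim1=512, hidden_dim2=256)
    logits = head(torch.randn(4, 768))  # -> shape (4, 100)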
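pretrain() now builds a one-row pandas DataFrame per epoch, concatenates it onto the running results, and rewrites results_pretrain.csv every epoch so a partial CSV survives an interrupted run. Since DataFrame.append was removed in pandas 2.0, pd.concat is the durable way to accumulate the rows. A self-contained sketch of that bookkeeping pattern, with dummy metrics and a hypothetical output path:

    import pandas as pd

    results = None
    for epoch in range(3):
        # stand-in metrics; in pretrain.py these come from _train() and _validate()
        train_loss, val_loss, accuracy = 1.0 / (epoch + 1), 1.2 / (epoch + 1), 0.5 + 0.1 * epoch

        epoch_results = pd.DataFrame(
            [[epoch, train_loss, val_loss, accuracy]],
            columns=["epoch", "train_loss", "val_loss", "accuracy"],
        )
        results = epoch_results if results is None else pd.concat([results, epoch_results], ignore_index=True)

        # rewriting the file each epoch keeps a usable CSV even if training dies early
        results.to_csv("results_pretrain.csv", index=False)

    print(results)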
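findk.elbow() and findk.silhouette() both sweep KMeans over ncd_findk_mink..ncd_findk_maxk; the former takes the knee of the inertia curve via kneed, the latter the k with the highest mean silhouette score. A small demonstration of the same two criteria on synthetic blobs (the range and seed here are illustrative, standing in for the real args namespace):

    import numpy as np
    from kneed import KneeLocator
    from sklearn.cluster import KMeans
    from sklearn.datasets import make_blobs
    from sklearn.metrics import silhouette_score

    X, _ = make_blobs(n_samples=500, centers=5, n_features=8, random_state=0)
    ks = list(range(2, 11))

    inertias, sils = [], []
    for k in ks:
        km = KMeans(n_clusters=k, random_state=0, n_init=10)
        labels = km.fit_predict(X)
        inertias.append(km.inertia_)
        sils.append(silhouette_score(X, labels))

    # Elbow: knee of the (convex, decreasing) inertia curve, as in findk.elbow
    elbow_k = KneeLocator(ks, inertias, curve="convex", direction="decreasing").knee

    # Silhouette: k with the highest mean silhouette score, as in findk.silhouette
    sil_k = ks[int(np.argmax(sils))]

    print(f"elbow k = {elbow_k}, silhouette k = {sil_k}")  # both should recover ~5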
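findk.gap() estimates gap(k) = mean(log W_k_ref) - log W_k against uniform reference data and returns the k with the largest gap. The original Tibshirani, Walther and Hastie (2001) formulation instead keeps the per-reference spread and picks the smallest k with gap(k) >= gap(k+1) - s_{k+1}. A sketch of that alternative stopping rule with made-up inputs (the patch itself uses the simpler argmax):

    import numpy as np

    def pick_k_gap(ks, gaps, ref_log_wk_std, n_refs):
        # Tibshirani et al. rule: smallest k such that gap(k) >= gap(k+1) - s_{k+1},
        # where s_k inflates the reference std to account for simulation error.
        s = np.asarray(ref_log_wk_std) * np.sqrt(1.0 + 1.0 / n_refs)
        gaps = np.asarray(gaps)
        for i in range(len(ks) - 1):
            if gaps[i] >= gaps[i + 1] - s[i + 1]:
                return ks[i]
        return ks[-1]  # no elbow found in the tested range; fall back to the largest k

    # Illustrative gap values for k = 2..6: the curve flattens after k = 3.
    print(pick_k_gap([2, 3, 4, 5, 6], [0.10, 0.25, 0.24, 0.23, 0.22], [0.02] * 5, n_refs=10))  # -> 3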
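_calculate_clustering_accuracy in ncd.py aligns arbitrary cluster ids with class ids by running the Hungarian algorithm on the negated confusion matrix, then scores the aligned predictions. A toy version of that alignment on nine samples, using only numpy, scipy and scikit-learn:

    import numpy as np
    from scipy.optimize import linear_sum_assignment
    from sklearn.metrics import confusion_matrix

    true_labels   = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])
    pseudo_labels = np.array([2, 2, 2, 0, 0, 1, 1, 1, 1])  # KMeans cluster ids, arbitrary order

    # Rows are true labels, columns are pseudo labels.
    confmat = confusion_matrix(true_labels, pseudo_labels)

    # The Hungarian algorithm maximises the matched counts by minimising the negated matrix.
    row_idx, col_idx = linear_sum_assignment(-confmat)
    mapping = dict(zip(col_idx, row_idx))  # pseudo label -> true label

    aligned = np.array([mapping[p] for p in pseudo_labels])
    accuracy = (aligned == true_labels).mean()
    print(mapping, accuracy)  # {2: 0, 0: 1, 1: 2} and 8/9 ~= 0.89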
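The ood.py hunk only renames the loader variable inside _get_scores, so the score computation itself is not visible in this patch; the collected quantities are per-sample entropies and energies derived from the model's logits. Assuming the usual definitions (softmax predictive entropy, and the negative log-sum-exp energy score of Liu et al., 2020), a sketch of those two scores:

    import torch
    import torch.nn.functional as F

    def entropy_and_energy(logits: torch.Tensor, temperature: float = 1.0):
        # logits: (N, C). Higher entropy / higher energy both suggest a sample is less like the known classes.
        log_probs = F.log_softmax(logits, dim=-1)
        entropy = -(log_probs.exp() * log_probs).sum(dim=-1)
        energy = -temperature * torch.logsumexp(logits / temperature, dim=-1)
        return entropy, energy

    logits = torch.randn(8, 50)          # e.g. a batch of 8 samples over 50 known classes
    ent, en = entropy_and_energy(logits)
    print(ent.shape, en.shape)           # torch.Size([8]) torch.Size([8])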
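generate_unique_path in util.py avoids overwriting existing artefacts (it is used above for the confusion-matrix plot) by appending a counter before the file extension. A quick usage example with throwaway files in a temporary directory, assuming the entcl package is importable:

    import os
    import tempfile

    from entcl.utils.util import generate_unique_path

    with tempfile.TemporaryDirectory() as tmp:
        target = os.path.join(tmp, "confusion_matrix.png")
        open(target, "w").close()                      # simulate an existing file
        print(generate_unique_path(target))            # .../confusion_matrix_0.png
        open(os.path.join(tmp, "confusion_matrix_0.png"), "w").close()
        print(generate_unique_path(target))            # .../confusion_matrix_1.png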