From a2f5b5b1342dbe6d08fef5c75cc7c431d0a32903 Mon Sep 17 00:00:00 2001
From: Joseph Omar <j.omar@soton.ac.uk>
Date: Tue, 3 Dec 2024 11:51:14 +0000
Subject: [PATCH] Enhance logging for learning rate, update CIFAR100
 transformations, and improve confusion matrix plotting. Added DataFrame
 export for clustering results.

---
 entcl/cl.py            |  2 ++
 entcl/data/cifar100.py | 10 ++++++++--
 entcl/utils/ncd.py     | 29 +++++++++++++++--------------
 3 files changed, 25 insertions(+), 16 deletions(-)

diff --git a/entcl/cl.py b/entcl/cl.py
index f42242f..2a3d403 100644
--- a/entcl/cl.py
+++ b/entcl/cl.py
@@ -87,6 +87,8 @@ def cl_session(args, session_dataset: TransformedTensorDataset, model: torch.nn.
     
     results = None
     
+    print("LR is ", optimiser.param_groups[0]["lr"])
+    print("LR is set to ")
     # train the model
     for epoch in range (args.cl_epochs):
         logger.debug(f"Session {args.current_session} Epoch {epoch} Started")
diff --git a/entcl/data/cifar100.py b/entcl/data/cifar100.py
index 1e534d6..4e9d751 100644
--- a/entcl/data/cifar100.py
+++ b/entcl/data/cifar100.py
@@ -8,11 +8,17 @@ from torchvision.datasets import CIFAR100 as _CIFAR100
 import torchvision.transforms.v2 as transforms
 from entcl.config import CIFAR100_DIR
 from loguru import logger
+
 CIFAR100_TRANSFORM = transforms.Compose(
     [
+        transforms.Resize(int(224/0.875), interpolation=3),
+        transforms.CenterCrop(224),
         transforms.ToTensor(),
-        transforms.Resize((224, 224), antialias=True),
-        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+        transforms.Normalize(
+            mean=[0.485, 0.456, 0.406],
+            std=[0.229, 0.224, 0.225]
+        )
+        
     ]
 )
 
diff --git a/entcl/utils/ncd.py b/entcl/utils/ncd.py
index 49ba177..5bdb669 100644
--- a/entcl/utils/ncd.py
+++ b/entcl/utils/ncd.py
@@ -69,7 +69,8 @@ def plot_confmat(confmat: torch.Tensor, path: str) -> None:
         ax.set_xlabel("Predicted Label")
         ax.set_ylabel("True Label")
         ax.set_title("Confusion Matrix")
-        plt.savefig(path)
+        ax.set_aspect("equal")
+        plt.savefig(path, bbox_inches="tight")
         plt.close()
     except ImportError:
         logger.error(
@@ -126,9 +127,14 @@ def generate_mapping(
         per_class_accuracy[true_class] = np.mean(
             true_labels[mask] == pseudo_labels_aligned[mask]
         )
-
+        
+    true_labels_for_df = []
+    accs_for_df = []
+    
     string = f"NCD Clustering Accuracies for Session {args.current_session}:"
     for true_class, acc in per_class_accuracy.items():
+        true_labels_for_df.append(true_class+labels_start)
+        accs_for_df.append(acc)
         string += f"\nTrue Class {true_class+labels_start}: {acc*100:4f}%"
     string += f"\n"
     string += f"\nOverall Accuracy: {overall_accuracy*100:4f}%"
@@ -153,7 +159,13 @@ def generate_mapping(
     for true_label, pseudo_label in mapping.items():
         string += f"\nTrue Label {true_label} -> Pseudo Label {pseudo_label}"
     logger.info(string)
-        
+    
+    psuedo_for_df = [mapping[true_label] for true_label in true_labels_for_df]
+    clustering_df = pd.DataFrame(columns=["true_labels", "pseudo_labels", "accuracy"])
+    clustering_df["true_labels"] = true_labels_for_df
+    clustering_df["pseudo_labels"] = psuedo_for_df
+    clustering_df["accuracy"] = accs_for_df
+    clustering_df.to_csv(os.path.join(args.exp_dir, f"clustering_s{args.current_session}.csv"), index=False)
     
     return mapping
 
@@ -232,17 +244,6 @@ def find_novel_classes_for_session(
         pseudo_labels  # whereever the sample is novel, assign the corresponding pseudo label.
     )
 
-    # just for checking that this stuff above works properly
-    clustering_df = pd.DataFrame(columns=["true_labels", "type", "pseudo_labels"])
-    clustering_df["true_labels"] = session_dataset.tensor_dataset.tensors[1].cpu().numpy()
-    clustering_df["type"] = session_dataset.tensor_dataset.tensors[2].cpu().numpy()
-    clustering_df["predtype"] = session_dataset.tensor_dataset.tensors[3].cpu().numpy()
-    clustering_df["pseudo_labels"] = pseudo_labels_aligned.cpu().numpy()
-
-    # save the dataset to a csv file
-    dataset_path = os.path.join(args.exp_dir, f"clustering_s{args.current_session}.csv")
-    clustering_df.to_csv(dataset_path, index=False)
-
     # create a new dataset with the pseudo labels alongside the original tensors and the same transform
     # NOTE: the tensors attribute of a TensorDataset is a tuple. we can't append to it so instead we create a new TensorDataset with the new pseudo labels tensor.
     # NOTE: The dataset at the end of NCD is in the for data, true labels, true type, predicted type, psuedo labels
-- 
GitLab