From 267400f4ccfadca9de8c3f1281572f662696a063 Mon Sep 17 00:00:00 2001
From: Joseph Omar <j.omar@soton.ac.uk>
Date: Fri, 22 Nov 2024 14:48:01 +0000
Subject: [PATCH] added freeze, expand to head

---
 entcl/models/linear_head.py | 37 +++++++++++++++++++++++++++
 entcl/run.py                | 51 ++++++++++++++++---------------------
 2 files changed, 59 insertions(+), 29 deletions(-)

diff --git a/entcl/models/linear_head.py b/entcl/models/linear_head.py
index 6931789..fec068e 100644
--- a/entcl/models/linear_head.py
+++ b/entcl/models/linear_head.py
@@ -8,6 +8,43 @@ class LinearHead(torch.nn.Module):
 
     def forward(self, x):
         return self.fc(x)
+
+    def expand(self, num: int, init_new: bool = False):
+        """
+        Expand the number of output features of the head
+        :param num: int with the number of features to expand by.
+        """
+        logger.info(f"Expanding Head: {self.fc.out_features} -> {self.fc.out_features + num}")
+        old_fc = self.fc
+        self.fc = torch.nn.Linear(old_fc.in_features, old_fc.out_features + num)
+        self.fc.weight.data[:old_fc.out_features] = old_fc.weight.data
+        self.fc.bias.data[:old_fc.out_features] = old_fc.bias.data
+        self.fc.weight.data[old_fc.out_features:] = 0
+        self.fc.bias.data[old_fc.out_features:] = 0
+
+        if init_new:
+            torch.nn.init.kaiming_normal_(self.fc.weight.data[old_fc.out_features:])
+            torch.nn.init.zeros_(self.fc.bias.data[old_fc.out_features:])
+
+
+    def freeze(self, start_idx: int, end_idx: int):
+        """
+        Freeze the weights of the head between the given indices
+        :param start_idx: int with the start index to freeze from.
+        :param end_idx: int with the end index to freeze to.
+        """
+        logger.info(f"Freezing Head: {start_idx} -> {end_idx}")
+        self.fc.weight.data[start_idx:end_idx].requires_grad = False
+        self.fc.bias.data[start_idx:end_idx].requires_grad = False
+
+    def unfreeze(self):
+        """
+        Unfreeze the weights of the head
+        """
+        logger.info("Unfreezing Head")
+        self.fc.weight.requires_grad = True
+        self.fc.bias.requires_grad = True
+
 
 class MLPHead(torch.nn.Module):
     def __init__(self, in_features: int, out_features: int, hidden_dim1:int, hidden_dim2:int):
diff --git a/entcl/run.py b/entcl/run.py
index 073a505..0fd3514 100644
--- a/entcl/run.py
+++ b/entcl/run.py
@@ -25,35 +25,28 @@ def main(args: argparse.Namespace):
 
     if args.mode in ['cl', 'both']:
         logger.info("Starting Continual Learning")
-        for session in range(1, args.sessions + 1):
-            logger.info(f"Starting Continual Learning Session {session}")
-            args.current_session = session
-            session_dataset = args.dataset.get_dataset(session)
+        for session in range(1, args.sessions + 1):
+            logger.info(f"Starting Continual Learning Session {session}")
+            args.current_session = session
+            session_dataset = args.dataset.get_dataset(session)
+
+            # OOD detection
+            session_dataset = label_ood_for_session(args, session_dataset, model) # returns a new dataset with the OOD samples labelled
+
+            # NCD
+            session_dataset, mapping = find_novel_classes_for_session(args, session_dataset, model) # returns a new dataset with the novel samples labelled
+
+            # dataset should now have the form (data, true labels, true types, pred types, pseudo labels)
+
+            # Expand Classification Head & Initialise
+            model.head.expand(args.dataset.novel_inc) # we are cheating here, we know the number of novel classes
+
+            # freeze the weights for the existing classes. We are only training unknown samples (EG: 50 (known) + (2 (session) - 1) * 10 (novel_inc) = 60 classes have been seen in cl session 2)
+            model.head.freeze(start_idx=0, end_idx=args.dataset.known + ((session -1) * args.dataset.novel_inc))
+
+            # run continual learning session
+            model = cl_session(args, session_dataset, model, mapping)
+
 
-            # OOD detection
-            session_dataset = label_ood_for_session(args, session_dataset, model) # returns a new dataset with the OOD samples labelled
-
-            # NCD
-            session_dataset, mapping = find_novel_classes_for_session(args, session_dataset, model) # returns a new dataset with the novel samples labelled
-
-            # dataset should now have the form (data, true labels, true types, pred types, pseudo labels)
-
-            # Expand Classification Head & Initialise
-            model.head.expand(args.dataset.novel_inc) # we are cheating here, we know the number of novel classes
-
-            # freeze the weights for the existing classes. We are only training unknown samples (EG: 50 (known) + (2 (session) - 1) * 10 (novel_inc) = 60 classes have been seen in cl session 2)
-            model.head.freeze(start_idx=0, end_idx=args.dataset.known + ((session -1) * args.dataset.novel_inc))
-
-            # run continual learning session
-            model = cl_session(args, session_dataset, model, mapping)
-
-
-
-
-
-
-
-
-
-
 if __name__ == "__main__":
     logger.debug("Entry Point: run.py")
@@ -103,7 +96,7 @@ if __name__ == "__main__":
     parser.add_argument('--retain_all', action='store_true', default=False, help='Keep all model checkpoints')
 
     # model args
-    parser.add_argument('--head', type=str, default='mlp', help='Classification head to use', choices=['linear','mlp', 'dino_head'])
+    parser.add_argument('--head', type=str, default='linear', help='Classification head to use', choices=['linear','mlp', 'dino_head'])
     parser.add_argument('--backbone_url', type=str, default="/cl/entcl/entcl/models/dinov2", help="URL to the repo containing the backbone model")
     parser.add_argument("--backbone", type=str, default="dinov2_vitb14", help="Name of the backbone model to use")
     parser.add_argument("--backbone_source", type=str, default="local", help="Source of the backbone model")
-- 
GitLab
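For reference, a short standalone sketch (illustrative only, not part of the patch) of the per-session arithmetic encoded by the run.py hunk, using the example figures from its inline comment (50 known classes, 10 novel classes per session): each session the loop expands the head by novel_inc outputs and freezes every output learned in earlier sessions.

    # Per-session head bookkeeping mirroring the run.py loop above (example values only).
    known, novel_inc = 50, 10
    for session in range(1, 4):
        seen_before = known + (session - 1) * novel_inc  # outputs learned in earlier sessions -> freeze(0, seen_before)
        head_size = seen_before + novel_inc              # head size after expand(novel_inc)
        print(session, seen_before, head_size)
    # session 1: freeze [0, 50), head has 60 outputs
    # session 2: freeze [0, 60), head has 70 outputs
    # session 3: freeze [0, 70), head has 80 outputs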