From 267400f4ccfadca9de8c3f1281572f662696a063 Mon Sep 17 00:00:00 2001
From: Joseph Omar <j.omar@soton.ac.uk>
Date: Fri, 22 Nov 2024 14:48:01 +0000
Subject: [PATCH] added freeze, expand to head

---
 entcl/models/linear_head.py | 37 +++++++++++++++++++++++++++
 entcl/run.py                | 51 ++++++++++++++++---------------------
 2 files changed, 59 insertions(+), 29 deletions(-)

diff --git a/entcl/models/linear_head.py b/entcl/models/linear_head.py
index 6931789..fec068e 100644
--- a/entcl/models/linear_head.py
+++ b/entcl/models/linear_head.py
@@ -8,6 +8,43 @@ class LinearHead(torch.nn.Module):
 
     def forward(self, x):
         return self.fc(x)
+
+    def expand(self, num: int, init_new: bool = False):
+        """
+        Expand the number of output features of the head
+        :param num: int with the number of features to expand by.
+        """
+        logger.info(f"Expanding Head: {self.fc.out_features} -> {self.fc.out_features + num}")
+        old_fc = self.fc
+        self.fc = torch.nn.Linear(old_fc.in_features, old_fc.out_features + num)
+        self.fc.weight.data[:old_fc.out_features] = old_fc.weight.data
+        self.fc.bias.data[:old_fc.out_features] = old_fc.bias.data
+        self.fc.weight.data[old_fc.out_features:] = 0
+        self.fc.bias.data[old_fc.out_features:] = 0
+
+        if init_new:
+            torch.nn.init.kaiming_normal_(self.fc.weight.data[old_fc.out_features:])
+            torch.nn.init.zeros_(self.fc.bias.data[old_fc.out_features:])
+
+
+    def freeze(self, start_idx: int, end_idx: int):
+        """
+        Freeze the weights of the head between the given indices
+        :param start_idx: int with the start index to freeze from.
+        :param end_idx: int with the end index to freeze to.
+        """
+        logger.info(f"Freezing Head: {start_idx} -> {end_idx}")
+        self.fc.weight.data[start_idx:end_idx].requires_grad = False
+        self.fc.bias.data[start_idx:end_idx].requires_grad = False
+
+    def unfreeze(self):
+        """
+        Unfreeze the weights of the head
+        """
+        logger.info("Unfreezing Head")
+        self.fc.weight.requires_grad = True
+        self.fc.bias.requires_grad = True
+
 
 class MLPHead(torch.nn.Module):
     def __init__(self, in_features: int, out_features: int, hidden_dim1:int, hidden_dim2:int):
diff --git a/entcl/run.py b/entcl/run.py
index 073a505..0fd3514 100644
--- a/entcl/run.py
+++ b/entcl/run.py
@@ -25,35 +25,28 @@ def main(args: argparse.Namespace):
 
     if args.mode in ['cl', 'both']:
         logger.info("Starting Continual Learning")
-        for session in range(1, args.sessions + 1):
-            logger.info(f"Starting Continual Learning Session {session}")
-            args.current_session = session
-            session_dataset = args.dataset.get_dataset(session)
+        for session in range(1, args.sessions + 1):
+            logger.info(f"Starting Continual Learning Session {session}")
+            args.current_session = session
+            session_dataset = args.dataset.get_dataset(session)
+
+            # OOD detection
+            session_dataset = label_ood_for_session(args, session_dataset, model) # returns a new dataset with the OOD samples labelled
+
+            # NCD
+            session_dataset, mapping = find_novel_classes_for_session(args, session_dataset, model) # returns a new dataset with the novel samples labelled
+
+            # dataset should now have the form (data, true labels, true types, pred types, pseudo labels)
+
+            # Expand Classification Head & Initialise
+            model.head.expand(args.dataset.novel_inc) # we are cheating here, we know the number of novel classes
+
+            # freeze the weights for the existing classes. We are only training unknown samples (EG: 50 (known) + (2 (session) - 1) * 10 (novel_inc) = 60 classes have been seen in cl session 2)
+            model.head.freeze(start_idx=0, end_idx=args.dataset.known + ((session -1) * args.dataset.novel_inc))
+
+            # run continual learning session
+            model = cl_session(args, session_dataset, model, mapping)
+
 
-            # OOD detection
-            session_dataset = label_ood_for_session(args, session_dataset, model) # returns a new dataset with the OOD samples labelled
-
-            # NCD
-            session_dataset, mapping = find_novel_classes_for_session(args, session_dataset, model) # returns a new dataset with the novel samples labelled
-
-            # dataset should now have the form (data, true labels, true types, pred types, pseudo labels)
-
-            # Expand Classification Head & Initialise
-            model.head.expand(args.dataset.novel_inc) # we are cheating here, we know the number of novel classes
-
-            # freeze the weights for the existing classes. We are only training unknown samples (EG: 50 (known) + (2 (session) - 1) * 10 (novel_inc) = 60 classes have been seen in cl session 2)
-            model.head.freeze(start_idx=0, end_idx=args.dataset.known + ((session -1) * args.dataset.novel_inc))
-
-            # run continual learning session
-            model = cl_session(args, session_dataset, model, mapping)
-
-
-
-
-
-
-
-
-
-
 if __name__ == "__main__":
     logger.debug("Entry Point: run.py")
@@ -103,7 +96,7 @@ if __name__ == "__main__":
     parser.add_argument('--retain_all', action='store_true', default=False, help='Keep all model checkpoints')
 
     # model args
-    parser.add_argument('--head', type=str, default='mlp', help='Classification head to use', choices=['linear','mlp', 'dino_head'])
+    parser.add_argument('--head', type=str, default='linear', help='Classification head to use', choices=['linear','mlp', 'dino_head'])
     parser.add_argument('--backbone_url', type=str, default="/cl/entcl/entcl/models/dinov2", help="URL to the repo containing the backbone model")
     parser.add_argument("--backbone", type=str, default="dinov2_vitb14", help="Name of the backbone model to use")
     parser.add_argument("--backbone_source", type=str, default="local", help="Source of the backbone model")
-- 
GitLab
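For reference, a short standalone sketch (illustrative only, not part of the patch) of the per-session arithmetic encoded by the run.py hunk, using the example figures from its inline comment (50 known classes, 10 novel classes per session): each session the loop expands the head by novel_inc outputs and freezes every output learned in earlier sessions.

    # Per-session head bookkeeping mirroring the run.py loop above (example values only).
    known, novel_inc = 50, 10
    for session in range(1, 4):
        seen_before = known + (session - 1) * novel_inc  # outputs learned in earlier sessions -> freeze(0, seen_before)
        head_size = seen_before + novel_inc              # head size after expand(novel_inc)
        print(session, seen_before, head_size)
    # session 1: freeze [0, 50), head has 60 outputs
    # session 2: freeze [0, 60), head has 70 outputs
    # session 3: freeze [0, 70), head has 80 outputs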