diff --git a/entcl/cl.py b/entcl/cl.py
index 4efdcb3cdd552799488e22f7900aa2e6ee4937e5..4a960e9bcca97c5479a76c58970bc0c3ac10247c 100644
--- a/entcl/cl.py
+++ b/entcl/cl.py
@@ -71,14 +71,6 @@ def cl_session(args, session_dataset: torch.utils.data.Dataset, model: torch.nn.
         weight_decay=args.weight_decay,
     )
 
-    if not args.no_sched:
-        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
-            optimizer=optimiser,
-            T_max=args.cl_epochs,
-            eta_min=args.lr * 1e-3,
-            verbose=args.verbose,
-        )
-
     criterion = torch.nn.CrossEntropyLoss()
 
     results = None
@@ -128,8 +120,6 @@ def cl_session(args, session_dataset: torch.utils.data.Dataset, model: torch.nn.
         results.to_csv(f"{args.exp_dir}/results_s{args.current_session}.csv", index=False)
         logger.debug(f"Session {args.current_session} Results saved to {args.exp_dir}/results_s{args.current_session}.csv")
 
-        if not args.no_sched:
-            scheduler.step()
 
     return model
 
diff --git a/entcl/pretrain.py b/entcl/pretrain.py
index 3d44526a80e01124d839874920b77d3818c3917d..328a795ae51ccbccf8e5cfe71898d081823d8e36 100644
--- a/entcl/pretrain.py
+++ b/entcl/pretrain.py
@@ -33,17 +33,6 @@ def pretrain(args, model):
             weight_decay=args.weight_decay,
         )
 
-        if not args.no_sched:
-            logger.debug("Using Cosine Annealing LR Scheduler")
-            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
-                optimizer=optimiser,
-                T_max=args.pretrain_epochs,
-                eta_min=args.lr * 1e-3,
-                verbose=args.verbose,
-            )
-        else:
-            logger.debug("Not Using Scheduler")
-
         criterion = torch.nn.CrossEntropyLoss()
 
         results = None
@@ -88,9 +77,6 @@ def pretrain(args, model):
             results.to_csv(f"{args.exp_dir}/results_s0.csv", index=False)
             logger.debug(f"Epoch {epoch} Finished. Pretrain Results Saved to {args.exp_dir}/results_s0.csv")
 
-            if not args.no_sched:
-                scheduler.step()
-
         return model
     else:
         raise ValueError(f"No Model to load and mode is not pretrain or both. Mode: {args.mode}, Pretrain Load: {args.pretrain_load}")
diff --git a/entcl/run.py b/entcl/run.py
index 5cc75c5717f9d140a94bef38c65d32d4825aabbb..dba4b8fc2920709933f6f5cc4cd08bcd46e359a1 100644
--- a/entcl/run.py
+++ b/entcl/run.py
@@ -77,11 +77,10 @@ if __name__ == "__main__":
     parser.add_argument('--dataset', type=str, default='cifar100', help='Dataset to use', choices=['cifar100'])
 
     # optimiser args
-    parser.add_argument('--lr', type=float, default=0.1, help='Learning Rate for all optimisers')
-    parser.add_argument('--gamma', type=float, default=0.1, help='Gamma for all optimisers')
-    parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for all optimisers')
-    parser.add_argument('--weight_decay', type=float, default=5e-5, help='Weight Decay for all optimisers')
-    parser.add_argument('--no_sched', action='store_true', help='Do not use a scheduler')
+    parser.add_argument('--pretrain_lr', type=float, default=1e-3, help='Learning Rate for Pretraining Optimiser')
+    parser.add_argument('--cl_lr', type=float, default=1e-6, help='Learning Rate for CL Optimiser')
+    parser.add_argument('--momentum', type=float, default=0, help='Momentum for all optimisers')
+    parser.add_argument('--weight_decay', type=float, default=0, help='Weight Decay for all optimisers')
 
     # cl args
     parser.add_argument('--known', type=int, default=50, help='Number of known classes. The rest are novel classes')
@@ -96,21 +95,19 @@ if __name__ == "__main__":
 
     # pretrain args
     parser.add_argument('--pretrain_epochs', type=int, default=100, help='Number of epochs for pretraining')
-    parser.add_argument('--pretrain_sel_strat', type=str, default='last', choices=['last', 'best'], help='Pretrain Model Selection Strategy')
     parser.add_argument('--pretrain_load', type=str, default=None, help='Path to a pretrained model to load')
-    parser.add_argument('--retain_all', action='store_true', default=False, help='Keep all model checkpoints')
-
     # model args
     parser.add_argument('--head', type=str, default='linear', help='Classification head to use', choices=['linear','mlp', 'dino_head'])
     parser.add_argument("--backbone", type=int, default=1, help="Version of DINO to use", choices=[1, 2])
 
     # ood args
-    parser.add_argument('--ood_score', type=str, default='entropy', help='Changes the metric(s) to base OOD detection on', choices=['entropy', 'energy', 'both'])
+    parser.add_argument('--ood_score', type=str, default='entropy', help='Changes the metric(s) to base OOD detection on', choices=['entropy', 'energy', 'both', 'cheat'])
     parser.add_argument('--ood_eps', type=float, default=1e-8, help='Epsilon value for computing entropy in OOD detection')
 
     # ncd args
-    parser.add_argument('--ncd_findk_method', type=str, default='cheat', help='Method to use for finding the number of novel classes', choices=['elbow', 'silhouette', 'gap', 'cheat'])
+    parser.add_argument('--cheat_ncd', action='store_true', default=False, help='Cheat NCD. Use the true labels for NCD')
+
     args = parser.parse_args()
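
Note (not part of the patch): the shared --lr flag is removed in favour of per-phase rates, but the optimiser construction lines in pretrain.py and cl.py fall outside the hunks shown above. The following is a minimal sketch, under stated assumptions, of how args.pretrain_lr and args.cl_lr could be consumed; the helper name build_optimiser, its phase parameter, and the choice of SGD are hypothetical and introduced here only for illustration.

# Hypothetical sketch, not part of this patch: wiring the new per-phase
# learning rates into optimiser construction. Helper name, phase argument,
# and SGD are assumptions for illustration only.
import argparse

import torch


def build_optimiser(model: torch.nn.Module, args: argparse.Namespace, phase: str) -> torch.optim.Optimizer:
    # With the cosine-annealing scheduler removed, the selected rate is used
    # unchanged for every epoch of the given phase.
    lr = args.pretrain_lr if phase == "pretrain" else args.cl_lr
    return torch.optim.SGD(
        model.parameters(),
        lr=lr,
        momentum=args.momentum,          # new default: 0
        weight_decay=args.weight_decay,  # new default: 0
    )

Under this sketch, pretrain() would call build_optimiser(model, args, phase="pretrain") and cl_session() would call it with phase="cl"; the actual call sites are not visible in the diff.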