Commit 7e093db8 authored by Joseph Omar

Removed LR Scheduler, updated args

parent 5b99f0c3
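The lines removed below followed the usual PyTorch pattern of pairing a CosineAnnealingLR scheduler with the optimiser and stepping it once per epoch. As a minimal, self-contained sketch of that pattern (the model, epoch count, and base learning rate here are placeholders rather than the project's values; the removed code additionally passed verbose=args.verbose):

    import torch

    # Placeholder model and optimiser; the real project builds these from args.
    model = torch.nn.Linear(512, 100)
    optimiser = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    # Cosine annealing over the full run, decaying to 1/1000 of the base LR,
    # mirroring T_max=args.cl_epochs and eta_min=args.lr * 1e-3 in the removed code.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimiser,
        T_max=100,
        eta_min=0.1 * 1e-3,
    )

    for epoch in range(100):
        # ... training step: forward pass, loss.backward(), optimiser.step() ...
        scheduler.step()  # the per-epoch call that this commit deletes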
@@ -71,14 +71,6 @@ def cl_session(args, session_dataset: torch.utils.data.Dataset, model: torch.nn.
         weight_decay=args.weight_decay,
     )
-    if not args.no_sched:
-        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
-            optimizer=optimiser,
-            T_max=args.cl_epochs,
-            eta_min=args.lr * 1e-3,
-            verbose=args.verbose,
-        )
     criterion = torch.nn.CrossEntropyLoss()
     results = None
@@ -128,8 +120,6 @@ def cl_session(args, session_dataset: torch.utils.data.Dataset, model: torch.nn.
         results.to_csv(f"{args.exp_dir}/results_s{args.current_session}.csv", index=False)
         logger.debug(f"Session {args.current_session} Results saved to {args.exp_dir}/results_s{args.current_session}.csv")
-        if not args.no_sched:
-            scheduler.step()
     return model
@@ -33,17 +33,6 @@ def pretrain(args, model):
         weight_decay=args.weight_decay,
     )
-    if not args.no_sched:
-        logger.debug("Using Cosine Annealing LR Scheduler")
-        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
-            optimizer=optimiser,
-            T_max=args.pretrain_epochs,
-            eta_min=args.lr * 1e-3,
-            verbose=args.verbose,
-        )
-    else:
-        logger.debug("Not Using Scheduler")
     criterion = torch.nn.CrossEntropyLoss()
     results = None
@@ -88,9 +77,6 @@ def pretrain(args, model):
            results.to_csv(f"{args.exp_dir}/results_s0.csv", index=False)
            logger.debug(f"Epoch {epoch} Finished. Pretrain Results Saved to {args.exp_dir}/results_s0.csv")
-           if not args.no_sched:
-               scheduler.step()
        return model
    else:
        raise ValueError(f"No Model to load and mode is not pretrain or both. Mode: {args.mode}, Pretrain Load: {args.pretrain_load}")
@@ -77,11 +77,10 @@ if __name__ == "__main__":
     parser.add_argument('--dataset', type=str, default='cifar100', help='Dataset to use', choices=['cifar100'])
     # optimiser args
-    parser.add_argument('--lr', type=float, default=0.1, help='Learning Rate for all optimisers')
-    parser.add_argument('--gamma', type=float, default=0.1, help='Gamma for all optimisers')
-    parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for all optimisers')
-    parser.add_argument('--weight_decay', type=float, default=5e-5, help='Weight Decay for all optimisers')
-    parser.add_argument('--no_sched', action='store_true', help='Do not use a scheduler')
+    parser.add_argument('--pretrain_lr', type=float, default=1e-3, help='Learning Rate for Pretraining Optimiser')
+    parser.add_argument('--cl_lr', type=float, default=1e-6, help='Learning Rate for CL Optimiser')
+    parser.add_argument('--momentum', type=float, default=0, help='Momentum for all optimisers')
+    parser.add_argument('--weight_decay', type=float, default=0, help='Weight Decay for all optimisers')
     # cl args
     parser.add_argument('--known', type=int, default=50, help='Number of known classes. The rest are novel classes')
@@ -96,21 +95,19 @@ if __name__ == "__main__":
     # pretrain args
     parser.add_argument('--pretrain_epochs', type=int, default=100, help='Number of epochs for pretraining')
     parser.add_argument('--pretrain_sel_strat', type=str, default='last', choices=['last', 'best'], help='Pretrain Model Selection Strategy')
     parser.add_argument('--pretrain_load', type=str, default=None, help='Path to a pretrained model to load')
     parser.add_argument('--retain_all', action='store_true', default=False, help='Keep all model checkpoints')
     # model args
     parser.add_argument('--head', type=str, default='linear', help='Classification head to use', choices=['linear','mlp', 'dino_head'])
     parser.add_argument("--backbone", type=int, default=1, help="Version of DINO to use", choices=[1, 2])
     # ood args
-    parser.add_argument('--ood_score', type=str, default='entropy', help='Changes the metric(s) to base OOD detection on', choices=['entropy', 'energy', 'both'])
+    parser.add_argument('--ood_score', type=str, default='entropy', help='Changes the metric(s) to base OOD detection on', choices=['entropy', 'energy', 'both', 'cheat'])
     parser.add_argument('--ood_eps', type=float, default=1e-8, help='Epsilon value for computing entropy in OOD detection')
     # ncd args
     parser.add_argument('--ncd_findk_method', type=str, default='cheat', help='Method to use for finding the number of novel classes', choices=['elbow', 'silhouette', 'gap', 'cheat'])
     parser.add_argument('--cheat_ncd', action='store_true', default=False, help='Cheat NCD. Use the true labels for NCD')
     args = parser.parse_args()
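The net effect of the argparse change is that the old --lr, --gamma, and --no_sched flags give way to per-phase learning rates. A rough sketch of how the new flags could feed two separate optimisers follows; the choice of SGD and the way pretrain()/cl_session() actually consume these flags are assumptions, since that code is outside this diff:

    import argparse
    import torch

    # Hypothetical wiring of the new per-phase flags into two optimisers.
    parser = argparse.ArgumentParser()
    parser.add_argument('--pretrain_lr', type=float, default=1e-3)
    parser.add_argument('--cl_lr', type=float, default=1e-6)
    parser.add_argument('--momentum', type=float, default=0)
    parser.add_argument('--weight_decay', type=float, default=0)
    args = parser.parse_args([])  # empty list: use the defaults

    model = torch.nn.Linear(512, 100)  # placeholder model

    # One optimiser per phase, each with its own learning rate.
    pretrain_optimiser = torch.optim.SGD(
        model.parameters(),
        lr=args.pretrain_lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )
    cl_optimiser = torch.optim.SGD(
        model.parameters(),
        lr=args.cl_lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )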