Commit 8d5ce588 authored by Fanis Baikas


Added learning rate scheduler (lr_scheduler.py). Also set the number of workers in the train_dataloader instantiation to 24.
parent 7987092e
@@ -16,4 +16,5 @@ config.dataset = 'casia_webface'
 config.data_dir = 'data/casia_webface'
 config.num_classes = 10575
 config.num_image = 494414
-config.num_epoch = 1
\ No newline at end of file
+config.warmup_epoch = 0
+config.num_epoch = 40
\ No newline at end of file
@@ -16,4 +16,5 @@ config.dataset = 'ms1mv3'
 config.data_dir = 'data/ms1m-retinaface-t1'
 config.num_classes = 93431
 config.num_image = 5179510
+config.warmup_epoch = 0
 config.num_epoch = 40
\ No newline at end of file
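
The scheduler introduced by this commit (lr_scheduler.py, reproduced in full below) works in iterations rather than epochs, so the epoch-level settings above are converted to step counts inside the training script. A minimal sketch of that arithmetic, using the CASIA-WebFace numbers from the first hunk and a purely hypothetical batch size of 128 (the real value comes from cfg.batch_size, which is not part of this diff):

num_image = 494414     # CASIA-WebFace image count, from the config above
num_epoch = 40
warmup_epoch = 0
batch_size = 128       # assumption for illustration only

steps_per_epoch = num_image // batch_size      # 3862 iterations per epoch
warmup_step = steps_per_epoch * warmup_epoch   # 0, so warmup is effectively disabled
total_step = steps_per_epoch * num_epoch       # 154480 iterations of polynomial decay
print(warmup_step, total_step)                 # 0 154480
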
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim import SGD
import torch
import warnings


class PolynomialLRWarmup(_LRScheduler):
    def __init__(self, optimizer, warmup_iters, total_iters=5, power=1.0, last_epoch=-1, verbose=False):
        super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
        self.total_iters = total_iters
        self.power = power
        self.warmup_iters = warmup_iters

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0 or self.last_epoch > self.total_iters:
            return [group["lr"] for group in self.optimizer.param_groups]

        if self.last_epoch <= self.warmup_iters:
            return [base_lr * self.last_epoch / self.warmup_iters for base_lr in self.base_lrs]
        else:
            l = self.last_epoch
            w = self.warmup_iters
            t = self.total_iters
            decay_factor = ((1.0 - (l - w) / (t - w)) / (1.0 - (l - 1 - w) / (t - w))) ** self.power
            return [group["lr"] * decay_factor for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        if self.last_epoch <= self.warmup_iters:
            return [
                base_lr * self.last_epoch / self.warmup_iters for base_lr in self.base_lrs]
        else:
            return [
                (
                    base_lr * (1.0 - (min(self.total_iters, self.last_epoch) - self.warmup_iters)
                               / (self.total_iters - self.warmup_iters)) ** self.power
                )
                for base_lr in self.base_lrs
            ]

if __name__ == "__main__":

    class TestModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(32, 32)

        def forward(self, x):
            return self.linear(x)

    test_module = TestModule()
    test_module_pfc = TestModule()

    lr_pfc_weight = 1 / 3
    base_lr = 10
    total_steps = 1000
    sgd = SGD([
        {"params": test_module.parameters(), "lr": base_lr},
        {"params": test_module_pfc.parameters(), "lr": base_lr * lr_pfc_weight}
    ], base_lr)

    scheduler = PolynomialLRWarmup(sgd, total_steps // 10, total_steps, power=2)
    x = []
    y = []
    y_pfc = []
    for i in range(total_steps):
        scheduler.step()
        lr = scheduler.get_last_lr()[0]
        lr_pfc = scheduler.get_last_lr()[1]
        x.append(i)
        y.append(lr)
        y_pfc.append(lr_pfc)

    import matplotlib.pyplot as plt

    fontsize = 15
    plt.figure(figsize=(6, 6))
    plt.plot(x, y, linestyle='-', linewidth=2)
    plt.plot(x, y_pfc, linestyle='-', linewidth=2)
    plt.xlabel('Iterations')  # x label
    plt.ylabel("Lr")  # y label
    plt.savefig("tmp.png", dpi=600, bbox_inches='tight')
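
Read off from _get_closed_form_lr above, the target learning rate at step $t$ for a parameter group with base rate $\eta_0$ is

$$
\eta(t) =
\begin{cases}
\eta_0 \, \dfrac{t}{w}, & t \le w \quad \text{(linear warmup)},\\[6pt]
\eta_0 \left(1 - \dfrac{\min(t,\,T) - w}{T - w}\right)^{p}, & t > w \quad \text{(polynomial decay)},
\end{cases}
$$

where $w$ is warmup_iters, $T$ is total_iters and $p$ is power. get_lr computes the decay branch incrementally: decay_factor is the ratio of consecutive closed-form values and is applied to each group's current learning rate, so per-group scaling such as the PFC group in the demo above is preserved.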

@@ -14,6 +14,7 @@ from backbones import get_model
 from dataset import MXFaceDataset
 from losses import CombinedMarginLoss
 from fc_layer import FC_Layer
+from lr_scheduler import PolynomialLRWarmup
 
 def train_pipleine(args):
     # Append parent directory to sys.path to ensure proper import of config modules
@@ -86,11 +87,16 @@ def train_pipleine(args):
         params=[{"params": backbone.parameters()}, {"params": fc_layer.parameters()}],
         lr=cfg.lr, momentum=cfg.momentum, weight_decay=cfg.weight_decay)
 
+    # Set learning rate scheduler ---------------------------------------------------------------------------------
+    cfg.warmup_step = cfg.num_image // cfg.batch_size * cfg.warmup_epoch
+    cfg.total_step = cfg.num_image // cfg.batch_size * cfg.num_epoch
+    lr_scheduler = PolynomialLRWarmup(optimizer=opt, warmup_iters=cfg.warmup_step, total_iters=cfg.total_step)
+
     # Set loss function --------------------------------------------------------------------------------------------
     loss_fn = torch.nn.CrossEntropyLoss()
 
-    # Create dataloader
-    train_dataloader = DataLoader(train_set, batch_size=cfg.batch_size, shuffle=True)
+    # Create dataloader --------------------------------------------------------------------------------------------
+    train_dataloader = DataLoader(train_set, batch_size=cfg.batch_size, shuffle=True, num_workers=24, pin_memory=True)
 
     # Monitor model through wandb
     wandb.watch(backbone, loss_fn, log="all", log_freq=10)
@@ -122,15 +128,18 @@ def train_pipleine(args):
             # Adjust learning weights
             opt.step()
 
+            # Adjust learning rate
+            lr_scheduler.step()
+
             example_ct += len(imgs)
             batch_ct += 1
             time_elapsed = time.time() - batch_start_time
 
             # Report metrics every 4 batches
             if ((batch_ct + 1) % 4) == 0:
                 # Log epoch and loss
-                wandb.log({'epoch': epoch, 'loss': loss}, step=example_ct)
-                print(f"Epoch {epoch}, Batch {i}, Loss after {str(example_ct)} examples: {loss:.3f},"
-                      f" Batch time: {time_elapsed:.4f}")
+                wandb.log({'epoch': epoch, 'loss': loss, 'lr': lr_scheduler.get_last_lr()}, step=batch_ct)
+                print(f"Epoch {epoch}, Batch {i}, Loss after {str(batch_ct)} batches: {loss:.3f},"
+                      f" Batch time: {time_elapsed:.4f}, lr: {lr_scheduler.get_last_lr()[0]:.3f}")
 
     trained_models_dir = 'trained_models'
     if not os.path.exists(trained_models_dir):
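
A note on the second change in this commit: num_workers=24 in the DataLoader assumes a machine with at least that many CPU cores available for data loading. A defensive variant, not part of this commit and purely illustrative (the stand-in dataset below replaces the MXFaceDataset used in the training script, and the shapes and batch size are arbitrary), would clamp the worker count to the cores actually present:

import os
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset for illustration only; train.py uses MXFaceDataset instead.
train_set = TensorDataset(torch.randn(256, 3, 112, 112), torch.randint(0, 10, (256,)))

num_workers = min(24, os.cpu_count() or 1)   # never request more workers than cores
train_dataloader = DataLoader(train_set, batch_size=128, shuffle=True,
                              num_workers=num_workers, pin_memory=True)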