diff --git a/src/train.py b/src/train.py index ba06af756a6a73c78371c1205cd2557ad5778046..9b6304aeda29d50371156fa4378c265b3bb2351a 100644 --- a/src/train.py +++ b/src/train.py @@ -137,7 +137,7 @@ def train_pipleine(args): # Report metrics every 4 batches if ((batch_ct + 1) % 4) == 0: # Log epoch and loss - wandb.log({'epoch': epoch, 'loss': loss, 'lr': lr_scheduler.get_last_lr()}, step=batch_ct) + wandb.log({'epoch': epoch, 'loss': loss, 'lr': lr_scheduler.get_last_lr()[0]}, step=batch_ct) print(f"Epoch {epoch}, Batch {i}, Loss after {str(batch_ct)} batches: {loss:.3f}," f" Batch time: {time_elapsed:.4f}, lr: {lr_scheduler.get_last_lr()[0]:.3f}")