Commit 5f16808a authored by eca1g19

Merge branch 'Sequence2Sequence' into 'main'

Added additional measurements during training

See merge request !3
parents 5d15a87b 18572a85
@@ -123,7 +123,7 @@ class BertLightning(LightningModule):
 class Seq2SeqLightning(LightningModule):
     def __init__(self, model, learning_rate=1e-4, greedy_decode=False, beam_search_decode=True,
-                 beam_width=3, name_override=None, max_output_length=512):
+                 beam_width=3, name_override=None, max_output_length=512, generate_during_validation=True, generate_during_test=True):
         super().__init__()
         self.model = model
@@ -153,6 +153,9 @@ class Seq2SeqLightning(LightningModule):
         self.max_length = max_output_length
         self.scorer = ROUGEScore()
+        self.generate_during_validation = generate_during_validation
+        self.generate_during_test = generate_during_test
+
     def rename_keys(self, name, rouge_result):
         new_dict = {}
         for key, value in rouge_result.items():
@@ -187,9 +190,34 @@ class Seq2SeqLightning(LightningModule):
         logits = self(input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, labels)
         val_loss = self.criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
         self.log('val_loss', val_loss)
+        if self.generate_during_validation:
+            validation_step_metrics = {}
+            decoded_targets = self.model.tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)
+            if self.greedy_decode:
+                greedy_outputs = greedy_search(logits)
+                greedy_decoded_outputs = self.model.tokenizer.batch_decode(greedy_outputs, skip_special_tokens=True)
+                rouge_result = self.scorer(greedy_decoded_outputs, decoded_targets)
+                for k, v in rouge_result.items():
+                    self.log(f"validation_greedy_{k}", v, on_step=False, on_epoch=True)
+                    validation_step_metrics[f"validation_greedy_{k}"] = v
+            if self.beam_search_decode:
+                beam_search_outputs, beam_search_scores, beam_search_lengths = beam_search_bert(
+                    logits, beam_width=self.beam_width, max_length=self.max_length)
+                best_beam_indices = torch.argmax(beam_search_scores, dim=-1)
+                best_beams = beam_search_outputs[
+                    torch.arange(beam_search_outputs.shape[0]), best_beam_indices, :].tolist()
+                beam_decoded_outputs = self.model.tokenizer.batch_decode(best_beams, skip_special_tokens=True)
+                rouge_result = self.scorer(beam_decoded_outputs, decoded_targets)
+                for k, v in rouge_result.items():
+                    self.log(f"validation_beam_{k}", v, on_step=False, on_epoch=True)
+                    validation_step_metrics[f"validation_beam_{k}"] = v
+            validation_step_metrics['val_loss'] = val_loss
+            self.log('val_loss', val_loss)
+            return validation_step_metrics
+        else:
+            return {'val_loss': val_loss}

     @torch.no_grad()
     def test_step(self, batch, batch_idx):
         input_ids = squeeze_if_needed(batch['input_ids'])
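
Note: greedy_search and beam_search_bert are existing repository helpers and are not changed by this merge request. The added code only assumes that greedy_search turns a (batch, seq_len, vocab_size) logits tensor into token ids; a minimal sketch of that assumed behaviour (illustrative name, not the repository's implementation):

import torch

def greedy_search_sketch(logits: torch.Tensor) -> torch.Tensor:
    # Greedy decoding: pick the highest-scoring token at every position.
    # logits: (batch, seq_len, vocab_size) -> (batch, seq_len) token ids.
    return torch.argmax(logits, dim=-1)
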
@@ -197,24 +225,31 @@ class Seq2SeqLightning(LightningModule):
         decoder_input_ids = squeeze_if_needed(batch['decoder_input_ids'])
         decoder_attention_mask = squeeze_if_needed(batch['decoder_attention_mask'])
         labels = squeeze_if_needed(batch['labels'])
-        outputs = self(input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, labels)
-        test_loss = self.criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1))
+        logits = self(input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, labels)
+        test_loss = self.criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
+        self.log('test_loss', test_loss)
+        if self.generate_during_test:
+            test_step_metrics = {}
             decoded_targets = self.model.tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)
             if self.greedy_decode:
-                greedy_outputs = greedy_search(outputs)
+                greedy_outputs = greedy_search(logits)
                 greedy_decoded_outputs = self.model.tokenizer.batch_decode(greedy_outputs, skip_special_tokens=True)
                 rouge_result = self.scorer(greedy_decoded_outputs, decoded_targets)
-                self.log_dict(rouge_result, on_step=False, on_epoch=True)
+                for k, v in rouge_result.items():
+                    self.log(f"test_greedy_{k}", v, on_step=False, on_epoch=True)
+                    test_step_metrics[f"test_greedy_{k}"] = v
             if self.beam_search_decode:
-                beam_search_outputs, beam_search_scores, beam_search_lengths = beam_search_bert(outputs, beam_width=self.beam_width, max_length=self.max_length)
+                beam_search_outputs, beam_search_scores, beam_search_lengths = beam_search_bert(logits, beam_width=self.beam_width, max_length=self.max_length)
                 best_beam_indices = torch.argmax(beam_search_scores, dim=-1)
                 best_beams = beam_search_outputs[torch.arange(beam_search_outputs.shape[0]), best_beam_indices, :].tolist()
                 beam_decoded_outputs = self.model.tokenizer.batch_decode(best_beams, skip_special_tokens=True)
                 rouge_result = self.scorer(beam_decoded_outputs, decoded_targets)
-                self.log_dict(rouge_result, on_step=False, on_epoch=True)
-                rouge_result['test_loss'] = test_loss
-                self.log('test_loss', test_loss)
-                return rouge_result
+                for k, v in rouge_result.items():
+                    self.log(f"test_beam_{k}", v, on_step=False, on_epoch=True)
+                    test_step_metrics[f"test_beam_{k}"] = v
+            test_step_metrics['test_loss'] = test_loss
+            return test_step_metrics
+        else:
+            return {'test_loss': test_loss}
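
Note: the metric names logged by the loops above depend on the keys returned by torchmetrics' ROUGEScore. Assuming its default keys (rouge1/rouge2/rougeL/rougeLsum, each with _fmeasure/_precision/_recall suffixes), the prefixed entries that show up in the logger look like this small illustrative sketch:

example_keys = ["rouge1_fmeasure", "rouge2_fmeasure", "rougeL_fmeasure"]
print([f"test_beam_{k}" for k in example_keys])
# ['test_beam_rouge1_fmeasure', 'test_beam_rouge2_fmeasure', 'test_beam_rougeL_fmeasure']
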
@@ -40,28 +40,42 @@ def main():
     verbose = 1
     add_time_to_model_name = True
+    TRAINING_FINAL_MODEL = False
     # Training config
+    if TRAINING_FINAL_MODEL:
         num_epochs = 10  # 3 for debugging, 8 or 10 for training
-    num_k_folds = 0  # For Cross-Validating to assess model performance
-    cross_validation_k_folder = None
-    if num_k_folds > 0:
-        cross_validation_k_folder = KFold(n_splits=num_k_folds)
-    batch_size = 16
+        batch_size = 32
         gradient_accumulation_steps = 2
-    # Dataset config
-    # percentage of each split to load
+        # Dataset configs
         train_split_percentage = 100  # 100 for train, 1 for debugging
-    validate_split_percentage = 100  # 10 for validation, 1 for debugging
+        validate_split_percentage = 10  # 10 for validation, 1 for debugging
         test_split_percentage = 100  # 10 for test, 1 for debugging
         # number of entries to load
         train_split_num = None  # 80 for debugging - None for train
         validate_split_num = None  # 20 for debugging - None for validation
         test_split_num = None  # 20 for debugging - None for test
+    else:
+        num_epochs = 3
+        batch_size = 4
+        gradient_accumulation_steps = 2
+        # Dataset configs
+        train_split_percentage = 10  # 100 for train, 1 for debugging
+        validate_split_percentage = 1  # 10 for validation, 1 for debugging
+        test_split_percentage = 1  # 10 for test, 1 for debugging
+        train_split_num = max([200, (batch_size * 50)])  # must be high enough for logger to log metrics at least once
+        validate_split_num = 20  # 20 for debugging - None for validation
+        test_split_num = 20  # 20 for debugging - None for test
+    num_k_folds = 0  # For Cross-Validating to assess model performance
+    cross_validation_kfolder = None
+    if num_k_folds > 0:
+        cross_validation_kfolder = KFold(n_splits=num_k_folds)
     use_lightning = True
     use_fp16 = True
     mixed_precision = "16-mixed"
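
Note: with gradient accumulation, the effective batch size per optimizer step is batch_size * gradient_accumulation_steps, so the settings above give an effective batch of 64 for the final run and 8 for the debug run:

for batch_size, gradient_accumulation_steps in [(32, 2), (4, 2)]:
    # Samples contributing to a single optimizer step.
    print(batch_size * gradient_accumulation_steps)  # -> 64 (final run), 8 (debug run)
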
@@ -109,10 +123,6 @@ def main():
     # Define Model Object
     decoders = [SingleDenseBertDecoder(), DoubleDenseBertDecoder(), BiLSTMBertDecoder()]
-    # Define KFold Object, set to None if not cross validating
     # Define Optimizer (AdamW) - Filters to only optimize params that are not frozen (i.e. not bert)
     # Define loss function object
     criterion = nn.NLLLoss(ignore_index=tokenizer.pad_token_id)
     # Load Datasets into data-loaders
     test_loader = DataLoader(test_dataset, batch_size=batch_size)
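
Note: nn.NLLLoss expects log-probabilities (e.g. the output of log_softmax) rather than raw logits, and ignore_index=tokenizer.pad_token_id keeps padded label positions out of the loss. A minimal sketch of the shapes involved (the tensors and pad id here are made up for illustration):

import torch
import torch.nn as nn
import torch.nn.functional as F

pad_token_id = 0  # stand-in for tokenizer.pad_token_id
vocab_size = 30522  # e.g. the BERT vocabulary size
criterion = nn.NLLLoss(ignore_index=pad_token_id)

log_probs = F.log_softmax(torch.randn(8, vocab_size), dim=-1)  # (batch * seq_len, vocab_size)
targets = torch.randint(1, vocab_size, (8,))                   # flattened label ids
targets[-2:] = pad_token_id                                    # padded positions are ignored
loss = criterion(log_probs, targets)
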
@@ -131,9 +141,9 @@ def main():
     datetime_string = string_utils.get_datetime_string()
-    if cross_validation_k_folder is not None:
+    if cross_validation_kfolder is not None:
         print("Cross Validating (this may take a while)")
-        for fold, (train_idx, val_idx) in enumerate(cross_validation_k_folder.split(train_dataset)):
+        for fold, (train_idx, val_idx) in enumerate(cross_validation_kfolder.split(train_dataset)):
             train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
             val_sampler = torch.utils.data.SubsetRandomSampler(val_idx)
@@ -155,18 +165,19 @@ def main():
             )
             trainer.fit(model, train_loader, val_loader)

-    for decoder in decoders:
+    for i, decoder in enumerate(decoders):
         model = EncoderDecoderBase(encoder=encoder, decoder=decoder, tokenizer=tokenizer)
         model = Seq2SeqLightning(model)
         wandb_name = f"seq2seq_{model.name}_run_" + datetime_string
         wandb_logger = WandbLogger(name=wandb_name, project="seq2seq_lightning")
-        checkpoint_callback = ModelCheckpoint(
+        loss_checkpoint_callback = ModelCheckpoint(
             monitor='val_loss',
             dirpath='checkpoints',
-            filename='seq2seq-{epoch:02d}-{val_loss:.2f}',
+            filename='seq2seq-' + model.name + '-{epoch:02d}-{val_loss:.2f}',
             save_top_k=2,
             mode='min',
         )
         early_stopping = EarlyStopping(monitor='val_loss', patience=3, mode='min')
         trainer = Trainer(
             max_epochs=num_epochs,
@@ -174,12 +185,17 @@ def main():
             devices="auto",
             precision=mixed_precision if use_fp16 else 32,
             logger=wandb_logger,
-            callbacks=[checkpoint_callback, early_stopping],
+            callbacks=[loss_checkpoint_callback, early_stopping],
             accumulate_grad_batches=gradient_accumulation_steps
         )
         trainer.fit(model, train_loader, val_loader)
         trainer.test(model, test_loader)
-        wandb.alert()
+        if i == len(decoders) - 1:  # Check if final model has been trained
+            wandb.alert(
+                title="Training Complete",
+                text="Training has completed for all models.",
+                level=wandb.AlertLevel.INFO,
+            )
         wandb.finish()
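
Note: both new constructor flags default to True, so existing callers keep logging ROUGE during validation and test. A caller that wants faster debug runs could opt out of generation, along the lines of this illustrative sketch:

model = EncoderDecoderBase(encoder=encoder, decoder=decoder, tokenizer=tokenizer)
model = Seq2SeqLightning(model, generate_during_validation=False, generate_during_test=False)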