Skip to content
Snippets Groups Projects

Added additional measurements during training

Merged eca1g19 requested to merge Sequence2Sequence into main
2 files
+ 102
51
Compare changes
  • Side-by-side
  • Inline

Files

+ 56
21
@@ -123,7 +123,7 @@ class BertLightning(LightningModule):
@@ -123,7 +123,7 @@ class BertLightning(LightningModule):
class Seq2SeqLightning(LightningModule):
class Seq2SeqLightning(LightningModule):
def __init__(self, model, learning_rate=1e-4, greedy_decode=False, beam_search_decode=True,
def __init__(self, model, learning_rate=1e-4, greedy_decode=False, beam_search_decode=True,
beam_width=3, name_override=None, max_output_length=512):
beam_width=3, name_override=None, max_output_length=512, generate_during_validation=True, generate_during_test=True):
super().__init__()
super().__init__()
self.model = model
self.model = model
@@ -153,6 +153,9 @@ class Seq2SeqLightning(LightningModule):
@@ -153,6 +153,9 @@ class Seq2SeqLightning(LightningModule):
self.max_length = max_output_length
self.max_length = max_output_length
self.scorer = ROUGEScore()
self.scorer = ROUGEScore()
 
self.generate_during_validation = generate_during_validation
 
self.generate_during_test = generate_during_test
 
def rename_keys(self, name, rouge_result):
def rename_keys(self, name, rouge_result):
new_dict = {}
new_dict = {}
for key, value in rouge_result.items():
for key, value in rouge_result.items():
@@ -187,8 +190,33 @@ class Seq2SeqLightning(LightningModule):
@@ -187,8 +190,33 @@ class Seq2SeqLightning(LightningModule):
logits = self(input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, labels)
logits = self(input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, labels)
val_loss = self.criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
val_loss = self.criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
self.log('val_loss', val_loss)
self.log('val_loss', val_loss)
return {'val_loss': val_loss}
if self.generate_during_validation:
validation_step_metrics = {}
 
decoded_targets = self.model.tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)
 
if self.greedy_decode:
 
greedy_outputs = greedy_search(logits)
 
greedy_decoded_outputs = self.model.tokenizer.batch_decode(greedy_outputs, skip_special_tokens=True)
 
rouge_result = self.scorer(greedy_decoded_outputs, decoded_targets)
 
for k, v in rouge_result.items():
 
self.log(f"validation_greedy_{k}", v, on_step=False, on_epoch=True)
 
validation_step_metrics[f"validation_greedy_{k}"] = v
 
if self.beam_search_decode:
 
beam_search_outputs, beam_search_scores, beam_search_lengths = beam_search_bert(logits,
 
beam_width=self.beam_width,
 
max_length=self.max_length)
 
best_beam_indices = torch.argmax(beam_search_scores, dim=-1)
 
best_beams = beam_search_outputs[torch.arange(beam_search_outputs.shape[0]), best_beam_indices,
 
:].tolist()
 
beam_decoded_outputs = self.model.tokenizer.batch_decode(best_beams, skip_special_tokens=True)
 
rouge_result = self.scorer(beam_decoded_outputs, decoded_targets)
 
for k, v in rouge_result.items():
 
self.log(f"validation_beam_{k}", v, on_step=False, on_epoch=True)
 
validation_step_metrics[f"validation_beam_{k}"] = v
 
validation_step_metrics['val_loss'] = val_loss
 
self.log('val_loss', val_loss)
 
return validation_step_metrics
 
else:
 
return {'val_loss': val_loss}
@torch.no_grad()
@torch.no_grad()
def test_step(self, batch, batch_idx):
def test_step(self, batch, batch_idx):
@@ -197,24 +225,31 @@ class Seq2SeqLightning(LightningModule):
@@ -197,24 +225,31 @@ class Seq2SeqLightning(LightningModule):
decoder_input_ids = squeeze_if_needed(batch['decoder_input_ids'])
decoder_input_ids = squeeze_if_needed(batch['decoder_input_ids'])
decoder_attention_mask = squeeze_if_needed(batch['decoder_attention_mask'])
decoder_attention_mask = squeeze_if_needed(batch['decoder_attention_mask'])
labels = squeeze_if_needed(batch['labels'])
labels = squeeze_if_needed(batch['labels'])
outputs = self(input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, labels)
logits = self(input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, labels)
test_loss = self.criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1))
test_loss = self.criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
self.log('test_loss', test_loss)
decoded_targets = self.model.tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)
if self.greedy_decode:
greedy_outputs = greedy_search(outputs)
greedy_decoded_outputs = self.model.tokenizer.batch_decode(greedy_outputs, skip_special_tokens=True)
rouge_result = self.scorer(greedy_decoded_outputs, decoded_targets)
self.log_dict(rouge_result, on_step=False, on_epoch=True)
if self.beam_search_decode:
beam_search_outputs, beam_search_scores, beam_search_lengths = beam_search_bert(outputs, beam_width=self.beam_width, max_length=self.max_length)
best_beam_indices = torch.argmax(beam_search_scores, dim=-1)
best_beams = beam_search_outputs[torch.arange(beam_search_outputs.shape[0]), best_beam_indices, :].tolist()
beam_decoded_outputs = self.model.tokenizer.batch_decode(best_beams, skip_special_tokens=True)
rouge_result = self.scorer(beam_decoded_outputs, decoded_targets)
self.log_dict(rouge_result, on_step=False, on_epoch=True)
rouge_result['test_loss'] = test_loss
self.log('test_loss', test_loss)
self.log('test_loss', test_loss)
return rouge_result
if self.generate_during_test:
 
test_step_metrics = {}
 
decoded_targets = self.model.tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)
 
if self.greedy_decode:
 
greedy_outputs = greedy_search(logits)
 
greedy_decoded_outputs = self.model.tokenizer.batch_decode(greedy_outputs, skip_special_tokens=True)
 
rouge_result = self.scorer(greedy_decoded_outputs, decoded_targets)
 
for k, v in rouge_result.items():
 
self.log(f"test_greedy_{k}", v, on_step=False, on_epoch=True)
 
test_step_metrics[f"test_greedy_{k}"] = v
 
if self.beam_search_decode:
 
beam_search_outputs, beam_search_scores, beam_search_lengths = beam_search_bert(logits, beam_width=self.beam_width, max_length=self.max_length)
 
best_beam_indices = torch.argmax(beam_search_scores, dim=-1)
 
best_beams = beam_search_outputs[torch.arange(beam_search_outputs.shape[0]), best_beam_indices, :].tolist()
 
beam_decoded_outputs = self.model.tokenizer.batch_decode(best_beams, skip_special_tokens=True)
 
rouge_result = self.scorer(beam_decoded_outputs, decoded_targets)
 
for k, v in rouge_result.items():
 
self.log(f"test_greedy_{k}", v, on_step=False, on_epoch=True)
 
test_step_metrics[f"test_greedy_{k}"] = v
 
test_step_metrics['test_loss'] = test_loss
 
return test_step_metrics
 
else:
 
return {'test_loss': test_loss}
Loading