Added additional measurements during training
Compare changes
+ 56
− 21
@@ -123,7 +123,7 @@ class BertLightning(LightningModule):
@@ -123,7 +123,7 @@ class BertLightning(LightningModule):
@@ -153,6 +153,9 @@ class Seq2SeqLightning(LightningModule):
@@ -153,6 +153,9 @@ class Seq2SeqLightning(LightningModule):
@@ -187,8 +190,33 @@ class Seq2SeqLightning(LightningModule):
@@ -187,8 +190,33 @@ class Seq2SeqLightning(LightningModule):
@@ -197,24 +225,31 @@ class Seq2SeqLightning(LightningModule):
@@ -197,24 +225,31 @@ class Seq2SeqLightning(LightningModule):
greedy_decoded_outputs = self.model.tokenizer.batch_decode(greedy_outputs, skip_special_tokens=True)
beam_search_outputs, beam_search_scores, beam_search_lengths = beam_search_bert(outputs, beam_width=self.beam_width, max_length=self.max_length)
best_beams = beam_search_outputs[torch.arange(beam_search_outputs.shape[0]), best_beam_indices, :].tolist()