Skip to content
Snippets Groups Projects
Commit 4df3d566 authored by eca1g19's avatar eca1g19
Browse files

Commit local changes (there are a lot)

parent 49c5f1e5
No related branches found
No related tags found
No related merge requests found
...@@ -161,7 +161,7 @@ class Seq2SeqLightning(LightningModule): ...@@ -161,7 +161,7 @@ class Seq2SeqLightning(LightningModule):
output = self(input_ids, attention_mask, decoder_input_ids, decoder_attention_mask) output = self(input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
loss = self.criterion(output.view(-1, output.size(-1)), labels.view(-1), ignore_index=0) loss = self.criterion(output.view(-1, output.size(-1)), labels.view(-1), ignore_index=0)
self.log('train_loss', loss) self.log('train_loss', loss)
return loss return loss.item()
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
input_ids = batch['input_ids'] input_ids = batch['input_ids']
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment