Skip to content
Snippets Groups Projects
Commit a905f09e authored by eca1g19's avatar eca1g19
Browse files

fixed requirements - added lightning

parent 66b3ac50
No related branches found
No related tags found
No related merge requests found
from pytorch_lightning import LightningModule
from torch import nn
from torch.optim import AdamW
from torchmetrics.functional.text.rouge import rouge_score
from transformers import BertTokenizerFast
import torch
from beam_search import beam_search_bert
class BertLightning(LightningModule):
def __init__(self, model, learning_rate=1e-5, tokenizer=None, greedy_decode=False, beam_width=3, name_override=None):
super().__init__()
self.model = model
try:
self.name = self.model.name
except AttributeError:
self.name = "BertLightning"
if name_override:
self.name = name_override
self.learning_rate = learning_rate
self.criterion = nn.NLLLoss()
self.greedy_decode = greedy_decode
self.beam_width = beam_width
self.validation_step_outputs = []
self.validation_step_labels = []
if tokenizer is None:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
self.tokenizer = tokenizer
def forward(self, input_ids, attention_mask, token_type_ids=None):
return self.model(input_ids, attention_mask, token_type_ids)
def training_step(self, batch, batch_idx):
input_ids = batch["input_ids"]
attention_mask = batch["attention_mask"]
labels = batch["labels"]
outputs = self(input_ids, attention_mask)
loss = self.criterion(outputs.view(-1, self.model.bert.config.vocab_size), labels.view(-1))
self.log('train_loss', loss, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
input_ids = batch["input_ids"]
attention_mask = batch["attention_mask"]
labels = batch["labels"]
outputs = self(input_ids=input_ids, attention_mask=attention_mask)
loss = self.criterion(outputs.view(-1, self.model.bert.config.vocab_size), labels.view(-1))
self.log('val_loss', loss, prog_bar=True)
#self.validation_step_outputs.append(predicted_indices)
#self.validation_step_labels.append(labels)
def on_validation_epoch_end(self):
# Calculate ROUGE score
scores = rouge_score(self.validation_step_outputs, self.validation_step_labels)
# Log ROUGE scores
for key, value in scores.items():
self.log(f'val_{key}', value, prog_bar=True)
def configure_optimizers(self):
optimizer = AdamW(self.parameters(), lr=self.learning_rate)
return optimizer
absl-py
aiohttp
aiosignal
async-timeout
cachetools
certifi
click
datasets
dill
evaluate
filelock
fonttools
frozenlist
fsspec
google-api-core
google-api-python-client
google-auth
google-auth-httplib2
googleapis-common-protos
httplib2
huggingface-hub
ipython-genutils
joblib
Jupyter-Beeper
lightning-utilities
mkl-fft
mkl-random
mkl-service
mpmath
multidict
multiprocess
munkres
networkx
nltk
oauth2client
pandas
Pillow
ply
protobuf
pyarrow
pyasn1
pyasn1-modules
PyDrive
pyenchant
PyQt5
pytorch-beam-search
pytorch-lightning
pywin32
PyYAML
pyzmq
regex
responses
rouge-score
rsa
scikit-learn
scipy
sentencepiece
seqeval
sympy
threadpoolctl
tokenizers
torch
torch-utils
torchaudio
torchdata
torchmetrics
torchtext
torchvision
transformers
uritemplate
webencodings
wincertstore
xxhash
yarl
absl-py==1.4.0
aiohttp==3.8.3
aiosignal==1.3.1
async-timeout==4.0.2
cachetools==5.3.1
certifi==2022.12.7
click==8.1.3
datasets==2.8.0
dill==0.3.6
evaluate==0.4.0
filelock==3.9.0
fonttools==4.25.0
frozenlist==1.3.3
fsspec==2022.11.0
google-api-core==2.11.0
google-api-python-client==2.89.0
google-auth==2.19.1
google-auth-httplib2==0.1.0
googleapis-common-protos==1.59.1
httplib2==0.22.0
huggingface-hub==0.11.1
ipython-genutils==0.2.0
joblib==1.2.0
Jupyter-Beeper==1.0.3
lightning-utilities==0.8.0
mkl-fft==1.3.1
mkl-random==1.2.2
mkl-service==2.4.0
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.14
munkres==1.1.4
networkx==3.1
nltk==3.8.1
oauth2client==4.1.3
pandas==1.5.1
Pillow==9.2.0
ply==3.11
protobuf==3.20.3
pyarrow==10.0.1
pyasn1==0.5.0
pyasn1-modules==0.3.0
PyDrive==1.3.1
pyenchant==3.2.2
PyQt5==5.15.7
pytorch-beam-search==1.2.2
pytorch-lightning==2.0.3
pywin32==302
PyYAML==6.0
pyzmq==25.1.0
regex==2022.10.31
responses==0.18.0
rouge-score==0.1.2
rsa==4.9
scikit-learn==1.1.3
scipy==1.9.3
sentencepiece==0.1.97
seqeval==1.2.2
sympy==1.11.1
threadpoolctl==3.1.0
tokenizers==0.13.2
torch==2.0.1+cu118
torch-utils==0.1.2
torchaudio==2.0.2+cu118
torchdata==0.6.1
torchmetrics==0.11.4
torchtext==0.15.2
torchvision==0.15.2+cu118
transformers==4.25.1
uritemplate==4.1.1
webencodings==0.5.1
wincertstore==0.2
xxhash==3.2.0
yarl==1.8.2
!pip install absl-py!pip install aiohttp!pip install aiosignal!pip install async-timeout!pip install cachetools!pip install certifi!pip install click!pip install datasets!pip install dill!pip install evaluate!pip install filelock!pip install fonttools!pip install frozenlist!pip install fsspec!pip install google-api-core!pip install google-api-python-client!pip install google-auth!pip install google-auth-httplib2!pip install googleapis-common-protos!pip install httplib2!pip install huggingface-hub!pip install ipython-genutils!pip install joblib!pip install Jupyter-Beeper!pip install lightning-utilities!pip install mkl-fft!pip install mkl-random!pip install mkl-service!pip install mpmath!pip install multidict!pip install multiprocess!pip install munkres!pip install networkx!pip install nltk!pip install oauth2client!pip install pandas!pip install Pillow!pip install ply!pip install protobuf!pip install pyarrow!pip install pyasn1!pip install pyasn1-modules!pip install PyDrive!pip install pyenchant!pip install PyQt5!pip install pytorch-beam-search!pip install pytorch-lightning!pip install pywin32!pip install PyYAML!pip install pyzmq!pip install regex!pip install responses!pip install rouge-score!pip install rsa!pip install scikit-learn!pip install scipy!pip install sentencepiece!pip install seqeval!pip install sympy!pip install threadpoolctl!pip install tokenizers!pip install torch!pip install torch-utils!pip install torchaudio!pip install torchdata!pip install torchmetrics!pip install torchtext!pip install torchvision!pip install transformers!pip install uritemplate!pip install webencodings!pip install wincertstore!pip install xxhash!pip install yarl
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment