Skip to content
Snippets Groups Projects
Commit 24a0d00f authored by eca1g19's avatar eca1g19
Browse files

added more data tracking stuff

parent a905f09e
No related branches found
No related tags found
No related merge requests found
...@@ -24,6 +24,7 @@ class FrozenBert(nn.Module): ...@@ -24,6 +24,7 @@ class FrozenBert(nn.Module):
def forward(self, input_ids, attention_mask, token_type_ids=None): def forward(self, input_ids, attention_mask, token_type_ids=None):
# Return Context Embeddings # Return Context Embeddings
return self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids) return self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
......
from sklearn.model_selection import KFold
from pytorch_lightning import Trainer, LightningModule
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
# Used if you ever want to verify if some percentages add to 100 - useless # Used if you ever want to verify if some percentages add to 100 - useless
def verify_split_percentages(test_split_percentage, train_split_percentage, validate_split_percentage=None, def verify_split_percentages(test_split_percentage, train_split_percentage, validate_split_percentage=None,
verbose=False): verbose=False):
...@@ -22,3 +28,29 @@ def verify_split_percentages(test_split_percentage, train_split_percentage, vali ...@@ -22,3 +28,29 @@ def verify_split_percentages(test_split_percentage, train_split_percentage, vali
output += f"\nTotal percentage: {total_percentage}%." output += f"\nTotal percentage: {total_percentage}%."
print(output) print(output)
return False return False
def cross_validate(model, test_dataloader, k_folds=10):
    """Run k-fold cross-validation of `model` over the dataset behind `test_dataloader`.

    Each fold's validation metrics are collected via a fresh Lightning
    ``Trainer``; afterwards a single test pass runs on the full, untouched
    ``test_dataloader``.

    Args:
        model: a ``LightningModule`` implementing ``validation_step``/``test_step``.
        test_dataloader: ``DataLoader`` whose ``.dataset`` is split into folds.
        k_folds: number of cross-validation folds (default 10).

    Returns:
        List with one ``trainer.validate`` result per fold.
    """
    # KFold needs an indexable dataset, not a DataLoader
    dataset = list(test_dataloader.dataset)
    kfolder = KFold(n_splits=k_folds, shuffle=True)
    results = []
    # Compute the split exactly once: the original called kfolder.split()
    # twice (once for the tqdm bar, once for the loop), which with
    # shuffle=True produced two different partitions and a bar that never
    # advanced.
    fold_iter = tqdm(enumerate(kfolder.split(dataset)), total=k_folds, desc="Folds")
    for fold, (train_ids, val_ids) in fold_iter:
        # train_ids is unused here: max_epochs=0 means we only validate,
        # never fit, on each fold.
        val_subsample = Subset(dataset, val_ids)
        val_data_loader = DataLoader(val_subsample, batch_size=test_dataloader.batch_size, shuffle=False)
        trainer = Trainer(max_epochs=0)
        # Collect per-fold metrics (the original created `results` but never
        # appended to it).
        results.append(trainer.validate(model=model, dataloaders=val_data_loader))
    # Final evaluation over the complete test loader; the original omitted
    # model=, which Trainer.test requires when no model is attached.
    trainer = Trainer(max_epochs=0)
    trainer.test(model=model, dataloaders=test_dataloader)
    return results
import torch import torch
from itertools import groupby
def shrink_list(lst):
    """Run-length encode consecutive duplicates in *lst*.

    A run of length > 1 is collapsed into the string ``"<value>*<count>"``;
    single occurrences are kept as-is (original type preserved).
    """
    encoded = []
    for value, run in groupby(lst):
        run_length = len(list(run))
        encoded.append(f"{value}*{run_length}" if run_length > 1 else value)
    return encoded
# Google's Neural Machine Translation System: Bridging the Gap between Human and Machine Translation
def beam_search_bert(model_outputs, max_length=512, beam_size=3, min_length=None, tokenizer_eos_token_id=None,
length_penalty_enabled=True, alpha=0.6):
device = model_outputs.device device = model_outputs.device
if len(model_outputs.shape) == 2: if len(model_outputs.shape) == 2:
...@@ -45,4 +62,15 @@ def beam_search_bert(model_outputs, max_length=512, beam_size=3, min_length=None ...@@ -45,4 +62,15 @@ def beam_search_bert(model_outputs, max_length=512, beam_size=3, min_length=None
beam_scores = top_scores beam_scores = top_scores
beam_lengths = beam_lengths.gather(1, prev_beam) + 1 beam_lengths = beam_lengths.gather(1, prev_beam) + 1
if length_penalty_enabled:
# Apply length penalty
length_penalty = ((5.0 + beam_lengths.float()) / 6.0).pow(alpha)
beam_scores = beam_scores / length_penalty
return beams, beam_scores, beam_lengths return beams, beam_scores, beam_lengths
def greedy_search(model_outputs):
    """Greedy decoding: pick the highest-scoring token id at every position.

    Args:
        model_outputs: tensor of per-token scores whose last dimension is the
            vocabulary axis.

    Returns:
        Tensor of token ids with the vocabulary dimension reduced away.
    """
    # torch.argmax gives the indices directly; the original used torch.max
    # and discarded the max scores it had materialised.
    return torch.argmax(model_outputs, dim=-1)
\ No newline at end of file
...@@ -4,13 +4,14 @@ from torch.optim import AdamW ...@@ -4,13 +4,14 @@ from torch.optim import AdamW
from torchmetrics.functional.text.rouge import rouge_score from torchmetrics.functional.text.rouge import rouge_score
from transformers import BertTokenizerFast from transformers import BertTokenizerFast
import torch import torch
from beam_search import beam_search_bert from decode_utils import beam_search_bert, greedy_search, shrink_list
class BertLightning(LightningModule): class BertLightning(LightningModule):
def __init__(self, model, learning_rate=1e-5, tokenizer=None, greedy_decode=False, beam_width=3, name_override=None): def __init__(self, model, learning_rate=1e-5, tokenizer=None, greedy_decode=True, beam_search_decode=True,
beam_width=3, name_override=None):
super().__init__() super().__init__()
self.save_hyperparameters()
self.model = model self.model = model
try: try:
self.name = self.model.name self.name = self.model.name
...@@ -24,6 +25,7 @@ class BertLightning(LightningModule): ...@@ -24,6 +25,7 @@ class BertLightning(LightningModule):
self.greedy_decode = greedy_decode self.greedy_decode = greedy_decode
self.beam_width = beam_width self.beam_width = beam_width
self.beam_search_decode = beam_search_decode
self.validation_step_outputs = [] self.validation_step_outputs = []
self.validation_step_labels = [] self.validation_step_labels = []
...@@ -53,8 +55,54 @@ class BertLightning(LightningModule): ...@@ -53,8 +55,54 @@ class BertLightning(LightningModule):
outputs = self(input_ids=input_ids, attention_mask=attention_mask) outputs = self(input_ids=input_ids, attention_mask=attention_mask)
loss = self.criterion(outputs.view(-1, self.model.bert.config.vocab_size), labels.view(-1)) loss = self.criterion(outputs.view(-1, self.model.bert.config.vocab_size), labels.view(-1))
self.log('val_loss', loss, prog_bar=True) self.log('val_loss', loss, prog_bar=True)
#self.validation_step_outputs.append(predicted_indices) input_ids = batch["input_ids"]
#self.validation_step_labels.append(labels) attention_mask = batch["attention_mask"]
labels = batch["labels"]
decoded_lables = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
outputs = self(input_ids=input_ids, attention_mask=attention_mask)
if self.greedy_decode:
outputs = greedy_search(outputs)
decoded_outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
scores = rouge_score(decoded_outputs, decoded_lables)
for key, value in scores.items():
self.log(f'test_greedy_{key}', value)
if self.beam_search_decode:
beams, beam_scores, beam_lengths = beam_search_bert(outputs, length_penalty_enabled=False, beam_size=10)
best_beam_indices = torch.argmax(beam_scores, dim=-1)
best_beams = beams[torch.arange(beams.size(0)), best_beam_indices] # Currently based on highest score
decoded_outputs = self.tokenizer.batch_decode(best_beams[0], skip_special_tokens=True)
scores = rouge_score(decoded_outputs, decoded_lables)
for key, value in scores.items():
if key == "rouge1_fmeasure" or (key == "val_rouge2_fmeasure" and value != 0.0):
self.log(f'val_beam_decode_{key}', value, prog_bar=True)
else:
self.log(f'val_beam_decode_{key}', value)
@torch.no_grad()
def test_step(self, batch, batch_idx):
    """Test step: decode model outputs (greedy and/or beam search) and log ROUGE.

    Args:
        batch: dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors.
        batch_idx: index of the batch (unused, Lightning API requirement).
    """
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]
    decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
    outputs = self(input_ids=input_ids, attention_mask=attention_mask)
    if self.greedy_decode:
        # Keep the raw model scores in `outputs` intact: the original
        # overwrote `outputs` with greedy token ids, so the beam-search
        # branch below then ran on token ids instead of scores.
        greedy_tokens = greedy_search(outputs)
        decoded_outputs = self.tokenizer.batch_decode(greedy_tokens, skip_special_tokens=True)
        scores = rouge_score(decoded_outputs, decoded_labels)
        for key, value in scores.items():
            self.log(f'test_greedy_{key}', value, prog_bar=False)
    if self.beam_search_decode:
        beams, beam_scores, beam_lengths = beam_search_bert(outputs, length_penalty_enabled=False, beam_size=10)
        best_beam_indices = torch.argmax(beam_scores, dim=-1)
        # Select, per batch element, the beam with the highest score
        best_beams = beams[torch.arange(beams.size(0)), best_beam_indices]
        # NOTE(review): best_beams[0] decodes only the first batch element's
        # sequence — presumably batch_size is 1 here; confirm against caller.
        decoded_outputs = self.tokenizer.batch_decode(best_beams[0], skip_special_tokens=True)
        scores = rouge_score(decoded_outputs, decoded_labels)
        for key, value in scores.items():
            # The original compared against "val_rouge2_fmeasure", a key
            # rouge_score never emits (its keys are un-prefixed), so that
            # clause was dead and rouge2 never reached the progress bar.
            if key == "rouge1_fmeasure" or (key == "rouge2_fmeasure" and value != 0.0):
                self.log(f'test_beam_decode_{key}', value, prog_bar=True)
            else:
                self.log(f'test_beam_decode_{key}', value, prog_bar=False)
def on_validation_epoch_end(self): def on_validation_epoch_end(self):
# Calculate ROUGE score # Calculate ROUGE score
...@@ -62,7 +110,7 @@ class BertLightning(LightningModule): ...@@ -62,7 +110,7 @@ class BertLightning(LightningModule):
# Log ROUGE scores # Log ROUGE scores
for key, value in scores.items(): for key, value in scores.items():
self.log(f'val_{key}', value, prog_bar=True) self.log(f'{self.name}val_{key}', value, prog_bar=True)
def configure_optimizers(self): def configure_optimizers(self):
optimizer = AdamW(self.parameters(), lr=self.learning_rate) optimizer = AdamW(self.parameters(), lr=self.learning_rate)
......
from pytorch_lightning.loggers import WandbLogger
from sklearn.model_selection import KFold from sklearn.model_selection import KFold
from torch.optim import AdamW from torch.optim import AdamW
from base_models import * from base_models import *
...@@ -15,6 +16,12 @@ from socket import gethostname ...@@ -15,6 +16,12 @@ from socket import gethostname
import evaluate import evaluate
import torch import torch
import train_utils import train_utils
import os
import glob
import wandb
wandb.login()
# Output cuda/cpu # Output cuda/cpu
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'==============') print(f'==============')
...@@ -119,15 +126,20 @@ test_dataset = CNNDailyMailDataset(tokenizer=tokenizer, split_type='test', split ...@@ -119,15 +126,20 @@ test_dataset = CNNDailyMailDataset(tokenizer=tokenizer, split_type='test', split
test_loader = DataLoader(test_dataset, batch_size=batch_size) test_loader = DataLoader(test_dataset, batch_size=batch_size)
# Define Model Object # Define Model Object
models = [BertDoubleDense(), BertBiLSTM(), BertSingleDense()] models = [BertSingleDense(),BertDoubleDense(), BertBiLSTM()]
# Define KFold Object, set to None if not cross validating # Define KFold Object, set to None if not cross validating
cross_validation_k_fold = KFold(n_splits=num_k_folds) if num_k_folds > 0 else None cross_validation_k_fold = KFold(n_splits=num_k_folds) if num_k_folds > 0 else None
# Define Optimizer (AdamW) - Filters to only optimize params that are not frozen (i.e. not bert) # Define Optimizer (AdamW) - Filters to only optimize params that are not frozen (i.e. not bert)
# Define loss function object # Define loss function object
criterion = nn.NLLLoss() criterion = nn.NLLLoss()
train = False
num_cpus = os.cpu_count()
available_gpus = [torch.cuda.device(i) for i in range(torch.cuda.device_count())]
output_config() output_config()
wandb_logger = WandbLogger()
for model in models: for model in models:
if train:
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=0.05) optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=0.05)
if not use_lightning: if not use_lightning:
model.to(device) model.to(device)
...@@ -154,11 +166,11 @@ for model in models: ...@@ -154,11 +166,11 @@ for model in models:
else: else:
train_loader = DataLoader(train_dataset, batch_size=batch_size) train_loader = DataLoader(train_dataset, batch_size=batch_size)
validation_loader = DataLoader(validation_dataset, batch_size=batch_size) validation_loader = DataLoader(validation_dataset, batch_size=batch_size)
available_gpus = [torch.cuda.device(i) for i in range(torch.cuda.device_count())]
print(available_gpus) print(available_gpus)
model = BertLightning(model) model = BertLightning(model)
print(f"Available GPUs: {len(available_gpus)}") print(f"Available GPUs: {len(available_gpus)}")
logger = CSVLogger("logs", name=model.name+"Logger") logger = CSVLogger("logs", name=model.name+"Logger")
trainer = Trainer(logger=wandb_logger)
loss_checkpoint_callback = ModelCheckpoint( loss_checkpoint_callback = ModelCheckpoint(
monitor='val_loss', monitor='val_loss',
dirpath='Models/', dirpath='Models/',
...@@ -191,3 +203,19 @@ for model in models: ...@@ -191,3 +203,19 @@ for model in models:
precision="16", precision="16",
callbacks=[loss_checkpoint_callback]) callbacks=[loss_checkpoint_callback])
trainer.fit(model, train_loader, validation_loader) trainer.fit(model, train_loader, validation_loader)
else:
checkpoint_dir = "Models/"
checkpoints = [f for f in os.listdir(checkpoint_dir)]
for checkpoint_path in glob.glob("Models/*.ckpt"):
for m in models:
try:
model = BertLightning.load_from_checkpoint(checkpoint_path, model=m)
model.beam_search_decode = True
wandb_logger = WandbLogger()
batch_size = 8 # I am a mortal on local machine
test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=0)
trainer = Trainer(logger=wandb_logger)
trainer.test(model=model, dataloaders=test_loader, ckpt_path=checkpoint_path)
except Exception as pokemon:
print(pokemon)
pass
...@@ -82,5 +82,6 @@ def create_pip_installs_from_requirements(text=None, file_path=None, print_file= ...@@ -82,5 +82,6 @@ def create_pip_installs_from_requirements(text=None, file_path=None, print_file=
if __name__ == '__main__': if __name__ == '__main__':
file_path = 'requirements.txt' file_path = 'requirements.txt'
fix_requirements(file_path=file_path)
create_pip_installs_from_requirements(file_path=file_path) create_pip_installs_from_requirements(file_path=file_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment