diff --git a/base_models.py b/base_models.py
index b555f2d0aa7baf874db2f02abdac1326f3ca1b30..564fb9ef4b6758b44b1ead4d70f69e63284b1a63 100644
--- a/base_models.py
+++ b/base_models.py
@@ -24,6 +24,7 @@ class FrozenBert(nn.Module):
 
     def forward(self, input_ids, attention_mask, token_type_ids=None):
         # Return Context Embeddings
+        return self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
diff --git a/data_loading_utils.py b/data_loading_utils.py
index 9f18e26ed3474bf831f18b68925dc60d7be0b193..a7dff9ab471221d5f7585d15736b857bf3caf7ce 100644
--- a/data_loading_utils.py
+++ b/data_loading_utils.py
@@ -1,3 +1,9 @@
+from sklearn.model_selection import KFold
+from pytorch_lightning import Trainer
+from torch.utils.data import DataLoader, Subset
+from tqdm import tqdm
+
+
 # Sanity check that a set of split percentages adds up to 100
 def verify_split_percentages(test_split_percentage, train_split_percentage, validate_split_percentage=None,
                              verbose=False):
@@ -22,3 +28,30 @@ def verify_split_percentages(test_split_percentage, train_split_percentage, vali
         output += f"\nTotal percentage: {total_percentage}%."
         print(output)
     return False
+
+
+def cross_validate(model, test_dataloader, k_folds=10):
+    # KFold operates on datasets, not dataloaders
+    dataset = list(test_dataloader.dataset)
+
+    # Create the k-folder
+    kfolder = KFold(n_splits=k_folds, shuffle=True)
+
+    progress_bar = tqdm(kfolder.split(dataset), total=k_folds, desc="Folds")
+
+    results = []
+    for fold, (train_ids, val_ids) in enumerate(progress_bar):
+        # Subsets made according to the ids returned by the k-folder
+        train_subsample = Subset(dataset, train_ids)
+        val_subsample = Subset(dataset, val_ids)
+
+        # Wrap the validation fold in a loader; the training subset is kept for future use
+        # train_data_loader = DataLoader(train_subsample, batch_size=test_dataloader.batch_size, shuffle=True)
+        val_data_loader = DataLoader(val_subsample, batch_size=test_dataloader.batch_size, shuffle=False)
+
+        trainer = Trainer(max_epochs=0)
+        results.append(trainer.validate(model=model, dataloaders=val_data_loader))
+
+    trainer = Trainer(max_epochs=0)
+    trainer.test(model=model, dataloaders=test_dataloader)
+    return results
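A minimal usage sketch of the new cross_validate helper; the model wrapper and dataset names are borrowed from main.py and are illustrative, not part of this patch:

from torch.utils.data import DataLoader
from data_loading_utils import cross_validate
from lightning_models import BertLightning
from base_models import BertSingleDense

# Assumes a tokenized test split named test_dataset is in scope
model = BertLightning(BertSingleDense())
test_loader = DataLoader(test_dataset, batch_size=8)

# One validation pass per fold, then a single test pass over the full loader
fold_results = cross_validate(model, test_loader, k_folds=10)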
diff --git a/beam_search.py b/decode_utils.py
similarity index 68%
rename from beam_search.py
rename to decode_utils.py
index 9e1deb961d34b7e77cc585c85ba4942fb4ff7ed6..bf3392cc001cee81b6f71c536e2ca2272f1854b0 100644
--- a/beam_search.py
+++ b/decode_utils.py
@@ -1,7 +1,26 @@
 import torch
+from itertools import groupby
 
 
+def shrink_list(lst):
+    # Run-length encode consecutive duplicates, e.g. ["a", "a", "b"] -> ["a*2", "b"]
+    groups = groupby(lst)
+    result = []
+
+    for key, group in groups:
+        count = sum(1 for _ in group)
+        if count > 1:
+            result.append(f"{key}*{count}")
+        else:
+            result.append(key)
+
+    return result
+
+
+# Length penalty follows Wu et al. (2016), "Google's Neural Machine Translation System:
+# Bridging the Gap between Human and Machine Translation"
-def beam_search_bert(model_outputs, max_length=512, beam_size=3, min_length=None, tokenizer_eos_token_id=None):
+def beam_search_bert(model_outputs, max_length=512, beam_size=3, min_length=None, tokenizer_eos_token_id=None,
+                     length_penalty_enabled=True, alpha=0.6):
     device = model_outputs.device
 
     if len(model_outputs.shape) == 2:
@@ -45,4 +64,15 @@
         beam_scores = top_scores
         beam_lengths = beam_lengths.gather(1, prev_beam) + 1
 
+    if length_penalty_enabled:
+        # Apply the GNMT length penalty: ((5 + |Y|) / 6) ** alpha
+        length_penalty = ((5.0 + beam_lengths.float()) / 6.0).pow(alpha)
+        beam_scores = beam_scores / length_penalty
+
     return beams, beam_scores, beam_lengths
+
+
+def greedy_search(model_outputs):
+    # Take the argmax over the vocabulary dimension
+    max_scores, token_ids = torch.max(model_outputs, dim=-1)
+    return token_ids
\ No newline at end of file
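A quick sanity check for the new decoding helpers; the tensor shapes are assumptions about the model's output (batch, sequence, vocabulary), not guarantees from this patch:

import torch
from decode_utils import greedy_search, shrink_list

logits = torch.randn(2, 4, 10)             # batch of 2, sequence length 4, vocab of 10
token_ids = greedy_search(logits)          # shape (2, 4): argmax token id per position
print(shrink_list(token_ids[0].tolist()))  # e.g. [3, 3, 7, 7] -> ['3*2', '7*2']

# GNMT length penalty for a length-12 hypothesis with the default alpha=0.6:
# ((5 + 12) / 6) ** 0.6 is about 1.87, so longer beams are divided by a larger factor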
diff --git a/lightning_models.py b/lightning_models.py
index cb967294c2d701db630c666ca8b1894b754ef7a6..c123c1a4d2e46cbf0fc264a9cd4aa6f847ab5a7b 100644
--- a/lightning_models.py
+++ b/lightning_models.py
@@ -4,13 +4,14 @@ from torch.optim import AdamW
 from torchmetrics.functional.text.rouge import rouge_score
 from transformers import BertTokenizerFast
 import torch
-from beam_search import beam_search_bert
+from decode_utils import beam_search_bert, greedy_search, shrink_list
 
 
 class BertLightning(LightningModule):
-    def __init__(self, model, learning_rate=1e-5, tokenizer=None, greedy_decode=False, beam_width=3, name_override=None):
+    def __init__(self, model, learning_rate=1e-5, tokenizer=None, greedy_decode=True, beam_search_decode=True,
+                 beam_width=3, name_override=None):
         super().__init__()
-
+        self.save_hyperparameters()
         self.model = model
         try:
             self.name = self.model.name
@@ -24,6 +25,7 @@ class BertLightning(LightningModule):
 
         self.greedy_decode = greedy_decode
         self.beam_width = beam_width
+        self.beam_search_decode = beam_search_decode
 
         self.validation_step_outputs = []
         self.validation_step_labels = []
@@ -53,8 +55,49 @@ class BertLightning(LightningModule):
         outputs = self(input_ids=input_ids, attention_mask=attention_mask)
         loss = self.criterion(outputs.view(-1, self.model.bert.config.vocab_size), labels.view(-1))
         self.log('val_loss', loss, prog_bar=True)
-        #self.validation_step_outputs.append(predicted_indices)
-        #self.validation_step_labels.append(labels)
+        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+        if self.greedy_decode:
+            # Keep greedy ids separate so the raw logits stay available for beam search
+            greedy_ids = greedy_search(outputs)
+            decoded_outputs = self.tokenizer.batch_decode(greedy_ids, skip_special_tokens=True)
+            scores = rouge_score(decoded_outputs, decoded_labels)
+            for key, value in scores.items():
+                self.log(f'val_greedy_{key}', value)
+        if self.beam_search_decode:
+            beams, beam_scores, beam_lengths = beam_search_bert(outputs, length_penalty_enabled=False, beam_size=10)
+            best_beam_indices = torch.argmax(beam_scores, dim=-1)
+            best_beams = beams[torch.arange(beams.size(0)), best_beam_indices]  # Currently based on highest score
+            decoded_outputs = self.tokenizer.batch_decode(best_beams, skip_special_tokens=True)
+            scores = rouge_score(decoded_outputs, decoded_labels)
+            for key, value in scores.items():
+                if key == "rouge1_fmeasure" or (key == "rouge2_fmeasure" and value != 0.0):
+                    self.log(f'val_beam_decode_{key}', value, prog_bar=True)
+                else:
+                    self.log(f'val_beam_decode_{key}', value)
+
+    @torch.no_grad()
+    def test_step(self, batch, batch_idx):
+        input_ids = batch["input_ids"]
+        attention_mask = batch["attention_mask"]
+        labels = batch["labels"]
+        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+        outputs = self(input_ids=input_ids, attention_mask=attention_mask)
+        if self.greedy_decode:
+            greedy_ids = greedy_search(outputs)
+            decoded_outputs = self.tokenizer.batch_decode(greedy_ids, skip_special_tokens=True)
+            scores = rouge_score(decoded_outputs, decoded_labels)
+            for key, value in scores.items():
+                self.log(f'test_greedy_{key}', value, prog_bar=False)
+        if self.beam_search_decode:
+            beams, beam_scores, beam_lengths = beam_search_bert(outputs, length_penalty_enabled=False, beam_size=10)
+            best_beam_indices = torch.argmax(beam_scores, dim=-1)
+            best_beams = beams[torch.arange(beams.size(0)), best_beam_indices]  # Currently based on highest score
+            decoded_outputs = self.tokenizer.batch_decode(best_beams, skip_special_tokens=True)
+            scores = rouge_score(decoded_outputs, decoded_labels)
+            for key, value in scores.items():
+                if key == "rouge1_fmeasure" or (key == "rouge2_fmeasure" and value != 0.0):
+                    self.log(f'test_beam_decode_{key}', value, prog_bar=True)
+                else:
+                    self.log(f'test_beam_decode_{key}', value, prog_bar=False)
@@ -62,9 +105,9 @@ class BertLightning(LightningModule):
     def on_validation_epoch_end(self):
         # Calculate ROUGE score
 
         # Log ROUGE scores
         for key, value in scores.items():
-            self.log(f'val_{key}', value, prog_bar=True)
+            self.log(f'{self.name}_val_{key}', value, prog_bar=True)
 
     def configure_optimizers(self):
         optimizer = AdamW(self.parameters(), lr=self.learning_rate)
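The key-name comparisons in validation_step and test_step above rely on the flat metric names returned by torchmetrics' functional rouge_score; a quick check (scores vary with input):

from torchmetrics.functional.text.rouge import rouge_score

scores = rouge_score(["the cat sat on the mat"], ["a cat sat on the mat"])
# Keys are flat names such as 'rouge1_fmeasure', 'rouge2_fmeasure' and
# 'rougeL_fmeasure' with no 'val_'/'test_' prefix, which is why the logging
# code compares against "rouge2_fmeasure" directly before logging.
print(sorted(scores.keys()))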
diff --git a/main.py b/main.py
index a58918174ab19136cc96463942d0fd3f2820a92d..45935752acb8a9ebf089e08082fc1096051c2fa4 100644
--- a/main.py
+++ b/main.py
@@ -1,3 +1,4 @@
+from pytorch_lightning.loggers import WandbLogger
 from sklearn.model_selection import KFold
 from torch.optim import AdamW
 from base_models import *
@@ -15,6 +16,12 @@ from socket import gethostname
 import evaluate
 import torch
 import train_utils
+import os
+import glob
+import wandb
+
+wandb.login()
+
 # Output cuda/cpu
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f'==============')
@@ -119,75 +126,95 @@ test_dataset = CNNDailyMailDataset(tokenizer=tokenizer, split_type='test', split
 test_loader = DataLoader(test_dataset, batch_size=batch_size)
 
 # Define Model Object
-models = [BertDoubleDense(), BertBiLSTM(), BertSingleDense()]
+models = [BertSingleDense(), BertDoubleDense(), BertBiLSTM()]
 
 # Define KFold Object, set to None if not cross validating
 cross_validation_k_fold = KFold(n_splits=num_k_folds) if num_k_folds > 0 else None
 
 # Define Optimizer (AdamW) - Filters to only optimize params that are not frozen (i.e. not bert)
 # Define loss function object
 criterion = nn.NLLLoss()
+train = False
+num_cpus = os.cpu_count()
+available_gpus = list(range(torch.cuda.device_count()))
 output_config()
+wandb_logger = WandbLogger()
 for model in models:
-    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=0.05)
-    if not use_lightning:
-        model.to(device)
-        # Load Old Rouge scorer
-        rouge_score = evaluate.load("rouge")
-        train_utils.train_model(
-            model=model,
-            num_epochs=10,
-            train_dataset=train_dataset,
-            validation_dataset=validation_dataset,
-            add_time_to_model_name=False,
-            criterion=criterion,
-            optimizer=optimizer,
-            scorer=rouge_score,
-            evaluate=True,
-            verbose=1,
-            save_after_epoch=True,
-            save_best=True,
-            device=device,
-            tokenizer=tokenizer,
-            batch_size=batch_size
-        )  # model_name="BertSingleDense" Omitted to test custom BERT wrappers + autoname - also means this is technically
-        # more generalized code
+    if train:
+        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=0.05)
+        if not use_lightning:
+            model.to(device)
+            # Load Old Rouge scorer
+            rouge_score = evaluate.load("rouge")
+            train_utils.train_model(
+                model=model,
+                num_epochs=10,
+                train_dataset=train_dataset,
+                validation_dataset=validation_dataset,
+                add_time_to_model_name=False,
+                criterion=criterion,
+                optimizer=optimizer,
+                scorer=rouge_score,
+                evaluate=True,
+                verbose=1,
+                save_after_epoch=True,
+                save_best=True,
+                device=device,
+                tokenizer=tokenizer,
+                batch_size=batch_size
+            )  # model_name="BertSingleDense" omitted so the custom BERT wrappers autoname themselves,
+            # which keeps this loop generic across models
+        else:
+            train_loader = DataLoader(train_dataset, batch_size=batch_size)
+            validation_loader = DataLoader(validation_dataset, batch_size=batch_size)
+            print(f"Available GPUs: {len(available_gpus)}")
+            model = BertLightning(model)
+            logger = CSVLogger("logs", name=model.name + "Logger")
+            loss_checkpoint_callback = ModelCheckpoint(
+                monitor='val_loss',
+                dirpath='Models/',
+                filename='{epoch}-{val_loss:.2f}-{rouge:.2f}',
+                save_top_k=2,
+                mode='min'
+            )
+            '''
+            rouge2_checkpoint_callback = ModelCheckpoint(
+                monitor='val_rouge2_fmeasure',
+                dirpath='Models/',
+                filename='{epoch}-{val_loss:.2f}-{rouge:.2f}',
+                save_top_k=2,
+                mode='max'
+            )
+            rouge_checkpoint_callback = ModelCheckpoint(
+                monitor='val_rouge1_fmeasure',
+                dirpath='Models/',
+                filename='{epoch}-{val_loss:.2f}-{rouge:.2f}',
+                save_top_k=2,
+                mode='max'
+            )
+            '''
+
+            print(f"Training {model.name}")
+            trainer = Trainer(
+                max_epochs=num_epochs,
+                devices=len(available_gpus) or 1,  # fall back to one CPU process when no GPU is available
+                accelerator="auto",
+                precision="16",
+                logger=wandb_logger,
+                callbacks=[loss_checkpoint_callback])
+            trainer.fit(model, train_loader, validation_loader)
     else:
-        train_loader = DataLoader(train_dataset, batch_size=batch_size)
-        validation_loader = DataLoader(validation_dataset, batch_size=batch_size)
-        available_gpus = [torch.cuda.device(i) for i in range(torch.cuda.device_count())]
-        print(available_gpus)
-        model = BertLightning(model)
-        print(f"Available GPUs: {len(available_gpus)}")
-        logger = CSVLogger("logs", name=model.name+"Logger")
-        loss_checkpoint_callback = ModelCheckpoint(
-            monitor='val_loss',
-            dirpath='Models/',
-            filename='{epoch}-{val_loss:.2f}-{rouge:.2f}',
-            save_top_k=2,
-            mode='min'
-        )
-        '''
-        rogue2_checkpoint_callback = ModelCheckpoint(
-            monitor='val_rouge2_fmeasure',
-            dirpath='Models/',
-            filename='{epoch}-{val_loss:.2f}-{rouge:.2f}',
-            save_top_k=2,
-            mode='max'
-        )
-        rogue_checkpoint_callback = ModelCheckpoint(
-            monitor='val_rouge1_fmeasure',
-            dirpath='Models/',
-            filename='{epoch}-{val_loss:.2f}-{rouge:.2f}',
-            save_top_k=2,
-            mode='max'
-        )
-        '''
-
-        print(f"Training {model.name}")
-        trainer = Trainer(
-            max_epochs=num_epochs,
-            devices=len(available_gpus),
-            accelerator="auto",
-            precision="16",
-            callbacks=[loss_checkpoint_callback])
-        trainer.fit(model, train_loader, validation_loader)
\ No newline at end of file
+        # Evaluate every saved checkpoint with the current architecture;
+        # checkpoints for a different architecture fail to load and are skipped.
+        for checkpoint_path in glob.glob("Models/*.ckpt"):
+            try:
+                lightning_model = BertLightning.load_from_checkpoint(checkpoint_path, model=model)
+                lightning_model.beam_search_decode = True
+                batch_size = 8  # Small batch size to stay within local-machine memory
+                test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=0)
+                trainer = Trainer(logger=wandb_logger)
+                trainer.test(model=lightning_model, dataloaders=test_loader, ckpt_path=checkpoint_path)
+            except Exception as error:
+                # Skip checkpoints that do not match this model
+                print(error)
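The "filter frozen params" comment above is the usual pattern for training a head on top of a frozen encoder; a minimal self-contained sketch with placeholder layer sizes:

import torch.nn as nn
from torch.optim import AdamW

encoder = nn.Linear(768, 768)   # stand-in for the frozen BERT encoder
head = nn.Linear(768, 30522)    # stand-in for the trainable decoder head
for p in encoder.parameters():
    p.requires_grad = False

model = nn.Sequential(encoder, head)
# Only parameters with requires_grad=True reach the optimizer, so the frozen
# encoder receives no optimizer state and no weight updates.
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=0.05)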
diff --git a/string_utils.py b/string_utils.py
index fcd9268e822ffef9238e0decbada5ffe41a98822..cff059abb03a6b13a4345686c7cc9a998adb09da 100644
--- a/string_utils.py
+++ b/string_utils.py
@@ -82,5 +82,6 @@ def create_pip_installs_from_requirements(text=None, file_path=None, print_file=
 
 if __name__ == '__main__':
     file_path = 'requirements.txt'
+    fix_requirements(file_path=file_path)
     create_pip_installs_from_requirements(file_path=file_path)