diff --git a/base_models.py b/base_models.py
index b555f2d0aa7baf874db2f02abdac1326f3ca1b30..564fb9ef4b6758b44b1ead4d70f69e63284b1a63 100644
--- a/base_models.py
+++ b/base_models.py
@@ -24,6 +24,7 @@ class FrozenBert(nn.Module):
 
     def forward(self, input_ids, attention_mask, token_type_ids=None):
         # Return Context Embeddings
+
         return self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
 
 
diff --git a/data_loading_utils.py b/data_loading_utils.py
index 9f18e26ed3474bf831f18b68925dc60d7be0b193..a7dff9ab471221d5f7585d15736b857bf3caf7ce 100644
--- a/data_loading_utils.py
+++ b/data_loading_utils.py
@@ -1,3 +1,9 @@
+from sklearn.model_selection import KFold
+from pytorch_lightning import Trainer, LightningModule
+from torch.utils.data import DataLoader, Subset
+from tqdm import tqdm
+
+
 # Used if you ever want to verify if some percentages add to 100 - useless
 def verify_split_percentages(test_split_percentage, train_split_percentage, validate_split_percentage=None,
                              verbose=False):
@@ -22,3 +28,29 @@ def verify_split_percentages(test_split_percentage, train_split_percentage, vali
             output += f"\nTotal percentage: {total_percentage}%."
             print(output)
         return False
+
+
+def cross_validate(model, test_dataloader, k_folds=10):
+    # KFold doesn't work on DataLoaders, so materialize the underlying dataset
+    dataset = list(test_dataloader.dataset)
+
+    # create k-folder
+    kfolder = KFold(n_splits=k_folds, shuffle=True)
+
+    progress_bar = tqdm(enumerate(kfolder.split(dataset)), total=k_folds, desc="Folds")
+
+    results = []
+    for fold, (train_ids, val_ids) in progress_bar:
+        # Subsets made according to ids returned by the kfolder
+        train_subsample = Subset(dataset, train_ids)
+        val_subsample = Subset(dataset, val_ids)
+
+        # Load corresponding dataset
+        # train_data_loader = DataLoader(train_subsample, batch_size=test_dataloader.batch_size, shuffle=True)
+        val_data_loader = DataLoader(val_subsample, batch_size=test_dataloader.batch_size, shuffle=False)
+
+        trainer = Trainer(max_epochs=0)
+        trainer.validate(model=model, dataloaders=val_data_loader)
+
+    trainer = Trainer(max_epochs=0)
+    trainer.test(model=model, dataloaders=test_dataloader)
diff --git a/beam_search.py b/decode_utils.py
similarity index 68%
rename from beam_search.py
rename to decode_utils.py
index 9e1deb961d34b7e77cc585c85ba4942fb4ff7ed6..bf3392cc001cee81b6f71c536e2ca2272f1854b0 100644
--- a/beam_search.py
+++ b/decode_utils.py
@@ -1,7 +1,24 @@
 import torch
+from itertools import groupby
 
 
-def beam_search_bert(model_outputs, max_length=512, beam_size=3, min_length=None, tokenizer_eos_token_id=None):
+def shrink_list(lst):
+    groups = groupby(lst)
+    result = []
+
+    for key, group in groups:
+        count = sum(1 for _ in group)
+        if count > 1:
+            result.append(f"{key}*{count}")
+        else:
+            result.append(key)
+
+    return result
+
+
+# Google's Neural Machine Translation System: Bridging the Gap between Human and Machine Translation
+def beam_search_bert(model_outputs, max_length=512, beam_size=3, min_length=None, tokenizer_eos_token_id=None,
+                     length_penalty_enabled=True, alpha=0.6):
     device = model_outputs.device
 
     if len(model_outputs.shape) == 2:
@@ -45,4 +62,15 @@ def beam_search_bert(model_outputs, max_length=512, beam_size=3, min_length=None
         beam_scores = top_scores
         beam_lengths = beam_lengths.gather(1, prev_beam) + 1
 
+    if length_penalty_enabled:
+        # GNMT length penalty, applied once after search so it does not compound per step
+        length_penalty = ((5.0 + beam_lengths.float()) / 6.0).pow(alpha)
+        beam_scores = beam_scores / length_penalty
+
     return beams, beam_scores, beam_lengths
+
+
+def greedy_search(model_outputs):
+    # Argmax probability dimension
+    max_scores, token_ids = torch.max(model_outputs, dim=-1)
+    return token_ids
\ No newline at end of file
diff --git a/lightning_models.py b/lightning_models.py
index cb967294c2d701db630c666ca8b1894b754ef7a6..c123c1a4d2e46cbf0fc264a9cd4aa6f847ab5a7b 100644
--- a/lightning_models.py
+++ b/lightning_models.py
@@ -4,13 +4,14 @@ from torch.optim import AdamW
 from torchmetrics.functional.text.rouge import rouge_score
 from transformers import BertTokenizerFast
 import torch
-from beam_search import beam_search_bert
+from decode_utils import beam_search_bert, greedy_search, shrink_list
 
 
 class BertLightning(LightningModule):
-    def __init__(self, model, learning_rate=1e-5, tokenizer=None, greedy_decode=False, beam_width=3, name_override=None):
+    def __init__(self, model, learning_rate=1e-5, tokenizer=None, greedy_decode=True, beam_search_decode=True,
+                 beam_width=3, name_override=None):
         super().__init__()
-
+        self.save_hyperparameters(ignore=["model"])  # don't pickle the nn.Module into hparams
         self.model = model
         try:
             self.name = self.model.name
@@ -24,6 +25,7 @@ class BertLightning(LightningModule):
 
         self.greedy_decode = greedy_decode
         self.beam_width = beam_width
+        self.beam_search_decode = beam_search_decode
 
         self.validation_step_outputs = []
         self.validation_step_labels = []
@@ -53,8 +55,54 @@ class BertLightning(LightningModule):
         outputs = self(input_ids=input_ids, attention_mask=attention_mask)
         loss = self.criterion(outputs.view(-1, self.model.bert.config.vocab_size), labels.view(-1))
         self.log('val_loss', loss, prog_bar=True)
-        #self.validation_step_outputs.append(predicted_indices)
-        #self.validation_step_labels.append(labels)
+        input_ids = batch["input_ids"]
+        attention_mask = batch["attention_mask"]
+        labels = batch["labels"]
+        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+        outputs = self(input_ids=input_ids, attention_mask=attention_mask)
+        if self.greedy_decode:
+            greedy_ids = greedy_search(outputs)
+            decoded_outputs = self.tokenizer.batch_decode(greedy_ids, skip_special_tokens=True)
+            scores = rouge_score(decoded_outputs, decoded_labels)
+            for key, value in scores.items():
+                self.log(f'val_greedy_{key}', value)
+        if self.beam_search_decode:
+            beams, beam_scores, beam_lengths = beam_search_bert(outputs, length_penalty_enabled=False, beam_size=10)
+            best_beam_indices = torch.argmax(beam_scores, dim=-1)
+            best_beams = beams[torch.arange(beams.size(0)), best_beam_indices]  # Currently based on highest score
+            decoded_outputs = self.tokenizer.batch_decode(best_beams, skip_special_tokens=True)
+            scores = rouge_score(decoded_outputs, decoded_labels)
+            for key, value in scores.items():
+                if key == "rouge1_fmeasure" or (key == "rouge2_fmeasure" and value != 0.0):
+                    self.log(f'val_beam_decode_{key}', value, prog_bar=True)
+                else:
+                    self.log(f'val_beam_decode_{key}', value)
+
+
+    @torch.no_grad()
+    def test_step(self, batch, batch_idx):
+        input_ids = batch["input_ids"]
+        attention_mask = batch["attention_mask"]
+        labels = batch["labels"]
+        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+        outputs = self(input_ids=input_ids, attention_mask=attention_mask)
+        if self.greedy_decode:
+            greedy_ids = greedy_search(outputs)
+            decoded_outputs = self.tokenizer.batch_decode(greedy_ids, skip_special_tokens=True)
+            scores = rouge_score(decoded_outputs, decoded_labels)
+            for key, value in scores.items():
+                self.log(f'test_greedy_{key}', value, prog_bar=False)
+        if self.beam_search_decode:
+            beams, beam_scores, beam_lengths = beam_search_bert(outputs, length_penalty_enabled=False, beam_size=10)
+            best_beam_indices = torch.argmax(beam_scores, dim=-1)
+            best_beams = beams[torch.arange(beams.size(0)), best_beam_indices]  # Currently based on highest score
+            decoded_outputs = self.tokenizer.batch_decode(best_beams, skip_special_tokens=True)
+            scores = rouge_score(decoded_outputs, decoded_labels)
+            for key, value in scores.items():
+                if key == "rouge1_fmeasure" or (key == "rouge2_fmeasure" and value != 0.0):
+                    self.log(f'test_beam_decode_{key}', value, prog_bar=True)
+                else:
+                    self.log(f'test_beam_decode_{key}', value, prog_bar=False)
 
     def on_validation_epoch_end(self):
         # Calculate ROUGE score
@@ -62,7 +110,7 @@ class BertLightning(LightningModule):
 
         # Log ROUGE scores
         for key, value in scores.items():
-            self.log(f'val_{key}', value, prog_bar=True)
+            self.log(f'{self.name}_val_{key}', value, prog_bar=True)
 
     def configure_optimizers(self):
         optimizer = AdamW(self.parameters(), lr=self.learning_rate)
diff --git a/main.py b/main.py
index a58918174ab19136cc96463942d0fd3f2820a92d..45935752acb8a9ebf089e08082fc1096051c2fa4 100644
--- a/main.py
+++ b/main.py
@@ -1,3 +1,4 @@
+from pytorch_lightning.loggers import WandbLogger
 from sklearn.model_selection import KFold
 from torch.optim import AdamW
 from base_models import *
@@ -15,6 +16,12 @@ from socket import gethostname
 import evaluate
 import torch
 import train_utils
+import os
+import glob
+import wandb
+
+wandb.login()
+
 # Output cuda/cpu
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f'==============')
@@ -119,75 +126,96 @@ test_dataset = CNNDailyMailDataset(tokenizer=tokenizer, split_type='test', split
 test_loader = DataLoader(test_dataset, batch_size=batch_size)
 
 # Define Model Object
-models = [BertDoubleDense(), BertBiLSTM(), BertSingleDense()]
+models = [BertSingleDense(), BertDoubleDense(), BertBiLSTM()]
 # Define KFold Object, set to None if not cross validating
 cross_validation_k_fold = KFold(n_splits=num_k_folds) if num_k_folds > 0 else None
 # Define Optimizer (AdamW) - Filters to only optimize params that are not frozen (i.e. not bert)
 # Define loss function object
 criterion = nn.NLLLoss()
 
+train = False
+num_cpus = os.cpu_count()
+available_gpus = [torch.cuda.device(i) for i in range(torch.cuda.device_count())]
 output_config()
+wandb_logger = WandbLogger()
 for model in models:
-    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=0.05)
-    if not use_lightning:
-        model.to(device)
-        # Load Old Rouge scorer
-        rouge_score = evaluate.load("rouge")
-        train_utils.train_model(
-            model=model,
-            num_epochs=10,
-            train_dataset=train_dataset,
-            validation_dataset=validation_dataset,
-            add_time_to_model_name=False,
-            criterion=criterion,
-            optimizer=optimizer,
-            scorer=rouge_score,
-            evaluate=True,
-            verbose=1,
-            save_after_epoch=True,
-            save_best=True,
-            device=device,
-            tokenizer=tokenizer,
-            batch_size=batch_size
-        )  # model_name="BertSingleDense" Omitted to test custom BERT wrappers + autoname - also means this is technically
-        # more generalized code
+    if train:
+        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=0.05)
+        if not use_lightning:
+            model.to(device)
+            # Load Old Rouge scorer
+            rouge_score = evaluate.load("rouge")
+            train_utils.train_model(
+                model=model,
+                num_epochs=10,
+                train_dataset=train_dataset,
+                validation_dataset=validation_dataset,
+                add_time_to_model_name=False,
+                criterion=criterion,
+                optimizer=optimizer,
+                scorer=rouge_score,
+                evaluate=True,
+                verbose=1,
+                save_after_epoch=True,
+                save_best=True,
+                device=device,
+                tokenizer=tokenizer,
+                batch_size=batch_size
+            )  # model_name="BertSingleDense" Omitted to test custom BERT wrappers + autoname - also means this is technically
+            # more generalized code
+        else:
+            train_loader = DataLoader(train_dataset, batch_size=batch_size)
+            validation_loader = DataLoader(validation_dataset, batch_size=batch_size)
+            print(available_gpus)
+            model = BertLightning(model)
+            print(f"Available GPUs: {len(available_gpus)}")
+            logger = CSVLogger("logs", name=model.name+"Logger")
+            trainer = Trainer(logger=wandb_logger)
+            loss_checkpoint_callback = ModelCheckpoint(
+                monitor='val_loss',
+                dirpath='Models/',
+                filename='{epoch}-{val_loss:.2f}-{rouge:.2f}',
+                save_top_k=2,
+                mode='min'
+            )
+            '''
+            rogue2_checkpoint_callback = ModelCheckpoint(
+                monitor='val_rouge2_fmeasure',
+                dirpath='Models/',
+                filename='{epoch}-{val_loss:.2f}-{rouge:.2f}',
+                save_top_k=2,
+                mode='max'
+            )
+            rogue_checkpoint_callback = ModelCheckpoint(
+                monitor='val_rouge1_fmeasure',
+                dirpath='Models/',
+                filename='{epoch}-{val_loss:.2f}-{rouge:.2f}',
+                save_top_k=2,
+                mode='max'
+            )
+            '''
+
+            print(f"Training {model.name}")
+            trainer = Trainer(
+                max_epochs=num_epochs,
+                devices=len(available_gpus),
+                accelerator="auto",
+                precision="16",
+                callbacks=[loss_checkpoint_callback])
+            trainer.fit(model, train_loader, validation_loader)
     else:
-        train_loader = DataLoader(train_dataset, batch_size=batch_size)
-        validation_loader = DataLoader(validation_dataset, batch_size=batch_size)
-        available_gpus = [torch.cuda.device(i) for i in range(torch.cuda.device_count())]
-        print(available_gpus)
-        model = BertLightning(model)
-        print(f"Available GPUs: {len(available_gpus)}")
-        logger = CSVLogger("logs", name=model.name+"Logger")
-        loss_checkpoint_callback = ModelCheckpoint(
-            monitor='val_loss',
-            dirpath='Models/',
-            filename='{epoch}-{val_loss:.2f}-{rouge:.2f}',
-            save_top_k=2,
-            mode='min'
-        )
-        '''
-        rogue2_checkpoint_callback = ModelCheckpoint(
-            monitor='val_rouge2_fmeasure',
-            dirpath='Models/',
-            filename='{epoch}-{val_loss:.2f}-{rouge:.2f}',
-            save_top_k=2,
-            mode='max'
-        )
-        rogue_checkpoint_callback = ModelCheckpoint(
-            monitor='val_rouge1_fmeasure',
-            dirpath='Models/',
-            filename='{epoch}-{val_loss:.2f}-{rouge:.2f}',
-            save_top_k=2,
-            mode='max'
-        )
-        '''
-
-        print(f"Training {model.name}")
-        trainer = Trainer(
-            max_epochs=num_epochs,
-            devices=len(available_gpus),
-            accelerator="auto",
-            precision="16",
-            callbacks=[loss_checkpoint_callback])
-        trainer.fit(model, train_loader, validation_loader)
\ No newline at end of file
+        # Evaluate every saved checkpoint against every candidate architecture;
+        # mismatched pairs raise inside the try below and are skipped.
+        for checkpoint_path in glob.glob("Models/*.ckpt"):
+            for m in models:
+                try:
+                    model = BertLightning.load_from_checkpoint(checkpoint_path, model=m)
+                    model.beam_search_decode = True
+                    wandb_logger = WandbLogger()
+                    batch_size = 8  # I am a mortal on local machine
+                    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=0)
+                    trainer = Trainer(logger=wandb_logger)
+                    trainer.test(model=model, dataloaders=test_loader, ckpt_path=checkpoint_path)
+                except Exception as pokemon:
+                    print(pokemon)
+                    pass
diff --git a/string_utils.py b/string_utils.py
index fcd9268e822ffef9238e0decbada5ffe41a98822..cff059abb03a6b13a4345686c7cc9a998adb09da 100644
--- a/string_utils.py
+++ b/string_utils.py
@@ -82,5 +82,6 @@ def create_pip_installs_from_requirements(text=None, file_path=None, print_file=
 
 if __name__ == '__main__':
     file_path = 'requirements.txt'
+    fix_requirements(file_path=file_path)
     create_pip_installs_from_requirements(file_path=file_path)