Commit 0c28c917 authored by Joseph Omar

added ood detection

parent 290c3a6e
from typing import Optional, Union
from entcl.data.util import TransformedTensorDataset
from entcl.models.model import ENTCLModel
import torch
from loguru import logger
from sklearn.mixture import GaussianMixture
import tqdm
@@ -3,11 +3,12 @@ from loguru import logger
import torch
class ENTCLModel(torch.nn.Module):
def __init__(self, head: torch.nn.Module):
def __init__(self, head: torch.nn.Module, backbone_url: str, backbone: str, backbone_source: str):
super().__init__()
# load the backbone
self.backbone = torch.hub.load(os.path.join(os.path.dirname(__file__), 'dinov2'), 'dinov2_vitb14', source='local')
self.backbone = torch.hub.load(backbone_url, backbone, source=backbone_source)
logger.debug(f"Loaded backbone: {backbone} from {backbone_url} (src: {backbone_source})")
# freeze the backbone
for param in self.backbone.parameters():
......
@@ -50,14 +50,14 @@ def pretrain(args, model):
logger.debug(f"Epoch {epoch} Started")
# train model
print(f"Epoch {epoch}:")
model, loss_total = _train(args, model, train_dataloader, optimiser, criterion)
logger.info(f"Epoch {epoch}: Loss: {loss_total}")
model, train_loss = _train(args, model, train_dataloader, optimiser, criterion)
logger.info(f"Epoch {epoch}: Loss: {train_loss}")
# validate model
model, accuracy = _validate(args, model, val_dataloader)
model, accuracy, val_loss = _validate(args, model, val_dataloader, criterion)
accuracies.append(accuracy)
logger.info(f"Epoch {epoch}: Accuracy: {accuracy}")
logger.info(f"Epoch {epoch}: Accuracy: {accuracy}, Loss: {val_loss}")
torch.save(model.state_dict(), os.path.join(args.exp_dir, f"model_{epoch}.pt"))
# select the best model
@@ -101,10 +101,11 @@ def _train(args, model, train_dataloader, optimiser, criterion):
loss_total /= len(train_dataloader)
return model, loss_total
def _validate(args, model, val_dataloader):
def _validate(args, model, val_dataloader, criterion):
model.eval()
correct = 0
total = 0
loss_total = 0
with torch.no_grad():
for x, y in tqdm(val_dataloader, desc=f"Validating", unit = "batch"):
x, y = x.to(args.device), y.to(args.device)
@@ -112,12 +113,17 @@ def _validate(args, model, val_dataloader):
logits, _ = model(x)
logger.debug(f"logits shape: {logits.shape}")
loss = criterion(logits, y)
loss = loss.item()
loss_total += loss
_, predicted = torch.max(logits, 1)
num_correct = (predicted == y).sum().item()
total += y.size(0)
correct += num_correct
logger.debug(f"This Batch Num Correct: {num_correct}, Total: {y.size(0)}, Accuracy: {num_correct / y.size(0)}")
return model, correct / total
logger.debug(f"This Batch Num Correct: {num_correct}, Total: {y.size(0)}, Accuracy: {num_correct / y.size(0)}, Loss: {loss}")
return model, correct / total, loss_total / len(val_dataloader)
@@ -5,14 +5,23 @@ from loguru import logger
from datetime import datetime
import torch
from entcl.utils.util import seed
from entcl.models.model import ENTCLModel
from entcl.pretrain import pretrain
@logger.catch
def main():
def main(args: argparse.Namespace):
model = ENTCLModel(head=args.head, backbone_url=args.backbone_url, backbone=args.backbone, backbone_source=args.backbone_source)
logger.debug(f"Model: {model}")
model = pretrain(args, model)
if __name__ == "__main__":
logger.debug("Entry Point: run.py")
parser = argparse.ArgumentParser()
# program args
parser.add_argument('--name', type=str, default="entcl_" + datetime.now().isoformat(timespec='seconds'))
@@ -35,10 +44,10 @@ def main():
parser.add_argument('--dataset', type=str, default='cifar100', help='Dataset to use', choices=['cifar100'])
# optimiser args
parser.add_argument('--lr', type=float, default=0.1, help='Learning Rate for all optimisers')
parser.add_argument('--lr', type=float, default=0.001, help='Learning Rate for all optimisers')
parser.add_argument('--gamma', type=float, default=0.1, help='Gamma for all optimisers')
parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for all optimisers')
parser.add_argument('--weight_decay', type=float, default=5e-5, help='Weight Decay for all optimisers')
parser.add_argument('--momentum', type=float, default=0, help='Momentum for all optimisers')
parser.add_argument('--weight_decay', type=float, default=0, help='Weight Decay for all optimisers')
# cl args
parser.add_argument('--known', type=int, default=50, help='Number of known classes. The rest are novel classes')
@@ -60,8 +69,17 @@ def main():
# model args
parser.add_argument('--head', type=str, default='linear2', help='Classification head to use', choices=['linear','linear2', 'dino_head'])
parser.add_argument('--backbone_url', type=str, default="/cl/entcl/entcl/models/dinov2", help="URL to the repo containing the backbone model")
parser.add_argument("--backbone", type=str, default="dinov2_vitb14", help="Name of the backbone model to use")
parser.add_argument("--backbone_source", type=str, default="local", help="Source of the backbone model")
# ood args
parser.add_argument('--ood_score', type=str, default='entropy', help='Changes the metric(s) to base OOD detection on', choices=['entropy', 'energy', 'both'])
parser.add_argument('--ood_eps', type=float, default=1e-8, help='Epsilon value for computing entropy in OOD detection')
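# Illustrative invocation (all flags have defaults): python run.py --ood_score both --lr 0.001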
args = parser.parse_args()
seed(args.seed) # seed everything
# setup device
if not torch.cuda.is_available():
raise ValueError("CUDA not available")
@@ -107,12 +125,8 @@ def main():
args.head = LinearHead2(in_features=768, out_features=args.dataset.num_classes, hidden_dim1=512, hidden_dim2=256)
logger.debug(f"Using Linear2 Head: {args.head}")
model = ENTCLModel(head=args.head)
logger.debug(f"Model: {model}")
model = pretrain(args, model)
if __name__ == "__main__":
logger.debug("Entry Point: run.py")
main()
\ No newline at end of file
argstr = "Arguments: \n"
for arg in vars(args):
argstr += f"{arg}: {getattr(args, arg)}\n"
main(args)
\ No newline at end of file
from typing import Iterable, Tuple, Union
from entcl.data.util import TransformedTensorDataset
from entcl.models.model import ENTCLModel
from loguru import logger
from sklearn.mixture import GaussianMixture
import torch
from tqdm import tqdm
def _get_scores(
loader: Union[torch.utils.data.DataLoader, Iterable], model: ENTCLModel, args
) -> Union[
Tuple[torch.Tensor, torch.Tensor],
Tuple[None, None],
Tuple[torch.Tensor, None],
Tuple[None, torch.Tensor],
]:
"""
Computes entropy and/or energy scores for the dataset based on `args.ood_score`.
:param loader: a Dataloader or Iterable for the dataset.
:param model: The model to evaluate.
:param args: Object with the attributes `ood_score`, `ood_eps` and `device`.
:return: Tuple of torch.Tensors with the entropy and energy scores respectively; either element is None if the corresponding score was not requested.
"""
logger.debug(f"Getting Scores: {args.ood_score}")
model.eval()
compute_entropy = args.ood_score in ["entropy", "both"]
compute_energy = args.ood_score in ["energy", "both"]
entropies, energies = [], []
for x, _ in tqdm(loader, desc="Calculating Scores", leave=True, unit="batch"):
x = x.to(args.device)
with torch.no_grad():
logits, _ = model(x)
if compute_entropy:
softmax = torch.nn.functional.softmax(logits, dim=1)
entropy = -torch.sum(
softmax * torch.log(softmax + args.ood_eps), dim=1
) # epsilon (args.ood_eps, default 1e-8) added for numerical stability
entropies.append(entropy)
if compute_energy:
energy = -torch.logsumexp(logits, dim=1)
energies.append(energy)
logger.debug(
f"Scores Calculated: Entropy: {len(entropies)} batches, Energy: {len(energies)} batches"
)
entropies = torch.cat(entropies) if compute_entropy else None
energies = torch.cat(energies) if compute_energy else None
return entropies, energies
def _fit_predict_gmm(data: torch.Tensor, args) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Helper function to fit a 2-component Gaussian Mixture Model to the data.
:param data: Tensor of shape [N] with the data to fit the GMM to.
:param args: Object with the attribute `seed`.
:return: Tuple of hard predictions (shape [N]) and soft predictions (shape [N, 2]), ordered so that type 0 is the lower-mean component.
"""
logger.debug(f"Fitting Gaussian Mixture Model to Data")
gmm = GaussianMixture(n_components=2, random_state=args.seed)
# Fit and predict using the GMM
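# fit_predict fits the 2-component mixture on the scores (reshaped to [N, 1], as scikit-learn
# expects 2-D inputs) and returns hard assignments; predict_proba then returns per-sample
# component probabilities from the same fitted model.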
predtypes_hard = torch.tensor(
gmm.fit_predict(data.view(-1, 1).cpu().numpy()), device=data.device
)
predtypes_soft = torch.tensor(
gmm.predict_proba(data.view(-1, 1).cpu().numpy()), device=data.device
)
# Retrieve the means of the two clusters
mean_0 = torch.mean(data[predtypes_hard == 0], dim=0)
mean_1 = torch.mean(data[predtypes_hard == 1], dim=0)
logger.debug(f"Mean for Type 0: {mean_0}, Mean for Type 1: {mean_1}")
# Swapping clusters if necessary
if mean_1 < mean_0:
logger.debug("Type 1 has lower mean than Type 0. Swapping types")
predtypes_hard = 1 - predtypes_hard # Swap the types
predtypes_soft = predtypes_soft[
:, [1, 0]
] # Swap probability columns so that the first column is the probability of type 0 and the second column is the probability of type 1
else:
logger.debug("Type 0 has lower mean than Type 1. Keeping types")
return predtypes_hard, predtypes_soft
def _resolve_conflicts(
entropy_predtypes_soft: torch.Tensor, energy_predtypes_soft: torch.Tensor
) -> torch.Tensor:
"""
Resolves conflicts between entropy and energy predictions by selecting the type with the highest confidence.
:param entropy_predtypes_soft: Tensor of shape [N, 2] with soft predictions from the entropy GMM.
:param energy_predtypes_soft: Tensor of shape [N, 2] with soft predictions from the energy GMM.
:return: Tensor of shape [N] with resolved hard predictions (0 or 1).
"""
logger.debug("Resolving Conflicts")
assert (
entropy_predtypes_soft.shape == energy_predtypes_soft.shape
), f"Entropy and Energy predictions must have the same shape. Got Entropy.shape: {entropy_predtypes_soft.shape}, Energy.shape: {energy_predtypes_soft.shape}"
# Compute hard predictions and their confidence scores
entropy_predtypes_hard = torch.argmax(entropy_predtypes_soft, dim=1)
energy_predtypes_hard = torch.argmax(energy_predtypes_soft, dim=1)
# for each sample, get the confidence of the predicted type
entropy_confidence = entropy_predtypes_soft[
torch.arange(entropy_predtypes_soft.size(0)), entropy_predtypes_hard
]
energy_confidence = energy_predtypes_soft[
torch.arange(energy_predtypes_soft.size(0)), energy_predtypes_hard
]
# Resolve conflicts by selecting the type with the highest confidence
# torch.where(condition, x, y) returns x if condition is True, otherwise y
resolved_predictions = torch.where(
entropy_predtypes_hard
== energy_predtypes_hard, # if the predictions are the same
entropy_predtypes_hard, # return the prediction
torch.where( # otherwise
energy_confidence
> entropy_confidence, # if the energy confidence is higher
energy_predtypes_hard, # use the energy prediction
entropy_predtypes_hard,
), # otherwise use the entropy prediction
)
return resolved_predictions
def label_ood_for_session(
args,
session_dataset: TransformedTensorDataset,
model: ENTCLModel,
return_new_dataset: bool = False,
) -> Union[TransformedTensorDataset, torch.Tensor]:
"""
OOD labelling for a session dataset. This function computes entropy and/or energy scores for the dataset based on `args.ood_score` and fits a Gaussian Mixture Model to each of the scores. Each GMM has 2 components, one for in-distribution samples and one for OOD samples. When both scores are used, conflicts between the entropy and energy predictions are resolved by selecting the type predicted with the highest confidence. Finally, the function returns either a new dataset for the session including the predicted types, or just the predicted types, depending on `return_new_dataset`.
:param args: Object with the attributes `ood_score`, `ood_eps`, `seed`, `device`, `batch_size`, `num_workers` and `pin_memory` (program arguments).
:param session_dataset: Dataset for the session.
:param model: The model to evaluate.
:param return_new_dataset: Whether to return the new dataset ready for the session or just the predicted types (useful when there's more to do before training).
:return: A TransformedTensorDataset with the predicted types or just a torch.Tensor of the predicted types.
"""
logger.debug("Starting OOD Labelling for Session")
# first we dataload the session dataset, for memory efficiency
session_loader = torch.utils.data.DataLoader(
session_dataset,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_memory,
shuffle=False,
)
# next we run the dataset through the model to retrieve an entropy/energy score for each sample
entropies, energies = _get_scores(session_loader, model, args)
# next step depends on the selected OOD Score
# placeholder for the final predictions
final_predtypes = None
# if we are using entropy only
if args.ood_score == "entropy":
logger.debug("Using Entropy Only")
predtypes_hard, _ = _fit_predict_gmm(
entropies, args
) # fit a GMM to the entropy scores, we do not care about the soft predictions
final_predtypes = predtypes_hard
# if we are using energy only
elif args.ood_score == "energy":
logger.debug("Using Energy Only")
predtypes_hard, _ = _fit_predict_gmm(energies, args)
final_predtypes = predtypes_hard
# if we are using both entropy and energy
elif args.ood_score == "both":
logger.debug("Using Both Entropy and Energy")
_, entropy_predtypes_soft = _fit_predict_gmm(
entropies, args
) # we do not care about the hard predictions
_, energy_predtypes_soft = _fit_predict_gmm(
energies, args
) # we do not care about the hard predictions
final_predtypes = _resolve_conflicts(
entropy_predtypes_soft, energy_predtypes_soft
)
else:
raise ValueError(f"Invalid OOD Score: {args.ood_score}")
if return_new_dataset:
logger.debug("Returning New Dataset")
# return the new dataset with the predicted types
session_dataset = TransformedTensorDataset(
tensor_dataset=torch.utils.data.TensorDataset(
session_dataset.tensor_dataset.tensors[0], # the data
session_dataset.tensor_dataset.tensors[1], # the labels
final_predtypes, # the predicted types (duh)
),
transform=session_dataset.transform,
)
return session_dataset
else:
logger.debug("Returning Predicted Types")
# return just the predicted types
return final_predtypes
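# Illustrative usage (not part of this commit's call sites): a continual-learning session could
# label its unlabelled data before training, e.g.
#   session_dataset = label_ood_for_session(args, session_dataset, model, return_new_dataset=True)
# so that each sample carries a third tensor entry: 0 for known/in-distribution, 1 for novel/OOD.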
import torch
import numpy as np
import os
import random
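# Reproducibility helper: seeds Python's random module, PYTHONHASHSEED, NumPy and all CUDA devices;
# the cuDNN flags below trade some speed for deterministic behaviour.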
def seed(seed=8008135):
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
\ No newline at end of file