Calculates the clustering accuracy between the true labels and the pseudo labels.
:param true_labels: torch.Tensor with the true labels.
:param pseudo_labels: torch.Tensor with the pseudo labels.
"""
assert true_labels.shape == pseudo_labels.shape, f"True and Pseudo labels must have the same shape. true_labels.shape: {true_labels.shape}, pseudo_labels.shape: {pseudo_labels.shape}"
:param args: Object with the attributes `ood_score` and `device`.
:return: Tuple of torch.Tensors with the entropy and energy scores respectively.
"""
session_loader = torch.utils.data.DataLoader(
session_dataset,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_memory,
shuffle=False,
)
logger.debug(f"Getting Scores: {args.ood_score}")
model.eval()
...
...
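The docstring above returns entropy and energy scores, but the body that computes them is elided. As a rough illustration only, the sketch below computes both scores for a single sample's logits in plain Python. It assumes the common definitions (Shannon entropy of the softmax distribution, and energy as the negative temperature-scaled log-sum-exp of the logits); the function name `entropy_and_energy` and the `temperature` parameter are hypothetical and do not come from this diff, which may define its scores differently.

```python
import math
from typing import List, Tuple


def entropy_and_energy(logits: List[float], temperature: float = 1.0) -> Tuple[float, float]:
    """Return (entropy, energy) for one sample's logits.

    Hypothetical helper: assumes softmax entropy and
    energy = -T * logsumexp(logits / T).
    """
    scaled = [z / temperature for z in logits]
    # Subtract the max before exponentiating for a numerically stable log-sum-exp.
    m = max(scaled)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in scaled))
    # Softmax probabilities recovered from the log-sum-exp.
    probs = [math.exp(z - log_sum_exp) for z in scaled]
    # Shannon entropy; terms with zero probability contribute nothing.
    entropy = -sum(p * math.log(p) for p in probs if p > 0.0)
    energy = -temperature * log_sum_exp
    return entropy, energy


if __name__ == "__main__":
    # A flat prediction versus a confident one.
    print(entropy_and_energy([0.0, 0.0, 0.0, 0.0]))
    print(entropy_and_energy([10.0, 0.0, 0.0, 0.0]))
```

Under these definitions a confident prediction yields both lower entropy and lower energy than a flat one, which is what makes either score usable as an OOD signal.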
@@ -140,30 +148,21 @@ def _resolve_conflicts(
def label_ood_for_session(
args,
session_dataset: TransformedTensorDataset,
-model: ENTCLModel,
-return_new_dataset: bool = False,
+model: torch.nn.Module,
+return_new_dataset: bool = True,
) -> Union[TransformedTensorDataset, torch.Tensor]:
"""
OOD Labelling for a session dataset. This function computes entropy and/or energy scores for the dataset based on `args.ood_score` and fits a Gaussian Mixture Model to each of the scores. The GMM has 2 components, one for in-distribution samples and one for OOD samples. The function then resolves conflicts between the entropy and energy predictions by selecting the type with the highest confidence. Finally, the function returns a new dataset for the session, including the predicted types.
:param args: Object with the attributes `ood_score`, `ood_eps`, `seed` and `device` (Program Arguments).
:param session_dataset: Dataset for the session.
:param model: The model to evaluate.
-:param return_new_dataset: Whether to return the new dataset ready for the session or just the predicted types (useful when there's more to do before training).
+:param return_new_dataset: Whether to return the new dataset ready for the session or just the predicted types.
:return: A TransformedTensorDataset with the predicted types or just a torch.Tensor of the predicted types.
"""
logger.debug("Starting OOD Labelling for Session")
# first we wrap the session dataset in a DataLoader so it is processed in batches, for memory efficiency
session_loader = torch.utils.data.DataLoader(
session_dataset,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_memory,
shuffle=False,
)
# next we run the dataset through the model to retrieve an entropy/energy score for each sample