Skip to content
Snippets Groups Projects
Commit 1d271cc1 authored by Joseph Omar
Browse files

Add pre-trained models and enhance training with learning rate scheduler

- Added pre-trained models for DINO and DINOv2 to backbone_store
- Updated .gitignore to exclude egg-info files
- Enhanced LinearHead to ensure new weights are initialized correctly
- Integrated Cosine Annealing LR Scheduler into pretraining and continual learning sessions
- Improved OOD accuracy computation with additional entropy and energy metrics
parent 62c39337
No related branches found
No related tags found
No related merge requests found
...@@ -2,6 +2,9 @@ ...@@ -2,6 +2,9 @@
*.pyc *.pyc
__pycache__/ __pycache__/
# ignore egg stuff
*.egg-info/
runs/ runs/
experiments/debug experiments/debug
... ...
......
...@@ -17,20 +17,16 @@ def cl_session(args, session_dataset: TransformedTensorDataset, model: torch.nn. ...@@ -17,20 +17,16 @@ def cl_session(args, session_dataset: TransformedTensorDataset, model: torch.nn.
logger.debug(f"Begin Continual Learning Session {args.current_session}") logger.debug(f"Begin Continual Learning Session {args.current_session}")
# make sure the dataset has the correct shape # make sure the dataset has the correct shape
assert len(session_dataset.tensor_dataset.tensors) == 5, "Session Dataset should have 5 tensors, (data, true labels, true types, pred types, pseudo labels). Got: " + str(len(session_dataset.tensor_dataset.tensors)) assert len(session_dataset.tensor_dataset.tensors) == 5, "Session Dataset should have 5 tensors, (data, true labels, true types, pred types, pseudo labels). Got: " + str(len(session_dataset.tensor_dataset.tensors))
for i, t in enumerate(session_dataset.tensor_dataset.tensors):
if t.device != torch.device("cpu"):
logger.warning(f"Tensor {i} is not on CPU")
# create the required training stuff and things and junk # create the required training stuff and things and junk
# we are only training with nove data at the moment, so we need to whittle down the dataset to only the predicted novel samples # we are only training with nove data at the moment, so we need to whittle down the dataset to only the predicted novel samples
novel_samples_mask = session_dataset.tensor_dataset.tensors[3] == 1 # novel samples are labelled with 1. predtypes are the 4th tensor in the dataset novel_samples_mask = session_dataset.tensor_dataset.tensors[3] == 1 # novel samples are labelled with 1. predtypes are the 4th tensor in the dataset
novel_samples_mask = novel_samples_mask.cpu()
novel_tensors = [tensor[novel_samples_mask] for tensor in session_dataset.tensor_dataset.tensors] novel_tensors = [tensor[novel_samples_mask] for tensor in session_dataset.tensor_dataset.tensors]
# adjust the psuedo labels (which start from 0 atm) to start from args.dataset.known + ((args.current_session - 1) * args.dataset.novel_inc)
adjust_value = args.dataset.known + ((args.current_session - 1) * args.dataset.novel_inc)
logger.debug(f"Adjusting Pseudo Labels by {adjust_value}")
novel_tensors[4] += adjust_value
logger.debug(f"Adjusted Pseudo Labels: {torch.unique(novel_tensors[4], sorted=True)}")
session_dataset = TransformedTensorDataset( session_dataset = TransformedTensorDataset(
tensor_dataset=torch.utils.data.TensorDataset(*novel_tensors), tensor_dataset=torch.utils.data.TensorDataset(*novel_tensors),
transform=session_dataset.transform, transform=session_dataset.transform,
...@@ -79,6 +75,14 @@ def cl_session(args, session_dataset: TransformedTensorDataset, model: torch.nn. ...@@ -79,6 +75,14 @@ def cl_session(args, session_dataset: TransformedTensorDataset, model: torch.nn.
weight_decay=args.weight_decay, weight_decay=args.weight_decay,
) )
if not args.no_sched:
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer=optimiser,
T_max=args.cl_epochs,
eta_min=args.lr * 1e-3,
verbose=args.verbose,
)
criterion = torch.nn.CrossEntropyLoss() criterion = torch.nn.CrossEntropyLoss()
results = None results = None
...@@ -126,6 +130,9 @@ def cl_session(args, session_dataset: TransformedTensorDataset, model: torch.nn. ...@@ -126,6 +130,9 @@ def cl_session(args, session_dataset: TransformedTensorDataset, model: torch.nn.
results.to_csv(f"{args.exp_dir}/results_s{args.current_session}.csv", index=False) results.to_csv(f"{args.exp_dir}/results_s{args.current_session}.csv", index=False)
logger.debug(f"Session {args.current_session} Results saved to {args.exp_dir}/results_s{args.current_session}.csv") logger.debug(f"Session {args.current_session} Results saved to {args.exp_dir}/results_s{args.current_session}.csv")
if not args.no_sched:
scheduler.step()
return model return model
def _train(args, model: torch.nn.Module, train_loader: torch.utils.data.DataLoader, optimiser: torch.optim.Optimizer, criterion: torch.nn.Module) -> Tuple[torch.nn.Module, float]: def _train(args, model: torch.nn.Module, train_loader: torch.utils.data.DataLoader, optimiser: torch.optim.Optimizer, criterion: torch.nn.Module) -> Tuple[torch.nn.Module, float]:
... ...
......
File added
File added
...@@ -16,12 +16,15 @@ class LinearHead(torch.nn.Module): ...@@ -16,12 +16,15 @@ class LinearHead(torch.nn.Module):
""" """
logger.info(f"Expanding Head: {self.fc.out_features} -> {self.fc.out_features + num}") logger.info(f"Expanding Head: {self.fc.out_features} -> {self.fc.out_features + num}")
old_fc = self.fc old_fc = self.fc
self.fc = torch.nn.Linear(old_fc.in_features, old_fc.out_features + num) self.fc = torch.nn.Linear(old_fc.in_features, old_fc.out_features + num)
self.fc.weight.data[:old_fc.out_features] = old_fc.weight.data self.fc.weight.data[:old_fc.out_features] = old_fc.weight.data
self.fc.bias.data[:old_fc.out_features] = old_fc.bias.data self.fc.bias.data[:old_fc.out_features] = old_fc.bias.data
self.fc.weight.data[old_fc.out_features:] = 0 self.fc.weight.data[old_fc.out_features:] = 0
self.fc.bias.data[old_fc.out_features:] = 0 self.fc.bias.data[old_fc.out_features:] = 0
self.fc = self.fc.to(old_fc.weight.device)
if init_new: if init_new:
torch.nn.init.kaiming_normal_(self.fc.weight.data[old_fc.out_features:]) torch.nn.init.kaiming_normal_(self.fc.weight.data[old_fc.out_features:])
torch.nn.init.zeros_(self.fc.bias.data[old_fc.out_features:]) torch.nn.init.zeros_(self.fc.bias.data[old_fc.out_features:])
... ...
......
...@@ -2,13 +2,16 @@ import os ...@@ -2,13 +2,16 @@ import os
from loguru import logger from loguru import logger
import torch import torch
class ENTCLModel(torch.nn.Module): class ENTCLModel(torch.nn.Module):
def __init__(self, head: torch.nn.Module, backbone_url: str, backbone: str, backbone_source: str): def __init__(
self,
head: torch.nn.Module,
backbone_version: int = 1,
):
super().__init__() super().__init__()
# load the backbone self.backbone = self._load_backbone(backbone_version=backbone_version)
self.backbone = torch.hub.load(backbone_url, backbone, source=backbone_source)
logger.debug(f"Loaded backbone: {backbone} from {backbone_url} (src: {backbone_source})")
# freeze the backbone # freeze the backbone
for param in self.backbone.parameters(): for param in self.backbone.parameters():
...@@ -26,6 +29,47 @@ class ENTCLModel(torch.nn.Module): ...@@ -26,6 +29,47 @@ class ENTCLModel(torch.nn.Module):
super().train(mode) super().train(mode)
self.backbone.train(False) self.backbone.train(False)
def _load_backbone(self, backbone_version: int = 1):
if backbone_version == 1:
from dino import vision_transformer as vits
model = vits.__dict__["vit_base"](patch_size=16, num_classes=0)
state_dict = torch.load(
f=os.path.join(
os.path.abspath(os.path.dirname(__file__)),
"backbone_store",
"dino_vitbase16_pretrain.pth",
),
map_location="cpu",
weights_only=False,
)
model.load_state_dict(state_dict, strict=True)
elif backbone_version == 2:
from dinov2 import vision_transformer as vits
model = vits.__dict__["vit_base"](
img_size=518,
patch_size=14,
init_values=1.0,
ffn_layer="mlp",
block_chunks=0,
num_register_tokens=0,
interpolate_antialias=False,
interpolate_offset=0.1,
)
state_dict = torch.load(
f=os.path.join(
os.path.abspath(os.path.dirname(__file__)),
"backbone_store",
"dinov2_vitb14_pretrain.pth",
),
map_location="cpu",
weights_only=False,
)
model.load_state_dict(state_dict, strict=True)
else:
raise ValueError(f"Unsupported backbone version: {backbone_version}")
return model
if __name__ == "__main__": if __name__ == "__main__":
model = ENTCLModel(head=None) model = ENTCLModel(head=None)
... ...
......
...@@ -33,6 +33,17 @@ def pretrain(args, model): ...@@ -33,6 +33,17 @@ def pretrain(args, model):
weight_decay=args.weight_decay, weight_decay=args.weight_decay,
) )
if not args.no_sched:
logger.debug("Using Cosine Annealing LR Scheduler")
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer=optimiser,
T_max=args.pretrain_epochs,
eta_min=args.lr * 1e-3,
verbose=args.verbose,
)
else:
logger.debug("Not Using Scheduler")
criterion = torch.nn.CrossEntropyLoss() criterion = torch.nn.CrossEntropyLoss()
results = None results = None
...@@ -75,6 +86,10 @@ def pretrain(args, model): ...@@ -75,6 +86,10 @@ def pretrain(args, model):
# save the results dataframe # save the results dataframe
results.to_csv(f"{args.exp_dir}/results_s0.csv", index=False) results.to_csv(f"{args.exp_dir}/results_s0.csv", index=False)
logger.debug(f"Epoch {epoch} Finished. Pretrain Results Saved to {args.exp_dir}/results_s0.csv") logger.debug(f"Epoch {epoch} Finished. Pretrain Results Saved to {args.exp_dir}/results_s0.csv")
if not args.no_sched:
scheduler.step()
return model return model
else: else:
raise ValueError(f"No Model to load and mode is not pretrain or both. Mode: {args.mode}, Pretrain Load: {args.pretrain_load}") raise ValueError(f"No Model to load and mode is not pretrain or both. Mode: {args.mode}, Pretrain Load: {args.pretrain_load}")
... ...
......
...@@ -16,7 +16,7 @@ from entcl.pretrain import pretrain ...@@ -16,7 +16,7 @@ from entcl.pretrain import pretrain
@logger.catch @logger.catch
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
model = ENTCLModel(head=args.head, backbone_url=args.backbone_url, backbone=args.backbone, backbone_source=args.backbone_source) model = ENTCLModel(head=args.head, backbone_version=args.backbone)
logger.debug(f"Model: {model}") logger.debug(f"Model: {model}")
logger.info("Pretraining Model (Session 0)") logger.info("Pretraining Model (Session 0)")
...@@ -77,10 +77,11 @@ if __name__ == "__main__": ...@@ -77,10 +77,11 @@ if __name__ == "__main__":
parser.add_argument('--dataset', type=str, default='cifar100', help='Dataset to use', choices=['cifar100']) parser.add_argument('--dataset', type=str, default='cifar100', help='Dataset to use', choices=['cifar100'])
# optimiser args # optimiser args
parser.add_argument('--lr', type=float, default=0.001, help='Learning Rate for all optimisers') parser.add_argument('--lr', type=float, default=0.1, help='Learning Rate for all optimisers')
parser.add_argument('--gamma', type=float, default=0.1, help='Gamma for all optimisers') parser.add_argument('--gamma', type=float, default=0.1, help='Gamma for all optimisers')
parser.add_argument('--momentum', type=float, default=0, help='Momentum for all optimisers') parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for all optimisers')
parser.add_argument('--weight_decay', type=float, default=0, help='Weight Decay for all optimisers') parser.add_argument('--weight_decay', type=float, default=5e-5, help='Weight Decay for all optimisers')
parser.add_argument('--no_sched', action='store_true', help='Do not use a scheduler')
# cl args # cl args
parser.add_argument('--known', type=int, default=50, help='Number of known classes. The rest are novel classes') parser.add_argument('--known', type=int, default=50, help='Number of known classes. The rest are novel classes')
...@@ -102,9 +103,7 @@ if __name__ == "__main__": ...@@ -102,9 +103,7 @@ if __name__ == "__main__":
# model args # model args
parser.add_argument('--head', type=str, default='linear', help='Classification head to use', choices=['linear','mlp', 'dino_head']) parser.add_argument('--head', type=str, default='linear', help='Classification head to use', choices=['linear','mlp', 'dino_head'])
parser.add_argument('--backbone_url', type=str, default="/cl/entcl/entcl/models/dinov2", help="URL to the repo containing the backbone model") parser.add_argument("--backbone", type=int, default=1, help="Version of DINO to use", choices=[1, 2])
parser.add_argument("--backbone", type=str, default="dinov2_vitb14", help="Name of the backbone model to use")
parser.add_argument("--backbone_source", type=str, default="local", help="Source of the backbone model")
# ood args # ood args
parser.add_argument('--ood_score', type=str, default='entropy', help='Changes the metric(s) to base OOD detection on', choices=['entropy', 'energy', 'both']) parser.add_argument('--ood_score', type=str, default='entropy', help='Changes the metric(s) to base OOD detection on', choices=['entropy', 'energy', 'both'])
...@@ -114,6 +113,7 @@ if __name__ == "__main__": ...@@ -114,6 +113,7 @@ if __name__ == "__main__":
parser.add_argument('--ncd_findk_method', type=str, default='cheat', help='Method to use for finding the number of novel classes', choices=['elbow', 'silhouette', 'gap', 'cheat']) parser.add_argument('--ncd_findk_method', type=str, default='cheat', help='Method to use for finding the number of novel classes', choices=['elbow', 'silhouette', 'gap', 'cheat'])
args = parser.parse_args() args = parser.parse_args()
seed(args.seed) # seed everything seed(args.seed) # seed everything
# setup device # setup device
...@@ -172,4 +172,6 @@ if __name__ == "__main__": ...@@ -172,4 +172,6 @@ if __name__ == "__main__":
for arg in vars(args): for arg in vars(args):
argstr += f"{arg}: {getattr(args, arg)}\n" argstr += f"{arg}: {getattr(args, arg)}\n"
logger.info(argstr)
main(args) main(args)
\ No newline at end of file
...@@ -87,32 +87,35 @@ def generate_mapping( ...@@ -87,32 +87,35 @@ def generate_mapping(
:return: Dict[int, int] a mapping between the true labels and the pseudo labels. :return: Dict[int, int] a mapping between the true labels and the pseudo labels.
""" """
logger.debug("Calculating Clustering Accuracy") logger.debug("Calculating Clustering Accuracy")
true_labels = true_labels.cpu().numpy()
pseudo_labels = pseudo_labels.cpu().numpy()
assert ( assert (
true_labels.shape == pseudo_labels.shape true_labels.shape == pseudo_labels.shape
), f"True and Pseudo labels must have the same shape. true_labels.shape: {true_labels.shape}, pseudo_labels.shape: {pseudo_labels.shape}" ), f"True and Pseudo labels must have the same shape. true_labels.shape: {true_labels.shape}, pseudo_labels.shape: {pseudo_labels.shape}"
# true labels will be > 50 and psuedo labels will start at 0. we need to adjust the pseudo labels to match the true labels. # this is used for testing, so we will cheat and remove all known classes from the data before finding the mapping
# we will assume the true labels are sequential, and the lowest true label is 0. no_known_data_mask = true_labels >= args.dataset.known + (args.current_session - 1) * args.dataset.novel_inc
true_labels = true_labels[no_known_data_mask]
pseudo_labels = pseudo_labels[no_known_data_mask]
novel_true_min_idx = true_labels.min() # true and psuedo_label should start at 0, so we subtract the minimum value from both to move them into the 0-starting space
pseudo_labels += novel_true_min_idx labels_start = true_labels.min()
true_labels -= labels_start
true_labels = true_labels.cpu().numpy() pseudo_labels -= labels_start
pseudo_labels = pseudo_labels.cpu().numpy()
# Hungarian Algorithm to find the best matching between the true and pseudo labels.
conf_mat = confusion_matrix(true_labels, pseudo_labels) conf_mat = confusion_matrix(true_labels, pseudo_labels)
row_idxs, col_idxs = linear_sum_assignment( row_idxs, col_idxs = linear_sum_assignment(
-conf_mat -conf_mat
) # Hungarian Algorithm to find the best matching between the true and pseudo labels. )
# align the pseudo labels with the true labels based on the hungarian algorithm results # align the pseudo labels with the true labels based on the hungarian algorithm results
pseudo_labels_aligned = np.zeros_like(pseudo_labels) pseudo_labels_aligned = np.zeros_like(pseudo_labels)
for pseudo_label, true_label in zip(col_idxs, row_idxs): for pseudo_label, true_label in zip(col_idxs, row_idxs):
pseudo_labels_aligned[pseudo_labels == pseudo_label] = true_label pseudo_labels_aligned[pseudo_labels == pseudo_label] = true_label
# create a mapping between the true labels and the pseudo labels, used in validation and testing
mapping = {true_label: pseudo_label for pseudo_label, true_label in zip(col_idxs, row_idxs)}
# compute the overall accuracy # compute the overall accuracy
overall_accuracy = np.mean(true_labels == pseudo_labels_aligned) overall_accuracy = np.mean(true_labels == pseudo_labels_aligned)
...@@ -126,31 +129,32 @@ def generate_mapping( ...@@ -126,31 +129,32 @@ def generate_mapping(
string = f"NCD Clustering Accuracies for Session {args.current_session}:" string = f"NCD Clustering Accuracies for Session {args.current_session}:"
for true_class, acc in per_class_accuracy.items(): for true_class, acc in per_class_accuracy.items():
string += f"\nTrue Class {true_class}: {acc*100:4f}%" string += f"\nTrue Class {true_class+labels_start}: {acc*100:4f}%"
string += f"\n" string += f"\n"
string += f"\nOverall Accuracy: {overall_accuracy*100:4f}%" string += f"\nOverall Accuracy: {overall_accuracy*100:4f}%"
logger.info(string) logger.info(string)
# plot the confusion matrix
plot_confmat(
conf_mat,
os.path.join(args.exp_dir, f"confmat_session_{args.current_session}.png"),
)
# Here is where we create the mapping ------------------------------------------------
# we now add the minimum value to the col and row idxs to move them back into the original label space
row_idxs += labels_start
col_idxs += labels_start
# create a mapping between the true labels and the pseudo labels
mapping = {true_label: pseudo_label for pseudo_label, true_label in zip(col_idxs, row_idxs)}
string = f"Mapping for Session {args.current_session}:" string = f"Mapping for Session {args.current_session}:"
for true_label, pseudo_label in mapping.items(): for true_label, pseudo_label in mapping.items():
string += f"\nTrue Label {true_label} -> Pseudo Label {pseudo_label}" string += f"\nTrue Label {true_label} -> Pseudo Label {pseudo_label}"
logger.info(string) logger.info(string)
# plot the confusion matrix
try:
plot_confmat(
conf_mat,
os.path.join(args.exp_dir, f"confmat_session_{args.currect_session}.png"),
)
logger.debug(
f"Confusion Matrix saved to {os.path.join(args.exp_dir, f'confmat_session_{args.currect_session}.png')}"
)
except Exception as e:
logger.error(
f"Could not plot the confusion matrix. Error: {e}\n Confusion Matrix not saved. Continuing..."
)
return mapping return mapping
def _cluster_features(args, features: torch.Tensor) -> torch.Tensor: def _cluster_features(args, features: torch.Tensor) -> torch.Tensor:
...@@ -166,7 +170,7 @@ def _cluster_features(args, features: torch.Tensor) -> torch.Tensor: ...@@ -166,7 +170,7 @@ def _cluster_features(args, features: torch.Tensor) -> torch.Tensor:
kmeans = KMeans(n_clusters=args.novel_classes_per_session, random_state=args.seed) kmeans = KMeans(n_clusters=args.novel_classes_per_session, random_state=args.seed)
pseudo_labels = torch.tensor(kmeans.fit_predict(features.cpu().numpy())) pseudo_labels = torch.tensor(kmeans.fit_predict(features.cpu().numpy()), dtype=torch.long)
return pseudo_labels return pseudo_labels
...@@ -208,6 +212,10 @@ def find_novel_classes_for_session( ...@@ -208,6 +212,10 @@ def find_novel_classes_for_session(
# cluster the features # cluster the features
pseudo_labels = _cluster_features(args, novel_features) pseudo_labels = _cluster_features(args, novel_features)
# pseudo_labels needs to start at the novel class start index, so we add that to the 0-starting pseudo labels
novel_class_start = args.dataset.known + (args.current_session - 1) * args.dataset.novel_inc
pseudo_labels += novel_class_start
# calculate the clustering accuracy (not used in the dataset, only for logging and testing) # calculate the clustering accuracy (not used in the dataset, only for logging and testing)
mapping = generate_mapping( mapping = generate_mapping(
novel_dataset.tensor_dataset.tensors[1], pseudo_labels, args novel_dataset.tensor_dataset.tensors[1], pseudo_labels, args
...@@ -228,6 +236,7 @@ def find_novel_classes_for_session( ...@@ -228,6 +236,7 @@ def find_novel_classes_for_session(
clustering_df = pd.DataFrame(columns=["true_labels", "type", "pseudo_labels"]) clustering_df = pd.DataFrame(columns=["true_labels", "type", "pseudo_labels"])
clustering_df["true_labels"] = session_dataset.tensor_dataset.tensors[1].cpu().numpy() clustering_df["true_labels"] = session_dataset.tensor_dataset.tensors[1].cpu().numpy()
clustering_df["type"] = session_dataset.tensor_dataset.tensors[2].cpu().numpy() clustering_df["type"] = session_dataset.tensor_dataset.tensors[2].cpu().numpy()
clustering_df["predtype"] = session_dataset.tensor_dataset.tensors[3].cpu().numpy()
clustering_df["pseudo_labels"] = pseudo_labels_aligned.cpu().numpy() clustering_df["pseudo_labels"] = pseudo_labels_aligned.cpu().numpy()
# save the dataset to a csv file # save the dataset to a csv file
... ...
......
...@@ -206,19 +206,19 @@ def label_ood_for_session( ...@@ -206,19 +206,19 @@ def label_ood_for_session(
session_dataset.tensor_dataset.tensors[0], # the data session_dataset.tensor_dataset.tensors[0], # the data
session_dataset.tensor_dataset.tensors[1], # the labels session_dataset.tensor_dataset.tensors[1], # the labels
session_dataset.tensor_dataset.tensors[2], # the real types session_dataset.tensor_dataset.tensors[2], # the real types
final_predtypes, # the predicted types (duh) final_predtypes.cpu(), # the predicted types (duh)
), ),
transform=session_dataset.transform, transform=session_dataset.transform,
) )
# compute the OOD Accuracy # compute the OOD Accuracy
_compute_ood_accuracy(session_dataset, args) _compute_ood_accuracy(session_dataset, entropies, energies, args)
return session_dataset return session_dataset
def _compute_ood_accuracy( def _compute_ood_accuracy(
session_dataset: TransformedTensorDataset, args session_dataset: TransformedTensorDataset, entropies, energies, args
) -> None: ) -> None:
""" """
Computes the Accuracy of the OOD Labelling for a session dataset. Computes the Accuracy of the OOD Labelling for a session dataset.
...@@ -227,8 +227,6 @@ def _compute_ood_accuracy( ...@@ -227,8 +227,6 @@ def _compute_ood_accuracy(
:return: None :return: None
""" """
# Create the DataFrame # Create the DataFrame
df = pd.DataFrame( df = pd.DataFrame(
{ {
...@@ -238,6 +236,11 @@ def _compute_ood_accuracy( ...@@ -238,6 +236,11 @@ def _compute_ood_accuracy(
} }
) )
if entropies is not None:
df["entropy"] = entropies.cpu().numpy()
if energies is not None:
df["energy"] = energies.cpu().numpy()
df["is_correct"] = df["predtype"] == df["type"] df["is_correct"] = df["predtype"] == df["type"]
...@@ -246,14 +249,14 @@ def _compute_ood_accuracy( ...@@ -246,14 +249,14 @@ def _compute_ood_accuracy(
string = f"OOD Accuracies for Session {args.current_session}:" string = f"OOD Accuracies for Session {args.current_session}:"
string += f"\nKnown Samples Correct/Total (Accuracy%): {known['is_correct'].sum()}/{known.shape[0]} ({known['is_correct'].mean()*100:.4f}%)" string += f"\nKnown Samples Correct/Total (Accuracy%): {known['is_correct'].sum()}/{known.shape[0]} ({known['is_correct'].mean()*100:.4f}%)"
string += f"\nKnown Samples Incorrect/Total (Error%) : {len(known) - known['is_correct'].sum()}/{known.shape[0]} ({1 - known['is_correct'].mean()*100:.4f}%)" string += f"\nKnown Samples Incorrect/Total (Error%) : {len(known) - known['is_correct'].sum()}/{known.shape[0]} ({((len(known) - known['is_correct'].sum())/len(known))*100:.4f}%)"
string += f"\n" string += f"\n"
string += f"\nNovel Samples Correct/Total (Accuracy%): {novel['is_correct'].sum()}/{novel.shape[0]} ({novel['is_correct'].mean()*100:.4f}%)" string += f"\nNovel Samples Correct/Total (Accuracy%): {novel['is_correct'].sum()}/{novel.shape[0]} ({novel['is_correct'].mean()*100:.4f}%)"
string += f"\nNovel Samples Incorrect/Total (Error%) : {len(novel) - novel['is_correct'].sum()}/{novel.shape[0]} ({1 - novel['is_correct'].mean()*100:.4f}%)" string += f"\nNovel Samples Incorrect/Total (Error%) : {len(novel) - novel['is_correct'].sum()}/{novel.shape[0]} ({((len(novel) - novel['is_correct'].sum())/len(novel))*100:.4f}%)"
string += f"\n" string += f"\n"
string += f"\nOverall Accuracy: {df['is_correct'].mean()*100:.4f}%" string += f"\nOverall Accuracy: {df['is_correct'].mean()*100:.4f}%"
logger.info(string) logger.info(string)
file_path = generate_unique_path(os.path.join(args.exp_dir, f'ood_accuracy_{args.currect_session}.csv')) file_path = generate_unique_path(os.path.join(args.exp_dir, f'ood_accuracy_{args.current_session}.csv'))
df.to_csv(file_path, index=False) df.to_csv(file_path, index=False)
logger.info(f"OOD Accuracy CSV saved to {file_path}") logger.info(f"OOD Accuracy CSV saved to {file_path}")
\ No newline at end of file
Source diff could not be displayed: it is too large. Options to address this: view the blob.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment