Commit 155568e8 authored by Joseph Omar

initial

Showing with 753 additions and 0 deletions
.gitattributes 0 → 100644
*.pth filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.svg filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
*.pdf filter=lfs diff=lfs merge=lfs -text
.gitignore 0 → 100644
# Ignore compiled or cached Python files
*.pyc
__pycache__/
runs/
experiments/debug
experiments/GM_MI_DEBUG
dataset
cifar-100-python
*.tar.gz
entcl.egg-info/PKG-INFO 0 → 100644
Metadata-Version: 2.1
Name: entcl
Version: 0.1.0
entcl.egg-info/SOURCES.txt 0 → 100644
setup.py
entcl/__init__.py
entcl/config.py
entcl/pretrain.py
entcl/run.py
entcl.egg-info/PKG-INFO
entcl.egg-info/SOURCES.txt
entcl.egg-info/dependency_links.txt
entcl.egg-info/top_level.txt
entcl/data/__init__.py
entcl/data/cifar100.py
entcl/data/test.py
entcl/models/__init__.py
entcl/models/dinohead.py
entcl/models/linear_head.py
entcl/models/model.py
entcl.egg-info/top_level.txt 0 → 100644
entcl
entcl/config.py 0 → 100644
# Root directory where the CIFAR-100 data is stored/downloaded (used by entcl/data/cifar100.py)
CIFAR100_DIR = '/cl/datasets/CIFAR/'
entcl/data/cifar100.py 0 → 100644
import os
from typing import Dict, List, Union
import torch
from torchvision.datasets import CIFAR100 as _CIFAR100
import torchvision.transforms.v2 as transforms
from entcl.config import CIFAR100_DIR
from loguru import logger
CIFAR100_TRANSFORM = transforms.Compose(
[
transforms.ToTensor(),
transforms.Resize((224, 224), antialias=True),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
)
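# Note: the normalisation mean/std above are the widely used CIFAR-10 statistics; they are
# assumed here as a stand-in, since CIFAR-100-specific statistics differ slightly.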
class CIFAR100Dataset:
def __init__(
self,
known: int = 50,
pretrain_n_known: int = 400,
cl_n_known: int = 20,
cl_n_novel: int = 400,
cl_n_prevnovel: int = 20,
sessions: int = 5,
mutex: bool = True,
force_download = False,
):
"""
CIFAR100 dataset with incremental learning settings.
:param known: Number of known classes. Default: 50
:param pretrain_n_known: Number of samples per known class for pretraining. Default: 400
:param cl_n_known: Number of samples per known class for each CL session. Default: 20
:param cl_n_novel: Number of samples per novel class for each CL session. Default: 400
:param cl_n_prevnovel: Number of samples per previously novel class for each CL session. Default: 20
"""
        if known >= 100:
            raise ValueError("Number of known classes must be less than 100")
self.transform = CIFAR100_TRANSFORM
self.known = known
self.sessions = sessions
self.pretrain_n_known = pretrain_n_known
self.num_classes = 100
self.cl_n_known = cl_n_known
self.cl_n_novel = cl_n_novel
self.cl_n_prevnovel = cl_n_prevnovel
self.novel_inc = (100 - self.known) // self.sessions
# Verify the CL settings
logger.debug(
"Verifying incremental learning settings\n"
+ f"Known classes: {self.known}\n"
+ f"Pretraining samples per known class: {self.pretrain_n_known}\n"
+ f"Samples per known class per CL session: {self.cl_n_known}\n"
+ f"Samples per novel class per CL session: {self.cl_n_novel}\n"
+ f"Samples per previously novel class per CL session: {self.cl_n_prevnovel}\n"
+ f"CL sessions: {self.sessions}"
)
self._verify_splits()
download = (not os.path.exists(os.path.join(CIFAR100_DIR, "cifar-100-python"))) or force_download
logger.debug(f"Download: {download}")
# load and sort the data into master lists
logger.debug("Loading and Sorting CIFAR100 Train split")
master_train_data: Dict[int, torch.Tensor] = self._split_data_by_class(_CIFAR100(
CIFAR100_DIR, train=True, transform=self.transform, download=download
))
# split the data into datasets for each session
logger.debug("Splitting Train Data for Sessions")
self.train_datasets = self._split_train_data_for_sessions(master_train_data, mutex=mutex)
del master_train_data
logger.debug("Loading and Sorting CIFAR100 Test split")
master_test_data: Dict[int, torch.Tensor] = self._split_data_by_class(_CIFAR100(
CIFAR100_DIR, train=False, transform=self.transform, download=download
))
logger.debug("Splitting Test Data for Sessions")
self.test_datasets = self._split_test_data_for_sessions(master_test_data)
del master_test_data
def get_dataset(self, session: int = 0, train: bool = True):
"""
Get the dataset for a given session.
:param session: Session number, 0 for pretraining. Default: 0
:param train: Whether to get the training set. Default: True
:return: Dataset for the given session
"""
if session == "pretrain":
session = 0
if session not in self.train_datasets:
raise ValueError(f"Session {session} does not exist, only sessions {list(self.train_datasets.keys())} exist")
return self.train_datasets[session] if train else self.test_datasets[session]
def _split_train_data_for_sessions(self, masterlist: Dict[int, torch.Tensor], mutex: bool = False):
"""
Split the data for each session, creating a dataset for each session
:param masterlist: Dict containing the data for each class
:param mutex: Whether to use mutex sampling. Default: False
:return: Dict[int, torch.util.data.Dataset] containing the datasets for each session
"""
def sample_data(data_tensor: torch.Tensor, num_samples: int):
"""
Sample data from the given tensor.
:param data_tensor: Data tensor
:param num_samples: Number of samples to select
:return: Sampled data and mask to remove the sampled data
"""
sample_idxs = torch.randperm(data_tensor.size(0))[:num_samples]
sampled_data = data_tensor[sample_idxs]
mask = torch.ones(data_tensor.size(0), dtype=torch.bool)
mask[sample_idxs] = False
return sampled_data, mask
def update_masterlist(class_idx: int, mask: torch.Tensor):
"""
Update the masterlist by removing the sampled data (if mutex is True).
:param class_idx: Class index
:param mask: Mask to remove the sampled data
"""
if mutex:
masterlist[class_idx] = masterlist[class_idx][mask]
def append_samples(data_tensor: torch.Tensor, num_samples: int, class_idx: int, samples_list: List[torch.Tensor], labels_list: List[torch.Tensor]):
"""
Append samples to the samples list.
:param data_tensor: Data tensor
:param num_samples: Number of samples to append
:param class_idx: Class index
:param samples_list: List to append the samples
:param labels_list: List to append the labels
"""
sampled_data, mask = sample_data(data_tensor, num_samples)
samples_list.append(sampled_data)
labels_list.append(torch.full((num_samples,), class_idx, dtype=torch.long))
update_masterlist(class_idx, mask)
# initialise the dict of final datasets
datasets = {}
logger.debug(f"Splitting data for {self.sessions} sessions")
# pretraining session's dataset (session 0)
logger.debug(f"Splitting data for session 0 (pretraining)")
samples, labels = [], []
for class_idx in range(self.known):
append_samples(masterlist[class_idx], self.pretrain_n_known, class_idx, samples, labels)
samples = torch.cat(samples)
labels = torch.cat(labels)
logger.debug(f"Creating dataset for session 0 (pretraining). There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
logger.debug(f"Classes in Pretraining Dataset: {labels.unique(sorted=True)}")
datasets[0] = torch.utils.data.TensorDataset(samples, labels)
# CL sessions' datasets
logger.debug("Splitting data for CL sessions")
for session in range(1, self.sessions + 1):
logger.debug(f"Splitting data for session {session}")
samples, labels = [], []
# Known classes
logger.debug(f"There are {self.known} known classes. Starting at 0 (inc), ending at {self.known} (exc)")
for class_idx in range(self.known):
append_samples(masterlist[class_idx], self.cl_n_known, class_idx, samples, labels)
novel_start = self.known + (session - 1) * self.novel_inc
novel_end = novel_start + self.novel_inc
logger.debug(f"There are {self.novel_inc} novel classes. Starting at {novel_start} (inc), ending at {novel_end} (exc)")
for class_idx in range(novel_start, novel_end):
append_samples(masterlist[class_idx], self.cl_n_novel, class_idx, samples, labels)
            if novel_start > self.known:  # if there are any previously novel classes
                # ATTN: Previously novel classes start right after the known classes (index self.known) and end where this session's novel classes begin (novel_start)
                # Example: for session 3, with known=50 and novel_inc=10, the previously novel classes run from 50 (inc) to 70 (exc); 70-80 are that session's novel classes
                logger.debug(f"There are {novel_start - self.known} previously novel classes. Starting at {self.known} (inc), ending at {novel_start} (exc)")
                for class_idx in range(self.known, novel_start):
                    append_samples(masterlist[class_idx], self.cl_n_prevnovel, class_idx, samples, labels)
samples = torch.cat(samples)
labels = torch.cat(labels)
logger.debug(f"Creating dataset for session {session}. There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
logger.debug(f"Classes in this Session {session}'s Train Dataset: {labels.unique(sorted=True)}")
datasets[session] = torch.utils.data.TensorDataset(samples, labels)
return datasets
def _split_test_data_for_sessions(self, masterlist: Dict[int, torch.Tensor]):
"""
        Split the test data for each session, creating a dataset containing every class seen up to and including that session
:param masterlist: Dict containing the data for each class
:return: Dict[int, torch.util.data.Dataset] containing the datasets for each session
"""
datasets = {}
logger.debug(f"Splitting test data for {self.sessions} sessions")
for session in range(0, self.sessions + 1):
logger.debug(f"Splitting test data for session {session}")
samples, labels = [], []
# get data for all seen classes (self.known + session * self.novel_inc)
seen_classes_end = self.known + (session * self.novel_inc)
logger.debug(f"There are {seen_classes_end} seen classes. Starting at 0 (inc), ending at {seen_classes_end} (exc)")
for class_idx in range(seen_classes_end):
samples.append(masterlist[class_idx])
labels.append(torch.full((masterlist[class_idx].size(0),), class_idx, dtype=torch.long))
samples = torch.cat(samples)
labels = torch.cat(labels)
logger.debug(f"Creating test dataset for session {session}. There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
logger.debug(f"Classes in Session {session}'s Test Dataset: {labels.unique(sorted=True)}")
datasets[session] = torch.utils.data.TensorDataset(samples, labels)
return datasets
def _split_data_by_class(self, dataset: _CIFAR100):
# loop through the dataset and split the data by class
all_data = {}
for data, label in dataset:
if label not in all_data:
all_data[label] = []
all_data[label].append(data)
# stack the data for each key
for class_id in sorted(all_data.keys()):
all_data[class_id] = torch.stack(all_data[class_id])
return all_data
def _verify_splits(self):
# verify sessions
if self.sessions <= 0:
raise ValueError(
f'Number of sessions should be greater than 0. Given "sessions": {self.sessions}'
)
# verify known classes
if not (0 <= self.known <= 100 - self.sessions):
raise ValueError(
f'Number of known classes should be between 0 and 100 - sessions ({self.sessions}). Given "known": {self.known}, "sessions": {self.sessions}'
)
# verify pretrain_n_known
if not (0 <= self.pretrain_n_known <= 500 - (self.cl_n_known * self.sessions)):
raise ValueError(
f'Number of samples per known class for pretraining should be between 0 and 500 - (cl_n_known * sessions). Given "pretrain_n_known": {self.pretrain_n_known}, "cl_n_known": {self.cl_n_known}, "sessions": {self.sessions}'
)
# verify cl_n_known
if not (0 <= self.cl_n_known <= (500 - self.pretrain_n_known) / self.sessions):
raise ValueError(
f'Number of samples per known class for each CL session should be between 0 and (500 - pretrain_n_known) / sessions. Given "cl_n_known": {self.cl_n_known}, "pretrain_n_known": {self.pretrain_n_known}, "sessions": {self.sessions}'
)
# verify cl_n_novel
if not (0 <= self.cl_n_novel <= 500 - self.sessions):
raise ValueError(
f'Number of samples per novel class for each CL session should be between 0 and 500 - sessions. Given "cl_n_novel": {self.cl_n_novel}, "sessions": {self.sessions}'
)
# verify cl_n_prevnovel
if not (0 <= self.cl_n_prevnovel <= (500 - self.cl_n_novel) / self.sessions):
raise ValueError(
f'Number of samples per previously novel class for each CL session should be between 0 and (500 - cl_n_novel) / sessions. Given "cl_n_prevnovel": {self.cl_n_prevnovel}, "cl_n_novel": {self.cl_n_novel}, "sessions": {self.sessions}'
)
if __name__ == "__main__":
from time import sleep
logger.debug("Entry Point: cifar100.py")
if CIFAR100_DIR is None:
raise ValueError("CIFAR100_DIR is not set. Please set it in entcl/config.py")
cifar100 = CIFAR100Dataset()
for session in range(cifar100.sessions + 1):
logger.debug(f"Session {session}")
logger.debug(f"Train Dataset: {cifar100.get_dataset(session, train=True)}")
logger.debug(f"Test Dataset: {cifar100.get_dataset(session, train=False)}")
sleep(5)
entcl/data/test.py 0 → 100644
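# Standalone sanity check (appears unused by the rest of the package): verifies that grouping
# CIFAR-100 samples by class and stacking them preserves the original image tensors.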
import torch
from torchvision.datasets import CIFAR100 as _CIFAR100
import torchvision.transforms as transforms
t = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
)
train = _CIFAR100(
"/cl/datasets/CIFAR", train=True, transform=t, download=True
)
first_sample, first_label = train[0]
datadict = {}
for data, label in train:
if label not in datadict:
datadict[label] = []
datadict[label].append(data)
class_tensors = []
for class_id in sorted(datadict.keys()):
class_tensors.append(torch.stack(datadict[class_id]))
to_return = torch.stack(class_tensors)
assert to_return[first_label][0].allclose(first_sample)
entcl/models/dinohead.py 0 → 100644
import torch
from loguru import logger
class DINOHead(torch.nn.Module):
def __init__(
self,
in_dim,
out_dim,
use_bn=False,
norm_last_layer=True,
nlayers=3,
hidden_dim=2048,
bottleneck_dim=256,
):
super().__init__()
nlayers = max(nlayers, 1)
if nlayers == 1:
self.mlp = torch.nn.Linear(in_dim, bottleneck_dim)
elif nlayers != 0:
layers = [torch.nn.Linear(in_dim, hidden_dim)]
if use_bn:
layers.append(torch.nn.BatchNorm1d(hidden_dim))
layers.append(torch.nn.GELU())
for _ in range(nlayers - 2):
layers.append(torch.nn.Linear(hidden_dim, hidden_dim))
if use_bn:
layers.append(torch.nn.BatchNorm1d(hidden_dim))
layers.append(torch.nn.GELU())
layers.append(torch.nn.Linear(hidden_dim, bottleneck_dim))
self.mlp = torch.nn.Sequential(*layers)
self.apply(self._init_weights)
self.last_layer = torch.nn.utils.weight_norm(
torch.nn.Linear(in_dim, out_dim, bias=False)
)
self.last_layer.weight_g.data.fill_(1)
if norm_last_layer:
self.last_layer.weight_g.requires_grad = False
def _init_weights(self, m):
if isinstance(m, torch.nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
if isinstance(m, torch.nn.Linear) and m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
    def forward(self, x):
        x_proj = self.mlp(x)  # bottleneck projection of the raw features
        x = torch.nn.functional.normalize(x, dim=-1, p=2)  # L2-normalise the input features
        # x = x.detach()
        logits = self.last_layer(x)  # weight-normed linear classifier over the normalised features
        return x_proj, logits
if __name__ == "__main__":
model = DINOHead(768, 100, nlayers=3)
print(model)
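    # Minimal shape check (sketch): DINOHead(768, 100, nlayers=3) maps a batch of 768-d
    # DINOv2 features to a 256-d projection and 100 logits.
    x = torch.randn(4, 768)
    x_proj, logits = model(x)
    print(x_proj.shape, logits.shape)  # expected: torch.Size([4, 256]) torch.Size([4, 100])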
entcl/models/linear_head.py 0 → 100644
import torch
from loguru import logger
class LinearHead(torch.nn.Module):
def __init__(self, in_features: int, out_features: int):
super().__init__()
self.fc = torch.nn.Linear(in_features, out_features)
def forward(self, x):
return self.fc(x)
class LinearHead2(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, hidden_dim1:int, hidden_dim2:int):
super().__init__()
self.mlp = torch.nn.Sequential(
torch.nn.Linear(in_features, hidden_dim1),
torch.nn.GELU(),
torch.nn.Linear(hidden_dim1, hidden_dim2),
torch.nn.GELU(),
torch.nn.Linear(hidden_dim2, out_features, bias=False)
)
self._init_weights(self.mlp)
    def _init_weights(self, m):
        for layer in m:
            if isinstance(layer, torch.nn.Linear):
                torch.nn.init.kaiming_normal_(layer.weight)
                if layer.bias is not None:  # the final projection layer has bias=False
                    torch.nn.init.zeros_(layer.bias)
def forward(self, x):
return self.mlp(x)
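if __name__ == "__main__":
    # Minimal smoke test (sketch, mirroring the __main__ check in dinohead.py)
    head = LinearHead2(in_features=768, out_features=100, hidden_dim1=512, hidden_dim2=256)
    print(head)
    print(head(torch.randn(4, 768)).shape)  # expected: torch.Size([4, 100])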
entcl/models/model.py 0 → 100644
import os
from loguru import logger
import torch
class ENTCLModel(torch.nn.Module):
def __init__(self, head: torch.nn.Module):
super().__init__()
        # load the frozen DINOv2 ViT-B/14 backbone from a local checkout of the dinov2 repo (expected at entcl/models/dinov2)
        self.backbone = torch.hub.load(os.path.join(os.path.dirname(__file__), 'dinov2'), 'dinov2_vitb14', source='local')
# freeze the backbone
for param in self.backbone.parameters():
param.requires_grad = False
# set the head
self.head = head
def forward(self, x):
feats = self.backbone(x)
logits = self.head(feats)
return logits, feats
    def train(self, mode=True):
        super().train(mode)
        self.backbone.train(False)  # keep the frozen backbone in eval mode regardless of the requested mode
        return self
if __name__ == "__main__":
model = ENTCLModel(head=None)
print(model)
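    # Usage sketch (assumes the local dinov2 checkout and 224x224 inputs from CIFAR100_TRANSFORM):
    #   model = ENTCLModel(head=LinearHead(in_features=768, out_features=100))
    #   logits, feats = model(torch.randn(2, 3, 224, 224))  # feats are 768-d DINOv2 features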
entcl/pretrain.py 0 → 100644
import os
from loguru import logger
import torch
from tqdm import tqdm
def pretrain(args, model):
train_dataset = args.dataset.get_dataset(session=0, train=True)
train_dataloader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
pin_memory=args.pin_memory,
drop_last=True,
)
val_dataset = args.dataset.get_dataset(session=0, train=False)
val_dataloader = torch.utils.data.DataLoader(
dataset=val_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
pin_memory=args.pin_memory,
drop_last=False,
)
optimiser = torch.optim.SGD(
model.parameters(),
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
)
criterion = torch.nn.CrossEntropyLoss()
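    # NOTE: the DINOv2 backbone inside ENTCLModel is frozen (requires_grad=False), so the
    # optimiser above effectively only updates the classification head.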
accuracies = []
if args.pretrain_load is not None:
logger.debug(f"Loading pretrained model from {args.pretrain_load}")
model.load_state_dict(torch.load(args.pretrain_load, weights_only=True))
model = model.to(args.device)
model, accuracy = _validate(args, model, val_dataloader)
logger.info(f"Loaded Pretrained Model Accuracy: {accuracy}")
return model
else:
logger.debug("No pretrained model to load, training from scratch")
model = model.to(args.device)
for epoch in range(args.pretrain_epochs):
logger.debug(f"Epoch {epoch} Started")
# train model
print(f"Epoch {epoch}:")
model, loss_total = _train(args, model, train_dataloader, optimiser, criterion)
logger.info(f"Epoch {epoch}: Loss: {loss_total}")
# validate model
model, accuracy = _validate(args, model, val_dataloader)
accuracies.append(accuracy)
logger.info(f"Epoch {epoch}: Accuracy: {accuracy}")
torch.save(model.state_dict(), os.path.join(args.exp_dir, f"model_{epoch}.pt"))
# select the best model
if args.pretrain_sel_strat == 'best':
best_epoch = accuracies.index(max(accuracies))
logger.info(f"Best Epoch: {best_epoch}")
model.load_state_dict(torch.load(os.path.join(args.exp_dir, f"model_{best_epoch}.pt"), weights_only=True))
# remove all other models
if not args.retain_all:
for epoch in range(args.pretrain_epochs):
if epoch != best_epoch:
os.remove(os.path.join(args.exp_dir, f"model_{epoch}.pt"))
# delete all models except the last one, return the last model
elif args.pretrain_sel_strat == 'last':
if not args.retain_all:
for epoch in range(args.pretrain_epochs - 1):
os.remove(os.path.join(args.exp_dir, f"model_{epoch}.pt"))
elif args.pretrain_sel_strat == 'load':
model.load_state_dict(torch.load(args.pretrain_load, weights_only=True))
return model
def _train(args, model, train_dataloader, optimiser, criterion):
model.train()
loss_total = 0
    for x, y in tqdm(train_dataloader, desc="Training", unit="batch"):
x, y = x.to(args.device), y.to(args.device)
logger.debug(f"x shape: {x.shape}, y shape: {y.shape}")
optimiser.zero_grad()
logits, _ = model(x)
logger.debug(f"logits shape: {logits.shape}")
loss = criterion(logits, y)
loss.backward()
optimiser.step()
loss = loss.item()
logger.debug(f"Loss: {loss}")
loss_total += loss
loss_total /= len(train_dataloader)
return model, loss_total
def _validate(args, model, val_dataloader):
model.eval()
correct = 0
total = 0
with torch.no_grad():
        for x, y in tqdm(val_dataloader, desc="Validating", unit="batch"):
x, y = x.to(args.device), y.to(args.device)
logger.debug(f"x shape: {x.shape}, y shape: {y.shape}")
logits, _ = model(x)
logger.debug(f"logits shape: {logits.shape}")
_, predicted = torch.max(logits, 1)
num_correct = (predicted == y).sum().item()
total += y.size(0)
correct += num_correct
logger.debug(f"This Batch Num Correct: {num_correct}, Total: {y.size(0)}, Accuracy: {num_correct / y.size(0)}")
return model, correct / total
entcl/run.py 0 → 100644
import argparse
import os
import sys
from loguru import logger
from datetime import datetime
import torch
from entcl.models.model import ENTCLModel
from entcl.pretrain import pretrain
@logger.catch
def main():
parser = argparse.ArgumentParser()
# program args
parser.add_argument('--name', type=str, default="entcl_" + datetime.now().isoformat(timespec='seconds'))
parser.add_argument('--mode', type=str, default='pretrain', help='Mode to run the program', choices=['pretrain', 'cl', 'dryrun'])
parser.add_argument('--dryrun', action='store_true', default=False, help='Dry Run Mode. Does not save anything')
parser.add_argument('--debug', action='store_true', default=False, help='Debug Mode. Epochs are only done once. Enables Verbose Mode automatically')
parser.add_argument('--verbose', action='store_true', default=False, help='Verbose Mode. Enables debug logs')
parser.add_argument('--seed', type=int, default=8008135, help='Seed for reproducibility')
parser.add_argument('--device', type=int, default=0, help='cuda device to use')
parser.add_argument('--exp_root', type=str, default='/cl/entcl_LFS/experiments', help='Root directory for experiments')
# dataloader args
parser.add_argument('--batch_size', type=int, default=128, help='Batch Size for all dataloaders')
parser.add_argument('--num_workers', type=int, default=4, help='Number of workers for all dataloaders')
parser.add_argument('--pin_memory', action='store_true', default=True, help='Pin Memory for all dataloaders')
# dataset args
parser.add_argument('--dataset', type=str, default='cifar100', help='Dataset to use', choices=['cifar100'])
# optimiser args
parser.add_argument('--lr', type=float, default=0.1, help='Learning Rate for all optimisers')
parser.add_argument('--gamma', type=float, default=0.1, help='Gamma for all optimisers')
parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for all optimisers')
parser.add_argument('--weight_decay', type=float, default=5e-5, help='Weight Decay for all optimisers')
# cl args
parser.add_argument('--known', type=int, default=50, help='Number of known classes. The rest are novel classes')
parser.add_argument('--pretrain_n_known', type=int, default=400, help='How many samples per known class to use for pretraining')
parser.add_argument('--cl_n_novel', type=int, default=400, help='How many novel samples per novel class to use in each session of cl') # this val * novel_classes = novel samples per session
parser.add_argument('--cl_n_known', type=int, default=20, help='How many known samples per known class to use in each session of cl') # this val * known_classes = known samples per session
parser.add_argument('--cl_n_prevnovel', type=int, default=20, help='How many known samples per previously-novel class to use in each session of cl (Classes that used to be novel, but are now known)') # this val * knownnovel_classes = knownnovel samples per session
parser.add_argument('--sessions', type=int, default=5, help='Number of mixed incremental continual learning sessions')
parser.add_argument('--cl_epochs', type=int, default=100, help='Number of epochs for continual learning sessions')
# pretrain args
parser.add_argument('--pretrain_epochs', type=int, default=100, help='Number of epochs for pretraining')
parser.add_argument('--pretrain_sel_strat', type=str, default='last', choices=['last', 'best'], help='Pretrain Model Selection Strategy')
parser.add_argument('--pretrain_load', type=str, default=None, help='Path to a pretrained model to load')
parser.add_argument('--retain_all', action='store_true', default=False, help='Keep all model checkpoints')
# model args
parser.add_argument('--head', type=str, default='linear2', help='Classification head to use', choices=['linear','linear2', 'dino_head'])
args = parser.parse_args()
# setup device
if not torch.cuda.is_available():
raise ValueError("CUDA not available")
else:
if args.device >= torch.cuda.device_count():
raise ValueError(f"Invalid device {args.device}. There are only {torch.cuda.device_count()} devices available")
else:
args.device = torch.device(f'cuda:{args.device}')
# enable verbose mode if debug mode is enabled
if args.debug:
args.verbose = True
        args.name = "debug" + args.name
args.cl_epochs = 1
args.pretrain_epochs = 1
# initialise directories
os.makedirs(args.exp_root, exist_ok=True)
args.exp_dir = os.path.join(args.exp_root, args.name)
os.makedirs(args.exp_dir, exist_ok=False)
# initialise logger
logger_format = "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
logger.remove(0)
logger.add(os.path.join(args.exp_dir, f'{args.name}.log'), format=logger_format, level='DEBUG' if args.verbose else 'INFO')
logger.add(sys.stdout, format=logger_format, level='DEBUG' if args.verbose else 'INFO')
# initialise dataset
if args.dataset == 'cifar100':
from entcl.data.cifar100 import CIFAR100Dataset
        args.dataset = CIFAR100Dataset(known=args.known, pretrain_n_known=args.pretrain_n_known, cl_n_known=args.cl_n_known, cl_n_novel=args.cl_n_novel, cl_n_prevnovel=args.cl_n_prevnovel, sessions=args.sessions)
if args.head == 'linear':
from entcl.models.linear_head import LinearHead
args.head = LinearHead(in_features=768, out_features=args.dataset.num_classes)
logger.debug(f"Using Linear Head: {args.head}")
elif args.head == 'dino_head':
from entcl.models.dinohead import DINOHead
args.head = DINOHead(768, args.dataset.known, nlayers=3)
elif args.head == 'linear2':
from entcl.models.linear_head import LinearHead2
args.head = LinearHead2(in_features=768, out_features=args.dataset.num_classes, hidden_dim1=512, hidden_dim2=256)
logger.debug(f"Using Linear2 Head: {args.head}")
model = ENTCLModel(head=args.head)
logger.debug(f"Model: {model}")
model = pretrain(args, model)
if __name__ == "__main__":
logger.debug("Entry Point: run.py")
main()
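# Example invocation (sketch; flags are the ones defined above):
#   python -m entcl.run --mode pretrain --head linear2 --batch_size 128 --device 0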
setup.py 0 → 100644
from setuptools import setup, find_packages
setup(
name='entcl',
version='0.1.0',
packages=find_packages(),
install_requires=[
# Add your project's dependencies here
# e.g., 'requests', 'numpy', etc.
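        # The code in this commit imports torch, torchvision, loguru and tqdm; they are left out
        # of install_requires here and assumed to be provided by the environment.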
],
entry_points={
'console_scripts': [
# Add command line scripts here
# e.g., 'your_command=your_module:main_function',
],
},
#author='Your Name',
#author_email='your.email@example.com',
#description='A short description of your package',
#long_description=open('README.md').read(),
#long_description_content_type='text/markdown',
#url='https://github.com/yourusername/your-repo',
#classifiers=[
# 'Programming Language :: Python :: 3',
# 'License :: OSI Approved :: MIT License',
# 'Operating System :: OS Independent',
#],
#python_requires='>=3.6',
)