Skip to content
Snippets Groups Projects
Commit 74273ffc authored by Joseph Omar
Browse files

idk

parent 267400f4
No related branches found
No related tags found
No related merge requests found
......@@ -159,9 +159,10 @@ class CIFAR100Dataset:
samples = torch.cat(samples)
labels = torch.cat(labels)
types = torch.full((labels.size(0),), 0, dtype=torch.long) # they are all known classes, so type is 0
logger.debug(f"Creating dataset for session 0 (pretraining). There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
logger.debug(f"Classes in Pretraining Dataset: {labels.unique(sorted=True)}")
datasets[0] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels), transform=self.transform)
datasets[0] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels, types), transform=self.transform)
# CL sessions' datasets
logger.debug("Splitting data for CL sessions")
......
......
......@@ -20,6 +20,8 @@ def main(args: argparse.Namespace):
logger.debug(f"Model: {model}")
logger.info("Pretraining Model (Session 0)")
args.current_session = 0
os.makedirs(os.path.join(args.exp_dir, 'session_0'), exist_ok=True)
model = pretrain(args, model)
if args.mode in ['cl', 'both']:
......@@ -28,6 +30,8 @@ def main(args: argparse.Namespace):
for session in range(1, args.sessions + 1):
logger.info(f"Starting Continual Learning Session {session}")
args.current_session = session
os.makedirs(os.path.join(args.exp_dir, f'session_{session}'), exist_ok=True)
session_dataset = args.dataset.get_dataset(session)
# OOD detection
......@@ -40,6 +44,7 @@ def main(args: argparse.Namespace):
# Expand Classification Head & Initialise
model.head.expand(args.dataset.novel_inc) # we are cheating here, we know the number of novel classes
assert model.head.fc.out_features == args.dataset.known + session * args.dataset.novel_inc, f"Head has {model.head.fc.out_features} features, expected {args.dataset.known + session * args.dataset.novel_inc}"
# Freeze the head weights for the classes that have already been seen; we only train on unknown samples. E.g. in CL session 2: 50 (known) + (2 (session) - 1) * 10 (novel_inc) = 60 classes have been seen so far.
model.head.freeze(start_idx=0, end_idx=args.dataset.known + ((session -1) * args.dataset.novel_inc))
......@@ -150,14 +155,14 @@ if __name__ == "__main__":
if args.head == 'linear':
from entcl.models.linear_head import LinearHead
args.head = LinearHead(in_features=768, out_features=args.dataset.num_classes)
args.head = LinearHead(in_features=768, out_features=args.dataset.known)
logger.debug(f"Using Linear Head: {args.head}")
elif args.head == 'dino_head':
from entcl.models.dinohead import DINOHead
args.head = DINOHead(768, args.dataset.known, nlayers=3)
elif args.head == 'mlp':
from entcl.models.linear_head import MLPHead
args.head = MLPHead(in_features=768, out_features=args.dataset.num_classes, hidden_dim1=512, hidden_dim2=256)
args.head = MLPHead(in_features=768, out_features=args.dataset.known, hidden_dim1=512, hidden_dim2=256)
logger.debug(f"Using MLP Head: {args.head}")
if args.mode == 'cl' and args.pretrain_load is None:
......
......
......@@ -6,16 +6,21 @@ channels:
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_gnu
- antlr-python-runtime=4.9.3=pyhd8ed1ab_1
- blas=1.0=mkl
- bottleneck=1.4.2=py310ha9d4c09_0
- brotli=1.0.9=h5eee18b_8
- brotli-bin=1.0.9=h5eee18b_8
- brotli-python=1.0.9=py310h6a678d5_8
- bzip2=1.0.8=h5eee18b_6
- c-ares=1.19.1=h5eee18b_0
- ca-certificates=2024.9.24=h06a4308_0
- certifi=2024.8.30=py310h06a4308_0
- cffi=1.17.1=py310h1fdaa30_0
- charset-normalizer=3.3.2=pyhd3eb1b0_0
- contourpy=1.2.0=py310hdb19cb5_0
- cpython=3.10.15=py310hd8ed1ab_2
- cuda-cudart=11.7.99=0
- cuda-cupti=11.7.101=0
......@@ -24,49 +29,86 @@ dependencies:
- cuda-nvtx=11.7.91=0
- cuda-runtime=11.7.1=0
- cudatoolkit=11.3.1=h2bc3f7f_2
- cycler=0.11.0=pyhd3eb1b0_0
- cyrus-sasl=2.1.28=h52b45da_1
- dataclasses=0.8=pyh6d0b6a4_7
- dbus=1.13.18=hb2f20db_0
- expat=2.6.3=h6a678d5_0
- ffmpeg=4.3=hf484d3e_0
- filelock=3.13.1=py310h06a4308_0
- fontconfig=2.14.1=h55d465d_3
- fonttools=4.51.0=py310h5eee18b_0
- freetype=2.12.1=h4a9f257_0
- future=1.0.0=py310h06a4308_0
- fvcore=0.1.5.post20221221=pyhd8ed1ab_0
- glib=2.78.4=h6a678d5_0
- glib-tools=2.78.4=h6a678d5_0
- gmp=6.2.1=h295c915_3
- gmpy2=2.1.2=py310heeb90bb_0
- gnutls=3.6.15=he1e5248_0
- gst-plugins-base=1.14.1=h6a678d5_1
- gstreamer=1.14.1=h5eee18b_1
- icu=73.1=h6a678d5_0
- idna=3.7=py310h06a4308_0
- intel-openmp=2021.4.0=h06a4308_3561
- iopath=0.1.10=pyhd8ed1ab_0
- jinja2=3.1.4=py310h06a4308_1
- joblib=1.4.2=py310h06a4308_0
- jpeg=9e=h5eee18b_3
- kiwisolver=1.4.4=py310h6a678d5_0
- kneed=0.8.5=pyhd8ed1ab_0
- krb5=1.20.1=h143b758_1
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.40=h12ee557_0
- lerc=3.0=h295c915_0
- libbrotlicommon=1.0.9=h5eee18b_8
- libbrotlidec=1.0.9=h5eee18b_8
- libbrotlienc=1.0.9=h5eee18b_8
- libclang=14.0.6=default_hc6dbbc7_1
- libclang13=14.0.6=default_he11475f_1
- libcublas=11.10.3.66=0
- libcufft=10.7.2.124=h4fbf590_0
- libcufile=1.9.1.3=0
- libcups=2.4.2=h2d74bed_1
- libcurand=10.3.5.147=0
- libcurl=8.9.1=h251f7ec_0
- libcusolver=11.4.0.1=0
- libcusparse=11.7.4.91=0
- libdeflate=1.17=h5eee18b_1
- libedit=3.1.20230828=h5eee18b_0
- libev=4.33=h7f8727e_1
- libffi=3.4.4=h6a678d5_1
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libgcc=14.2.0=h77fa898_1
- libgcc-ng=14.2.0=h69a702a_1
- libgfortran-ng=11.2.0=h00389a5_1
- libgfortran5=11.2.0=h1234567_1
- libglib=2.78.4=hdc74915_0
- libgomp=14.2.0=h77fa898_1
- libiconv=1.16=h5eee18b_3
- libidn2=2.3.4=h5eee18b_0
- libllvm14=14.0.6=hecde1de_4
- libnghttp2=1.57.0=h2d74bed_0
- libnpp=11.7.4.75=0
- libnvjpeg=11.8.0.2=0
- libpng=1.6.39=h5eee18b_0
- libpq=17.0=hdbd6064_0
- libprotobuf=3.20.3=he621ea3_0
- libssh2=1.11.0=h251f7ec_0
- libstdcxx-ng=11.2.0=h1234567_1
- libtasn1=4.19.0=h5eee18b_0
- libtiff=4.5.1=h6a678d5_0
- libunistring=0.9.10=h27cfd23_0
- libuuid=1.41.5=h5eee18b_0
- libwebp-base=1.3.2=h5eee18b_1
- libxcb=1.15=h7f8727e_0
- libxkbcommon=1.0.1=h097e994_2
- libxml2=2.13.1=hfdd30dd_2
- loguru=0.7.2=py310h06a4308_1
- lz4-c=1.9.4=h6a678d5_1
- markupsafe=2.1.3=py310h5eee18b_0
- matplotlib=3.9.2=py310h06a4308_0
- matplotlib-base=3.9.2=py310hbfdbfaf_0
- mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py310h7f8727e_0
- mkl_fft=1.3.1=py310hd6ae3a3_0
......@@ -74,43 +116,64 @@ dependencies:
- mpc=1.1.0=h10f8cd9_1
- mpfr=4.0.2=hb69a4c5_1
- mpmath=1.3.0=py310h06a4308_0
- mysql=8.4.0=h0bac5ae_0
- ncurses=6.4=h6a678d5_0
- nettle=3.7.3=hbbd107a_1
- networkx=3.3=py310h06a4308_0
- ninja=1.12.1=h06a4308_0
- ninja-base=1.12.1=hdb19cb5_0
- numexpr=2.8.4=py310h8879344_0
- numpy=1.24.3=py310hd5efca6_0
- numpy-base=1.24.3=py310h8e6c178_0
- omegaconf=2.3.0=pyhd8ed1ab_0
- openh264=2.1.1=h4ff587b_0
- openjpeg=2.5.2=he7f1fd0_0
- openssl=3.0.15=h5eee18b_0
- openldap=2.6.4=h42fbc30_0
- openssl=3.4.0=hb9d3cd8_0
- pandas=2.2.2=py310h6a678d5_0
- pcre2=10.42=hebb0a14_1
- pillow=10.4.0=py310h5eee18b_0
- pip=24.2=py310h06a4308_0
- platformdirs=3.10.0=py310h06a4308_0
- ply=3.11=py310h06a4308_0
- pooch=1.8.2=py310h06a4308_0
- portalocker=2.3.0=py310h06a4308_1
- pycparser=2.21=pyhd3eb1b0_0
- pyparsing=3.2.0=py310h06a4308_0
- pyqt=5.15.10=py310h6a678d5_0
- pyqt5-sip=12.13.0=py310h5eee18b_0
- pysocks=1.7.1=py310h06a4308_0
- python=3.10.15=he870216_1
- python-dateutil=2.9.0post0=py310h06a4308_2
- python-tzdata=2023.3=pyhd3eb1b0_0
- python_abi=3.10=2_cp310
- pytorch=2.0.0=py3.10_cuda11.7_cudnn8.5.0_0
- pytorch-cuda=11.7=h778d358_5
- pytorch-mutex=1.0=cuda
- pyyaml=6.0.2=py310h5eee18b_0
- qt-main=5.15.2=hb6262e9_11
- readline=8.2=h5eee18b_0
- requests=2.32.3=py310h06a4308_1
- scikit-learn=1.5.1=py310h1128e8f_0
- seaborn=0.13.2=py310h06a4308_0
- setuptools=75.1.0=py310h06a4308_0
- sip=6.7.12=py310h6a678d5_0
- six=1.16.0=pyhd3eb1b0_1
- sqlite=3.45.3=h5eee18b_0
- sympy=1.13.3=pyh2585a3b_104
- tabulate=0.9.0=py310h06a4308_0
- termcolor=2.1.0=py310h06a4308_0
- threadpoolctl=3.5.0=py310h2f386ee_0
- tk=8.6.14=h39e8969_0
- tomli=2.0.1=py310h06a4308_0
- torchaudio=2.0.0=py310_cu117
- torchtriton=2.0.0=py310
- torchvision=0.15.0=py310_cu117
- tornado=6.4.1=py310h5eee18b_0
- tqdm=4.66.5=py310h2f386ee_0
- typing-extensions=4.11.0=py310h06a4308_0
- typing_extensions=4.11.0=py310h06a4308_0
- unicodedata2=15.1.0=py310h5eee18b_0
- urllib3=2.2.3=py310h06a4308_0
- wheel=0.44.0=py310h06a4308_0
- xformers=0.0.19=py310_cu11.8.0_pyt2.0.0
......@@ -137,8 +200,8 @@ dependencies:
- distributed-ucxx-cu11==0.40.0
- fastrlock==0.8.2
- fsspec==2024.10.0
- gapstatistics==0.1.5
- importlib-metadata==8.5.0
- joblib==1.4.2
- libcudf-cu11==24.10.1
- libucx-cu11==1.17.0
- libucxx-cu11==0.40.0
......@@ -150,7 +213,6 @@ dependencies:
- numba==0.60.0
- nvtx==0.2.10
- packaging==24.2
- pandas==2.2.2
- partd==1.4.2
- psutil==6.1.0
- ptxcompiler-cu11==0.8.1.post2
......@@ -159,7 +221,6 @@ dependencies:
- pylibcudf-cu11==24.10.1
- pylibraft-cu11==24.10.0
- pynvml==11.4.1
- python-dateutil==2.9.0.post0
- pytz==2024.2
- raft-dask-cu11==24.10.0
- rapids-dask-dependency==24.10.0
......@@ -170,7 +231,6 @@ dependencies:
- submitit==1.5.2
- tblib==3.0.0
- toolz==1.0.0
- tornado==6.4.1
- treelite==4.3.0
- tzdata==2024.2
- ucx-py-cu11==0.40.0
......
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment