Skip to content
Snippets Groups Projects
Commit 74273ffc authored by Joseph Omar's avatar Joseph Omar
Browse files

idk

parent 267400f4
No related branches found
No related tags found
No related merge requests found
...@@ -159,9 +159,10 @@ class CIFAR100Dataset: ...@@ -159,9 +159,10 @@ class CIFAR100Dataset:
samples = torch.cat(samples) samples = torch.cat(samples)
labels = torch.cat(labels) labels = torch.cat(labels)
types = torch.full((labels.size(0),), 0, dtype=torch.long) # they are all known classes, so type is 0
logger.debug(f"Creating dataset for session 0 (pretraining). There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes") logger.debug(f"Creating dataset for session 0 (pretraining). There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
logger.debug(f"Classes in Pretraining Dataset: {labels.unique(sorted=True)}") logger.debug(f"Classes in Pretraining Dataset: {labels.unique(sorted=True)}")
datasets[0] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels), transform=self.transform) datasets[0] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels, types), transform=self.transform)
# CL sessions' datasets # CL sessions' datasets
logger.debug("Splitting data for CL sessions") logger.debug("Splitting data for CL sessions")
... ...
......
...@@ -20,6 +20,8 @@ def main(args: argparse.Namespace): ...@@ -20,6 +20,8 @@ def main(args: argparse.Namespace):
logger.debug(f"Model: {model}") logger.debug(f"Model: {model}")
logger.info("Pretraining Model (Session 0)") logger.info("Pretraining Model (Session 0)")
args.current_session = 0
os.makedirs(os.path.join(args.exp_dir, 'session_0'), exist_ok=True)
model = pretrain(args, model) model = pretrain(args, model)
if args.mode in ['cl', 'both']: if args.mode in ['cl', 'both']:
...@@ -28,6 +30,8 @@ def main(args: argparse.Namespace): ...@@ -28,6 +30,8 @@ def main(args: argparse.Namespace):
for session in range(1, args.sessions + 1): for session in range(1, args.sessions + 1):
logger.info(f"Starting Continual Learning Session {session}") logger.info(f"Starting Continual Learning Session {session}")
args.current_session = session args.current_session = session
os.makedirs(os.path.join(args.exp_dir, f'session_{session}'), exist_ok=True)
session_dataset = args.dataset.get_dataset(session) session_dataset = args.dataset.get_dataset(session)
# OOD detection # OOD detection
...@@ -40,6 +44,7 @@ def main(args: argparse.Namespace): ...@@ -40,6 +44,7 @@ def main(args: argparse.Namespace):
# Expand Classification Head & Initialise # Expand Classification Head & Initialise
model.head.expand(args.dataset.novel_inc) # we are cheating here, we know the number of novel classes model.head.expand(args.dataset.novel_inc) # we are cheating here, we know the number of novel classes
assert model.head.fc.out_features == args.dataset.known + session * args.dataset.novel_inc, f"Head has {model.head.fc.out_features} features, expected {args.dataset.known + session * args.dataset.novel_inc}"
# freeze the weights for the existing classes. We are only training unknown samples (EG: 50 (known) + (2 (session) - 1) * 10 (novel_inc) = 60 classes have been seen in cl session 2) # freeze the weights for the existing classes. We are only training unknown samples (EG: 50 (known) + (2 (session) - 1) * 10 (novel_inc) = 60 classes have been seen in cl session 2)
model.head.freeze(start_idx=0, end_idx=args.dataset.known + ((session -1) * args.dataset.novel_inc)) model.head.freeze(start_idx=0, end_idx=args.dataset.known + ((session -1) * args.dataset.novel_inc))
...@@ -150,14 +155,14 @@ if __name__ == "__main__": ...@@ -150,14 +155,14 @@ if __name__ == "__main__":
if args.head == 'linear': if args.head == 'linear':
from entcl.models.linear_head import LinearHead from entcl.models.linear_head import LinearHead
args.head = LinearHead(in_features=768, out_features=args.dataset.num_classes) args.head = LinearHead(in_features=768, out_features=args.dataset.known)
logger.debug(f"Using Linear Head: {args.head}") logger.debug(f"Using Linear Head: {args.head}")
elif args.head == 'dino_head': elif args.head == 'dino_head':
from entcl.models.dinohead import DINOHead from entcl.models.dinohead import DINOHead
args.head = DINOHead(768, args.dataset.known, nlayers=3) args.head = DINOHead(768, args.dataset.known, nlayers=3)
elif args.head == 'mlp': elif args.head == 'mlp':
from entcl.models.linear_head import MLPHead from entcl.models.linear_head import MLPHead
args.head = MLPHead(in_features=768, out_features=args.dataset.num_classes, hidden_dim1=512, hidden_dim2=256) args.head = MLPHead(in_features=768, out_features=args.dataset.known, hidden_dim1=512, hidden_dim2=256)
logger.debug(f"Using MLP Head: {args.head}") logger.debug(f"Using MLP Head: {args.head}")
if args.mode == 'cl' and args.pretrain_load is None: if args.mode == 'cl' and args.pretrain_load is None:
... ...
......
...@@ -6,16 +6,21 @@ channels: ...@@ -6,16 +6,21 @@ channels:
- conda-forge - conda-forge
- defaults - defaults
dependencies: dependencies:
- _libgcc_mutex=0.1=main - _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=5.1=1_gnu - _openmp_mutex=4.5=2_gnu
- antlr-python-runtime=4.9.3=pyhd8ed1ab_1 - antlr-python-runtime=4.9.3=pyhd8ed1ab_1
- blas=1.0=mkl - blas=1.0=mkl
- bottleneck=1.4.2=py310ha9d4c09_0
- brotli=1.0.9=h5eee18b_8
- brotli-bin=1.0.9=h5eee18b_8
- brotli-python=1.0.9=py310h6a678d5_8 - brotli-python=1.0.9=py310h6a678d5_8
- bzip2=1.0.8=h5eee18b_6 - bzip2=1.0.8=h5eee18b_6
- c-ares=1.19.1=h5eee18b_0
- ca-certificates=2024.9.24=h06a4308_0 - ca-certificates=2024.9.24=h06a4308_0
- certifi=2024.8.30=py310h06a4308_0 - certifi=2024.8.30=py310h06a4308_0
- cffi=1.17.1=py310h1fdaa30_0 - cffi=1.17.1=py310h1fdaa30_0
- charset-normalizer=3.3.2=pyhd3eb1b0_0 - charset-normalizer=3.3.2=pyhd3eb1b0_0
- contourpy=1.2.0=py310hdb19cb5_0
- cpython=3.10.15=py310hd8ed1ab_2 - cpython=3.10.15=py310hd8ed1ab_2
- cuda-cudart=11.7.99=0 - cuda-cudart=11.7.99=0
- cuda-cupti=11.7.101=0 - cuda-cupti=11.7.101=0
...@@ -24,49 +29,86 @@ dependencies: ...@@ -24,49 +29,86 @@ dependencies:
- cuda-nvtx=11.7.91=0 - cuda-nvtx=11.7.91=0
- cuda-runtime=11.7.1=0 - cuda-runtime=11.7.1=0
- cudatoolkit=11.3.1=h2bc3f7f_2 - cudatoolkit=11.3.1=h2bc3f7f_2
- cycler=0.11.0=pyhd3eb1b0_0
- cyrus-sasl=2.1.28=h52b45da_1
- dataclasses=0.8=pyh6d0b6a4_7 - dataclasses=0.8=pyh6d0b6a4_7
- dbus=1.13.18=hb2f20db_0
- expat=2.6.3=h6a678d5_0
- ffmpeg=4.3=hf484d3e_0 - ffmpeg=4.3=hf484d3e_0
- filelock=3.13.1=py310h06a4308_0 - filelock=3.13.1=py310h06a4308_0
- fontconfig=2.14.1=h55d465d_3
- fonttools=4.51.0=py310h5eee18b_0
- freetype=2.12.1=h4a9f257_0 - freetype=2.12.1=h4a9f257_0
- future=1.0.0=py310h06a4308_0 - future=1.0.0=py310h06a4308_0
- fvcore=0.1.5.post20221221=pyhd8ed1ab_0 - fvcore=0.1.5.post20221221=pyhd8ed1ab_0
- glib=2.78.4=h6a678d5_0
- glib-tools=2.78.4=h6a678d5_0
- gmp=6.2.1=h295c915_3 - gmp=6.2.1=h295c915_3
- gmpy2=2.1.2=py310heeb90bb_0 - gmpy2=2.1.2=py310heeb90bb_0
- gnutls=3.6.15=he1e5248_0 - gnutls=3.6.15=he1e5248_0
- gst-plugins-base=1.14.1=h6a678d5_1
- gstreamer=1.14.1=h5eee18b_1
- icu=73.1=h6a678d5_0
- idna=3.7=py310h06a4308_0 - idna=3.7=py310h06a4308_0
- intel-openmp=2021.4.0=h06a4308_3561 - intel-openmp=2021.4.0=h06a4308_3561
- iopath=0.1.10=pyhd8ed1ab_0 - iopath=0.1.10=pyhd8ed1ab_0
- jinja2=3.1.4=py310h06a4308_1 - jinja2=3.1.4=py310h06a4308_1
- joblib=1.4.2=py310h06a4308_0
- jpeg=9e=h5eee18b_3 - jpeg=9e=h5eee18b_3
- kiwisolver=1.4.4=py310h6a678d5_0
- kneed=0.8.5=pyhd8ed1ab_0
- krb5=1.20.1=h143b758_1
- lame=3.100=h7b6447c_0 - lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0 - lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.40=h12ee557_0 - ld_impl_linux-64=2.40=h12ee557_0
- lerc=3.0=h295c915_0 - lerc=3.0=h295c915_0
- libbrotlicommon=1.0.9=h5eee18b_8
- libbrotlidec=1.0.9=h5eee18b_8
- libbrotlienc=1.0.9=h5eee18b_8
- libclang=14.0.6=default_hc6dbbc7_1
- libclang13=14.0.6=default_he11475f_1
- libcublas=11.10.3.66=0 - libcublas=11.10.3.66=0
- libcufft=10.7.2.124=h4fbf590_0 - libcufft=10.7.2.124=h4fbf590_0
- libcufile=1.9.1.3=0 - libcufile=1.9.1.3=0
- libcups=2.4.2=h2d74bed_1
- libcurand=10.3.5.147=0 - libcurand=10.3.5.147=0
- libcurl=8.9.1=h251f7ec_0
- libcusolver=11.4.0.1=0 - libcusolver=11.4.0.1=0
- libcusparse=11.7.4.91=0 - libcusparse=11.7.4.91=0
- libdeflate=1.17=h5eee18b_1 - libdeflate=1.17=h5eee18b_1
- libedit=3.1.20230828=h5eee18b_0
- libev=4.33=h7f8727e_1
- libffi=3.4.4=h6a678d5_1 - libffi=3.4.4=h6a678d5_1
- libgcc-ng=11.2.0=h1234567_1 - libgcc=14.2.0=h77fa898_1
- libgomp=11.2.0=h1234567_1 - libgcc-ng=14.2.0=h69a702a_1
- libgfortran-ng=11.2.0=h00389a5_1
- libgfortran5=11.2.0=h1234567_1
- libglib=2.78.4=hdc74915_0
- libgomp=14.2.0=h77fa898_1
- libiconv=1.16=h5eee18b_3 - libiconv=1.16=h5eee18b_3
- libidn2=2.3.4=h5eee18b_0 - libidn2=2.3.4=h5eee18b_0
- libllvm14=14.0.6=hecde1de_4
- libnghttp2=1.57.0=h2d74bed_0
- libnpp=11.7.4.75=0 - libnpp=11.7.4.75=0
- libnvjpeg=11.8.0.2=0 - libnvjpeg=11.8.0.2=0
- libpng=1.6.39=h5eee18b_0 - libpng=1.6.39=h5eee18b_0
- libpq=17.0=hdbd6064_0
- libprotobuf=3.20.3=he621ea3_0 - libprotobuf=3.20.3=he621ea3_0
- libssh2=1.11.0=h251f7ec_0
- libstdcxx-ng=11.2.0=h1234567_1 - libstdcxx-ng=11.2.0=h1234567_1
- libtasn1=4.19.0=h5eee18b_0 - libtasn1=4.19.0=h5eee18b_0
- libtiff=4.5.1=h6a678d5_0 - libtiff=4.5.1=h6a678d5_0
- libunistring=0.9.10=h27cfd23_0 - libunistring=0.9.10=h27cfd23_0
- libuuid=1.41.5=h5eee18b_0 - libuuid=1.41.5=h5eee18b_0
- libwebp-base=1.3.2=h5eee18b_1 - libwebp-base=1.3.2=h5eee18b_1
- libxcb=1.15=h7f8727e_0
- libxkbcommon=1.0.1=h097e994_2
- libxml2=2.13.1=hfdd30dd_2
- loguru=0.7.2=py310h06a4308_1 - loguru=0.7.2=py310h06a4308_1
- lz4-c=1.9.4=h6a678d5_1 - lz4-c=1.9.4=h6a678d5_1
- markupsafe=2.1.3=py310h5eee18b_0 - markupsafe=2.1.3=py310h5eee18b_0
- matplotlib=3.9.2=py310h06a4308_0
- matplotlib-base=3.9.2=py310hbfdbfaf_0
- mkl=2021.4.0=h06a4308_640 - mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py310h7f8727e_0 - mkl-service=2.4.0=py310h7f8727e_0
- mkl_fft=1.3.1=py310hd6ae3a3_0 - mkl_fft=1.3.1=py310hd6ae3a3_0
...@@ -74,43 +116,64 @@ dependencies: ...@@ -74,43 +116,64 @@ dependencies:
- mpc=1.1.0=h10f8cd9_1 - mpc=1.1.0=h10f8cd9_1
- mpfr=4.0.2=hb69a4c5_1 - mpfr=4.0.2=hb69a4c5_1
- mpmath=1.3.0=py310h06a4308_0 - mpmath=1.3.0=py310h06a4308_0
- mysql=8.4.0=h0bac5ae_0
- ncurses=6.4=h6a678d5_0 - ncurses=6.4=h6a678d5_0
- nettle=3.7.3=hbbd107a_1 - nettle=3.7.3=hbbd107a_1
- networkx=3.3=py310h06a4308_0 - networkx=3.3=py310h06a4308_0
- ninja=1.12.1=h06a4308_0 - ninja=1.12.1=h06a4308_0
- ninja-base=1.12.1=hdb19cb5_0 - ninja-base=1.12.1=hdb19cb5_0
- numexpr=2.8.4=py310h8879344_0
- numpy=1.24.3=py310hd5efca6_0 - numpy=1.24.3=py310hd5efca6_0
- numpy-base=1.24.3=py310h8e6c178_0 - numpy-base=1.24.3=py310h8e6c178_0
- omegaconf=2.3.0=pyhd8ed1ab_0 - omegaconf=2.3.0=pyhd8ed1ab_0
- openh264=2.1.1=h4ff587b_0 - openh264=2.1.1=h4ff587b_0
- openjpeg=2.5.2=he7f1fd0_0 - openjpeg=2.5.2=he7f1fd0_0
- openssl=3.0.15=h5eee18b_0 - openldap=2.6.4=h42fbc30_0
- openssl=3.4.0=hb9d3cd8_0
- pandas=2.2.2=py310h6a678d5_0
- pcre2=10.42=hebb0a14_1
- pillow=10.4.0=py310h5eee18b_0 - pillow=10.4.0=py310h5eee18b_0
- pip=24.2=py310h06a4308_0 - pip=24.2=py310h06a4308_0
- platformdirs=3.10.0=py310h06a4308_0
- ply=3.11=py310h06a4308_0
- pooch=1.8.2=py310h06a4308_0
- portalocker=2.3.0=py310h06a4308_1 - portalocker=2.3.0=py310h06a4308_1
- pycparser=2.21=pyhd3eb1b0_0 - pycparser=2.21=pyhd3eb1b0_0
- pyparsing=3.2.0=py310h06a4308_0
- pyqt=5.15.10=py310h6a678d5_0
- pyqt5-sip=12.13.0=py310h5eee18b_0
- pysocks=1.7.1=py310h06a4308_0 - pysocks=1.7.1=py310h06a4308_0
- python=3.10.15=he870216_1 - python=3.10.15=he870216_1
- python-dateutil=2.9.0post0=py310h06a4308_2
- python-tzdata=2023.3=pyhd3eb1b0_0
- python_abi=3.10=2_cp310 - python_abi=3.10=2_cp310
- pytorch=2.0.0=py3.10_cuda11.7_cudnn8.5.0_0 - pytorch=2.0.0=py3.10_cuda11.7_cudnn8.5.0_0
- pytorch-cuda=11.7=h778d358_5 - pytorch-cuda=11.7=h778d358_5
- pytorch-mutex=1.0=cuda - pytorch-mutex=1.0=cuda
- pyyaml=6.0.2=py310h5eee18b_0 - pyyaml=6.0.2=py310h5eee18b_0
- qt-main=5.15.2=hb6262e9_11
- readline=8.2=h5eee18b_0 - readline=8.2=h5eee18b_0
- requests=2.32.3=py310h06a4308_1 - requests=2.32.3=py310h06a4308_1
- scikit-learn=1.5.1=py310h1128e8f_0
- seaborn=0.13.2=py310h06a4308_0
- setuptools=75.1.0=py310h06a4308_0 - setuptools=75.1.0=py310h06a4308_0
- sip=6.7.12=py310h6a678d5_0
- six=1.16.0=pyhd3eb1b0_1 - six=1.16.0=pyhd3eb1b0_1
- sqlite=3.45.3=h5eee18b_0 - sqlite=3.45.3=h5eee18b_0
- sympy=1.13.3=pyh2585a3b_104 - sympy=1.13.3=pyh2585a3b_104
- tabulate=0.9.0=py310h06a4308_0 - tabulate=0.9.0=py310h06a4308_0
- termcolor=2.1.0=py310h06a4308_0 - termcolor=2.1.0=py310h06a4308_0
- threadpoolctl=3.5.0=py310h2f386ee_0
- tk=8.6.14=h39e8969_0 - tk=8.6.14=h39e8969_0
- tomli=2.0.1=py310h06a4308_0
- torchaudio=2.0.0=py310_cu117 - torchaudio=2.0.0=py310_cu117
- torchtriton=2.0.0=py310 - torchtriton=2.0.0=py310
- torchvision=0.15.0=py310_cu117 - torchvision=0.15.0=py310_cu117
- tornado=6.4.1=py310h5eee18b_0
- tqdm=4.66.5=py310h2f386ee_0 - tqdm=4.66.5=py310h2f386ee_0
- typing-extensions=4.11.0=py310h06a4308_0 - typing-extensions=4.11.0=py310h06a4308_0
- typing_extensions=4.11.0=py310h06a4308_0 - typing_extensions=4.11.0=py310h06a4308_0
- unicodedata2=15.1.0=py310h5eee18b_0
- urllib3=2.2.3=py310h06a4308_0 - urllib3=2.2.3=py310h06a4308_0
- wheel=0.44.0=py310h06a4308_0 - wheel=0.44.0=py310h06a4308_0
- xformers=0.0.19=py310_cu11.8.0_pyt2.0.0 - xformers=0.0.19=py310_cu11.8.0_pyt2.0.0
...@@ -137,8 +200,8 @@ dependencies: ...@@ -137,8 +200,8 @@ dependencies:
- distributed-ucxx-cu11==0.40.0 - distributed-ucxx-cu11==0.40.0
- fastrlock==0.8.2 - fastrlock==0.8.2
- fsspec==2024.10.0 - fsspec==2024.10.0
- gapstatistics==0.1.5
- importlib-metadata==8.5.0 - importlib-metadata==8.5.0
- joblib==1.4.2
- libcudf-cu11==24.10.1 - libcudf-cu11==24.10.1
- libucx-cu11==1.17.0 - libucx-cu11==1.17.0
- libucxx-cu11==0.40.0 - libucxx-cu11==0.40.0
...@@ -150,7 +213,6 @@ dependencies: ...@@ -150,7 +213,6 @@ dependencies:
- numba==0.60.0 - numba==0.60.0
- nvtx==0.2.10 - nvtx==0.2.10
- packaging==24.2 - packaging==24.2
- pandas==2.2.2
- partd==1.4.2 - partd==1.4.2
- psutil==6.1.0 - psutil==6.1.0
- ptxcompiler-cu11==0.8.1.post2 - ptxcompiler-cu11==0.8.1.post2
...@@ -159,7 +221,6 @@ dependencies: ...@@ -159,7 +221,6 @@ dependencies:
- pylibcudf-cu11==24.10.1 - pylibcudf-cu11==24.10.1
- pylibraft-cu11==24.10.0 - pylibraft-cu11==24.10.0
- pynvml==11.4.1 - pynvml==11.4.1
- python-dateutil==2.9.0.post0
- pytz==2024.2 - pytz==2024.2
- raft-dask-cu11==24.10.0 - raft-dask-cu11==24.10.0
- rapids-dask-dependency==24.10.0 - rapids-dask-dependency==24.10.0
...@@ -170,7 +231,6 @@ dependencies: ...@@ -170,7 +231,6 @@ dependencies:
- submitit==1.5.2 - submitit==1.5.2
- tblib==3.0.0 - tblib==3.0.0
- toolz==1.0.0 - toolz==1.0.0
- tornado==6.4.1
- treelite==4.3.0 - treelite==4.3.0
- tzdata==2024.2 - tzdata==2024.2
- ucx-py-cu11==0.40.0 - ucx-py-cu11==0.40.0
... ...
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment