diff --git a/entcl/data/cifar100.py b/entcl/data/cifar100.py index 5fbbf9df6df141b67f376ca197e298f999ca19c9..1e534d6783deef3d15d553e7932d252eadb6dbe4 100644 --- a/entcl/data/cifar100.py +++ b/entcl/data/cifar100.py @@ -159,9 +159,10 @@ class CIFAR100Dataset: samples = torch.cat(samples) labels = torch.cat(labels) + types = torch.full((labels.size(0),), 0, dtype=torch.long) # they are all known classes, so type is 0 logger.debug(f"Creating dataset for session 0 (pretraining). There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes") logger.debug(f"Classes in Pretraining Dataset: {labels.unique(sorted=True)}") - datasets[0] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels), transform=self.transform) + datasets[0] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels, types), transform=self.transform) # CL sessions' datasets logger.debug("Splitting data for CL sessions") diff --git a/entcl/run.py b/entcl/run.py index 0fd351457fe7b9834c4b0a5e717c452bc0b790dc..4133e1ea018e9374330857c9625d97b82fe52a04 100644 --- a/entcl/run.py +++ b/entcl/run.py @@ -20,6 +20,8 @@ def main(args: argparse.Namespace): logger.debug(f"Model: {model}") logger.info("Pretraining Model (Session 0)") + args.current_session = 0 + os.makedirs(os.path.join(args.exp_dir, 'session_0'), exist_ok=True) model = pretrain(args, model) if args.mode in ['cl', 'both']: @@ -28,6 +30,8 @@ def main(args: argparse.Namespace): for session in range(1, args.sessions + 1): logger.info(f"Starting Continual Learning Session {session}") args.current_session = session + os.makedirs(os.path.join(args.exp_dir, f'session_{session}'), exist_ok=True) + session_dataset = args.dataset.get_dataset(session) # OOD detection @@ -40,6 +44,7 @@ def main(args: argparse.Namespace): # Expand Classification Head & Initialise model.head.expand(args.dataset.novel_inc) # we are cheating here, we know the number of novel classes + assert model.head.fc.out_features == args.dataset.known + session * args.dataset.novel_inc, f"Head has {model.head.fc.out_features} features, expected {args.dataset.known + session * args.dataset.novel_inc}" # freeze the weights for the existing classes. We are only training unknown samples (EG: 50 (known) + (2 (session) - 1) * 10 (novel_inc) = 60 classes have been seen in cl session 2) model.head.freeze(start_idx=0, end_idx=args.dataset.known + ((session -1) * args.dataset.novel_inc)) @@ -150,14 +155,14 @@ if __name__ == "__main__": if args.head == 'linear': from entcl.models.linear_head import LinearHead - args.head = LinearHead(in_features=768, out_features=args.dataset.num_classes) + args.head = LinearHead(in_features=768, out_features=args.dataset.known) logger.debug(f"Using Linear Head: {args.head}") elif args.head == 'dino_head': from entcl.models.dinohead import DINOHead args.head = DINOHead(768, args.dataset.known, nlayers=3) elif args.head == 'mlp': from entcl.models.linear_head import MLPHead - args.head = MLPHead(in_features=768, out_features=args.dataset.num_classes, hidden_dim1=512, hidden_dim2=256) + args.head = MLPHead(in_features=768, out_features=args.dataset.known, hidden_dim1=512, hidden_dim2=256) logger.debug(f"Using MLP Head: {args.head}") if args.mode == 'cl' and args.pretrain_load is None: diff --git a/enviroment.yml b/enviroment.yml index 86f93904f4232f6a75749b06811cd61efbec2b4d..abc7d80627f122324dff6a5fa217e20b138b758d 100644 --- a/enviroment.yml +++ b/enviroment.yml @@ -6,16 +6,21 @@ channels: - conda-forge - defaults dependencies: - - _libgcc_mutex=0.1=main - - _openmp_mutex=5.1=1_gnu + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=2_gnu - antlr-python-runtime=4.9.3=pyhd8ed1ab_1 - blas=1.0=mkl + - bottleneck=1.4.2=py310ha9d4c09_0 + - brotli=1.0.9=h5eee18b_8 + - brotli-bin=1.0.9=h5eee18b_8 - brotli-python=1.0.9=py310h6a678d5_8 - bzip2=1.0.8=h5eee18b_6 + - c-ares=1.19.1=h5eee18b_0 - ca-certificates=2024.9.24=h06a4308_0 - certifi=2024.8.30=py310h06a4308_0 - cffi=1.17.1=py310h1fdaa30_0 - charset-normalizer=3.3.2=pyhd3eb1b0_0 + - contourpy=1.2.0=py310hdb19cb5_0 - cpython=3.10.15=py310hd8ed1ab_2 - cuda-cudart=11.7.99=0 - cuda-cupti=11.7.101=0 @@ -24,49 +29,86 @@ dependencies: - cuda-nvtx=11.7.91=0 - cuda-runtime=11.7.1=0 - cudatoolkit=11.3.1=h2bc3f7f_2 + - cycler=0.11.0=pyhd3eb1b0_0 + - cyrus-sasl=2.1.28=h52b45da_1 - dataclasses=0.8=pyh6d0b6a4_7 + - dbus=1.13.18=hb2f20db_0 + - expat=2.6.3=h6a678d5_0 - ffmpeg=4.3=hf484d3e_0 - filelock=3.13.1=py310h06a4308_0 + - fontconfig=2.14.1=h55d465d_3 + - fonttools=4.51.0=py310h5eee18b_0 - freetype=2.12.1=h4a9f257_0 - future=1.0.0=py310h06a4308_0 - fvcore=0.1.5.post20221221=pyhd8ed1ab_0 + - glib=2.78.4=h6a678d5_0 + - glib-tools=2.78.4=h6a678d5_0 - gmp=6.2.1=h295c915_3 - gmpy2=2.1.2=py310heeb90bb_0 - gnutls=3.6.15=he1e5248_0 + - gst-plugins-base=1.14.1=h6a678d5_1 + - gstreamer=1.14.1=h5eee18b_1 + - icu=73.1=h6a678d5_0 - idna=3.7=py310h06a4308_0 - intel-openmp=2021.4.0=h06a4308_3561 - iopath=0.1.10=pyhd8ed1ab_0 - jinja2=3.1.4=py310h06a4308_1 + - joblib=1.4.2=py310h06a4308_0 - jpeg=9e=h5eee18b_3 + - kiwisolver=1.4.4=py310h6a678d5_0 + - kneed=0.8.5=pyhd8ed1ab_0 + - krb5=1.20.1=h143b758_1 - lame=3.100=h7b6447c_0 - lcms2=2.12=h3be6417_0 - ld_impl_linux-64=2.40=h12ee557_0 - lerc=3.0=h295c915_0 + - libbrotlicommon=1.0.9=h5eee18b_8 + - libbrotlidec=1.0.9=h5eee18b_8 + - libbrotlienc=1.0.9=h5eee18b_8 + - libclang=14.0.6=default_hc6dbbc7_1 + - libclang13=14.0.6=default_he11475f_1 - libcublas=11.10.3.66=0 - libcufft=10.7.2.124=h4fbf590_0 - libcufile=1.9.1.3=0 + - libcups=2.4.2=h2d74bed_1 - libcurand=10.3.5.147=0 + - libcurl=8.9.1=h251f7ec_0 - libcusolver=11.4.0.1=0 - libcusparse=11.7.4.91=0 - libdeflate=1.17=h5eee18b_1 + - libedit=3.1.20230828=h5eee18b_0 + - libev=4.33=h7f8727e_1 - libffi=3.4.4=h6a678d5_1 - - libgcc-ng=11.2.0=h1234567_1 - - libgomp=11.2.0=h1234567_1 + - libgcc=14.2.0=h77fa898_1 + - libgcc-ng=14.2.0=h69a702a_1 + - libgfortran-ng=11.2.0=h00389a5_1 + - libgfortran5=11.2.0=h1234567_1 + - libglib=2.78.4=hdc74915_0 + - libgomp=14.2.0=h77fa898_1 - libiconv=1.16=h5eee18b_3 - libidn2=2.3.4=h5eee18b_0 + - libllvm14=14.0.6=hecde1de_4 + - libnghttp2=1.57.0=h2d74bed_0 - libnpp=11.7.4.75=0 - libnvjpeg=11.8.0.2=0 - libpng=1.6.39=h5eee18b_0 + - libpq=17.0=hdbd6064_0 - libprotobuf=3.20.3=he621ea3_0 + - libssh2=1.11.0=h251f7ec_0 - libstdcxx-ng=11.2.0=h1234567_1 - libtasn1=4.19.0=h5eee18b_0 - libtiff=4.5.1=h6a678d5_0 - libunistring=0.9.10=h27cfd23_0 - libuuid=1.41.5=h5eee18b_0 - libwebp-base=1.3.2=h5eee18b_1 + - libxcb=1.15=h7f8727e_0 + - libxkbcommon=1.0.1=h097e994_2 + - libxml2=2.13.1=hfdd30dd_2 - loguru=0.7.2=py310h06a4308_1 - lz4-c=1.9.4=h6a678d5_1 - markupsafe=2.1.3=py310h5eee18b_0 + - matplotlib=3.9.2=py310h06a4308_0 + - matplotlib-base=3.9.2=py310hbfdbfaf_0 - mkl=2021.4.0=h06a4308_640 - mkl-service=2.4.0=py310h7f8727e_0 - mkl_fft=1.3.1=py310hd6ae3a3_0 @@ -74,43 +116,64 @@ dependencies: - mpc=1.1.0=h10f8cd9_1 - mpfr=4.0.2=hb69a4c5_1 - mpmath=1.3.0=py310h06a4308_0 + - mysql=8.4.0=h0bac5ae_0 - ncurses=6.4=h6a678d5_0 - nettle=3.7.3=hbbd107a_1 - networkx=3.3=py310h06a4308_0 - ninja=1.12.1=h06a4308_0 - ninja-base=1.12.1=hdb19cb5_0 + - numexpr=2.8.4=py310h8879344_0 - numpy=1.24.3=py310hd5efca6_0 - numpy-base=1.24.3=py310h8e6c178_0 - omegaconf=2.3.0=pyhd8ed1ab_0 - openh264=2.1.1=h4ff587b_0 - openjpeg=2.5.2=he7f1fd0_0 - - openssl=3.0.15=h5eee18b_0 + - openldap=2.6.4=h42fbc30_0 + - openssl=3.4.0=hb9d3cd8_0 + - pandas=2.2.2=py310h6a678d5_0 + - pcre2=10.42=hebb0a14_1 - pillow=10.4.0=py310h5eee18b_0 - pip=24.2=py310h06a4308_0 + - platformdirs=3.10.0=py310h06a4308_0 + - ply=3.11=py310h06a4308_0 + - pooch=1.8.2=py310h06a4308_0 - portalocker=2.3.0=py310h06a4308_1 - pycparser=2.21=pyhd3eb1b0_0 + - pyparsing=3.2.0=py310h06a4308_0 + - pyqt=5.15.10=py310h6a678d5_0 + - pyqt5-sip=12.13.0=py310h5eee18b_0 - pysocks=1.7.1=py310h06a4308_0 - python=3.10.15=he870216_1 + - python-dateutil=2.9.0post0=py310h06a4308_2 + - python-tzdata=2023.3=pyhd3eb1b0_0 - python_abi=3.10=2_cp310 - pytorch=2.0.0=py3.10_cuda11.7_cudnn8.5.0_0 - pytorch-cuda=11.7=h778d358_5 - pytorch-mutex=1.0=cuda - pyyaml=6.0.2=py310h5eee18b_0 + - qt-main=5.15.2=hb6262e9_11 - readline=8.2=h5eee18b_0 - requests=2.32.3=py310h06a4308_1 + - scikit-learn=1.5.1=py310h1128e8f_0 + - seaborn=0.13.2=py310h06a4308_0 - setuptools=75.1.0=py310h06a4308_0 + - sip=6.7.12=py310h6a678d5_0 - six=1.16.0=pyhd3eb1b0_1 - sqlite=3.45.3=h5eee18b_0 - sympy=1.13.3=pyh2585a3b_104 - tabulate=0.9.0=py310h06a4308_0 - termcolor=2.1.0=py310h06a4308_0 + - threadpoolctl=3.5.0=py310h2f386ee_0 - tk=8.6.14=h39e8969_0 + - tomli=2.0.1=py310h06a4308_0 - torchaudio=2.0.0=py310_cu117 - torchtriton=2.0.0=py310 - torchvision=0.15.0=py310_cu117 + - tornado=6.4.1=py310h5eee18b_0 - tqdm=4.66.5=py310h2f386ee_0 - typing-extensions=4.11.0=py310h06a4308_0 - typing_extensions=4.11.0=py310h06a4308_0 + - unicodedata2=15.1.0=py310h5eee18b_0 - urllib3=2.2.3=py310h06a4308_0 - wheel=0.44.0=py310h06a4308_0 - xformers=0.0.19=py310_cu11.8.0_pyt2.0.0 @@ -137,8 +200,8 @@ dependencies: - distributed-ucxx-cu11==0.40.0 - fastrlock==0.8.2 - fsspec==2024.10.0 + - gapstatistics==0.1.5 - importlib-metadata==8.5.0 - - joblib==1.4.2 - libcudf-cu11==24.10.1 - libucx-cu11==1.17.0 - libucxx-cu11==0.40.0 @@ -150,7 +213,6 @@ dependencies: - numba==0.60.0 - nvtx==0.2.10 - packaging==24.2 - - pandas==2.2.2 - partd==1.4.2 - psutil==6.1.0 - ptxcompiler-cu11==0.8.1.post2 @@ -159,7 +221,6 @@ dependencies: - pylibcudf-cu11==24.10.1 - pylibraft-cu11==24.10.0 - pynvml==11.4.1 - - python-dateutil==2.9.0.post0 - pytz==2024.2 - raft-dask-cu11==24.10.0 - rapids-dask-dependency==24.10.0 @@ -170,7 +231,6 @@ dependencies: - submitit==1.5.2 - tblib==3.0.0 - toolz==1.0.0 - - tornado==6.4.1 - treelite==4.3.0 - tzdata==2024.2 - ucx-py-cu11==0.40.0