diff --git a/entcl/data/cifar100.py b/entcl/data/cifar100.py
index 5fbbf9df6df141b67f376ca197e298f999ca19c9..1e534d6783deef3d15d553e7932d252eadb6dbe4 100644
--- a/entcl/data/cifar100.py
+++ b/entcl/data/cifar100.py
@@ -159,9 +159,10 @@ class CIFAR100Dataset:
         
         samples = torch.cat(samples)
         labels = torch.cat(labels)
+        types = torch.full((labels.size(0),), 0, dtype=torch.long) # they are all known classes, so type is 0
         logger.debug(f"Creating dataset for session 0 (pretraining). There are {len(samples)} samples, and {len(labels)} labels. There are {labels.unique().size(0)} different classes")
         logger.debug(f"Classes in Pretraining Dataset: {labels.unique(sorted=True)}")
-        datasets[0] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels), transform=self.transform)
+        datasets[0] = TransformedTensorDataset(tensor_dataset=torch.utils.data.TensorDataset(samples, labels, types), transform=self.transform)
         
         # CL sessions' datasets
         logger.debug("Splitting data for CL sessions")    
diff --git a/entcl/run.py b/entcl/run.py
index 0fd351457fe7b9834c4b0a5e717c452bc0b790dc..4133e1ea018e9374330857c9625d97b82fe52a04 100644
--- a/entcl/run.py
+++ b/entcl/run.py
@@ -20,6 +20,8 @@ def main(args: argparse.Namespace):
     logger.debug(f"Model: {model}")
     
     logger.info("Pretraining Model (Session 0)")
+    args.current_session = 0
+    os.makedirs(os.path.join(args.exp_dir, 'session_0'), exist_ok=True)
     model = pretrain(args, model)
     
     if args.mode in ['cl', 'both']:
@@ -28,6 +30,8 @@ def main(args: argparse.Namespace):
         for session in range(1, args.sessions + 1):
             logger.info(f"Starting Continual Learning Session {session}")
             args.current_session = session
+            os.makedirs(os.path.join(args.exp_dir, f'session_{session}'), exist_ok=True)
+            
             session_dataset = args.dataset.get_dataset(session)
             
             # OOD detection
@@ -40,6 +44,7 @@ def main(args: argparse.Namespace):
             
             # Expand Classification Head & Initialise
             model.head.expand(args.dataset.novel_inc) # we are cheating here, we know the number of novel classes
+            assert model.head.fc.out_features == args.dataset.known + session * args.dataset.novel_inc, f"Head has {model.head.fc.out_features} features, expected {args.dataset.known + session * args.dataset.novel_inc}"
             
             # freeze the weights for the existing classes. We are only training unknown samples (EG: 50 (known) + (2 (session) - 1) * 10 (novel_inc) = 60 classes have been seen in cl session 2)
             model.head.freeze(start_idx=0, end_idx=args.dataset.known + ((session -1) * args.dataset.novel_inc))
@@ -150,14 +155,14 @@ if __name__ == "__main__":
         
     if args.head == 'linear':
         from entcl.models.linear_head import LinearHead
-        args.head = LinearHead(in_features=768, out_features=args.dataset.num_classes)
+        args.head = LinearHead(in_features=768, out_features=args.dataset.known)
         logger.debug(f"Using Linear Head: {args.head}")
     elif args.head == 'dino_head':
         from entcl.models.dinohead import DINOHead
         args.head = DINOHead(768, args.dataset.known, nlayers=3)
     elif args.head == 'mlp':
         from entcl.models.linear_head import MLPHead
-        args.head = MLPHead(in_features=768, out_features=args.dataset.num_classes, hidden_dim1=512, hidden_dim2=256)
+        args.head = MLPHead(in_features=768, out_features=args.dataset.known, hidden_dim1=512, hidden_dim2=256)
         logger.debug(f"Using MLP Head: {args.head}")
         
     if args.mode == 'cl' and args.pretrain_load is None:
diff --git a/enviroment.yml b/enviroment.yml
index 86f93904f4232f6a75749b06811cd61efbec2b4d..abc7d80627f122324dff6a5fa217e20b138b758d 100644
--- a/enviroment.yml
+++ b/enviroment.yml
@@ -6,16 +6,21 @@ channels:
   - conda-forge
   - defaults
 dependencies:
-  - _libgcc_mutex=0.1=main
-  - _openmp_mutex=5.1=1_gnu
+  - _libgcc_mutex=0.1=conda_forge
+  - _openmp_mutex=4.5=2_gnu
   - antlr-python-runtime=4.9.3=pyhd8ed1ab_1
   - blas=1.0=mkl
+  - bottleneck=1.4.2=py310ha9d4c09_0
+  - brotli=1.0.9=h5eee18b_8
+  - brotli-bin=1.0.9=h5eee18b_8
   - brotli-python=1.0.9=py310h6a678d5_8
   - bzip2=1.0.8=h5eee18b_6
+  - c-ares=1.19.1=h5eee18b_0
   - ca-certificates=2024.9.24=h06a4308_0
   - certifi=2024.8.30=py310h06a4308_0
   - cffi=1.17.1=py310h1fdaa30_0
   - charset-normalizer=3.3.2=pyhd3eb1b0_0
+  - contourpy=1.2.0=py310hdb19cb5_0
   - cpython=3.10.15=py310hd8ed1ab_2
   - cuda-cudart=11.7.99=0
   - cuda-cupti=11.7.101=0
@@ -24,49 +29,86 @@ dependencies:
   - cuda-nvtx=11.7.91=0
   - cuda-runtime=11.7.1=0
   - cudatoolkit=11.3.1=h2bc3f7f_2
+  - cycler=0.11.0=pyhd3eb1b0_0
+  - cyrus-sasl=2.1.28=h52b45da_1
   - dataclasses=0.8=pyh6d0b6a4_7
+  - dbus=1.13.18=hb2f20db_0
+  - expat=2.6.3=h6a678d5_0
   - ffmpeg=4.3=hf484d3e_0
   - filelock=3.13.1=py310h06a4308_0
+  - fontconfig=2.14.1=h55d465d_3
+  - fonttools=4.51.0=py310h5eee18b_0
   - freetype=2.12.1=h4a9f257_0
   - future=1.0.0=py310h06a4308_0
   - fvcore=0.1.5.post20221221=pyhd8ed1ab_0
+  - glib=2.78.4=h6a678d5_0
+  - glib-tools=2.78.4=h6a678d5_0
   - gmp=6.2.1=h295c915_3
   - gmpy2=2.1.2=py310heeb90bb_0
   - gnutls=3.6.15=he1e5248_0
+  - gst-plugins-base=1.14.1=h6a678d5_1
+  - gstreamer=1.14.1=h5eee18b_1
+  - icu=73.1=h6a678d5_0
   - idna=3.7=py310h06a4308_0
   - intel-openmp=2021.4.0=h06a4308_3561
   - iopath=0.1.10=pyhd8ed1ab_0
   - jinja2=3.1.4=py310h06a4308_1
+  - joblib=1.4.2=py310h06a4308_0
   - jpeg=9e=h5eee18b_3
+  - kiwisolver=1.4.4=py310h6a678d5_0
+  - kneed=0.8.5=pyhd8ed1ab_0
+  - krb5=1.20.1=h143b758_1
   - lame=3.100=h7b6447c_0
   - lcms2=2.12=h3be6417_0
   - ld_impl_linux-64=2.40=h12ee557_0
   - lerc=3.0=h295c915_0
+  - libbrotlicommon=1.0.9=h5eee18b_8
+  - libbrotlidec=1.0.9=h5eee18b_8
+  - libbrotlienc=1.0.9=h5eee18b_8
+  - libclang=14.0.6=default_hc6dbbc7_1
+  - libclang13=14.0.6=default_he11475f_1
   - libcublas=11.10.3.66=0
   - libcufft=10.7.2.124=h4fbf590_0
   - libcufile=1.9.1.3=0
+  - libcups=2.4.2=h2d74bed_1
   - libcurand=10.3.5.147=0
+  - libcurl=8.9.1=h251f7ec_0
   - libcusolver=11.4.0.1=0
   - libcusparse=11.7.4.91=0
   - libdeflate=1.17=h5eee18b_1
+  - libedit=3.1.20230828=h5eee18b_0
+  - libev=4.33=h7f8727e_1
   - libffi=3.4.4=h6a678d5_1
-  - libgcc-ng=11.2.0=h1234567_1
-  - libgomp=11.2.0=h1234567_1
+  - libgcc=14.2.0=h77fa898_1
+  - libgcc-ng=14.2.0=h69a702a_1
+  - libgfortran-ng=11.2.0=h00389a5_1
+  - libgfortran5=11.2.0=h1234567_1
+  - libglib=2.78.4=hdc74915_0
+  - libgomp=14.2.0=h77fa898_1
   - libiconv=1.16=h5eee18b_3
   - libidn2=2.3.4=h5eee18b_0
+  - libllvm14=14.0.6=hecde1de_4
+  - libnghttp2=1.57.0=h2d74bed_0
   - libnpp=11.7.4.75=0
   - libnvjpeg=11.8.0.2=0
   - libpng=1.6.39=h5eee18b_0
+  - libpq=17.0=hdbd6064_0
   - libprotobuf=3.20.3=he621ea3_0
+  - libssh2=1.11.0=h251f7ec_0
   - libstdcxx-ng=11.2.0=h1234567_1
   - libtasn1=4.19.0=h5eee18b_0
   - libtiff=4.5.1=h6a678d5_0
   - libunistring=0.9.10=h27cfd23_0
   - libuuid=1.41.5=h5eee18b_0
   - libwebp-base=1.3.2=h5eee18b_1
+  - libxcb=1.15=h7f8727e_0
+  - libxkbcommon=1.0.1=h097e994_2
+  - libxml2=2.13.1=hfdd30dd_2
   - loguru=0.7.2=py310h06a4308_1
   - lz4-c=1.9.4=h6a678d5_1
   - markupsafe=2.1.3=py310h5eee18b_0
+  - matplotlib=3.9.2=py310h06a4308_0
+  - matplotlib-base=3.9.2=py310hbfdbfaf_0
   - mkl=2021.4.0=h06a4308_640
   - mkl-service=2.4.0=py310h7f8727e_0
   - mkl_fft=1.3.1=py310hd6ae3a3_0
@@ -74,43 +116,64 @@ dependencies:
   - mpc=1.1.0=h10f8cd9_1
   - mpfr=4.0.2=hb69a4c5_1
   - mpmath=1.3.0=py310h06a4308_0
+  - mysql=8.4.0=h0bac5ae_0
   - ncurses=6.4=h6a678d5_0
   - nettle=3.7.3=hbbd107a_1
   - networkx=3.3=py310h06a4308_0
   - ninja=1.12.1=h06a4308_0
   - ninja-base=1.12.1=hdb19cb5_0
+  - numexpr=2.8.4=py310h8879344_0
   - numpy=1.24.3=py310hd5efca6_0
   - numpy-base=1.24.3=py310h8e6c178_0
   - omegaconf=2.3.0=pyhd8ed1ab_0
   - openh264=2.1.1=h4ff587b_0
   - openjpeg=2.5.2=he7f1fd0_0
-  - openssl=3.0.15=h5eee18b_0
+  - openldap=2.6.4=h42fbc30_0
+  - openssl=3.4.0=hb9d3cd8_0
+  - pandas=2.2.2=py310h6a678d5_0
+  - pcre2=10.42=hebb0a14_1
   - pillow=10.4.0=py310h5eee18b_0
   - pip=24.2=py310h06a4308_0
+  - platformdirs=3.10.0=py310h06a4308_0
+  - ply=3.11=py310h06a4308_0
+  - pooch=1.8.2=py310h06a4308_0
   - portalocker=2.3.0=py310h06a4308_1
   - pycparser=2.21=pyhd3eb1b0_0
+  - pyparsing=3.2.0=py310h06a4308_0
+  - pyqt=5.15.10=py310h6a678d5_0
+  - pyqt5-sip=12.13.0=py310h5eee18b_0
   - pysocks=1.7.1=py310h06a4308_0
   - python=3.10.15=he870216_1
+  - python-dateutil=2.9.0post0=py310h06a4308_2
+  - python-tzdata=2023.3=pyhd3eb1b0_0
   - python_abi=3.10=2_cp310
   - pytorch=2.0.0=py3.10_cuda11.7_cudnn8.5.0_0
   - pytorch-cuda=11.7=h778d358_5
   - pytorch-mutex=1.0=cuda
   - pyyaml=6.0.2=py310h5eee18b_0
+  - qt-main=5.15.2=hb6262e9_11
   - readline=8.2=h5eee18b_0
   - requests=2.32.3=py310h06a4308_1
+  - scikit-learn=1.5.1=py310h1128e8f_0
+  - seaborn=0.13.2=py310h06a4308_0
   - setuptools=75.1.0=py310h06a4308_0
+  - sip=6.7.12=py310h6a678d5_0
   - six=1.16.0=pyhd3eb1b0_1
   - sqlite=3.45.3=h5eee18b_0
   - sympy=1.13.3=pyh2585a3b_104
   - tabulate=0.9.0=py310h06a4308_0
   - termcolor=2.1.0=py310h06a4308_0
+  - threadpoolctl=3.5.0=py310h2f386ee_0
   - tk=8.6.14=h39e8969_0
+  - tomli=2.0.1=py310h06a4308_0
   - torchaudio=2.0.0=py310_cu117
   - torchtriton=2.0.0=py310
   - torchvision=0.15.0=py310_cu117
+  - tornado=6.4.1=py310h5eee18b_0
   - tqdm=4.66.5=py310h2f386ee_0
   - typing-extensions=4.11.0=py310h06a4308_0
   - typing_extensions=4.11.0=py310h06a4308_0
+  - unicodedata2=15.1.0=py310h5eee18b_0
   - urllib3=2.2.3=py310h06a4308_0
   - wheel=0.44.0=py310h06a4308_0
   - xformers=0.0.19=py310_cu11.8.0_pyt2.0.0
@@ -137,8 +200,8 @@ dependencies:
       - distributed-ucxx-cu11==0.40.0
       - fastrlock==0.8.2
       - fsspec==2024.10.0
+      - gapstatistics==0.1.5
       - importlib-metadata==8.5.0
-      - joblib==1.4.2
       - libcudf-cu11==24.10.1
       - libucx-cu11==1.17.0
       - libucxx-cu11==0.40.0
@@ -150,7 +213,6 @@ dependencies:
       - numba==0.60.0
       - nvtx==0.2.10
       - packaging==24.2
-      - pandas==2.2.2
       - partd==1.4.2
       - psutil==6.1.0
       - ptxcompiler-cu11==0.8.1.post2
@@ -159,7 +221,6 @@ dependencies:
       - pylibcudf-cu11==24.10.1
       - pylibraft-cu11==24.10.0
       - pynvml==11.4.1
-      - python-dateutil==2.9.0.post0
       - pytz==2024.2
       - raft-dask-cu11==24.10.0
       - rapids-dask-dependency==24.10.0
@@ -170,7 +231,6 @@ dependencies:
       - submitit==1.5.2
       - tblib==3.0.0
       - toolz==1.0.0
-      - tornado==6.4.1
       - treelite==4.3.0
       - tzdata==2024.2
       - ucx-py-cu11==0.40.0