diff --git a/exp2_ms_with_colorMNIST/exp2l_bias_data_bg_colour_sender_unbias_receiver_l1_reg.ipynb b/exp2_ms_with_colorMNIST/exp2l_bias_data_bg_colour_sender_unbias_receiver_l1_reg.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..780ad6cea62e2cda9657f04d50d92f5556e7a545
--- /dev/null
+++ b/exp2_ms_with_colorMNIST/exp2l_bias_data_bg_colour_sender_unbias_receiver_l1_reg.ipynb
@@ -0,0 +1,650 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "46e81ec0-e3c4-4060-97c1-bfccf47bc350",
+   "metadata": {
+    "jp-MarkdownHeadingCollapsed": true
+   },
+   "source": [
+    "# Exp 2k: Bias data, unbias trained receiver, bg_unbiased sender network\n",
+    "bg_unbiased means digit and bg_colour are unrelated, but the model is trained and tested on the background.\n",
+    "\n",
+    "Now we will \n",
+    "only use unbias-trained network as receiver\n",
+    "only use bias data to train the stitch\n",
+    "Try the sender networks at all different stitch levels\n",
+    "## Rank\n",
+    "Also perform rank analysis on the stitched networks based on exp1e\n",
+    "## 4 Epochs\n",
+    "Only do 4 epochs of training (keep 10 epochs of stitch training) so that the initial models are weaker"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "51fc60f8-f6f6-469c-8705-c9015bd43951",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Packages\n",
+    "%matplotlib inline\n",
+    "\n",
+    "import argparse\n",
+    "import gc\n",
+    "import os.path\n",
+    "\n",
+    "import pandas as pd\n",
+    "from torch.linalg import LinAlgError\n",
+    "\n",
+    "import matplotlib.pyplot as plt\n",
+    "import torchvision\n",
+    "import torch\n",
+    "from torch import optim\n",
+    "\n",
+    "from torch import nn\n",
+    "from torch.utils.data import DataLoader\n",
+    "from torchvision.datasets import MNIST\n",
+    "import torchvision.transforms as transforms\n",
+    "import datetime\n",
+    "\n",
+    "import random\n",
+    "import numpy as np\n",
+    "\n",
+    "import sys\n",
+    "import os\n",
+    "# add the path to find colour_mnist\n",
+    "sys.path.append(os.path.abspath('../ReferenceCode'))\n",
+    "import colour_mnist\n",
+    "from stitch_utils import train_model, RcvResNet18, StitchedResNet18, get_layer_output_shape\n",
+    "from stitch_utils import generate_activations, SyntheticDataset\n",
+    "import stitch_utils\n",
+    "\n",
+    "# add the path to find the rank analysis code\n",
+    "# https://github.com/DHLSmith/jons-tunnel-effect/tree/NeurIPSPaper\n",
+    "sys.path.append(os.path.abspath('../../jons-tunnel-effect/'))\n",
+    "from utils.modelfitting import evaluate_model, set_seed\n",
+    "from extract_weight_rank import install_hooks, perform_analysis\n",
+    "\n",
+    "import torchvision\n",
+    "import torchvision.transforms as transforms\n",
+    "from torchvision.datasets import MNIST\n",
+    "\n",
+    "# To track memory usage\n",
+    "import psutil\n",
+    "process = psutil.Process()\n",
+    "            \n",
+    "\n",
+    "def logtofile(log_text, verbose=True):\n",
+    "    if verbose:\n",
+    "        print(log_text)\n",
+    "    with open(save_log_as, \"a\") as f:    \n",
+    "        print(log_text, file=f)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6761870b-2996-4763-8d90-76529ec5822e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Set Parameters\n",
+    "\n",
+    "# fix random seed for reproducibility\n",
+    "seed = 57\n",
+    "torch.manual_seed(seed)\n",
+    "torch.backends.cudnn.benchmark = False\n",
+    "torch.backends.cudnn.deterministic = True\n",
+    "random.seed(seed)\n",
+    "torch.cuda.manual_seed(seed)\n",
+    "np.random.seed(seed)\n",
+    "\n",
+    "results_root = \"results_2k\"\n",
+    "train_all = False  # Just use pretrained model\n",
+    "\n",
+    "# BG_UNBIASED is digits with randomly selected colour background. Targets represent the colour\n",
+    "train_bg_unbiased_colour_mnist_model = train_all  # when False, automatically loads a trained model\n",
+    "bg_unbiased_colour_mnist_model_to_load = \"./results_4_epochs/2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_bg_unbiased_colour_mnist.weights\"\n",
+    "\n",
+    "# UNBIASED is digits with randoly selected colour background. Targets are digit values\n",
+    "train_unbiased_colour_mnist_model = train_all  # when False, automatically loads a trained model\n",
+    "unbiased_colour_mnist_model_to_load = \"./results_4_epochs/2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_unbiased_colour_mnist.weights\"\n",
+    "\n",
+    "original_train_epochs = 4\n",
+    "bg_noise = 0.1\n",
+    "\n",
+    "stitch_train_epochs = 10\n",
+    "\n",
+    "batch_size = 128"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "20ee6e98-a6f2-4647-b378-5f7b1af48581",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Generate filenames and log the setup details\n",
+    "formatted_time = datetime.datetime.now().strftime(\"%Y-%m-%d_%H-%M-%S\")\n",
+    "filename_prefix = f\"./{results_root}/{formatted_time}_SEED{seed}_EPOCHS{original_train_epochs}_BGN{bg_noise}_exp2e_ResNet18\"\n",
+    "#save_mix_mnist_model_as = f\"{filename_prefix}_mix_mnist.weights\"\n",
+    "#save_bw_mnist_model_as = f\"{filename_prefix}_bw_mnist.weights\"\n",
+    "#save_bg_only_colour_mnist_model_as = f\"{filename_prefix}_bg_only_colour_mnist.weights\"\n",
+    "save_bg_unbiased_colour_mnist_model_as = f\"{filename_prefix}_bg_unbiased_colour_mnist.weights\"\n",
+    "#save_biased_colour_mnist_model_as = f\"{filename_prefix}_biased_colour_mnist.weights\"\n",
+    "save_unbiased_colour_mnist_model_as = f\"{filename_prefix}_unbiased_colour_mnist.weights\"\n",
+    "save_log_as = f\"{filename_prefix}_log.txt\"\n",
+    "\n",
+    "colour_mnist_shape = (3,28,28)\n",
+    "\n",
+    "\n",
+    "logtofile(f\"Executed at {formatted_time}\")\n",
+    "logtofile(f\"logging to {save_log_as}\")\n",
+    "logtofile(f\"{seed=}\")\n",
+    "logtofile(f\"{bg_noise=}\")\n",
+    "\n",
+    "logtofile(f\"{train_bg_unbiased_colour_mnist_model=}\")\n",
+    "if train_bg_unbiased_colour_mnist_model:\n",
+    "    logtofile(f\"{save_bg_unbiased_colour_mnist_model_as=}\")\n",
+    "    logtofile(f\"{original_train_epochs=}\")\n",
+    "else:\n",
+    "    logtofile(f\"{bg_unbiased_colour_mnist_model_to_load=}\")\n",
+    "\n",
+    "logtofile(f\"{train_unbiased_colour_mnist_model=}\")\n",
+    "if train_unbiased_colour_mnist_model:\n",
+    "    logtofile(f\"{save_unbiased_colour_mnist_model_as=}\")\n",
+    "    logtofile(f\"{original_train_epochs=}\")\n",
+    "else:\n",
+    "    logtofile(f\"{unbiased_colour_mnist_model_to_load=}\")\n",
+    "\n",
+    "logtofile(f\"{stitch_train_epochs=}\")\n",
+    "logtofile(f\"================================================\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6cc621f6-55e8-4191-8288-dc2493cd6bff",
+   "metadata": {},
+   "source": [
+    "mnist and cifar-10 both use 10-classes, with 60_000 train samples and 10_000 test samples. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d34a54d2-c8fa-4f51-8809-7a40b4fefc6c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "# unbiased means each digit has correct label and random colour - but bg means we will use colour as label (i.e. the bias_target will be the target)\n",
+    "bg_unbiased_train_dataloader = colour_mnist.get_biased_mnist_dataloader(root=\"./MNIST\", batch_size=batch_size, data_label_correlation=0.1, train=True, bg_noise_level=bg_noise, bias_targets_as_targets=True)\n",
+    "bg_unbiased_test_dataloader = colour_mnist.get_biased_mnist_dataloader(root=\"./MNIST\", batch_size=batch_size, data_label_correlation=0.1, train=False, bg_noise_level=bg_noise, bias_targets_as_targets=True)\n",
+    "\n",
+    "# biased means each digit has correct label and consistent colour - Expect network to learn the colours only\n",
+    "biased_train_dataloader = colour_mnist.get_biased_mnist_dataloader(root=\"./MNIST\", batch_size=batch_size, data_label_correlation=1.0, train=True, bg_noise_level=bg_noise, standard_getitem=True)\n",
+    "biased_test_dataloader = colour_mnist.get_biased_mnist_dataloader(root=\"./MNIST\", batch_size=batch_size, data_label_correlation=1.0, train=False, bg_noise_level=bg_noise, standard_getitem=True)\n",
+    "\n",
+    "# unbiased means each digit has correct label and random colour - Expect network to disregard colours?\n",
+    "unbiased_train_dataloader = colour_mnist.get_biased_mnist_dataloader(root=\"./MNIST\", batch_size=batch_size, data_label_correlation=0.1, train=True, bg_noise_level=bg_noise, standard_getitem=True)\n",
+    "unbiased_test_dataloader = colour_mnist.get_biased_mnist_dataloader(root=\"./MNIST\", batch_size=batch_size, data_label_correlation=0.1, train=False, bg_noise_level=bg_noise, standard_getitem=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "24842d82-5f62-4d1b-acc5-d1997a08b0b9",
+   "metadata": {},
+   "source": [
+    "## Set up resnet18 models and train it on versions of MNIST"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "cf635cfd-a9ad-4e37-98a0-80d0db2a3b9f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "process_structure = dict()\n",
+    "device = 'cuda:0'\n",
+    "\n",
+    "\n",
+    "process_structure[\"bg\"] = dict()\n",
+    "process_structure[\"unbias\"]    = dict()\n",
+    "\n",
+    "# \"bg_unbiased_colour\"\n",
+    "process_structure[\"bg\"][\"model\"] = torchvision.models.resnet18(num_classes=10).to(device) # Untrained model\n",
+    "process_structure[\"bg\"][\"train\"] = train_bg_unbiased_colour_mnist_model \n",
+    "process_structure[\"bg\"][\"train_loader\"] = bg_unbiased_train_dataloader\n",
+    "process_structure[\"bg\"][\"test_loader\"] = bg_unbiased_test_dataloader\n",
+    "process_structure[\"bg\"][\"saveas\"] = save_bg_unbiased_colour_mnist_model_as\n",
+    "process_structure[\"bg\"][\"loadfrom\"] = bg_unbiased_colour_mnist_model_to_load\n",
+    "\n",
+    "# \"unbiased_colour_mnist\"\n",
+    "process_structure[\"unbias\"][\"model\"] = torchvision.models.resnet18(num_classes=10).to(device) # Untrained model\n",
+    "process_structure[\"unbias\"][\"train\"] = train_unbiased_colour_mnist_model\n",
+    "process_structure[\"unbias\"][\"train_loader\"] = unbiased_train_dataloader\n",
+    "process_structure[\"unbias\"][\"test_loader\"] = unbiased_test_dataloader\n",
+    "process_structure[\"unbias\"][\"saveas\"] = save_unbiased_colour_mnist_model_as\n",
+    "process_structure[\"unbias\"][\"loadfrom\"] =  unbiased_colour_mnist_model_to_load\n",
+    "\n",
+    "for key, val in process_structure.items():\n",
+    "    print(f\"Processing for {key=}\")\n",
+    "    if val[\"train\"]:\n",
+    "        train_model(model=val[\"model\"], train_loader=val[\"train_loader\"], \n",
+    "                    epochs=original_train_epochs, saveas=val[\"saveas\"], \n",
+    "                    description=key, device=device, logtofile=logtofile)\n",
+    "    else:\n",
+    "        logtofile(f\"{val['loadfrom']=}\")\n",
+    "        val[\"model\"].load_state_dict(torch.load(val[\"loadfrom\"], map_location=torch.device(device)))\n",
+    "    val[\"model\"].eval()\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b169c5c7-7929-48e1-82eb-6f047aa4e5f2",
+   "metadata": {},
+   "source": [
+    "## Measure the Accuracy, Record the Confusion Matrix\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4fbb925a-269d-4027-9500-6cdce4de9d70",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "logtofile(\"Entering Confusion\")\n",
+    "# logtofile(process.memory_info().rss)  # in bytes \n",
+    "\n",
+    "original_accuracy = dict()\n",
+    "for key, val in process_structure.items():\n",
+    "    logtofile(f\"Accuracy Calculation for ResNet18 with {key=}\")\n",
+    "    model = val[\"model\"]\n",
+    "    model.eval() # ALWAYS DO THIS BEFORE YOU EVALUATE MODELS\n",
+    "    \n",
+    "    # Compute the model accuracy on the test set\n",
+    "    correct = 0\n",
+    "    total = 0\n",
+    "    \n",
+    "    # assuming 10 classes\n",
+    "    # rows represent actual class, columns are predicted\n",
+    "    confusion_matrix = torch.zeros(10,10, dtype=torch.int)\n",
+    "    \n",
+    "    TDL = biased_test_dataloader  # In test 2D - ALWAYS use biased dataset to measure/train stitch\n",
+    "    for data in TDL:\n",
+    "        inputs, labels = data\n",
+    "        inputs = inputs.to(device)\n",
+    "        labels = labels.to(device)\n",
+    "        predictions = torch.argmax(model(inputs),1)\n",
+    "        \n",
+    "        matches = predictions == labels\n",
+    "        correct += matches.sum().item()\n",
+    "        total += len(labels)\n",
+    "        for idx, l in enumerate(labels):\n",
+    "            confusion_matrix[l, predictions[idx]] = 1 + confusion_matrix[l, predictions[idx]] \n",
+    "    \n",
+    "    logtofile(\"Test the Trained Resnet18 against BIASED TEST DATALOADER\")\n",
+    "    acc = ((100.0 * correct) / total)\n",
+    "    logtofile('Test Accuracy: %2.2f %%' % acc)\n",
+    "    original_accuracy[key] = acc\n",
+    "    logtofile('Confusion Matrix')\n",
+    "    logtofile(confusion_matrix)\n",
+    "    logtofile(confusion_matrix.sum())\n",
+    "    # logtofile(process.memory_info().rss)  # in bytes \n",
+    "\n",
+    "\n",
+    "logtofile(f\"{original_accuracy=}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "da373b34-fe35-4256-a9e0-1040f699d45d",
+   "metadata": {},
+   "source": [
+    "## Measure Rank with __biased__ dataloader (test) before cutting and stitching"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c4f74591-2a3f-4521-8aec-32c324125a5b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "logtofile(\"Entering whole model check\")\n",
+    "# logtofile(process.memory_info().rss)  # in bytes \n",
+    "\n",
+    "# For the Whole Model - but we will pass it through the RcvResNet18 function to get matching feature names\n",
+    "for key, val in process_structure.items():\n",
+    "    \n",
+    "    TDL = biased_test_dataloader  # ALWAYS use biased dataloader for this test\n",
+    "    \n",
+    "    if val[\"train\"]:\n",
+    "        filename = val[\"saveas\"] \n",
+    "    else:    \n",
+    "        filename = val[\"loadfrom\"] \n",
+    "    assert os.path.exists(filename)\n",
+    "    mdl = torchvision.models.resnet18(num_classes=10) # Untrained model\n",
+    "    state = torch.load(filename, map_location=torch.device(\"cpu\"))\n",
+    "    mdl.load_state_dict(state, assign=True)\n",
+    "    mdl=mdl.to(device)\n",
+    "    mdl = RcvResNet18(mdl, -1, colour_mnist_shape, device).to(device)\n",
+    "\n",
+    "    out_filename = filename.split('/')[-1].replace('.weights', '-test.csv')\n",
+    "    \n",
+    "    outpath = f\"./{results_root}_rank/{key}-bias-{seed}_{out_filename}\"  # denote output name as <model_training_type>-dataset-<name>\n",
+    "    \n",
+    "    if os.path.exists(f\"{outpath}\"):\n",
+    "        logtofile(f\"Already evaluated for {outpath}\")\n",
+    "        continue\n",
+    "    logtofile(f\"Measure Rank for {key=}\")\n",
+    "    print(f\"output to {outpath}\")\n",
+    "            \n",
+    "    params = {}\n",
+    "    params[\"model\"] = key\n",
+    "    params[\"dataset\"] = \"bias\"\n",
+    "    params[\"seed\"] = seed\n",
+    "    if val[\"train\"]: # as only one network used, record its filename as both send and receive files\n",
+    "        params[\"send_file\"] = val[\"saveas\"] \n",
+    "        params[\"rcv_file\"] = val[\"saveas\"] \n",
+    "    else:    \n",
+    "        params[\"send_file\"] = val[\"loadfrom\"] \n",
+    "        params[\"rcv_file\"] = val[\"loadfrom\"]     \n",
+    "    \n",
+    "    with torch.no_grad():\n",
+    "        layers, features, handles = install_hooks(mdl)\n",
+    "        \n",
+    "        metrics = evaluate_model(mdl, TDL, 'acc', verbose=2)\n",
+    "        params.update(metrics)\n",
+    "        classes = None\n",
+    "        df = perform_analysis(features, classes, layers, params, n=-1)\n",
+    "        df.to_csv(f\"{outpath}\")\n",
+    "    for h in handles:\n",
+    "        h.remove()\n",
+    "    del mdl, layers, features, metrics, params, df, handles\n",
+    "    gc.collect()\n",
+    "    # logtofile(process.memory_info().rss)  # in bytes \n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "de2928d3-9c5f-411f-a590-14fd04026ab6",
+   "metadata": {},
+   "source": [
+    "# Stitch at a given layer\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3cbaf773-ed43-4d91-b0a0-a35fd468ac02",
+   "metadata": {},
+   "source": [
+    "## Train the stitch layer and check rank"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2c29d808-b86e-4a0c-a3c5-127c952f3ab9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "logtofile(\"Entering Stitch/Rank\")\n",
+    "# logtofile(process.memory_info().rss)  # in bytes \n",
+    "\n",
+    "logtofile(f\"{device=}\")\n",
+    "stitching_accuracies = dict()\n",
+    "stitching_penalties = dict()\n",
+    "# NOTE this is only valid as all models are the same architecture\n",
+    "num_layers_in_model = len(list(process_structure[\"bg\"][\"model\"].children()))  \n",
+    "for send_key, send_val in process_structure.items():\n",
+    "    if (send_key != \"bg\"):\n",
+    "        logtofile(f\"NOTE: Only running stitch with bg send model\")\n",
+    "        continue\n",
+    "    stitching_accuracies[send_key] = dict()\n",
+    "    stitching_penalties[send_key] = dict()\n",
+    "    \n",
+    "    for rcv_key, rcv_val in process_structure.items():        \n",
+    "        if (rcv_key != \"unbias\"):\n",
+    "            logtofile(f\"NOTE: Only running stitch with unbias receive model\")\n",
+    "            continue       \n",
+    "            \n",
+    "        stitching_accuracies[send_key][rcv_key] = dict()\n",
+    "        stitching_penalties[send_key][rcv_key] = dict()\n",
+    "        for layer_to_cut_after in range(3,num_layers_in_model - 1):\n",
+    "            # for consistency, use the rcv network for the filename stem.\n",
+    "            if rcv_val[\"train\"]:\n",
+    "                filename = rcv_val[\"saveas\"] \n",
+    "            else:    \n",
+    "                filename = rcv_val[\"loadfrom\"] \n",
+    "            \n",
+    "            rank_filename = filename.split('/')[-1].replace('.weights', '-test.csv')        \n",
+    "            # denote output name as <model_training_type>-dataset-<name>\n",
+    "            # where <model_training_type> is [sender_model or X][layer_to_cut_after][Receiver_model]\n",
+    "            model_training_type = f\"{send_key}{layer_to_cut_after}{rcv_key}\"\n",
+    "            dataset_type = \"bias\"  # ALWAYS use bias dataset in this test\n",
+    "            outpath = f\"./{results_root}_rank/{model_training_type}-{dataset_type}-{seed}_{rank_filename}\"  \n",
+    "                            \n",
+    "            if os.path.exists(f\"{outpath}\"):\n",
+    "                logtofile(f\"Already evaluated for {outpath}\")\n",
+    "                continue\n",
+    "            ####################################################################################\n",
+    "            logtofile(f\"Evaluate ranks and output to {outpath}\")\n",
+    "            # logtofile(process.memory_info().rss)  # in bytes \n",
+    "\n",
+    "            logtofile(f\"Train the stitch to a model stitched after layer {layer_to_cut_after} from {send_key} to {rcv_key}\")    \n",
+    "            logtofile(f\"Use the biased data loader (train and test) regardless of what {rcv_key} was trained on\")\n",
+    "            \n",
+    "            # train a stitch on the unbiased_colour dataset to compare receiver network performance with stitched\n",
+    "            model_stitched = StitchedResNet18(send_model=send_val[\"model\"], \n",
+    "                                              after_layer_index=layer_to_cut_after, \n",
+    "                                              rcv_model=rcv_val[\"model\"],\n",
+    "                                              input_image_shape=colour_mnist_shape, device=device  ).to(device)\n",
+    "                        \n",
+    "            #############################################################\n",
+    "            # store the initial stitch state\n",
+    "            initial_stitch_weight = model_stitched.stitch.s_conv1.weight.clone()\n",
+    "            initial_stitch_bias   = model_stitched.stitch.s_conv1.bias.clone()\n",
+    "            stitch_initial_weight_outpath    = f\"./{results_root}/STITCH_initial_weight_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}\"  \n",
+    "            stitch_initial_bias_outpath      = f\"./{results_root}/STITCH_initial_bias_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}\"  \n",
+    "            torch.save(initial_stitch_weight, stitch_initial_weight_outpath)\n",
+    "            torch.save(initial_stitch_bias, stitch_initial_bias_outpath)\n",
+    "            ############################################################\n",
+    "                    \n",
+    "            # define the loss function and the optimiser\n",
+    "            loss_function = nn.CrossEntropyLoss()\n",
+    "            # Hernandez said : momentum 0.9, batch size 256, weight decay 0.01, learning rate 0.01, and a post-warmup cosine learning rate scheduler.\n",
+    "            # optimiser = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)\n",
+    "            optimiser = optim.SGD(model_stitched.parameters(), lr=1e-4, momentum=0.9, weight_decay=0.01)\n",
+    "            \n",
+    "            # Put top model into train mode so that bn and dropout perform in training mode\n",
+    "            model_stitched.train()\n",
+    "            # Freeze the whole model\n",
+    "            model_stitched.requires_grad_(False)\n",
+    "            # Un-Freeze the stitch layer\n",
+    "            for name, param in model_stitched.stitch.named_parameters():\n",
+    "                param.requires_grad_(True)\n",
+    "            # the epoch loop: note that we're training the whole network\n",
+    "            for epoch in range(stitch_train_epochs):\n",
+    "                running_loss = 0.0\n",
+    "                for data in biased_train_dataloader:\n",
+    "                    # data is (representations, labels) tuple\n",
+    "                    # get the inputs and put them on the GPU\n",
+    "                    inputs, labels = data\n",
+    "                    inputs = inputs.to(device)\n",
+    "                    labels = labels.to(device)\n",
+    "            \n",
+    "                    # zero the parameter gradients\n",
+    "                    optimiser.zero_grad()\n",
+    "            \n",
+    "                    # forward + loss + backward + optimise (update weights)\n",
+    "                    outputs = model_stitched(inputs)\n",
+    "                    loss = loss_function(outputs, labels)\n",
+    "                    lambda_reg = 0.01\n",
+    "                    l1_norm = sum(p.abs().sum() for p in model_stitched.stitch.parameters())\n",
+    "                    loss += lambda_reg * l1_norm\n",
+    "                    loss.backward()\n",
+    "                    optimiser.step()\n",
+    "            \n",
+    "                    # keep track of the loss this epoch\n",
+    "                    running_loss += loss.item()\n",
+    "                logtofile(\"Epoch %d, loss %4.2f\" % (epoch, running_loss))\n",
+    "                # logtofile(process.memory_info().rss)  # in bytes \n",
+    "\n",
+    "            logtofile('**** Finished Training ****')\n",
+    "            \n",
+    "            model_stitched.eval() # ALWAYS DO THIS BEFORE YOU EVALUATE MODELS\n",
+    "\n",
+    "            ############################################################\n",
+    "            # store the trained stitch\n",
+    "            trained_stitch_weight = model_stitched.stitch.s_conv1.weight.clone()\n",
+    "            trained_stitch_bias   = model_stitched.stitch.s_conv1.bias.clone()\n",
+    "            stitch_trained_weight_outpath    = f\"./{results_root}/STITCH_trained_weight_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}\"  \n",
+    "            stitch_trained_bias_outpath      = f\"./{results_root}/STITCH_trained_bias_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}\"  \n",
+    "            torch.save(trained_stitch_weight, stitch_trained_weight_outpath)\n",
+    "            torch.save(trained_stitch_bias, stitch_trained_bias_outpath)\n",
+    "                       \n",
+    "            stitch_weight_diff = trained_stitch_weight - initial_stitch_weight\n",
+    "            stitch_weight_delta = torch.linalg.norm(stitch_weight_diff).item()\n",
+    "            logtofile(f\"Change in stitch weights: {stitch_weight_delta}\")\n",
+    "            maxabsweight =  torch.max(stitch_weight_diff.abs()).item()\n",
+    "            logtofile(f\"Largest abs weight change: {maxabsweight}\")\n",
+    "            stitch_weight_number = torch.sum(torch.where(stitch_weight_diff.abs() > 0.1*maxabsweight, True, False)).item()\n",
+    "            logtofile(f\"Number of weights changing > 0.1 of that: {stitch_weight_number}\")\n",
+    "\n",
+    "            \n",
+    "            print(f\"Number of weight / bias in stitch layer is {len(initial_stitch_weight)}\")\n",
+    "            stitch_bias_diff = trained_stitch_bias - initial_stitch_bias\n",
+    "            stitch_bias_delta = torch.linalg.norm(stitch_bias_diff).item()\n",
+    "            logtofile(f\"Change in stitch bias: {stitch_bias_delta}\")\n",
+    "            maxabsbias =  torch.max(stitch_bias_diff.abs()).item()\n",
+    "            logtofile(f\"Largest abs bias change: {maxabsbias}\")\n",
+    "            stitch_bias_number = torch.sum(torch.where(stitch_bias_diff.abs() > 0.1*maxabsbias, True, False)).item()\n",
+    "            logtofile(f\"Number of bias changing > 0.1 of that: {stitch_bias_number}\")\n",
+    "            ##############################################################\n",
+    "\n",
+    "            \n",
+    "            # Compute the model accuracy on the test set\n",
+    "            correct = 0\n",
+    "            total = 0\n",
+    "            \n",
+    "            # assuming 10 classes\n",
+    "            # rows represent actual class, columns are predicted\n",
+    "            confusion_matrix = torch.zeros(10,10, dtype=torch.int)\n",
+    "            \n",
+    "            for data in biased_test_dataloader:  # Only use biased test data\n",
+    "                inputs, labels = data\n",
+    "                inputs = inputs.to(device)\n",
+    "                labels = labels.to(device)\n",
+    "                \n",
+    "                predictions = torch.argmax(model_stitched(inputs),1)\n",
+    "                matches = predictions == labels.to(device)\n",
+    "                correct += matches.sum().item()\n",
+    "                total += len(labels)\n",
+    "            \n",
+    "                for idx, l in enumerate(labels):\n",
+    "                    confusion_matrix[l, predictions[idx]] = 1 + confusion_matrix[l, predictions[idx]] \n",
+    "            logtofile(\"Test the trained stitch against biased data\")    \n",
+    "            acc =  ((100.0 * correct) / total)\n",
+    "            logtofile('Test Accuracy: %2.2f %%' % acc)\n",
+    "            logtofile('Confusion Matrix')\n",
+    "            logtofile(confusion_matrix)\n",
+    "            logtofile(\"===================================================================\")\n",
+    "            # logtofile(process.memory_info().rss)  # in bytes \n",
+    "\n",
+    "            # Stitching penalty should be negative if there is an improvement, and is relative to the original receiver network\n",
+    "            stitching_accuracies[send_key][rcv_key][layer_to_cut_after] = acc\n",
+    "            stitching_penalties[send_key][rcv_key][layer_to_cut_after] = original_accuracy[rcv_key] - acc\n",
+    "\n",
+    "            TDL = biased_test_dataloader\n",
+    "            params = {}\n",
+    "            params[\"model\"] = model_training_type # a mnemonic\n",
+    "            params[\"dataset\"] = dataset_type\n",
+    "            params[\"seed\"] = seed\n",
+    "            if send_val[\"train\"]:\n",
+    "                params[\"send_file\"] = send_val[\"saveas\"] \n",
+    "            else:    \n",
+    "                params[\"send_file\"] = send_val[\"loadfrom\"] \n",
+    "            if rcv_val[\"train\"]:\n",
+    "                params[\"rcv_file\"] = rcv_val[\"saveas\"] \n",
+    "            else:    \n",
+    "                params[\"rcv_file\"] = rcv_val[\"loadfrom\"] \n",
+    "            params[\"stitch_weight_delta\"] = stitch_weight_delta\n",
+    "            params[\"stitch_bias_delta\"] = stitch_bias_delta        \n",
+    "            params[\"stitch_weight_number\"] = stitch_weight_number\n",
+    "            params[\"stitch_bias_number\"] = stitch_bias_number\n",
+    "            # logtofile(process.memory_info().rss)  # in bytes \n",
+    "            with torch.no_grad():\n",
+    "                layers, features, handles = install_hooks(model_stitched)                \n",
+    "                metrics = evaluate_model(model_stitched, TDL, 'acc', verbose=2)\n",
+    "                params.update(metrics)\n",
+    "                classes = None\n",
+    "                df = perform_analysis(features, classes, layers, params, n=-1)\n",
+    "                df.to_csv(f\"{outpath}\")\n",
+    "                \n",
+    "            for h in handles:\n",
+    "                h.remove()\n",
+    "            del model_stitched, layers, features, metrics, params, df, handles\n",
+    "            gc.collect()\n",
+    "            # logtofile(process.memory_info().rss)  # in bytes \n",
+    "\n",
+    "            "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3db5808a-6351-4133-9e7c-85e8e9248cfb",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "logtofile(f\"{stitching_accuracies=}\")\n",
+    "logtofile(f\"{stitching_penalties=}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "cd669a19-7c42-4265-b8e8-57165995c097",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for s_key in stitching_accuracies:    \n",
+    "    for r_key in stitching_accuracies[s_key]:\n",
+    "        logtofile(f\"{s_key}-{r_key}\")\n",
+    "        logtofile(f\"{original_accuracy[r_key]=}\")\n",
+    "        logtofile(\"Stitch Accuracy\")\n",
+    "        for layer in stitching_accuracies[s_key][r_key]:\n",
+    "            logtofile(f\"L{layer}: {stitching_accuracies[s_key][r_key][layer]}\")\n",
+    "        logtofile(\"--------------------------\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "affectnet_env",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}