diff --git a/exp2_ms_with_colorMNIST/exp2l_bias_data_bg_colour_sender_unbias_receiver_l1_reg.ipynb b/exp2_ms_with_colorMNIST/exp2l_bias_data_bg_colour_sender_unbias_receiver_l1_reg.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..780ad6cea62e2cda9657f04d50d92f5556e7a545
--- /dev/null
+++ b/exp2_ms_with_colorMNIST/exp2l_bias_data_bg_colour_sender_unbias_receiver_l1_reg.ipynb
@@ -0,0 +1,650 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "46e81ec0-e3c4-4060-97c1-bfccf47bc350",
+   "metadata": {
+    "jp-MarkdownHeadingCollapsed": true
+   },
+   "source": [
+    "# Exp 2l: Bias data, unbias-trained receiver, bg_unbiased sender network, L1-regularised stitch\n",
+    "bg_unbiased means digit and bg_colour are unrelated, but the model is trained and tested on the background.\n",
+    "\n",
+    "In this experiment we:\n",
+    "\n",
+    "- only use the unbias-trained network as receiver\n",
+    "- only use bias data to train the stitch\n",
+    "- try the sender network at all the different stitch levels\n",
+    "## Rank\n",
+    "Also perform rank analysis on the stitched networks, based on exp1e\n",
+    "## L1 Regularisation\n",
+    "Add an L1 penalty on the stitch parameters while training the stitch (lambda_reg = 0.01; see the objective sketched below)\n",
+    "## 4 Epochs\n",
+    "Only do 4 epochs of training (keep 10 epochs of stitch training) so that the initial models are weaker"
+   ]
+  },
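+  {
+   "cell_type": "markdown",
+   "id": "l1-objective-sketch",
+   "metadata": {},
+   "source": [
+    "A sketch of the stitch-training objective, added for reference; it restates the loss computed in the stitch-training cell below, where $\\lambda$ corresponds to `lambda_reg = 0.01`:\n",
+    "\n",
+    "$$\\mathcal{L}(\\theta_s) = \\mathrm{CE}\\big(f(x;\\theta_s),\\,y\\big) + \\lambda \\sum_{w \\in \\theta_s} |w|$$\n",
+    "\n",
+    "where $\\theta_s$ are the stitch-layer parameters (everything else is frozen), so the L1 term pushes unneeded stitch weights towards zero.\n"
+   ]
+  },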
"train_all = False # Just use pretrained model\n", + "\n", + "# BG_UNBIASED is digits with randomly selected colour background. Targets represent the colour\n", + "train_bg_unbiased_colour_mnist_model = train_all # when False, automatically loads a trained model\n", + "bg_unbiased_colour_mnist_model_to_load = \"./results_4_epochs/2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_bg_unbiased_colour_mnist.weights\"\n", + "\n", + "# UNBIASED is digits with randoly selected colour background. Targets are digit values\n", + "train_unbiased_colour_mnist_model = train_all # when False, automatically loads a trained model\n", + "unbiased_colour_mnist_model_to_load = \"./results_4_epochs/2024-08-02_11-10-38_SEED57_EPOCHS4_BGN0.1_exp2d_ResNet18_unbiased_colour_mnist.weights\"\n", + "\n", + "original_train_epochs = 4\n", + "bg_noise = 0.1\n", + "\n", + "stitch_train_epochs = 10\n", + "\n", + "batch_size = 128" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20ee6e98-a6f2-4647-b378-5f7b1af48581", + "metadata": {}, + "outputs": [], + "source": [ + "# Generate filenames and log the setup details\n", + "formatted_time = datetime.datetime.now().strftime(\"%Y-%m-%d_%H-%M-%S\")\n", + "filename_prefix = f\"./{results_root}/{formatted_time}_SEED{seed}_EPOCHS{original_train_epochs}_BGN{bg_noise}_exp2e_ResNet18\"\n", + "#save_mix_mnist_model_as = f\"{filename_prefix}_mix_mnist.weights\"\n", + "#save_bw_mnist_model_as = f\"{filename_prefix}_bw_mnist.weights\"\n", + "#save_bg_only_colour_mnist_model_as = f\"{filename_prefix}_bg_only_colour_mnist.weights\"\n", + "save_bg_unbiased_colour_mnist_model_as = f\"{filename_prefix}_bg_unbiased_colour_mnist.weights\"\n", + "#save_biased_colour_mnist_model_as = f\"{filename_prefix}_biased_colour_mnist.weights\"\n", + "save_unbiased_colour_mnist_model_as = f\"{filename_prefix}_unbiased_colour_mnist.weights\"\n", + "save_log_as = f\"{filename_prefix}_log.txt\"\n", + "\n", + "colour_mnist_shape = (3,28,28)\n", + "\n", + "\n", + "logtofile(f\"Executed at {formatted_time}\")\n", + "logtofile(f\"logging to {save_log_as}\")\n", + "logtofile(f\"{seed=}\")\n", + "logtofile(f\"{bg_noise=}\")\n", + "\n", + "logtofile(f\"{train_bg_unbiased_colour_mnist_model=}\")\n", + "if train_bg_unbiased_colour_mnist_model:\n", + " logtofile(f\"{save_bg_unbiased_colour_mnist_model_as=}\")\n", + " logtofile(f\"{original_train_epochs=}\")\n", + "else:\n", + " logtofile(f\"{bg_unbiased_colour_mnist_model_to_load=}\")\n", + "\n", + "logtofile(f\"{train_unbiased_colour_mnist_model=}\")\n", + "if train_unbiased_colour_mnist_model:\n", + " logtofile(f\"{save_unbiased_colour_mnist_model_as=}\")\n", + " logtofile(f\"{original_train_epochs=}\")\n", + "else:\n", + " logtofile(f\"{unbiased_colour_mnist_model_to_load=}\")\n", + "\n", + "logtofile(f\"{stitch_train_epochs=}\")\n", + "logtofile(f\"================================================\")" + ] + }, + { + "cell_type": "markdown", + "id": "6cc621f6-55e8-4191-8288-dc2493cd6bff", + "metadata": {}, + "source": [ + "mnist and cifar-10 both use 10-classes, with 60_000 train samples and 10_000 test samples. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d34a54d2-c8fa-4f51-8809-7a40b4fefc6c", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# unbiased means each digit has correct label and random colour - but bg means we will use colour as label (i.e. 
+  {
+   "cell_type": "markdown",
+   "id": "24842d82-5f62-4d1b-acc5-d1997a08b0b9",
+   "metadata": {},
+   "source": [
+    "## Set up resnet18 models and train them on versions of MNIST"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "cf635cfd-a9ad-4e37-98a0-80d0db2a3b9f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "process_structure = dict()\n",
+    "device = 'cuda:0'\n",
+    "\n",
+    "\n",
+    "process_structure[\"bg\"] = dict()\n",
+    "process_structure[\"unbias\"] = dict()\n",
+    "\n",
+    "# \"bg_unbiased_colour\"\n",
+    "process_structure[\"bg\"][\"model\"] = torchvision.models.resnet18(num_classes=10).to(device)  # Untrained model\n",
+    "process_structure[\"bg\"][\"train\"] = train_bg_unbiased_colour_mnist_model\n",
+    "process_structure[\"bg\"][\"train_loader\"] = bg_unbiased_train_dataloader\n",
+    "process_structure[\"bg\"][\"test_loader\"] = bg_unbiased_test_dataloader\n",
+    "process_structure[\"bg\"][\"saveas\"] = save_bg_unbiased_colour_mnist_model_as\n",
+    "process_structure[\"bg\"][\"loadfrom\"] = bg_unbiased_colour_mnist_model_to_load\n",
+    "\n",
+    "# \"unbiased_colour_mnist\"\n",
+    "process_structure[\"unbias\"][\"model\"] = torchvision.models.resnet18(num_classes=10).to(device)  # Untrained model\n",
+    "process_structure[\"unbias\"][\"train\"] = train_unbiased_colour_mnist_model\n",
+    "process_structure[\"unbias\"][\"train_loader\"] = unbiased_train_dataloader\n",
+    "process_structure[\"unbias\"][\"test_loader\"] = unbiased_test_dataloader\n",
+    "process_structure[\"unbias\"][\"saveas\"] = save_unbiased_colour_mnist_model_as\n",
+    "process_structure[\"unbias\"][\"loadfrom\"] = unbiased_colour_mnist_model_to_load\n",
+    "\n",
+    "for key, val in process_structure.items():\n",
+    "    print(f\"Processing for {key=}\")\n",
+    "    if val[\"train\"]:\n",
+    "        train_model(model=val[\"model\"], train_loader=val[\"train_loader\"],\n",
+    "                    epochs=original_train_epochs, saveas=val[\"saveas\"],\n",
+    "                    description=key, device=device, logtofile=logtofile)\n",
+    "    else:\n",
+    "        logtofile(f\"{val['loadfrom']=}\")\n",
+    "        val[\"model\"].load_state_dict(torch.load(val[\"loadfrom\"], map_location=torch.device(device)))\n",
+    "    val[\"model\"].eval()\n"
+   ]
+  },
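+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "resnet-children-listing",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# For orientation (an added sketch): layer_to_cut_after in the stitching loop below\n",
+    "# counts layers via model.children(), matching num_layers_in_model there. For\n",
+    "# torchvision's resnet18 the children are, in order:\n",
+    "#   0 conv1, 1 bn1, 2 relu, 3 maxpool, 4 layer1, 5 layer2, 6 layer3, 7 layer4, 8 avgpool, 9 fc\n",
+    "# so range(3, num_layers_in_model - 1) cuts after maxpool through avgpool\n",
+    "# (assuming StitchedResNet18's after_layer_index uses this same indexing).\n",
+    "for idx, child in enumerate(process_structure[\"bg\"][\"model\"].children()):\n",
+    "    print(idx, type(child).__name__)\n"
+   ]
+  },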
" logtofile(f\"{val['loadfrom']=}\")\n", + " val[\"model\"].load_state_dict(torch.load(val[\"loadfrom\"], map_location=torch.device(device)))\n", + " val[\"model\"].eval()\n" + ] + }, + { + "cell_type": "markdown", + "id": "b169c5c7-7929-48e1-82eb-6f047aa4e5f2", + "metadata": {}, + "source": [ + "## Measure the Accuracy, Record the Confusion Matrix\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4fbb925a-269d-4027-9500-6cdce4de9d70", + "metadata": {}, + "outputs": [], + "source": [ + "logtofile(\"Entering Confusion\")\n", + "# logtofile(process.memory_info().rss) # in bytes \n", + "\n", + "original_accuracy = dict()\n", + "for key, val in process_structure.items():\n", + " logtofile(f\"Accuracy Calculation for ResNet18 with {key=}\")\n", + " model = val[\"model\"]\n", + " model.eval() # ALWAYS DO THIS BEFORE YOU EVALUATE MODELS\n", + " \n", + " # Compute the model accuracy on the test set\n", + " correct = 0\n", + " total = 0\n", + " \n", + " # assuming 10 classes\n", + " # rows represent actual class, columns are predicted\n", + " confusion_matrix = torch.zeros(10,10, dtype=torch.int)\n", + " \n", + " TDL = biased_test_dataloader # In test 2D - ALWAYS use biased dataset to measure/train stitch\n", + " for data in TDL:\n", + " inputs, labels = data\n", + " inputs = inputs.to(device)\n", + " labels = labels.to(device)\n", + " predictions = torch.argmax(model(inputs),1)\n", + " \n", + " matches = predictions == labels\n", + " correct += matches.sum().item()\n", + " total += len(labels)\n", + " for idx, l in enumerate(labels):\n", + " confusion_matrix[l, predictions[idx]] = 1 + confusion_matrix[l, predictions[idx]] \n", + " \n", + " logtofile(\"Test the Trained Resnet18 against BIASED TEST DATALOADER\")\n", + " acc = ((100.0 * correct) / total)\n", + " logtofile('Test Accuracy: %2.2f %%' % acc)\n", + " original_accuracy[key] = acc\n", + " logtofile('Confusion Matrix')\n", + " logtofile(confusion_matrix)\n", + " logtofile(confusion_matrix.sum())\n", + " # logtofile(process.memory_info().rss) # in bytes \n", + "\n", + "\n", + "logtofile(f\"{original_accuracy=}\")" + ] + }, + { + "cell_type": "markdown", + "id": "da373b34-fe35-4256-a9e0-1040f699d45d", + "metadata": {}, + "source": [ + "## Measure Rank with __biased__ dataloader (test) before cutting and stitching" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c4f74591-2a3f-4521-8aec-32c324125a5b", + "metadata": {}, + "outputs": [], + "source": [ + "logtofile(\"Entering whole model check\")\n", + "# logtofile(process.memory_info().rss) # in bytes \n", + "\n", + "# For the Whole Model - but we will pass it through the RcvResNet18 function to get matching feature names\n", + "for key, val in process_structure.items():\n", + " \n", + " TDL = biased_test_dataloader # ALWAYS use biased dataloader for this test\n", + " \n", + " if val[\"train\"]:\n", + " filename = val[\"saveas\"] \n", + " else: \n", + " filename = val[\"loadfrom\"] \n", + " assert os.path.exists(filename)\n", + " mdl = torchvision.models.resnet18(num_classes=10) # Untrained model\n", + " state = torch.load(filename, map_location=torch.device(\"cpu\"))\n", + " mdl.load_state_dict(state, assign=True)\n", + " mdl=mdl.to(device)\n", + " mdl = RcvResNet18(mdl, -1, colour_mnist_shape, device).to(device)\n", + "\n", + " out_filename = filename.split('/')[-1].replace('.weights', '-test.csv')\n", + " \n", + " outpath = f\"./{results_root}_rank/{key}-bias-{seed}_{out_filename}\" # denote output name as <model_training_type>-dataset-<name>\n", 
+ " \n", + " if os.path.exists(f\"{outpath}\"):\n", + " logtofile(f\"Already evaluated for {outpath}\")\n", + " continue\n", + " logtofile(f\"Measure Rank for {key=}\")\n", + " print(f\"output to {outpath}\")\n", + " \n", + " params = {}\n", + " params[\"model\"] = key\n", + " params[\"dataset\"] = \"bias\"\n", + " params[\"seed\"] = seed\n", + " if val[\"train\"]: # as only one network used, record its filename as both send and receive files\n", + " params[\"send_file\"] = val[\"saveas\"] \n", + " params[\"rcv_file\"] = val[\"saveas\"] \n", + " else: \n", + " params[\"send_file\"] = val[\"loadfrom\"] \n", + " params[\"rcv_file\"] = val[\"loadfrom\"] \n", + " \n", + " with torch.no_grad():\n", + " layers, features, handles = install_hooks(mdl)\n", + " \n", + " metrics = evaluate_model(mdl, TDL, 'acc', verbose=2)\n", + " params.update(metrics)\n", + " classes = None\n", + " df = perform_analysis(features, classes, layers, params, n=-1)\n", + " df.to_csv(f\"{outpath}\")\n", + " for h in handles:\n", + " h.remove()\n", + " del mdl, layers, features, metrics, params, df, handles\n", + " gc.collect()\n", + " # logtofile(process.memory_info().rss) # in bytes \n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "de2928d3-9c5f-411f-a590-14fd04026ab6", + "metadata": {}, + "source": [ + "# Stitch at a given layer\n" + ] + }, + { + "cell_type": "markdown", + "id": "3cbaf773-ed43-4d91-b0a0-a35fd468ac02", + "metadata": {}, + "source": [ + "## Train the stitch layer and check rank" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2c29d808-b86e-4a0c-a3c5-127c952f3ab9", + "metadata": {}, + "outputs": [], + "source": [ + "logtofile(\"Entering Stitch/Rank\")\n", + "# logtofile(process.memory_info().rss) # in bytes \n", + "\n", + "logtofile(f\"{device=}\")\n", + "stitching_accuracies = dict()\n", + "stitching_penalties = dict()\n", + "# NOTE this is only valid as all models are the same architecture\n", + "num_layers_in_model = len(list(process_structure[\"bg\"][\"model\"].children())) \n", + "for send_key, send_val in process_structure.items():\n", + " if (send_key != \"bg\"):\n", + " logtofile(f\"NOTE: Only running stitch with bg send model\")\n", + " continue\n", + " stitching_accuracies[send_key] = dict()\n", + " stitching_penalties[send_key] = dict()\n", + " \n", + " for rcv_key, rcv_val in process_structure.items(): \n", + " if (rcv_key != \"unbias\"):\n", + " logtofile(f\"NOTE: Only running stitch with unbias receive model\")\n", + " continue \n", + " \n", + " stitching_accuracies[send_key][rcv_key] = dict()\n", + " stitching_penalties[send_key][rcv_key] = dict()\n", + " for layer_to_cut_after in range(3,num_layers_in_model - 1):\n", + " # for consistency, use the rcv network for the filename stem.\n", + " if rcv_val[\"train\"]:\n", + " filename = rcv_val[\"saveas\"] \n", + " else: \n", + " filename = rcv_val[\"loadfrom\"] \n", + " \n", + " rank_filename = filename.split('/')[-1].replace('.weights', '-test.csv') \n", + " # denote output name as <model_training_type>-dataset-<name>\n", + " # where <model_training_type> is [sender_model or X][layer_to_cut_after][Receiver_model]\n", + " model_training_type = f\"{send_key}{layer_to_cut_after}{rcv_key}\"\n", + " dataset_type = \"bias\" # ALWAYS use bias dataset in this test\n", + " outpath = f\"./{results_root}_rank/{model_training_type}-{dataset_type}-{seed}_{rank_filename}\" \n", + " \n", + " if os.path.exists(f\"{outpath}\"):\n", + " logtofile(f\"Already evaluated for {outpath}\")\n", + " continue\n", + " 
+  {
+   "cell_type": "markdown",
+   "id": "de2928d3-9c5f-411f-a590-14fd04026ab6",
+   "metadata": {},
+   "source": [
+    "# Stitch at a given layer\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3cbaf773-ed43-4d91-b0a0-a35fd468ac02",
+   "metadata": {},
+   "source": [
+    "## Train the stitch layer and check rank"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2c29d808-b86e-4a0c-a3c5-127c952f3ab9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "logtofile(\"Entering Stitch/Rank\")\n",
+    "# logtofile(process.memory_info().rss)  # in bytes\n",
+    "\n",
+    "logtofile(f\"{device=}\")\n",
+    "stitching_accuracies = dict()\n",
+    "stitching_penalties = dict()\n",
+    "# NOTE this is only valid as all models share the same architecture\n",
+    "num_layers_in_model = len(list(process_structure[\"bg\"][\"model\"].children()))\n",
+    "for send_key, send_val in process_structure.items():\n",
+    "    if send_key != \"bg\":\n",
+    "        logtofile(\"NOTE: Only running stitch with the bg send model\")\n",
+    "        continue\n",
+    "    stitching_accuracies[send_key] = dict()\n",
+    "    stitching_penalties[send_key] = dict()\n",
+    "\n",
+    "    for rcv_key, rcv_val in process_structure.items():\n",
+    "        if rcv_key != \"unbias\":\n",
+    "            logtofile(\"NOTE: Only running stitch with the unbias receive model\")\n",
+    "            continue\n",
+    "\n",
+    "        stitching_accuracies[send_key][rcv_key] = dict()\n",
+    "        stitching_penalties[send_key][rcv_key] = dict()\n",
+    "        for layer_to_cut_after in range(3, num_layers_in_model - 1):\n",
+    "            # for consistency, use the rcv network for the filename stem.\n",
+    "            if rcv_val[\"train\"]:\n",
+    "                filename = rcv_val[\"saveas\"]\n",
+    "            else:\n",
+    "                filename = rcv_val[\"loadfrom\"]\n",
+    "\n",
+    "            rank_filename = filename.split('/')[-1].replace('.weights', '-test.csv')\n",
+    "            # denote output name as <model_training_type>-dataset-<name>\n",
+    "            # where <model_training_type> is [sender_model or X][layer_to_cut_after][Receiver_model]\n",
+    "            model_training_type = f\"{send_key}{layer_to_cut_after}{rcv_key}\"\n",
+    "            dataset_type = \"bias\"  # ALWAYS use the bias dataset in this test\n",
+    "            outpath = f\"./{results_root}_rank/{model_training_type}-{dataset_type}-{seed}_{rank_filename}\"\n",
+    "\n",
+    "            if os.path.exists(outpath):\n",
+    "                logtofile(f\"Already evaluated for {outpath}\")\n",
+    "                continue\n",
+    "            ####################################################################################\n",
+    "            logtofile(f\"Evaluate ranks and output to {outpath}\")\n",
+    "            # logtofile(process.memory_info().rss)  # in bytes\n",
+    "\n",
+    "            logtofile(f\"Train the stitch to a model stitched after layer {layer_to_cut_after} from {send_key} to {rcv_key}\")\n",
+    "            logtofile(f\"Use the biased data loader (train and test) regardless of what {rcv_key} was trained on\")\n",
+    "\n",
+    "            # train a stitch on the biased_colour dataset to compare receiver network performance with stitched\n",
+    "            model_stitched = StitchedResNet18(send_model=send_val[\"model\"],\n",
+    "                                              after_layer_index=layer_to_cut_after,\n",
+    "                                              rcv_model=rcv_val[\"model\"],\n",
+    "                                              input_image_shape=colour_mnist_shape, device=device).to(device)\n",
+    "\n",
+    "            #############################################################\n",
+    "            # store the initial stitch state\n",
+    "            initial_stitch_weight = model_stitched.stitch.s_conv1.weight.clone()\n",
+    "            initial_stitch_bias = model_stitched.stitch.s_conv1.bias.clone()\n",
+    "            stitch_initial_weight_outpath = f\"./{results_root}/STITCH_initial_weight_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}\"\n",
+    "            stitch_initial_bias_outpath = f\"./{results_root}/STITCH_initial_bias_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}\"\n",
+    "            torch.save(initial_stitch_weight, stitch_initial_weight_outpath)\n",
+    "            torch.save(initial_stitch_bias, stitch_initial_bias_outpath)\n",
+    "            ############################################################\n",
+    "\n",
+    "            # define the loss function and the optimiser\n",
+    "            loss_function = nn.CrossEntropyLoss()\n",
+    "            # Hernandez et al. used: momentum 0.9, batch size 256, weight decay 0.01, learning rate 0.01, and a post-warmup cosine learning rate scheduler.\n",
+    "            # optimiser = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)\n",
+    "            optimiser = optim.SGD(model_stitched.parameters(), lr=1e-4, momentum=0.9, weight_decay=0.01)\n",
+    "\n",
+    "            # Put the stitched model into train mode so that BN and dropout behave as in training\n",
+    "            model_stitched.train()\n",
+    "            # Freeze the whole model\n",
+    "            model_stitched.requires_grad_(False)\n",
+    "            # Un-freeze the stitch layer\n",
+    "            for name, param in model_stitched.stitch.named_parameters():\n",
+    "                param.requires_grad_(True)\n",
+    "            # the epoch loop: note that only the stitch layer is unfrozen, so only it is updated\n",
+    "            for epoch in range(stitch_train_epochs):\n",
+    "                running_loss = 0.0\n",
+    "                for data in biased_train_dataloader:\n",
+    "                    # data is an (images, labels) tuple\n",
+    "                    # get the inputs and put them on the GPU\n",
+    "                    inputs, labels = data\n",
+    "                    inputs = inputs.to(device)\n",
+    "                    labels = labels.to(device)\n",
+    "\n",
+    "                    # zero the parameter gradients\n",
+    "                    optimiser.zero_grad()\n",
+    "\n",
+    "                    # forward + loss (cross-entropy plus the L1 penalty on the stitch) + backward + optimise\n",
+    "                    outputs = model_stitched(inputs)\n",
+    "                    loss = loss_function(outputs, labels)\n",
+    "                    lambda_reg = 0.01\n",
+    "                    l1_norm = sum(p.abs().sum() for p in model_stitched.stitch.parameters())\n",
+    "                    loss += lambda_reg * l1_norm\n",
+    "                    loss.backward()\n",
+    "                    optimiser.step()\n",
+    "\n",
+    "                    # keep track of the loss this epoch\n",
+    "                    running_loss += loss.item()\n",
+    "                logtofile(\"Epoch %d, loss %4.2f\" % (epoch, running_loss))\n",
+    "                # logtofile(process.memory_info().rss)  # in bytes\n",
+    "\n",
+    "            logtofile('**** Finished Training ****')\n",
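+    "\n",
+    "            # Aside (an added sketch): the stitch is L1-regularised, so log how sparse\n",
+    "            # it ended up; the 1e-3 threshold for \"near zero\" is an arbitrary choice here.\n",
+    "            with torch.no_grad():\n",
+    "                near_zero_frac = (model_stitched.stitch.s_conv1.weight.abs() < 1e-3).float().mean().item()\n",
+    "            logtofile(f\"Fraction of stitch weights with |w| < 1e-3: {near_zero_frac:.3f}\")\n",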
MODELS\n", + "\n", + " ############################################################\n", + " # store the trained stitch\n", + " trained_stitch_weight = model_stitched.stitch.s_conv1.weight.clone()\n", + " trained_stitch_bias = model_stitched.stitch.s_conv1.bias.clone()\n", + " stitch_trained_weight_outpath = f\"./{results_root}/STITCH_trained_weight_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}\" \n", + " stitch_trained_bias_outpath = f\"./{results_root}/STITCH_trained_bias_{model_training_type}-{dataset_type}-{seed}_{filename.split('/')[-1]}\" \n", + " torch.save(trained_stitch_weight, stitch_trained_weight_outpath)\n", + " torch.save(trained_stitch_bias, stitch_trained_bias_outpath)\n", + " \n", + " stitch_weight_diff = trained_stitch_weight - initial_stitch_weight\n", + " stitch_weight_delta = torch.linalg.norm(stitch_weight_diff).item()\n", + " logtofile(f\"Change in stitch weights: {stitch_weight_delta}\")\n", + " maxabsweight = torch.max(stitch_weight_diff.abs()).item()\n", + " logtofile(f\"Largest abs weight change: {maxabsweight}\")\n", + " stitch_weight_number = torch.sum(torch.where(stitch_weight_diff.abs() > 0.1*maxabsweight, True, False)).item()\n", + " logtofile(f\"Number of weights changing > 0.1 of that: {stitch_weight_number}\")\n", + "\n", + " \n", + " print(f\"Number of weight / bias in stitch layer is {len(initial_stitch_weight)}\")\n", + " stitch_bias_diff = trained_stitch_bias - initial_stitch_bias\n", + " stitch_bias_delta = torch.linalg.norm(stitch_bias_diff).item()\n", + " logtofile(f\"Change in stitch bias: {stitch_bias_delta}\")\n", + " maxabsbias = torch.max(stitch_bias_diff.abs()).item()\n", + " logtofile(f\"Largest abs bias change: {maxabsbias}\")\n", + " stitch_bias_number = torch.sum(torch.where(stitch_bias_diff.abs() > 0.1*maxabsbias, True, False)).item()\n", + " logtofile(f\"Number of bias changing > 0.1 of that: {stitch_bias_number}\")\n", + " ##############################################################\n", + "\n", + " \n", + " # Compute the model accuracy on the test set\n", + " correct = 0\n", + " total = 0\n", + " \n", + " # assuming 10 classes\n", + " # rows represent actual class, columns are predicted\n", + " confusion_matrix = torch.zeros(10,10, dtype=torch.int)\n", + " \n", + " for data in biased_test_dataloader: # Only use biased test data\n", + " inputs, labels = data\n", + " inputs = inputs.to(device)\n", + " labels = labels.to(device)\n", + " \n", + " predictions = torch.argmax(model_stitched(inputs),1)\n", + " matches = predictions == labels.to(device)\n", + " correct += matches.sum().item()\n", + " total += len(labels)\n", + " \n", + " for idx, l in enumerate(labels):\n", + " confusion_matrix[l, predictions[idx]] = 1 + confusion_matrix[l, predictions[idx]] \n", + " logtofile(\"Test the trained stitch against biased data\") \n", + " acc = ((100.0 * correct) / total)\n", + " logtofile('Test Accuracy: %2.2f %%' % acc)\n", + " logtofile('Confusion Matrix')\n", + " logtofile(confusion_matrix)\n", + " logtofile(\"===================================================================\")\n", + " # logtofile(process.memory_info().rss) # in bytes \n", + "\n", + " # Stitching penalty should be negative if there is an improvement, and is relative to the original receiver network\n", + " stitching_accuracies[send_key][rcv_key][layer_to_cut_after] = acc\n", + " stitching_penalties[send_key][rcv_key][layer_to_cut_after] = original_accuracy[rcv_key] - acc\n", + "\n", + " TDL = biased_test_dataloader\n", + " params = {}\n", + " 
params[\"model\"] = model_training_type # a mnemonic\n", + " params[\"dataset\"] = dataset_type\n", + " params[\"seed\"] = seed\n", + " if send_val[\"train\"]:\n", + " params[\"send_file\"] = send_val[\"saveas\"] \n", + " else: \n", + " params[\"send_file\"] = send_val[\"loadfrom\"] \n", + " if rcv_val[\"train\"]:\n", + " params[\"rcv_file\"] = rcv_val[\"saveas\"] \n", + " else: \n", + " params[\"rcv_file\"] = rcv_val[\"loadfrom\"] \n", + " params[\"stitch_weight_delta\"] = stitch_weight_delta\n", + " params[\"stitch_bias_delta\"] = stitch_bias_delta \n", + " params[\"stitch_weight_number\"] = stitch_weight_number\n", + " params[\"stitch_bias_number\"] = stitch_bias_number\n", + " # logtofile(process.memory_info().rss) # in bytes \n", + " with torch.no_grad():\n", + " layers, features, handles = install_hooks(model_stitched) \n", + " metrics = evaluate_model(model_stitched, TDL, 'acc', verbose=2)\n", + " params.update(metrics)\n", + " classes = None\n", + " df = perform_analysis(features, classes, layers, params, n=-1)\n", + " df.to_csv(f\"{outpath}\")\n", + " \n", + " for h in handles:\n", + " h.remove()\n", + " del model_stitched, layers, features, metrics, params, df, handles\n", + " gc.collect()\n", + " # logtofile(process.memory_info().rss) # in bytes \n", + "\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3db5808a-6351-4133-9e7c-85e8e9248cfb", + "metadata": {}, + "outputs": [], + "source": [ + "logtofile(f\"{stitching_accuracies=}\")\n", + "logtofile(f\"{stitching_penalties=}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd669a19-7c42-4265-b8e8-57165995c097", + "metadata": {}, + "outputs": [], + "source": [ + "for s_key in stitching_accuracies: \n", + " for r_key in stitching_accuracies[s_key]:\n", + " logtofile(f\"{s_key}-{r_key}\")\n", + " logtofile(f\"{original_accuracy[r_key]=}\")\n", + " logtofile(\"Stitch Accuracy\")\n", + " for layer in stitching_accuracies[s_key][r_key]:\n", + " logtofile(f\"L{layer}: {stitching_accuracies[s_key][r_key][layer]}\")\n", + " logtofile(\"--------------------------\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "affectnet_env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}