Skip to content
Snippets Groups Projects
figure_3_generator.ipynb 10.1 KiB
Newer Older
D.K.Burns's avatar
D.K.Burns committed
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fd3885d7-26e9-4a94-8375-4c6c608d6837",
   "metadata": {},
   "source": [
    "© University of Southampton IT Innovation Centre, 2020-2021 \n",
    "\n",
    "Copyright in this software belongs to University of Southampton \n",
    "IT Innovation Centre of Gamma House, Enterprise Road, \n",
    "Chilworth Science Park, Southampton, SO16 7NS, UK. \n",
    "\n",
    "This software may not be used, sold, licensed, transferred, copied \n",
    "or reproduced in whole or in part in any manner or form or in or \n",
    "on any media by any person other than in accordance with the terms \n",
    "of the Licence Agreement supplied with the software, or otherwise \n",
    "without the prior written consent of the copyright owners. \n",
    "\n",
    "This software is distributed WITHOUT ANY WARRANTY, without even the \n",
    "implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR \n",
    "PURPOSE, except where stated in the Licence Agreement supplied with \n",
    "the software. \n",
    "\n",
    "Created for Project :   Alan Turing Institute Project EP/N510129/1 \n",
    "Decision support algorithms for emergency departments "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a7f342b-bad9-4cc5-a946-60dc7f8266ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from matplotlib.gridspec import GridSpec\n",
    "\n",
    "np.random.seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "supposed-nothing",
   "metadata": {},
   "outputs": [],
   "source": [
    "# input: array of SHAP values, read with np.load in the next cell\n",
    "VALUES_PATH = \"%path to input file%\"\n",
    "# input: CSV of feature values (used to colour the points), read with pd.read_csv\n",
    "SHAP_DF_PATH = \"%path to input file%\"\n",
    "# output directory for the saved figure; ending it with a path separator is safest\n",
    "SAVE_PATH = '%path to output directory%'# folder to save to"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "color-pollution",
   "metadata": {},
   "outputs": [],
   "source": [
    "shap_values = np.load(VALUES_PATH)\n",
    "shap_df = pd.read_csv(SHAP_DF_PATH)\n",
    "\n",
    "# NOTE(review): assumes the SHAP array has one column per entry of\n",
    "# shap_df.columns[1:-3], i.e. the first and the last three CSV columns are not\n",
    "# model features - confirm against whatever produced these files.\n",
    "feature_names = shap_df.columns[1:-3]\n",
    "\n",
    "# later cells pair rows of shap_df and shap_values_df by position, so the\n",
    "# shapes must line up exactly\n",
    "assert shap_values.shape == (len(shap_df), len(feature_names)), (\n",
    "    f'SHAP array shape {shap_values.shape} does not match '\n",
    "    f'{len(shap_df)} rows x {len(feature_names)} feature columns')\n",
    "\n",
    "shap_values_df = pd.DataFrame(shap_values, columns=feature_names)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "artistic-vegetable",
   "metadata": {},
   "source": [
    "#### Functions used throughout"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "earlier-acceptance",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _remove_axis(ax):\n",
    "    \"\"\"\n",
    "    Strip the top and right spines from an axes, leaving only left/bottom.\n",
    "    \n",
    "    Tick marks are likewise restricted to the left (y) and bottom (x) spines,\n",
    "    giving the plot an open-box look.\n",
    "    \n",
    "    Parameters:\n",
    "    -----------\n",
    "    ax : matplotlib.axes.Axes\n",
    "        The axes to restyle; it is modified in place.\n",
    "    \"\"\"\n",
    "    hidden_sides = ('right', 'top')\n",
    "    for side in hidden_sides:\n",
    "        ax.spines[side].set_visible(False)\n",
    "\n",
    "    # keep tick marks only where a spine is still drawn\n",
    "    ax.yaxis.set_ticks_position('left')\n",
    "    ax.xaxis.set_ticks_position('bottom')\n",
    "    \n",
    "def _apply_clever_jitter(shaps, row_height=0.4, nbins=100):\n",
    "    \"\"\"\n",
    "    Compute density-aware vertical jitter for one row of a beeswarm plot.\n",
    "    \n",
    "    Took this from the SHAP library:\n",
    "    https://github.com/slundberg/shap/blob/d0b4d59f96adc5d067586c0dd4f7f2326532c47a/shap/plots/_beeswarm.py#L305\n",
    "    \n",
    "    The SHAP values are quantised into `nbins` bins. Points that fall in the\n",
    "    same bin are stacked alternately above and below the row centre\n",
    "    (offsets 0, +1, -1, +2, -2, ...), so the spread of a row grows with the\n",
    "    number of data points in a given range of SHAP values. This looks tidier\n",
    "    than applying uniform random jitter to the points.\n",
    "    \n",
    "    Parameters:\n",
    "    -----------\n",
    "    shaps : np.array,\n",
    "        SHAP values for a single feature. Shape (n_instances,)\n",
    "    row_height : float, default 0.4\n",
    "        Offsets are rescaled so that their magnitude stays below\n",
    "        0.9 * row_height.\n",
    "    nbins : int, default 100\n",
    "        Number of bins used to quantise the SHAP values.\n",
    "    \n",
    "    Returns:\n",
    "    --------\n",
    "    np.array\n",
    "        One vertical offset per instance, in the same order as `shaps`.\n",
    "        The stacking order within a bin is drawn from the global NumPy RNG.\n",
    "    \"\"\"\n",
    "    shaps = np.asarray(shaps)\n",
    "    N = len(shaps)\n",
    "    if N == 0:\n",
    "        # np.min / np.max below raise on an empty array\n",
    "        return np.zeros(0)\n",
    "    quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8))\n",
    "    # tiny random noise only breaks ties, randomising the order within each bin\n",
    "    inds = np.argsort(quant + np.random.randn(N) * 1e-6)\n",
    "    layer = 0\n",
    "    last_bin = -1\n",
    "    ys = np.zeros(N)\n",
    "    for ind in inds:\n",
    "        if quant[ind] != last_bin:\n",
    "            layer = 0\n",
    "        # layers 0, 1, 2, 3, 4, ... map to offsets 0, +1, -1, +2, -2, ...\n",
    "        ys[ind] = np.ceil(layer / 2) * ((layer % 2) * 2 - 1)\n",
    "        layer += 1\n",
    "        last_bin = quant[ind]\n",
    "    ys *= 0.9 * (row_height / np.max(ys + 1))\n",
    "    \n",
    "    return ys\n",
    "    \n",
    "def shap_summary_plot(shap_values_df,\n",
    "                      shap_df,\n",
    "                      ax=None,\n",
    "                      cmap='cividis',\n",
    "                      feature_cols=None,\n",
    "                      max_features=10,\n",
    "                      markersize=8):\n",
    "    \"\"\"\n",
    "    Draw a SHAP beeswarm-style summary plot with one jittered row per feature.\n",
    "    \n",
    "    Each point is one instance: its x position is the instance's SHAP value for\n",
    "    the feature. Points are coloured by the feature value from `shap_df` when\n",
    "    that column's dtype is float32/float64/int16/int32/int64; columns of any\n",
    "    other dtype are drawn in grey.\n",
    "    \n",
    "    Parameters:\n",
    "    -----------\n",
    "    shap_values_df : pd.DataFrame\n",
    "        SHAP values, one column per feature. Rows are paired positionally with\n",
    "        the rows of `shap_df`.\n",
    "    shap_df : pd.DataFrame\n",
    "        Feature values used to colour the points; must contain every plotted column.\n",
    "    ax : matplotlib.axes.Axes, optional\n",
    "        Axes to draw on. If None, a new figure and axes are created.\n",
    "    cmap : str or matplotlib colormap, default 'cividis'\n",
    "        Colormap applied to numeric feature values.\n",
    "    feature_cols : list of str, optional\n",
    "        Features to plot; the first entry is drawn as the bottom row. If None,\n",
    "        the `max_features` features with the largest mean absolute SHAP value\n",
    "        are used, with the most important one as the top row.\n",
    "    max_features : int, default 10\n",
    "        Number of features selected automatically when `feature_cols` is None.\n",
    "    markersize : float, default 8\n",
    "        Marker size, passed to `ax.scatter` as `s`.\n",
    "    \n",
    "    Returns:\n",
    "    --------\n",
    "    matplotlib.axes.Axes\n",
    "        The axes that were drawn on.\n",
    "    \"\"\"\n",
    "    if feature_cols is None:\n",
    "        # ascending sort, so the most important feature comes last (top row)\n",
    "        mean_abs = shap_values_df.abs().mean(axis=0).sort_values()\n",
    "        start = max(len(mean_abs) - max_features, 0)\n",
    "        top_features = list(mean_abs.index[start:])\n",
    "    else:\n",
    "        top_features = list(feature_cols)\n",
    "    n_features = len(top_features)\n",
    "    \n",
    "    # select the plotted columns once; column i is drawn as row i\n",
    "    x_all = shap_values_df[top_features].to_numpy()\n",
    "    color_df = shap_df[top_features]\n",
    "    \n",
    "    # fake y-axis positions: row i is centred on y=i, plus density-aware jitter\n",
    "    y = np.zeros(shape=x_all.shape)\n",
    "    for i in range(n_features):\n",
    "        y[:, i] = i + _apply_clever_jitter(x_all[:, i])\n",
    "    \n",
    "    if ax is None:\n",
    "        _, ax = plt.subplots()\n",
    "    \n",
    "    numeric_dtypes = ['float64', 'float32', 'int64', 'int32', 'int16']\n",
    "    for i in range(n_features):\n",
    "        colors = color_df.iloc[:, i]\n",
    "        if colors.dtype in numeric_dtypes:\n",
    "            col_cmap = cmap\n",
    "        else:\n",
    "            # non-numeric column: draw every point grey and pass no colormap\n",
    "            # (matplotlib ignores cmap for a plain colour and warns about it)\n",
    "            colors = '0.6'\n",
    "            col_cmap = None\n",
    "        ax.scatter(x_all[:, i], y[:, i], c=colors, cmap=col_cmap, s=markersize, alpha=0.9)\n",
    "        \n",
    "        # light grey dashed guide line, drawn behind the points\n",
    "        ax.axhline(y=i, ls='--', lw=0.5, color='0.9', zorder=-1e7)\n",
    "    \n",
    "    # format the plot\n",
    "    _remove_axis(ax)\n",
    "    ax.set_yticks(range(n_features))\n",
    "    ax.set_yticklabels(top_features)\n",
    "    ax.set_xlabel('SHAP value', fontsize=10)\n",
    "    return ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "worst-continent",
   "metadata": {},
   "outputs": [],
   "source": [
    "# list the available column names (useful when choosing fig_cols below)\n",
    "shap_df.columns.values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "surprising-samoa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# features shown in panel a, listed top row first\n",
    "fig_cols = [\n",
    "    'Condition count',\n",
    "    '30 day visit count',\n",
    "    'Diagnosis',\n",
    "    'Triage complaint',\n",
    "    'Current smoker',\n",
    "    'History of smoking',\n",
    "    'Lives alone',\n",
    "    'Hour of day',\n",
    "    'Harmful use of alcohol',\n",
    "    'Depression',\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "intermediate-closer",
   "metadata": {},
   "source": [
    "### Rename columns so they display nicely\n",
    "\n",
    "Add an entry to `mapper` below for every column whose raw name would look poor as an axis label in the figure (e.g. `Pulse_rate` -> `Pulse rate`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "nuclear-sitting",
   "metadata": {},
   "outputs": [],
   "source": [
    "# raw column name -> display name used for axis labels\n",
    "mapper = {\n",
    "    'Pulse_rate': 'Pulse rate',\n",
    "    'Systolic_bp': 'Systolic BP',\n",
    "}\n",
    "\n",
    "# apply the same renaming to the feature values and to the SHAP values\n",
    "shap_df = shap_df.rename(columns=mapper)\n",
    "shap_values_df = shap_values_df.rename(columns=mapper)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "boxed-alias",
   "metadata": {},
   "source": [
    "### Create the figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pressing-money",
   "metadata": {},
   "outputs": [],
   "source": [
    "# FIGURE params\n",
    "\n",
    "AXIS_LABEL_SIZE = 10\n",
    "MARKERSIZE = 8# feel free to up size if they are a bit small\n",
    "CMAP = 'cividis'\n",
    "\n",
    "# parameters for figure labels: a), b) c)\n",
    "TBOX_PARAMS = {'facecolor':'white',\n",
    "               'alpha':0.85,\n",
    "               'linewidth':0.5,\n",
    "               'edgecolor':'gray',\n",
    "               'boxstyle':'round,pad=0.35'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "color-charity",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(7.75, 4.25))\n",
    "gs = GridSpec(2, 3)\n",
    "\n",
    "# panel a spans both rows of the first two columns; b and c stack on the right\n",
    "ax1 = fig.add_subplot(gs[:, 0:2])\n",
    "ax2 = fig.add_subplot(gs[0, 2])\n",
    "ax3 = fig.add_subplot(gs[1, 2])\n",
    "\n",
    "# plot panel a; fig_cols is reversed so that its first entry is the top row\n",
    "shap_summary_plot(shap_values_df, shap_df, ax=ax1, cmap=CMAP,\n",
    "                  feature_cols=fig_cols[::-1])\n",
    "\n",
    "# plot panels b and c: SHAP value against the raw feature value\n",
    "for ax, col_name in zip([ax2, ax3], ['Hour of day', '30 day visit count']):\n",
    "    ax.scatter(shap_df[col_name],\n",
    "               shap_values_df[col_name],\n",
    "               marker='o',\n",
    "               s=MARKERSIZE,\n",
    "               c=shap_df[col_name],\n",
    "               cmap=CMAP)\n",
    "    ax.set_xlabel(col_name, fontsize=AXIS_LABEL_SIZE)\n",
    "    ax.set_ylabel('SHAP value', fontsize=AXIS_LABEL_SIZE)\n",
    "    ax.tick_params(labelsize=9)\n",
    "    _remove_axis(ax)\n",
    "\n",
    "# hack for panel a: fixed x-range chosen for this dataset\n",
    "# NOTE(review): SHAP values outside [-0.1, 0.5] are cut off - revisit if the data change\n",
    "ax1.set_xlim(-0.1, 0.5)\n",
    "\n",
    "# add figure panel labels\n",
    "ax1.text(0.03, 0.94, 'a)', fontsize=10, transform=ax1.transAxes, bbox=TBOX_PARAMS)\n",
    "for ax, lab in zip([ax2, ax3], ['b)', 'c)']):\n",
    "    ax.text(0.07, 0.92, lab, fontsize=10, transform=ax.transAxes, bbox=TBOX_PARAMS)\n",
    "\n",
    "# tidy up and save; os.path.join works whether or not SAVE_PATH ends in a separator\n",
    "fig.tight_layout()\n",
    "fig.savefig(os.path.join(SAVE_PATH, 'shap_summary_plot.png'), dpi=250, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "spare-suicide",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}