Skip to content
Snippets Groups Projects
Commit ea6f2a0d authored by D.K.Burns's avatar D.K.Burns
Browse files

Initial commit

parents
No related branches found
No related tags found
No related merge requests found
This diff is collapsed.
%% Cell type:markdown id: tags:
© University of Southampton IT Innovation Centre, 2020-2021
Copyright in this software belongs to University of Southampton
IT Innovation Centre of Gamma House, Enterprise Road,
Chilworth Science Park, Southampton, SO16 7NS, UK.
This software may not be used, sold, licensed, transferred, copied
or reproduced in whole or in part in any manner or form or in or
on any media by any person other than in accordance with the terms
of the Licence Agreement supplied with the software, or otherwise
without the prior written consent of the copyright owners.
This software is distributed WITHOUT ANY WARRANTY, without even the
implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE, except where stated in the Licence Agreement supplied with
the software.
Created for Project : Alan Turing Institute Project EP/N510129/1
Decision support algorithms for emergency departments
%% Cell type:code id: tags:
``` python
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# stats imports
from statsmodels.stats.proportion import proportion_confint
from umap import UMAP
# clustering imports
from sklearn.cluster import DBSCAN
from sklearn.metrics import silhouette_score
import seaborn as sns
```
%% Cell type:markdown id: tags:
### Load output from model
%% Cell type:code id: tags:
``` python
shap_df = pd.read_csv('%path to input file%')
shap_values = np.load('%path to input file%')
```
%% Cell type:code id: tags:
``` python
shap_df["preds"] = np.load("%path to input file%")
```
%% Cell type:code id: tags:
``` python
embedder = UMAP(random_state=42, n_neighbors=75)
X_emb = embedder.fit_transform(shap_values)
```
%% Cell type:code id: tags:
``` python
shap_df['X1'] = X_emb[:,0]
shap_df['X2'] = X_emb[:,1]
```
%% Cell type:code id: tags:
``` python
clustering = DBSCAN(eps=3)
cluster_labels = clustering.fit_predict(shap_df.loc[:,['X1','X2']])
shap_df['cluster_labels'] = cluster_labels
```
%% Cell type:code id: tags:
``` python
n_clusters = shap_df['cluster_labels'].nunique()
```
%% Cell type:markdown id: tags:
#### Plot the raw data alongslide the labelled clusters
%% Cell type:code id: tags:
``` python
import matplotlib
```
%% Cell type:code id: tags:
``` python
cmap = matplotlib.cm.get_cmap('tab10')
```
%% Cell type:code id: tags:
``` python
colors = [cmap(x) for x in range(10)]
cmap_colors = [colors[i] for i in cluster_labels]
```
%% Cell type:code id: tags:
``` python
fig, axes = plt.subplots(ncols=2)
fig.subplots_adjust(left=0.02, bottom=0.06, right=0.95, top=0.94, wspace=0.1)
crange2 = axes[1].scatter(shap_df['X1'], shap_df['X2'], c=cmap_colors, s=10)
#cax2 = fig.add_axes([0.525, 0.07, 0.015, 0.9])
#fig.colorbar(crange2, cax=cax2, ticks=np.arange(0,10))
#cax2.tick_params(which='both', labelsize=8)
#cax2.set_ylabel('Cluster label')
#plt.colorbar(cax, ax=axes[1], shrink=0.9, label='Cluster label')
#axes[0].axis('off')
crange = axes[0].scatter(shap_df['X1'], shap_df['X2'], c=shap_df['preds'], s=10, vmin=0, vmax=0.6)
#plt.colorbar(cax, ax=axes[0], shrink=0.9, label='Predicted risk', orientation='horizontal')
cax = fig.add_axes([0.40, 0.1, 0.015, 0.35])
fig.colorbar(crange, cax=cax)
cax.tick_params(which='both', labelsize=7)
cax.set_ylabel('Predicted risk',fontsize=8)
# add text
coords = {'0':(20,9), '1':(15, -2.5), '2':(-2,2), '3':(5,-2.5), '4':(0,15),
'5':(3,20),'6':(6,-10), '7':(15,-7.5)}
for key, value in coords.items():
axes[1].text(*value, key, color=colors[int(key)], size=10)
for ax in axes:
ax.tick_params(which='both', tick1On=False, label1On=False)
plt.setp(ax.spines.values(), color='0.9')
ax.set_xlim(-15,22.5)
ax.set_ylim(-15,25)
# add the labels
labels = ['a)', 'b)']
for i, ax in enumerate(axes):
ax.text(0.03,0.92, labels[i], fontsize=10, transform=ax.transAxes,
bbox=dict(facecolor='White',alpha=0.6, edgecolor='white', boxstyle='round'))
fig.set_size_inches(6,3)
#axes[1].axis('off')
plt.tight_layout()
plt.savefig('%path to output file%', dpi=300, bbox_inches='tight')
```
%% Cell type:code id: tags:
``` python
sns.histplot(shap_df['preds'])
```
%% Cell type:markdown id: tags:
#### Summarize properties of each cluster
%% Cell type:code id: tags:
``` python
def create_cluster_summary(df):
bin_cols = ['reattended_in_72hours']
cont_cols = ['Condition count','30 day visit count', 'Age', 'preds']
cat_cols = ['Diagnosis', 'Triage complaint']
summary = shap_df.groupby('cluster_labels').agg(['count', 'sum', 'mean', 'std'])[bin_cols+cont_cols]
# calculate standard error on continuous columns
for col in cont_cols:
summary.loc[:,(col,'std_err')] = summary[col]['std']/summary[col]['count']**0.5
# calculate 95 % CI for binary columns
for col in bin_cols:
ci_lower, ci_upper = proportion_confint(summary[col]['sum'], summary[col]['count'], method='wilson')
summary.loc[:,(col,'ci_lower')] = ci_lower
summary.loc[:,(col,'ci_upper')] = ci_upper
return summary[bin_cols+cont_cols] # sort multi-index
```
%% Cell type:code id: tags:
``` python
summary = create_cluster_summary(shap_df)
```
%% Cell type:code id: tags:
``` python
shap_df
```
%% Cell type:code id: tags:
``` python
summary.columns.append()
```
%% Cell type:code id: tags:
``` python
summary['Age'][['mean','std_err']]
```
%% Cell type:code id: tags:
``` python
shap_df.columns
```
%% Cell type:code id: tags:
``` python
shap_value_df = pd.DataFrame(shap_values, columns=shap_df.loc[:,'Age':'Temperature'].columns)
shap_value_df['cluster_labels'] = cluster_labels
```
%% Cell type:code id: tags:
``` python
shap_value_df['medical_history'] = shap_value_df.loc[:,'Hypertension':'Allergy'].sum(axis=1)
```
%% Cell type:code id: tags:
``` python
shap_value_df = shap_value_df.abs()
```
%% Cell type:code id: tags:
``` python
cluster_mean_shaps = shap_value_df.groupby('cluster_labels').mean()
```
%% Cell type:code id: tags:
``` python
cluster_feature_ranks = []
cluster_feature_rank_values = []
for clabel in range(n_clusters):
print('---'*5)
print(clabel)
print(cluster_mean_shaps.loc[clabel,:].sort_values()[-3:])
cluster_feature_ranks.append(cluster_mean_shaps.loc[clabel,:].sort_values()[-3:].index.tolist()[::-1])
cluster_feature_rank_values.append(cluster_mean_shaps.loc[clabel,:].sort_values()[-3:].values.tolist()[::-1])
cluster_feature_ranks = np.array(cluster_feature_ranks)
cluster_feature_rank_values = np.array(cluster_feature_rank_values)
```
%% Cell type:code id: tags:
``` python
summary = pd.concat(
(
summary,
pd.DataFrame(
data = cluster_feature_ranks,
columns = pd.MultiIndex.from_tuples([
("Most important SHAP values (mean absolute SHAP values)", "1st"),
("Most important SHAP values (mean absolute SHAP values)", "2nd"),
("Most important SHAP values (mean absolute SHAP values)", "3rd")
])
),
pd.DataFrame(
data = cluster_feature_rank_values,
columns = pd.MultiIndex.from_tuples([
("Mean absolute SHAP value", "1st"),
("Mean absolute SHAP value", "2nd"),
("Mean absolute SHAP value", "3rd")
]
)
)), axis = 1)
summary.to_csv("model_output/cluster_summary_table.csv")
```
This diff is collapsed.
%% Cell type:markdown id:fd3885d7-26e9-4a94-8375-4c6c608d6837 tags:
© University of Southampton IT Innovation Centre, 2020-2021
Copyright in this software belongs to University of Southampton
IT Innovation Centre of Gamma House, Enterprise Road,
Chilworth Science Park, Southampton, SO16 7NS, UK.
This software may not be used, sold, licensed, transferred, copied
or reproduced in whole or in part in any manner or form or in or
on any media by any person other than in accordance with the terms
of the Licence Agreement supplied with the software, or otherwise
without the prior written consent of the copyright owners.
This software is distributed WITHOUT ANY WARRANTY, without even the
implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE, except where stated in the Licence Agreement supplied with
the software.
Created for Project : Alan Turing Institute Project EP/N510129/1
Decision support algorithms for emergency departments
%% Cell type:code id:8a7f342b-bad9-4cc5-a946-60dc7f8266ab tags:
``` python
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
np.random.seed(42)
```
%% Cell type:code id:supposed-nothing tags:
``` python
VALUES_PATH = "%path to input file%"
SHAP_DF_PATH = "%path to input file%"
SAVE_PATH = '%path to output directory%'# folder to save to
```
%% Cell type:code id:color-pollution tags:
``` python
shap_values = np.load(VALUES_PATH)
shap_df = pd.read_csv(SHAP_DF_PATH)
shap_values_df = pd.DataFrame(shap_values, columns=shap_df.columns[1:-3])
```
%% Cell type:markdown id:artistic-vegetable tags:
#### Functions used throughout
%% Cell type:code id:earlier-acceptance tags:
``` python
def _remove_axis(ax):
"""
Removes top and right parts of axis.
Parameters:
-----------
ax : matplotlib.pyplot.Axis,
An axis object
"""
# Hide the right and top spines
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
# Only show ticks on the left and bottom spines
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')
def _apply_clever_jitter(shaps, row_height=0.4):
"""
Took this from the SHAP library:
https://github.com/slundberg/shap/blob/d0b4d59f96adc5d067586c0dd4f7f2326532c47a/shap/plots/_beeswarm.py#L305
It creates jitter which is proportional to the number of data points in a given range of shap values.
This looks tidier than just applying uniform jitter to the points.
Parameters:
-----------
shaps : np.array,
SHAP values for a single feature. Shape (n_instances,)
"""
N = len(shaps)
nbins = 100
quant = np.round(nbins * (shaps - np.min(shaps)) / (np.max(shaps) - np.min(shaps) + 1e-8))
inds = np.argsort(quant + np.random.randn(N) * 1e-6)
layer = 0
last_bin = -1
ys = np.zeros(N)
for ind in inds:
if quant[ind] != last_bin:
layer = 0
ys[ind] = np.ceil(layer / 2) * ((layer % 2) * 2 - 1)
layer += 1
last_bin = quant[ind]
ys *= 0.9 * (row_height / np.max(ys + 1))
return ys
def shap_summary_plot(shap_values_df,
shap_df,
ax=None,
cmap='cividis',
feature_cols=None):
"""
"""
n_instances = shap_values_df.shape[0]
if feature_cols is None:
# find top ten by mean absolute SHAP values
top_ten = shap_values_df.abs().mean(axis=0).sort_values()[-10:].index
else:
top_ten = feature_cols
# make fake y-axis positons, adding jitter to the points
y = np.zeros(shape=(n_instances, 10))
for i in range(10):
y[:,i] += i + _apply_clever_jitter(shap_values_df[top_ten].to_numpy()[:,i])
# plot graphic
if ax is None:
fig, ax = plt.subplots()
for i in range(10):
colors = shap_df[top_ten].iloc[:,i]
# if column not number, plot all points as grey
col_dtype = colors.dtype
if col_dtype not in ['float64','float32','int64','int32','int16']:
colors = '0.6'
x_values = shap_values_df[top_ten].iloc[:,i]
ax.scatter(x_values, y[:,i], c=colors, cmap=cmap, s=8, alpha=0.9)
# add light gray dashed lines
ax.axhline(y=i, ls='--',lw=0.5, color='0.9', zorder=-1e7)
# format the plot
_remove_axis(ax)
ax.set_yticks(range(10))
ax.set_yticklabels(top_ten)
ax.set_xlabel('SHAP value', fontsize=10)
```
%% Cell type:code id:worst-continent tags:
``` python
shap_df.columns.values
```
%% Cell type:code id:surprising-samoa tags:
``` python
fig_cols = ['Condition count', '30 day visit count', 'Diagnosis', 'Triage complaint', 'Current smoker', 'History of smoking', 'Lives alone', 'Hour of day', 'Harmful use of alcohol', 'Depression']
```
%% Cell type:markdown id:intermediate-closer tags:
### Rename any column names to display nicely
You will have to rename any columns which do not display nicely in the figure below.
%% Cell type:code id:nuclear-sitting tags:
``` python
mapper = {'Pulse_rate':'Pulse rate','Systolic_bp': 'Systolic BP'}
shap_df = shap_df.rename(mapper, axis=1)
shap_values_df = shap_values_df.rename(mapper, axis=1)
```
%% Cell type:markdown id:boxed-alias tags:
### Create the figure
%% Cell type:code id:pressing-money tags:
``` python
# FIGURE params
AXIS_LABEL_SIZE = 10
MARKERSIZE = 8# feel free to up size if they are a bit small
CMAP = 'cividis'
# parameters for figure labels: a), b) c)
TBOX_PARAMS = {'facecolor':'white',
'alpha':0.85,
'linewidth':0.5,
'edgecolor':'gray',
'boxstyle':'round,pad=0.35'}
```
%% Cell type:code id:color-charity tags:
``` python
fig = plt.figure()
gs = GridSpec(2, 3)
ax1 = fig.add_subplot(gs[:,0:2])
ax2 = fig.add_subplot(gs[0, 2])
ax3 = fig.add_subplot(gs[1, 2])
# plot panel a
shap_summary_plot(shap_values_df, shap_df, ax=ax1,
feature_cols=fig_cols[::-1])
# plot panel b and c
for ax, col_name in zip([ax2, ax3], ['Hour of day', '30 day visit count']):
ax.scatter(shap_df[col_name],
shap_values_df[col_name],
marker='o',
s=MARKERSIZE,
c=shap_df[col_name],
cmap=CMAP)
ax.set_xlabel(col_name, fontsize=AXIS_LABEL_SIZE)
ax.set_ylabel('SHAP value', fontsize=AXIS_LABEL_SIZE)
# format tick labels and ax2, ax3
for ax in [ax2, ax3]:
ax.tick_params(labelsize=9)
_remove_axis(ax)
# hack for panel a
ax1.set_xlim(-0.1,0.5)
# add figure panel labels
ax1.text(0.03, 0.94, 'a)', fontsize=10, transform=ax1.transAxes, bbox=TBOX_PARAMS)
for ax, lab in zip([ax2, ax3], ['b)', 'c)']):
ax.text(0.07, 0.92, lab, fontsize=10, transform=ax.transAxes, bbox=TBOX_PARAMS)
# tidy up the plot
fig.set_size_inches(7.75,4.25)
plt.tight_layout()
plt.savefig(SAVE_PATH + "shap_summary_plot.png", dpi=250, bbox_inches='tight')
```
%% Cell type:code id:spare-suicide tags:
``` python
```
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment