This package offers an efficient implementation of Kernel SHAP, see [1] and [2]. For up to \(p=8\) features, the resulting Kernel SHAP values are exact with respect to the selected background data. For larger \(p\), an almost exact hybrid algorithm involving iterative sampling is used by default.
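As a minimal sketch of how this default might be overridden, the snippet below uses the `exact` and `hybrid_degree` arguments as documented in `?kernelshap` (argument names are assumptions about your installed version; the toy model and data are made up for illustration and are not part of the diamonds example further down):

```r
library(kernelshap)

# Toy setup (illustration only): p = 4 features of the iris data
fit <- lm(Sepal.Length ~ ., data = iris)
X <- iris[1:20, -1]   # rows to explain (feature columns only)
bg_X <- iris          # background data

# With p <= 8, the exact estimator is used by default; force it explicitly
# (argument name `exact` assumed from ?kernelshap)
s_exact <- kernelshap(fit, X, bg_X = bg_X, exact = TRUE)

# Switch to the hybrid estimator with iterative sampling instead
# (`hybrid_degree` assumed from ?kernelshap)
s_hybrid <- kernelshap(fit, X, bg_X = bg_X, exact = FALSE, hybrid_degree = 1)
```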
The typical workflow to explain any model `object`:

1. Sample rows to be explained: Select the rows `X` to be explained. If the training dataset is small, simply use the full training data for this purpose. `X` should only contain feature columns.
2. Select background data: Kernel SHAP needs a background dataset `bg_X` to calculate marginal means. For this purpose, set aside 50 to 500 rows from the training data. If the training data is small, use the full training data. In cases with a natural “off” value (like MNIST digits), this can also be a single row with all values set to the off value.
3. Crunch: Call `kernelshap(object, X, bg_X, ...)` to calculate SHAP values. Runtime is proportional to `nrow(X)`, while memory consumption scales linearly in `nrow(bg_X)`.

Remark: Case weights for the background data can be passed via `bg_w` (a short sketch follows below).
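A hedged sketch of that remark, on made-up toy data: `bg_w` is assumed to take one weight per row of `bg_X` (uniform here), as documented in `?kernelshap`.

```r
library(kernelshap)

fit <- lm(Sepal.Length ~ ., data = iris)
X <- iris[1:20, -1]        # rows to explain (feature columns only)
bg_X <- iris               # background data
w <- rep(1, nrow(bg_X))    # one case weight per background row (uniform here)

s <- kernelshap(fit, X, bg_X = bg_X, bg_w = w)
```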
The package can be installed from CRAN or GitHub:

```r
# From CRAN
install.packages("kernelshap")

# Or the development version:
devtools::install_github("ModelOriented/kernelshap")
```
Let’s model diamond prices!
```r
library(ggplot2)
library(kernelshap)
library(shapviz)

diamonds <- transform(
  diamonds,
  log_price = log(price),
  log_carat = log(carat)
)

fit_lm <- lm(log_price ~ log_carat + clarity + color + cut, data = diamonds)

# 1) Sample rows to be explained
set.seed(10)
xvars <- c("log_carat", "clarity", "color", "cut")
X <- diamonds[sample(nrow(diamonds), 1000), xvars]

# 2) Select background data
bg_X <- diamonds[sample(nrow(diamonds), 200), ]

# 3) Crunch SHAP values for all 1000 rows of X (~6 seconds)
system.time(
  shap_lm <- kernelshap(fit_lm, X, bg_X = bg_X)
)
shap_lm

# SHAP values of first 2 observations:
#       log_carat    clarity       color         cut
# [1,]  1.2692479  0.1081900 -0.07847065 0.004630899
# [2,] -0.4499226 -0.1111329  0.11832292 0.026503850

# 4) Analyze
sv_lm <- shapviz(shap_lm)
sv_importance(sv_lm)
sv_dependence(sv_lm, "log_carat", color_var = NULL)
```
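If you prefer the raw numbers over the {shapviz} plots, the returned object is assumed to carry the SHAP value matrix and the baseline (average prediction over the background data); the slot names below are assumptions, so verify them with `str(shap_lm)`:

```r
# Assumed slots of the "kernelshap" object; verify with str(shap_lm)
head(shap_lm$S, 2)   # matrix of SHAP values, one row per row of X
shap_lm$baseline     # average prediction on the background data
```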
We can also explain a specific prediction instead of the full model:
```r
single_row <- diamonds[5000, xvars]

fit_lm |>
  kernelshap(single_row, bg_X = bg_X) |>
  shapviz() |>
  sv_waterfall()
```
We can use the same `X` and `bg_X` to inspect other models:
```r
library(ranger)

fit_rf <- ranger(
  log_price ~ log_carat + clarity + color + cut,
  data = diamonds,
  num.trees = 20,
  seed = 20
)

shap_rf <- kernelshap(fit_rf, X, bg_X = bg_X)
shap_rf

# SHAP values of first 2 observations:
#       log_carat     clarity      color         cut
# [1,]  1.1987785  0.09578879 -0.1397765 0.002761832
# [2,] -0.4969451 -0.12006207  0.1050928 0.029680717

sv_rf <- shapviz(shap_rf)
sv_importance(sv_rf, kind = "bee", show_numbers = TRUE)
sv_dependence(sv_rf, "log_carat")
```
Or a deep neural net (results not fully reproducible):
```r
library(keras)

# Define and compile a small feed-forward network
nn <- keras_model_sequential()
nn |>
  layer_dense(units = 30, activation = "relu", input_shape = 4) |>
  layer_dense(units = 15, activation = "relu") |>
  layer_dense(units = 1)

nn |>
  compile(optimizer = optimizer_adam(0.1), loss = "mse")

cb <- list(
  callback_early_stopping(patience = 20),
  callback_reduce_lr_on_plateau(patience = 5)
)

nn |>
  fit(
    x = data.matrix(diamonds[xvars]),
    y = diamonds$log_price,
    epochs = 100,
    batch_size = 400,
    validation_split = 0.2,
    callbacks = cb
  )

# Custom prediction function turning the data.frame into a numeric matrix
pred_fun <- function(mod, X) predict(mod, data.matrix(X), batch_size = 10000)
shap_nn <- kernelshap(nn, X, bg_X = bg_X, pred_fun = pred_fun)

sv_nn <- shapviz(shap_nn)
sv_importance(sv_nn, show_numbers = TRUE)
sv_dependence(sv_nn, "clarity")
```
Parallel computing is supported via `foreach`, at the price of losing the progress bar. Note that this does not work with Keras models (and some others).
```r
library(doFuture)

# Set up parallel backend
registerDoFuture()
plan(multisession, workers = 4)  # Windows
# plan(multicore, workers = 4)   # Linux, macOS, Solaris

# ~3 seconds on second run
system.time(
  s <- kernelshap(fit_lm, X, bg_X = bg_X, parallel = TRUE)
)
```
On Windows, sometimes not all packages or global objects are passed to the parallel sessions. In this case, the necessary instructions to `foreach` can be specified through a named list via `parallel_args`, as in the following example:
```r
library(mgcv)

fit_gam <- gam(log_price ~ s(log_carat) + clarity + color + cut, data = diamonds)

system.time(
  shap_gam <- kernelshap(
    fit_gam,
    X,
    bg_X = bg_X,
    parallel = TRUE,
    parallel_args = list(.packages = "mgcv")
  )
)
shap_gam

# SHAP values of first 2 observations:
#       log_carat    clarity       color         cut
# [1,]  1.2714988  0.1115546 -0.08454955 0.003220451
# [2,] -0.5153642 -0.1080045  0.11967804 0.031341595
```
Here, we provide some working examples for “tidymodels”, “caret”, and “mlr3”.
```r
library(tidymodels)
library(kernelshap)

iris_recipe <- iris %>%
  recipe(Sepal.Length ~ .)

reg <- linear_reg() %>%
  set_engine("lm")

iris_wf <- workflow() %>%
  add_recipe(iris_recipe) %>%
  add_model(reg)

fit <- iris_wf %>%
  fit(iris)

ks <- kernelshap(fit, iris[, -1], bg_X = iris)
ks
```
```r
library(caret)
library(kernelshap)
library(shapviz)

fit <- train(
  Sepal.Length ~ .,
  data = iris,
  method = "lm",
  tuneGrid = data.frame(intercept = TRUE),
  trControl = trainControl(method = "none")
)

s <- kernelshap(fit, iris[, -1], pred_fun = predict, bg_X = iris)
sv <- shapviz(s)
sv_waterfall(sv, 1)
```
```r
library(mlr3)
library(mlr3learners)
library(kernelshap)
library(shapviz)

# Two ways to access the built-in iris (classification) task
mlr_tasks$get("iris")
tsk("iris")

# Build a regression task on Sepal.Length instead
task_iris <- TaskRegr$new(id = "iris", backend = iris, target = "Sepal.Length")
fit_lm <- lrn("regr.lm")
fit_lm$train(task_iris)
s <- kernelshap(fit_lm, iris[-1], bg_X = iris)
sv <- shapviz(s)
sv_dependence(sv, "Species")
```
[1] Scott M. Lundberg and Su-In Lee. A Unified Approach to Interpreting Model Predictions. Advances in Neural Information Processing Systems 30, 2017.
[2] Ian Covert and Su-In Lee. Improving KernelSHAP: Practical Shapley Value Estimation Using Linear Regression. Proceedings of The 24th International Conference on Artificial Intelligence and Statistics, PMLR 130:3457-3465, 2021.