1 Introduction

DALEX is designed to work with various black-box models, such as tree ensembles, linear models, neural networks, etc. Unfortunately, R packages that create such models are very inconsistent: different tools use different interfaces to train, validate and use models. Fortunately, DALEX can handle them all easily.

In this vignette we will show explanations for models created with the parsnip package (Max Kuhn and Davis Vaughan, 2019).

2 Regression use case - titanic

library(DALEX)
library(parsnip)
library(dplyr)

To illustrate applications of DALEX to regression problems we use the titanic dataset available in the DALEX package. Our goal is to predict the fare based on selected features such as gender, age, class, place of embarkation, number of siblings/spouses aboard, number of parents/children aboard, and survival status. It should be noted that all of these features, except age, are discrete. Another important note is that the target variable, fare, is 0 for crew members, musicians and employees of the shipyard company, so we remove those rows.

The Titanic dataset was copied from the stablelearner package and went through a few variable transformations. The complete list of persons on the RMS Titanic was downloaded from https://www.encyclopedia-titanica.org on April 5, 2016. The information given in sibsp and parch was adopted from a data set obtained from http://biostat.mc.vanderbilt.edu/DataSets.

titanic_r <- na.omit(select(titanic, -c(country)) %>%
                       filter(fare > 0))
head(titanic_r)
##   gender age class    embarked  fare sibsp parch survived
## 1   male  42   3rd Southampton  7.11     0     0       no
## 2   male  13   3rd Southampton 20.05     0     2       no
## 3   male  16   3rd Southampton 20.05     1     1       no
## 4 female  39   3rd Southampton 20.05     1     1      yes
## 5 female  16   3rd Southampton  7.13     0     0      yes
## 6   male  25   3rd Southampton  7.13     0     0      yes

2.1 The explain() function

The first step of using the DALEX package is to wrap up the black-box model with meta-data that unifies model interfacing.

First, we create train and test sets, which are needed to train and validate the parsnip models since we do not have an additional test set given.

set.seed(123)
titanic_r$survived <- factor(titanic_r$survived)
train_index <- sample(1:nrow(titanic_r), 0.7 * nrow(titanic_r))
test_index <- setdiff(1:nrow(titanic_r), train_index)
titanic_r_test <- titanic_r[test_index,]

In this vignette we will use three models: a boosted tree, a single-layer neural network, and a support vector machine for regression.

According to the semantics of the parsnip package, at the beginning we have to specify our regression models. In contrast to other similar packages, like mlr or caret, the task and the learner are one object, which may be tuned in many ways at further steps.

bt_model  <- boost_tree(trees = 2000, mtry = 4, mode = "regression") %>% set_engine("xgboost")
nn_model  <- mlp(penalty = 10, epochs = varying(), mode = "regression") %>% set_engine("nnet")
svm_model <- svm_rbf(mode = "regression", rbf_sigma = 0.2) %>% set_engine("kernlab")

Additionally, for the neural network model we set the epochs parameter, which was left as a placeholder thanks to the varying() function.

nn_model <- nn_model %>% update(epochs = 30)

Below, we use the parsnip function fit() to train our models.

bt_fitted  <- bt_model  %>% fit(fare ~ ., data = titanic_r[train_index,])
nn_fitted  <- nn_model  %>% fit(fare ~ ., data = titanic_r[train_index,])
svm_fitted <- svm_model %>% fit(fare ~ ., data = titanic_r[train_index,])

To create an explainer for these models it is enough to use the explain() function with the model, data and y parameters. The validation dataset for the models is the titanic_r_test data created above. For models created with the parsnip package we have to provide a custom predict function, which takes two arguments, model and newdata, and returns a numeric vector with predictions, because the predict() function from parsnip returns predictions in an object (a tibble) that is not recognised by the explain() function.

# parsnip's predict() returns a tibble; extract the numeric .pred column
custom_predict <- function(object, newdata) {
  pred <- predict(object, newdata)
  response <- pred$.pred
  return(response)
}

explainer_regr_bt <- DALEX::explain(bt_fitted, 
                                    data = titanic_r_test, 
                                    y = titanic_r_test$fare, 
                                    predict_function = custom_predict, 
                                    label = "bt", 
                                    colorize = FALSE,
                                    verbose = FALSE)
explainer_regr_nn <- DALEX::explain(nn_fitted,
                                    data = titanic_r_test,
                                    y = titanic_r_test$fare,
                                    predict_function = custom_predict,
                                    label = "nn",
                                    colorize = FALSE,
                                    verbose = FALSE)
explainer_regr_svm <- DALEX::explain(svm_fitted,
                                     data = titanic_r_test,
                                     y = titanic_r_test$fare,
                                     predict_function = custom_predict, 
                                     label = "svm",
                                     colorize = FALSE,
                                     verbose = FALSE)

2.2 Model performance

Function model_performance() calculates predictions and residuals for the validation dataset.

mp_regr_bt <- model_performance(explainer_regr_bt)
mp_regr_nn <- model_performance(explainer_regr_nn)
mp_regr_svm <- model_performance(explainer_regr_svm)

The generic print() function returns quality measures and deciles of the residuals.

mp_regr_bt
## Measures for:  regression
## mse        : 3004.486 
## rmse       : 54.81319 
## r2         : -0.1617351 
## mad        : 6.056754
## 
## Residuals:
##           0%          10%          20%          30%          40%          50% 
## -344.0690149  -27.3608732   -9.6301566   -3.2425701   -1.0582287   -0.1446486 
##          60%          70%          80%          90%         100% 
##    1.0915377    4.1375328   10.5687443   23.0671528  474.1560865
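
As a quick sanity check (this snippet is not part of the original workflow), the mse and rmse reported above can be reproduced directly from the test data and the custom predict function defined earlier:

pred_bt <- custom_predict(bt_fitted, titanic_r_test)
c(mse  = mean((titanic_r_test$fare - pred_bt)^2),
  rmse = sqrt(mean((titanic_r_test$fare - pred_bt)^2)))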

The generic plot() function shows the reversed empirical cumulative distribution function of the absolute residuals. Plots can be generated for one or more models.

plot(mp_regr_bt, mp_regr_nn, mp_regr_svm)

The figure above shows that residuals are similar for all three models.

We are also able to use the plot() function to get an alternative comparison of residuals. Setting the geom = "boxplot" parameter, we can compare the distribution of residuals for the selected models.

plot(mp_regr_bt, mp_regr_nn, mp_regr_svm, geom = "boxplot")

2.3 Variable importance

Using the DALEX package we are able to better understand which variables are important.

Model agnostic variable importance is calculated by means of permutations. We simply compare the loss function calculated on the validation dataset with permuted values of a single variable against the loss function calculated on the original validation dataset; the increase in loss measures the importance of that variable.
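
To make this idea concrete, below is a minimal hand-rolled sketch of permutation importance for a single variable (class), using the fields stored in the explainer created above. It only illustrates the core idea; model_parts() repeats the permutations and handles all variables for you.

set.seed(123)
# loss on the original validation data
loss_full <- loss_root_mean_square(explainer_regr_bt$y,
                                   explainer_regr_bt$predict_function(explainer_regr_bt$model,
                                                                      explainer_regr_bt$data))
# loss after permuting the class variable
permuted <- explainer_regr_bt$data
permuted$class <- sample(permuted$class)
loss_perm <- loss_root_mean_square(explainer_regr_bt$y,
                                   explainer_regr_bt$predict_function(explainer_regr_bt$model,
                                                                      permuted))
# the larger the increase in loss, the more important the variable
loss_perm - loss_full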

This method is implemented in the model_parts() function.

vi_regr_bt <- model_parts(explainer_regr_bt, loss_function = loss_root_mean_square)
vi_regr_nn <- model_parts(explainer_regr_nn, loss_function = loss_root_mean_square)
vi_regr_svm <- model_parts(explainer_regr_svm, loss_function = loss_root_mean_square)

We can compare all models using the generic plot() function.

plot(vi_regr_bt, vi_regr_nn, vi_regr_svm)

The length of the interval corresponds to the variable importance. A longer interval means a larger loss, so the variable is more important.

For better comparison of the models we can anchor the variable importance at 0 using type = "difference".

vi_regr_bt <- model_parts(explainer_regr_bt, 
                          loss_function = loss_root_mean_square, 
                          type = "difference")
vi_regr_nn <- model_parts(explainer_regr_nn, 
                          loss_function = loss_root_mean_square, 
                          type = "difference")
vi_regr_svm <- model_parts(explainer_regr_svm, 
                           loss_function = loss_root_mean_square,
                           type = "difference")

plot(vi_regr_bt, vi_regr_nn, vi_regr_svm)

We see that for the boosted tree and the neural network model the most important variable is class; surprisingly, the survived variable was not that important.

2.4 Variable response

Explainers presented in this section are designed to better understand the relation between a variable and model output.
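
For example, partial dependence profiles for the continuous age variable can be computed with model_profile(), analogously to the classification examples shown later (a short sketch using the regression explainers created above):

pdp_regr_bt  <- model_profile(explainer_regr_bt, variables = "age", type = "partial")
pdp_regr_nn  <- model_profile(explainer_regr_nn, variables = "age", type = "partial")
pdp_regr_svm <- model_profile(explainer_regr_svm, variables = "age", type = "partial")
plot(pdp_regr_bt, pdp_regr_nn, pdp_regr_svm)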

2.5 Break down plots

We have already shown how to check variable importance and several other model explanations using the DALEX package. All of those referred to the global behaviour of our models. Now we will look at a specific passenger, a twenty-seven-year-old woman travelling in the first class, and see how our models predicted her fare.

pb_bt <- predict_parts(explainer_regr_bt, 
                       new_observation = filter(titanic_r_test, 
                                            gender == "female", 
                                            age == 27, 
                                            class == "1st")[1,],
                       type = "break_down")
pb_nn <- predict_parts(explainer_regr_nn, 
                       new_observation = filter(titanic_r_test, 
                                            gender == "female", 
                                            age == 27, 
                                            class == "1st")[1,],
                       type = "break_down")
pb_svm <- predict_parts(explainer_regr_svm, 
                       new_observation = filter(titanic_r_test, 
                                            gender == "female", 
                                            age == 27, 
                                            class == "1st")[1,],
                       type = "break_down")
plot(pb_bt)

plot(pb_nn)

plot(pb_svm)

3 Classification use case - titanic

As previously in the regression use case, to illustrate applications of DALEX to classification problems we are going to use the titanic dataset. We want to predict whether a given passenger survived the Titanic's maiden voyage. Originally the dataset had a factor with two levels, yes and no, as the target column, but we have changed it to 1 and 0; the reasons will be explained below. Our classification will be based on seven features from this data set, so we drop the country column.

library(parsnip)
library(dplyr)
library(DALEX)
titanic_num <- na.omit(select(titanic, -c(country)))
titanic_num$survived <-  (titanic_num$survived %>% as.numeric()) - 1

First, we once again create train and test sets, which are needed to train and validate the parsnip models since we do not have an additional test set given.

set.seed(123)
titanic_num$survived <- factor(titanic_num$survived)
train_index <- sample(1:nrow(titanic_num), 0.7 * nrow(titanic_num))
test_index <- setdiff(1:nrow(titanic_num), train_index)
titanic_num_test <- titanic_num[test_index,]

In this vignette, we will use three models: random forest, logistic regression and a support vector machine for classification.

According to the semantics of the parsnip package we have to define our model specifications. As mentioned in the regression use case, be aware of the differences between parsnip and other similar packages, like mlr or caret: the task and the learner are one object here.

rf_model <- rand_forest(trees = 2000, mtry = 4, mode = "classification") %>%  set_engine("ranger")
lr_model <- logistic_reg(penalty = 10, mixture = 0.1, mode = "classification") %>% set_engine("glm")
svm_model <- svm_rbf(mode = "classification", rbf_sigma = 0.2) %>% set_engine("kernlab")

Next, we use fit() to train our models.

rf_fitted  <- rf_model  %>% fit(survived ~ ., data = titanic_num[train_index,])
lr_fitted  <- lr_model  %>% fit(survived ~ ., data = titanic_num[train_index,])
svm_fitted <- svm_model %>% fit(survived ~ ., data = titanic_num[train_index,])

As previously, to create explainers for these models we use the explain() function. The validation dataset for the models is titanic_num_test.

In this case we consider the differences between the observed class and the predicted probabilities to be residuals. So, we have to provide a custom predict function which takes two arguments, model and newdata, and returns a numeric vector with probabilities. Keep in mind that if you want to use the default loss function while creating the model_performance object later on, you have to convert the factor target vector to a numeric one right now.

y_test <- as.numeric(as.character(titanic_num_test$survived))
# return the predicted probability of survival (the second class-probability column)
custom_predict_classif <- function(objectPred, set){
  as.data.frame(predict(objectPred, set, type = "prob"))[,2]
}
explainer_classif_rf <- DALEX::explain(rf_fitted,
                                       data = titanic_num_test, 
                                       y = y_test, 
                                       label = "rf",
                                       predict_function = custom_predict_classif, 
                                       colorize = FALSE,
                                       verbose = FALSE)
explainer_classif_lr <- DALEX::explain(lr_fitted, 
                                       data = titanic_num_test, 
                                       y = y_test,
                                       label = "lr", 
                                       predict_function = custom_predict_classif,
                                       colorize = FALSE,
                                       verbose = FALSE)
explainer_classif_svm <- DALEX::explain(svm_fitted,
                                        data = titanic_num_test, 
                                        y = y_test, 
                                        label = "svm", 
                                        predict_function = custom_predict_classif,
                                        colorize = FALSE,
                                        verbose = FALSE)

3.1 Model performance

Function model_performance() calculates predictions and residuals for the validation dataset titanic_num_test.

We use the generic plot() function to get a comparison of models.

mp_classif_rf <- model_performance(explainer_classif_rf)
mp_classif_lr <- model_performance(explainer_classif_lr)
mp_classif_svm <- model_performance(explainer_classif_svm)
plot(mp_classif_rf, mp_classif_lr, mp_classif_svm)

Setting the geom = "boxplot" parameter lets us compare the distribution of residuals for the selected models.

plot(mp_classif_rf, mp_classif_lr, mp_classif_svm, geom = "boxplot")

3.2 Variable importance

Function model_parts() computes variable importance. The output may be plotted with the generic plot() function.

vi_classif_rf <- model_parts(explainer_classif_rf, loss_function = loss_root_mean_square)
vi_classif_lr <- model_parts(explainer_classif_lr, loss_function = loss_root_mean_square)
vi_classif_svm <- model_parts(explainer_classif_svm, loss_function = loss_root_mean_square)
plot(vi_classif_rf, vi_classif_lr, vi_classif_svm)

Left edges of the intervals start at the loss of the full model. The length of the interval corresponds to the variable importance. A longer interval means a larger loss, so the variable is more important.

3.3 Variable response

As previously, we create explainers designed to better understand the relation between a variable and the model output: PDP plots and ALE plots.

3.3.1 Partial Dependence Plot

Partial Dependence Plots (PDP) are one of the most popular methods for exploration of the relation between a continuous variable and the model outcome.

We use the function model_profile() with the parameter type = "partial" to calculate the PDP response.

pdp_classif_rf  <- model_profile(explainer_classif_rf, variable = "fare", type = "partial")
pdp_classif_lr  <- model_profile(explainer_classif_lr, variable = "fare", type = "partial")
pdp_classif_svm  <- model_profile(explainer_classif_svm, variable = "fare", type = "partial")
plot(pdp_classif_rf, pdp_classif_lr, pdp_classif_svm)

We use PDP plots to compare our three models. As we can see above, the shape of the random forest profile may tell us that there is a non-linear relation in the data. It looks like the logistic regression did not capture that relation.

3.3.2 Accumulated Local Effects plot

The Accumulated Local Effects (ALE) plot is an extension of the PDP that is better suited for highly correlated variables.

We use the function model_profile() with the parameter type = "accumulated" to calculate the ALE curve for the variable fare.

ale_classif_rf  <- model_profile(explainer_classif_rf, variable = "fare", type = "accumulated")
ale_classif_lr  <- model_profile(explainer_classif_lr, variable = "fare", type = "accumulated")
ale_classif_svm  <- model_profile(explainer_classif_svm, variable = "fare", type = "accumulated")
plot(ale_classif_rf, ale_classif_lr, ale_classif_svm)

3.4 Break down plots

As previously in the regression use case, we will now look at a specific passenger. This time we will see how our models predicted the survival of a two-year-old girl travelling in the 3rd class.

pb_rf <- predict_parts(explainer_classif_rf, 
                       new_observation = filter(titanic_num_test, 
                                            gender == "female",
                                            age == 2,
                                            class == "3rd")[1,],
                       type = "break_down")
pb_lr <- predict_parts(explainer_classif_lr, 
                       new_observation = filter(titanic_num_test, 
                                            gender == "female",
                                            age == 2,
                                            class == "3rd")[1,],
                       type = "break_down")
pb_svm <- predict_parts(explainer_classif_svm, 
                       new_observation = filter(titanic_num_test, 
                                            gender == "female",
                                            age == 2,
                                            class == "3rd")[1,],
                       type = "break_down")
plot(pb_rf)

plot(pb_lr)

plot(pb_svm)

4 Session info

sessionInfo()
## R version 3.6.3 (2020-02-29)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 10 x64 (build 18363)
## 
## Matrix products: default
## 
## locale:
## [1] LC_COLLATE=Polish_Poland.1250  LC_CTYPE=Polish_Poland.1250   
## [3] LC_MONETARY=Polish_Poland.1250 LC_NUMERIC=C                  
## [5] LC_TIME=Polish_Poland.1250    
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
## [1] dplyr_1.0.0   parsnip_0.1.4 DALEX_2.0.1  
## 
## loaded via a namespace (and not attached):
##  [1] Rcpp_1.0.4        pillar_1.4.3      compiler_3.6.3    ingredients_2.0  
##  [5] tools_3.6.3       digest_0.6.25     lattice_0.20-38   evaluate_0.14    
##  [9] lifecycle_0.2.0   tibble_2.1.3      gtable_0.3.0      pkgconfig_2.0.3  
## [13] rlang_0.4.6       Matrix_1.2-18     yaml_2.2.1        xfun_0.12        
## [17] ranger_0.12.1     stringr_1.4.0     knitr_1.28        generics_0.0.2   
## [21] vctrs_0.3.1       nnet_7.3-12       xgboost_1.2.0.1   grid_3.6.3       
## [25] tidyselect_1.1.0  data.table_1.12.8 glue_1.3.2        R6_2.4.1         
## [29] iBreakDown_1.3.1  rmarkdown_2.1     farver_2.0.3      kernlab_0.9-29   
## [33] ggplot2_3.3.0     purrr_0.3.3       tidyr_1.0.2       magrittr_1.5     
## [37] scales_1.1.0      htmltools_0.4.0   ellipsis_0.3.0    colorspace_1.4-1 
## [41] labeling_0.3      stringi_1.4.6     munsell_0.5.0     crayon_1.3.4