DALEX explainers can be used to see what kinds of relations a model is able to learn and what it has actually learned.
If we know the ground truth, we can verify how well a model learns particular types of relations.
Let’s simulate a model response as a function of five arguments
\[ (2x_1-1)^2 + \sin(10 x_2) + x_3^{6} + (2 x_4 - 1) + |2x_5-1| \]
Let’s compare four models against the ground truth: a random forest, an SVM, a linear model, and a linear model with restricted cubic splines (rms).
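The code that follows refers to y, X1 to X5 and a function f, none of which is defined in this section. Here is a minimal sketch of one way to create them. It assumes 250 independent uniform draws on [0, 1] per variable and a noise-free response that follows the formula above. The seed, the sample size N and the body of f are illustrative assumptions, not necessarily the original setup.

```r
# Assumption: the seed and the sample size are illustrative choices.
set.seed(13)
N <- 250

# Assumption: predictors are independent and uniform on [0, 1].
X1 <- runif(N)
X2 <- runif(N)
X3 <- runif(N)
X4 <- runif(N)
X5 <- runif(N)

# Ground truth, written to match the formula above term by term.
f <- function(x1, x2, x3, x4, x5) {
  (2 * x1 - 1)^2 + sin(10 * x2) + x3^6 + (2 * x4 - 1) + abs(2 * x5 - 1)
}

# Noise-free response computed from the ground truth.
y <- f(X1, X2, X3, X4, X5)
```

Keeping f as a standalone function lets it double as the prediction function of a "true model" explainer.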
library(randomForest)
library(DALEX)
library(e1071)
library(rms)
df <- data.frame(y, X1, X2, X3, X4, X5)
model_rf <- randomForest(y~., df)
model_svm <- svm(y ~ ., df)
model_lm <- lm(y ~ ., df)
# thanks to https://github.com/pbiecek/DALEX/issues/24
## important setup step required for use of rms functions
dd <- datadist(df)
options(datadist="dd")
## add rcs terms to linear model
## this is a very convenient, objective way to account for non-linearity
## still a "linear" model because terms are linear combinations (additive)
model_rms <- ols(y ~ rcs(X1) + rcs(X2) + rcs(X3) + rcs(X4) + rcs(X5), df)
ex_rf <- explain(model_rf, data = df, y = df$y)
ex_svm <- explain(model_svm, data = df, y = df$y)
ex_lm <- explain(model_lm, data = df, y = df$y)
ex_rms <- explain(model_rms, label = "rms", data = df, y = df$y)
ex_tr <- explain(NULL, data = df[,-1],
                 predict_function = function(m, x) f(x[,1], x[,2], x[,3], x[,4], x[,5]),
                 label = "True Model")
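The comparisons that follow rely on model_profile(), whose default profile type is partial dependence. The base-R sketch below shows the idea: fix one variable at each grid value, predict for every row, and average. The toy data and the helper partial_dependence() are hypothetical, introduced only for illustration. They are not part of DALEX, and the real function adds sampling and grouping options that are omitted here.

```r
# Assumption: toy data, unrelated to the data frame used in this document.
set.seed(1)
n <- 200
toy <- data.frame(X1 = runif(n), X2 = runif(n))
toy$y <- (2 * toy$X1 - 1)^2 + toy$X2
fit <- lm(y ~ X1 + X2, data = toy)

# Hypothetical helper: average prediction with `variable` fixed at each grid value.
partial_dependence <- function(model, data, variable, grid) {
  sapply(grid, function(value) {
    data[[variable]] <- value            # overwrite the column in a local copy
    mean(predict(model, newdata = data)) # average over all observations
  })
}

grid <- seq(0, 1, length.out = 11)
pd_x2 <- partial_dependence(fit, toy, "X2", grid)

# For a linear model the profile is a straight line with slope coef(fit)["X2"].
pd_x2
```

For an additive model such as lm the profile is simply the fitted effect of that variable shifted by a constant, which is why its curve is always a straight line in the plots.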
For X1 we want to see (2*x1 - 1)^2.
The linear model cannot recover this relation without prior preprocessing. The random forest picks up part of the shape, but the SVM comes closest.
library(ggplot2)
plot(model_profile(ex_rf, "X1"),
     model_profile(ex_svm, "X1"),
     model_profile(ex_lm, "X1"),
     model_profile(ex_rms, "X1"),
     model_profile(ex_tr, "X1")) +
  ggtitle("Responses for X1. Truth: y ~ (2*x1 - 1)^2")
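The remark that a linear model needs prior preprocessing can be illustrated with a small self-contained sketch. It assumes a single uniform predictor and adds a little Gaussian noise. The noise level, sample size and variable names are illustrative choices, not taken from the setup above. Adding a quadratic term with poly() is one possible form of preprocessing; the rcs() terms used earlier are another.

```r
# Assumption: toy data with one predictor and small Gaussian noise.
set.seed(1)
n <- 500
x1 <- runif(n)
y_toy <- (2 * x1 - 1)^2 + rnorm(n, sd = 0.05)

# Plain linear model: the symmetric U-shape has no linear trend to pick up.
fit_raw <- lm(y_toy ~ x1)

# Same model class after adding a quadratic term.
fit_quad <- lm(y_toy ~ poly(x1, 2))

# Compare the share of variance explained by each fit.
c(raw = summary(fit_raw)$r.squared,
  quadratic = summary(fit_quad)$r.squared)
```

The quadratic fit is still linear in its coefficients, so it stays within the linear model family.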
For X2 we want to see sin(10 * x2).
The random forest captures the shape, the SVM is not flexible enough, and the linear model sees nothing.
plot(model_profile(ex_rf, "X2"),
     model_profile(ex_svm, "X2"),
     model_profile(ex_lm, "X2"),
     model_profile(ex_rms, "X2"),
     model_profile(ex_tr, "X2")) +
  ggtitle("Responses for X2. Truth: y ~ sin(10 * x2)")
For X3 we want to see x3^6.
The random forest is still able to guess the shape; the SVM and the linear model are close.
plot(model_profile(ex_rf, "X3"),
     model_profile(ex_svm, "X3"),
     model_profile(ex_lm, "X3"),
     model_profile(ex_rms, "X3"),
     model_profile(ex_tr, "X3")) +
  ggtitle("Responses for X3. Truth: y ~ x3^6")
For X4 we want to see 2 x4 - 1.
The linear model does the best job (as expected), the SVM is still pretty good, and the random forest is more biased towards the mean.
plot(model_profile(ex_rf, "X4"),
     model_profile(ex_svm, "X4"),
     model_profile(ex_lm, "X4"),
     model_profile(ex_rms, "X4"),
     model_profile(ex_tr, "X4")) +
  ggtitle("Responses for X4. Truth: y ~ (2 * x4 - 1)")
For X5 we want to see |2 x5 - 1|.
All models except the linear one capture the shape.
plot(model_profile(ex_rf, "X5"),
     model_profile(ex_svm, "X5"),
     model_profile(ex_lm, "X5"),
     model_profile(ex_rms, "X5"),
     model_profile(ex_tr, "X5")) +
  ggtitle("Responses for X5. Truth: y ~ |2 * x5 - 1|")
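A short derivation may help explain why the plain linear model sees nothing for X5 (and, by the same argument, for X1). It assumes the predictor is distributed symmetrically around 1/2, for example uniformly on [0, 1], and is independent of the other predictors. Neither assumption is stated in this section. Write U = 2X - 1, so U is symmetric around 0, and let g be an even function such as g(u) = |u| or g(u) = u^2. Then u g(u) is odd, so both expectations below vanish:

\[ \operatorname{Cov}(U, g(U)) = \mathbb{E}[U\,g(U)] - \mathbb{E}[U]\,\mathbb{E}[g(U)] = 0 - 0 \cdot \mathbb{E}[g(U)] = 0 \]

Since X = (U + 1)/2, the best linear slope is

\[ \beta = \frac{\operatorname{Cov}(X, g(U))}{\operatorname{Var}(X)} = \frac{\tfrac{1}{2}\operatorname{Cov}(U, g(U))}{\operatorname{Var}(X)} = 0 \]

Under these assumptions the fitted line is flat up to sampling error, even though the true effect is far from constant.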
## R version 3.6.3 (2020-02-29)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 10 x64 (build 18363)
##
## Matrix products: default
##
## locale:
## [1] LC_COLLATE=Polish_Poland.1250 LC_CTYPE=Polish_Poland.1250
## [3] LC_MONETARY=Polish_Poland.1250 LC_NUMERIC=C
## [5] LC_TIME=Polish_Poland.1250
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] rms_5.1-3 SparseM_1.78 Hmisc_4.3-1
## [4] ggplot2_3.3.0 Formula_1.2-3 survival_3.1-8
## [7] lattice_0.20-38 e1071_1.7-3 DALEX_2.0.1
## [10] randomForest_4.6-14
##
## loaded via a namespace (and not attached):
## [1] Rcpp_1.0.4 mvtnorm_1.1-0 png_0.1-7
## [4] class_7.3-15 zoo_1.8-7 digest_0.6.25
## [7] R6_2.4.1 backports_1.1.5 acepack_1.4.1
## [10] MatrixModels_0.4-1 evaluate_0.14 pillar_1.4.3
## [13] rlang_0.4.6 multcomp_1.4-12 rstudioapi_0.11
## [16] data.table_1.12.8 rpart_4.1-15 Matrix_1.2-18
## [19] checkmate_2.0.0 rmarkdown_2.1 labeling_0.3
## [22] splines_3.6.3 stringr_1.4.0 foreign_0.8-76
## [25] htmlwidgets_1.5.1 munsell_0.5.0 compiler_3.6.3
## [28] xfun_0.12 pkgconfig_2.0.3 base64enc_0.1-3
## [31] htmltools_0.4.0 nnet_7.3-12 tidyselect_1.1.0
## [34] tibble_2.1.3 gridExtra_2.3 htmlTable_1.13.3
## [37] codetools_0.2-16 crayon_1.3.4 dplyr_1.0.0
## [40] withr_2.1.2 MASS_7.3-51.5 grid_3.6.3
## [43] nlme_3.1-144 polspline_1.1.17 gtable_0.3.0
## [46] lifecycle_0.2.0 magrittr_1.5 scales_1.1.0
## [49] stringi_1.4.6 farver_2.0.3 latticeExtra_0.6-29
## [52] generics_0.0.2 vctrs_0.3.1 sandwich_2.5-1
## [55] TH.data_1.0-10 RColorBrewer_1.1-2 tools_3.6.3
## [58] glue_1.3.2 purrr_0.3.3 ingredients_2.0
## [61] jpeg_0.1-8.1 yaml_2.2.1 colorspace_1.4-1
## [64] cluster_2.1.0 knitr_1.28 quantreg_5.54