Machine learning models are developed with many different tools and in many different languages. On the one hand, this is great, because everyone can choose the tool that suits their needs; on the other hand, it makes it hard to compare such models in any way other than by comparing accuracy scores.
In this vignette we show how `DALEX` can be used to compare models built in different languages.
We trained four models: `gbm` and `CatBoost` in R, the `h2o` implementation of gbm in Java, and the `scikit-learn` implementation of gbm in Python. We then visually explore their similarities and differences with `DALEX` explainers.
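We use the `titanic` dataset. It is divided into `titanic_train` and `titanic_test`, both stored in CSV files. On this dataset we train binary classifiers that predict the probability of surviving the Titanic disaster. The categorical variables (gender, class and port of embarkation) are one-hot encoded, as the first rows of both sets show.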
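The preprocessing that produced these CSV files is not shown here. As a hypothetical sketch of the layout, the Python function below (`one_hot` is an illustrative name, not part of any package) expands categorical columns into 0/1 indicator columns named `<variable>.<level>`, the naming visible in the tables that follow.

```python
# Hypothetical sketch: expand categorical columns into 0/1 indicator
# columns named "<variable>.<level>", the layout seen in the tables.
def one_hot(rows, categorical, sep="."):
    """rows: list of dicts; categorical: names of the columns to expand."""
    # Sorted set of levels observed for each categorical column
    levels = {col: sorted({row[col] for row in rows}) for col in categorical}
    encoded = []
    for row in rows:
        out = {}
        for col, value in row.items():
            if col in levels:
                # One indicator per level: 1 for the observed level, 0 otherwise
                for level in levels[col]:
                    out[f"{col}{sep}{level}"] = int(value == level)
            else:
                out[col] = value  # numeric columns are kept unchanged
        encoded.append(out)
    return encoded


passengers = [
    {"gender": "female", "age": 16, "class": "3rd", "fare": 7.13},
    {"gender": "male", "age": 25, "class": "3rd", "fare": 7.13},
    {"gender": "female", "age": 28, "class": "2nd", "fare": 24.0},
]
rows = one_hot(passengers, ["gender", "class"])
for row in rows:
    print(row)
```

Every passenger has exactly one gender indicator set to 1, so `gender.female` and `gender.male` always carry the same information; this redundancy comes up again when variable importance is compared.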
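The first rows of the test set:

```r
library("dplyr")       # %>%
library("knitr")       # kable()
library("kableExtra")  # kable_styling(), scroll_box()

kable(titanic_test_X %>% head(), "html") %>%
  kable_styling("striped") %>%
  scroll_box(width = "100%")
```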
gender.female | gender.male | age | class.1st | class.2nd | class.3rd | class.deck.crew | class.engineering.crew | class.restaurant.staff | class.victualling.crew | embarked.Belfast | embarked.Cherbourg | embarked.Queenstown | embarked.Southampton | fare | sibsp | parch |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
1 | 0 | 16 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 7.1300 | 0 | 0 |
0 | 1 | 25 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 7.1300 | 0 | 0 |
1 | 0 | 28 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 24.0000 | 1 | 0 |
0 | 1 | 20 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 7.1806 | 0 | 0 |
0 | 1 | 30 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 7.0500 | 0 | 0 |
1 | 0 | 19 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 7.1701 | 1 | 0 |
The first rows of the training set:

```r
kable(titanic_train_X %>% head(), "html") %>%
  kable_styling("striped") %>%
  scroll_box(width = "100%")
```
gender.female | gender.male | age | class.1st | class.2nd | class.3rd | class.deck.crew | class.engineering.crew | class.restaurant.staff | class.victualling.crew | embarked.Belfast | embarked.Cherbourg | embarked.Queenstown | embarked.Southampton | fare | sibsp | parch |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 42 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 7.1100 | 0 | 0 |
0 | 1 | 13 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 20.0500 | 0 | 2 |
0 | 1 | 16 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 20.0500 | 1 | 1 |
1 | 0 | 39 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 20.0500 | 1 | 1 |
0 | 1 | 30 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 24.0000 | 1 | 0 |
0 | 1 | 27 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 18.1509 | 0 | 0 |
First, we train the R implementation of `gbm`, which we access through the `mlr` package as a wrapper. Setting most of the parameters explicitly helps us fit similar models across languages, at least in theory.
```r
library("mlr")
set.seed(123, "L'Ecuyer")

# Classification task built from the training predictors and labels
task <- makeClassifTask(
  id = "R",
  data = cbind(titanic_train_X, titanic_train_Y),
  target = "survived"
)

# gbm learner with explicitly set hyperparameters
learner <- makeLearner(
  "classif.gbm",
  par.vals = list(
    distribution = "bernoulli",
    n.trees = 5000,
    interaction.depth = 4,
    n.minobsinnode = 12,
    shrinkage = 0.001,
    bag.fraction = 0.5,
    train.fraction = 1
  ),
  predict.type = "prob"
)
r_gbm <- train(learner, task)

# AUC on the test set
performance(predict(r_gbm,
                    newdata = cbind(titanic_test_X, titanic_test_Y)),
            measures = auc)
## auc
## 0.826275
```
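All four models are compared by AUC on the test set. AUC can be read as the probability that a randomly chosen survivor gets a higher predicted probability than a randomly chosen non-survivor, with ties counting as half. The following pure-Python sketch computes it directly from that definition (quadratic in the number of observations, fine for illustration); it is not the implementation used by `mlr` or `mltools`.

```python
def auc(scores, labels):
    """Rank-based AUC: probability that a random positive (label 1) scores
    higher than a random negative (label 0); ties count as half a win."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


print(auc([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]))  # 0.75
```

Three of the four survivor/non-survivor pairs are ranked correctly here, hence 0.75. The loss used later for variable importance is simply `1 - auc(...)`.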
`CatBoost` is a machine learning algorithm that uses gradient boosting on decision trees. It is similar in spirit to `gbm`, and we will see how similar the trained models are.
```r
library("catboost")
pool_train <- catboost.load_pool(titanic_train_X, titanic_train_Y$survived)
pool_test <- catboost.load_pool(titanic_test_X)

r_catboost <- catboost.train(pool_train,
                             test_pool = NULL,
                             params = list(
                               custom_loss = "AUC",  # AUC reported as an extra metric
                               iterations = 5000,
                               depth = 4,
                               logging_level = "Silent"
                             ))

preds <- catboost.predict(r_catboost, pool_test, prediction_type = "Probability")

# Observed test labels, also used by the code below
y <- titanic_test_Y$survived
mltools::auc_roc(preds, y)
## [1] 0.8146678
```
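`gbm` and `CatBoost` both build an additive model from many small trees, each fitted to the errors of the current ensemble, with every tree's contribution scaled down by a learning rate (`shrinkage` in `gbm`). The sketch below shows that loop for binary log-loss under heavy simplifying assumptions: a single numeric feature, depth-1 trees (stumps) and leaf values equal to the mean residual. Real implementations use Newton-style leaf values, deeper trees, subsampling and, in CatBoost's case, symmetric trees, so this illustrates the idea only; the function names are illustrative.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def fit_stump(x, residuals):
    """Depth-1 regression tree on one feature: choose the threshold that
    minimises squared error to the residuals; leaves predict the mean."""
    best = None
    for t in sorted(set(x))[:-1]:  # split as x <= t versus x > t
        left = [r for xi, r in zip(x, residuals) if xi <= t]
        right = [r for xi, r in zip(x, residuals) if xi > t]
        left_mean = sum(left) / len(left)
        right_mean = sum(right) / len(right)
        sse = (sum((r - left_mean) ** 2 for r in left)
               + sum((r - right_mean) ** 2 for r in right))
        if best is None or sse < best[0]:
            best = (sse, t, left_mean, right_mean)
    if best is None:  # constant feature: a single leaf
        mean = sum(residuals) / len(residuals)
        return lambda xi: mean
    _, t, left_mean, right_mean = best
    return lambda xi: left_mean if xi <= t else right_mean


def boost(x, y, n_trees=100, learning_rate=0.1):
    """Gradient boosting for binary log-loss; assumes both classes occur in y."""
    base_rate = sum(y) / len(y)
    f0 = math.log(base_rate / (1 - base_rate))  # start from the base-rate log-odds
    trees = []
    scores = [f0] * len(x)
    for _ in range(n_trees):
        # Negative gradient of log-loss with respect to the score: y - p
        residuals = [yi - sigmoid(s) for yi, s in zip(y, scores)]
        tree = fit_stump(x, residuals)
        trees.append(tree)
        # Shrinkage: each tree contributes only a small step
        scores = [s + learning_rate * tree(xi) for s, xi in zip(scores, x)]

    def predict_proba(xi):
        return sigmoid(f0 + learning_rate * sum(tree(xi) for tree in trees))

    return predict_proba


x = [1, 2, 3, 4, 5, 6]
y = [0, 0, 0, 1, 1, 1]
model = boost(x, y, n_trees=100, learning_rate=0.1)
print(round(model(1), 3), round(model(6), 3))
```

Here `n_trees` and `learning_rate` play the roles of `n.trees`/`iterations` and `shrinkage`/`learning_rate`. With a small learning rate many trees are needed before predictions move far from the base rate, which is consistent with the 5000 trees paired with `shrinkage = 0.001` above.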
We access H2O models via the `h2o` R package. Using the `h2o` documentation, we matched as many parameters as possible, which helps us compare the models objectively.
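```r
library("h2o")
h2o.init()  # start (or connect to) a local H2O cluster
# H2O works on its own frames; the bernoulli distribution needs a factor response
titanic_h2o <- as.h2o(cbind(titanic_train_X, titanic_train_Y))
titanic_h2o["survived"] <- h2o.asfactor(titanic_h2o["survived"])
titanic_test_h2o <- as.h2o(cbind(titanic_test_X, titanic_test_Y))
titanic_test_h2o["survived"] <- h2o.asfactor(titanic_test_h2o["survived"])

# set.seed() does not seed H2O; h2o.gbm() has its own `seed` argument
set.seed(123, "L'Ecuyer")
java_h2o_gbm <- h2o.gbm(
  training_frame = titanic_h2o,
  y = "survived",
  distribution = "bernoulli",
  ntrees = 5000,
  max_depth = 4,
  min_rows = 12,
  learn_rate = 0.001
)
h2o.auc(h2o.performance(java_h2o_gbm, newdata = titanic_test_h2o))
## [1] 0.8146509
```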
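The parameter matching done so far can be summarised as a lookup table. The Python sketch below records approximate correspondences between the R `gbm` arguments and their counterparts in `h2o`, scikit-learn and CatBoost. The libraries do not define these parameters identically (for example, CatBoost honours `min_data_in_leaf` and `subsample` only under certain grow-policy and bootstrap settings), and `PARAM_MAP` and `translate` are illustrative names, not part of any package.

```python
# Approximate correspondences between hyperparameters; definitions differ
# in detail between libraries, so treat this as a starting point only.
PARAM_MAP = {
    "n.trees":           {"h2o": "ntrees",      "sklearn": "n_estimators",     "catboost": "iterations"},
    "interaction.depth": {"h2o": "max_depth",   "sklearn": "max_depth",        "catboost": "depth"},
    "n.minobsinnode":    {"h2o": "min_rows",    "sklearn": "min_samples_leaf", "catboost": "min_data_in_leaf"},
    "shrinkage":         {"h2o": "learn_rate",  "sklearn": "learning_rate",    "catboost": "learning_rate"},
    "bag.fraction":      {"h2o": "sample_rate", "sklearn": "subsample",        "catboost": "subsample"},
}


def translate(gbm_params, target):
    """Rename R gbm parameters for `target`; return (translated, unmapped names)."""
    translated, unmapped = {}, []
    for name, value in gbm_params.items():
        if name in PARAM_MAP and target in PARAM_MAP[name]:
            translated[PARAM_MAP[name][target]] = value
        else:
            unmapped.append(name)
    return translated, unmapped


# The gbm settings used in this vignette
gbm_params = {"distribution": "bernoulli", "n.trees": 5000, "interaction.depth": 4,
              "n.minobsinnode": 12, "shrinkage": 0.001, "bag.fraction": 0.5}
print(translate(gbm_params, "h2o"))
```

`distribution` is left unmapped here; `h2o` happens to use the same name and value. Laying the settings side by side also shows where the models are not fully matched: the `h2o` call above leaves `sample_rate` at its default, the scikit-learn model uses `min_samples_split` and no `subsample`, and the CatBoost model sets neither `learning_rate` nor a minimum leaf size.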
Inspecting models created in Python from R is not as hard as it may seem. It is possible thanks to two packages: `reticulate` in R and the `pickle` module in Python.
```python
import pickle

import pandas as pd
from sklearn.ensemble import GradientBoostingClassifier

# Read the training data once and split it into predictors and label
titanic_train = pd.read_csv("titanic_train.csv")
titanic_train_X = titanic_train.drop("survived", axis=1)
titanic_train_Y = titanic_train["survived"]

model = GradientBoostingClassifier(
    n_estimators=5000,
    learning_rate=0.001,
    max_depth=4,
    min_samples_split=12,  # note: gbm's n.minobsinnode is closer to min_samples_leaf
)
model = model.fit(titanic_train_X, titanic_train_Y)

# Serialize the fitted model so that R can load it via reticulate
with open("gbm.pkl", "wb") as f:
    pickle.dump(model, f)
```
```r
library("reticulate")
# if needed, install the Python dependencies:
# py_install("pandas")
# py_install("scikit-learn")
python_scikitlearn_gbm <- py_load_object("gbm.pkl", pickle = "pickle")
# second column of predict_proba() = probability of survival
preds <- python_scikitlearn_gbm$predict_proba(titanic_test_X)[, 2]
mltools::auc_roc(preds, y)
## [1] 0.8244396
```
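`py_load_object()` reads the file with Python's `pickle` module (named by its `pickle` argument), so on the Python side the round trip is just `pickle.dump()` followed by `pickle.load()`. The sketch below uses a hypothetical stand-in class, `ThresholdModel`, instead of a real scikit-learn model, and shows two details the R code relies on: the restored object keeps its methods, and `predict_proba()` returns one column per class, so the survival probability is the second column (`[, 2]` in R, index `1` in Python).

```python
import os
import pickle
import tempfile


class ThresholdModel:
    """Hypothetical stand-in for a fitted scikit-learn classifier."""

    def __init__(self, threshold):
        self.threshold = threshold

    def predict_proba(self, rows):
        # One row per observation, columns = [P(class 0), P(class 1)],
        # the same layout scikit-learn classifiers return
        out = []
        for x in rows:
            p = 0.9 if x > self.threshold else 0.2
            out.append([1 - p, p])
        return out


path = os.path.join(tempfile.mkdtemp(), "model.pkl")
with open(path, "wb") as f:
    pickle.dump(ThresholdModel(30), f)

with open(path, "rb") as f:
    restored = pickle.load(f)

# Probability of the positive class = second column (index 1 in Python)
probs = [row[1] for row in restored.predict_proba([16, 42])]
print(probs)
```

Unpickling needs the class definition to be importable, which is why `scikit-learn` must be installed in the Python environment that `reticulate` uses.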
The `scikit-learn` model turned out to be better than the `h2o` and `CatBoost` models (AUC 0.824 versus 0.815) and slightly worse than the R `gbm` model (0.826). But is this a big difference? Let’s explore these models in detail.
Because the four packages return different kinds of objects, we have to create `DALEX` wrappers around them. For three of them we use the `DALEXtra` package; for `catboost` we use a custom predict function.
```r
# catboost wrapper
catboost_predict <- function(object, newdata) {
  newdata_pool <- catboost.load_pool(newdata)
  catboost.predict(object, newdata_pool, prediction_type = "Probability")
}
```
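A custom predict function is needed because each library exposes predictions differently, while `DALEX` only needs a function taking `(model, newdata)` and returning a vector of probabilities. The Python sketch below illustrates that idea with a hypothetical minimal `Explainer` class and two toy "libraries" with different APIs; it is not the API of `DALEX` or of its Python counterpart `dalex`.

```python
from typing import Callable, Sequence


class Explainer:
    """Hypothetical, minimal analogue of an explainer: a model plus a predict
    function with one fixed signature, so downstream code can call
    explainer.predict(newdata) without knowing which library built the model."""

    def __init__(self, model, data, y, predict_function: Callable, label: str):
        self.model = model
        self.data = data
        self.y = y
        self.predict_function = predict_function
        self.label = label

    def predict(self, newdata: Sequence) -> list:
        return list(self.predict_function(self.model, newdata))


class LibA:
    def predict_proba(self, rows):
        return [[1 - r[0], r[0]] for r in rows]  # two-column output


class LibB:
    def score(self, rows):
        return [r[0] for r in rows]  # already P(y = 1)


rows = [[0.2], [0.7]]
explainers = [
    Explainer(LibA(), rows, [0, 1], lambda m, d: [p[1] for p in m.predict_proba(d)], "A"),
    Explainer(LibB(), rows, [0, 1], lambda m, d: m.score(d), "B"),
]
for e in explainers:
    print(e.label, e.predict(rows))
```

Once every model sits behind the same signature, the same performance, importance and profile code can run on all of them unchanged.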
Now we can create DALEX wrappers around our models.
```r
library("DALEX")
library("DALEXtra")

r_explain <- DALEXtra::explain_mlr(r_gbm,
                                   data = titanic_test_X,
                                   y = y,
                                   label = "gbm (R)",
                                   type = "classification",
                                   verbose = FALSE)

catboost_explain <- DALEX::explain(r_catboost,
                                   data = titanic_test_X,
                                   y = y,
                                   label = "CatBoost (R)",
                                   predict_function = catboost_predict,
                                   type = "classification",
                                   verbose = FALSE)

h2o_explain <- DALEXtra::explain_h2o(java_h2o_gbm,
                                     data = titanic_test_X,
                                     y = y,
                                     label = "gbm (h2o/java)",
                                     type = "classification",
                                     verbose = FALSE)

py_explain <- DALEXtra::explain_scikitlearn("gbm.pkl",
                                            data = titanic_test_X,
                                            y = y,
                                            label = "gbm (python/sklearn)",
                                            type = "classification",
                                            verbose = FALSE)
```
With the explainers ready, we can compare our models to find possible differences. We start with model performance and the distribution of residuals.
```r
plot(
  model_performance(r_explain),
  model_performance(h2o_explain),
  model_performance(py_explain),
  model_performance(catboost_explain)
)
```
As we can see, some models differ from each other, although the Java and Python models are quite similar.
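`DALEX` computes residuals as the observed value minus the prediction, and the residual plot summarises the distribution of their absolute values. As a hedged, pure-Python sketch of one way to read such a curve: for each threshold, take the share of observations whose absolute residual exceeds it. A model whose curve falls faster has smaller residuals overall. The helper names and data are illustrative.

```python
def residuals(y, p):
    # Residual = observed label minus predicted probability
    return [yi - pi for yi, pi in zip(y, p)]


def share_above(abs_res, thresholds):
    """Reverse empirical CDF: fraction of |residuals| strictly above each threshold."""
    n = len(abs_res)
    return [sum(r > t for r in abs_res) / n for t in thresholds]


y = [1, 0, 1, 0]
p = [0.9, 0.2, 0.4, 0.7]
abs_res = [abs(r) for r in residuals(y, p)]
print(share_above(abs_res, [0.0, 0.3, 0.65]))  # [1.0, 0.5, 0.25]
```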
Here, the increase in 1 - AUC after permuting a variable (that is, the drop in AUC) is used to compare variable importance. Keep in mind that when defining `custom_loss_function` you have to provide the arguments in the correct order: the observed values of `y` first and the predictions second.
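```r
# Loss = 1 - AUC; DALEX passes observed values first, predictions second
# (DALEX also provides a built-in loss_one_minus_auc())
custom_loss_function <- function(y, yhat) {
  1 - mltools::auc_roc(yhat, y)
}

# type = "difference": loss after permutation minus the original loss
mp1 <- model_parts(r_explain, type = "difference", loss_function = custom_loss_function)
mp2 <- model_parts(h2o_explain, type = "difference", loss_function = custom_loss_function)
mp3 <- model_parts(py_explain, type = "difference", loss_function = custom_loss_function)
mp4 <- model_parts(catboost_explain, type = "difference", loss_function = custom_loss_function)

plot(mp2, mp3)  # h2o and scikit-learn models
```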
Let’s see which features are important in two selected models, created with `h2o` and `sklearn`.
We can see a significant difference. The `h2o` model picked up that `gender.male` and `gender.female` carry the same information (one is always the complement of the other) and relies on only one of them, whereas `sklearn` (and the other models too) uses both columns. Interestingly, the next four most important variables are the same in both models.
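`model_parts()` with `type = "difference"` measures permutation importance: shuffle one variable, recompute the loss, and report the change relative to the unshuffled loss, averaged over several permutations. The sketch below reimplements that idea in plain Python with the same 1 - AUC loss and the same observed-then-predicted argument order. `DALEX`'s own implementation has further options (such as subsampling observations) that are omitted here, and the toy model and data are hypothetical.

```python
import random


def auc(scores, labels):
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def loss_one_minus_auc(observed, predicted):
    # Observed values first, predictions second
    return 1 - auc(predicted, observed)


def permutation_importance(predict, X, y, loss, n_repeats=10, seed=0):
    """Loss after permuting each column minus the original loss,
    averaged over n_repeats random permutations."""
    rng = random.Random(seed)
    baseline = loss(y, predict(X))
    importance = {}
    for j in range(len(X[0])):
        drops = []
        for _ in range(n_repeats):
            column = [row[j] for row in X]
            rng.shuffle(column)
            X_perm = [row[:j] + [v] + row[j + 1:] for row, v in zip(X, column)]
            drops.append(loss(y, predict(X_perm)) - baseline)
        importance[j] = sum(drops) / n_repeats
    return baseline, importance


# Toy model that only looks at feature 0
predict = lambda X: [row[0] for row in X]
X = [[0.1, 5], [0.3, 1], [0.2, 9], [0.7, 2], [0.9, 7], [0.8, 3]]
y = [0, 0, 0, 1, 1, 1]
baseline, imp = permutation_importance(predict, X, y, loss_one_minus_auc)
print(baseline, imp)
```

A variable the model ignores gets an importance of exactly zero, because permuting it leaves every prediction unchanged. This is also how two redundant columns such as `gender.male` and `gender.female` can end up with very different importances: a model that splits on only one of them is unaffected when the other is shuffled.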
```r
pdp_r <- model_profile(r_explain,
                       variable_splits = list(fare = seq(0, 100, 0.1)))
pdp_h2o <- model_profile(h2o_explain,
                         variable_splits = list(fare = seq(0, 100, 0.1)))
pdp_py <- model_profile(py_explain,
                        variable_splits = list(fare = seq(0, 100, 0.1)))
pdp_catboost <- model_profile(catboost_explain,
                              variable_splits = list(fare = seq(0, 100, 0.1)))
plot(pdp_r, pdp_h2o, pdp_py, pdp_catboost)
```
We can see differences in how our models behave for different values of `fare`.
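By default `model_profile()` computes partial dependence profiles: for each value in the grid (here given through `variable_splits`), the variable is set to that value for every observation and the predictions are averaged. A minimal pure-Python sketch with a hypothetical threshold model on two toy columns (`fare`, `gender.female`):

```python
def partial_dependence(predict, X, j, grid):
    """Average prediction when feature j is set to each grid value in every row."""
    profile = []
    for value in grid:
        X_mod = [row[:j] + [value] + row[j + 1:] for row in X]
        preds = predict(X_mod)
        profile.append(sum(preds) / len(preds))
    return profile


# Hypothetical tree-like model: columns are [fare, gender.female]
predict = lambda X: [0.5 * row[1] + (0.5 if row[0] > 30 else 0.0) for row in X]
X = [[7.13, 1], [7.13, 0], [24.0, 1], [7.1806, 0]]
profile = partial_dependence(predict, X, 0, [0, 10, 50, 100])
print(profile)  # [0.25, 0.25, 0.75, 0.75]
```

Because tree ensembles split on thresholds, their profiles are step functions, as in this toy example; differences between the models may show up as steps at different `fare` values or of different heights.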
Similarly, we compare the profiles for `age`:

```r
pdp_r <- model_profile(r_explain,
                       variable_splits = list(age = seq(0, 80, 0.1)))
pdp_h2o <- model_profile(h2o_explain,
                         variable_splits = list(age = seq(0, 80, 0.1)))
pdp_py <- model_profile(py_explain,
                        variable_splits = list(age = seq(0, 80, 0.1)))
pdp_catboost <- model_profile(catboost_explain,
                              variable_splits = list(age = seq(0, 80, 0.1)))
plot(pdp_r, pdp_h2o, pdp_py, pdp_catboost)
```
```r
sessionInfo()
## R version 4.0.2 (2020-06-22)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 10 x64 (build 18363)
##
## Matrix products: default
##
## Random number generation:
## RNG: L'Ecuyer-CMRG
## Normal: Inversion
## Sample: Rejection
##
## locale:
## [1] LC_COLLATE=English_United States.1252
## [2] LC_CTYPE=English_United States.1252
## [3] LC_MONETARY=English_United States.1252
## [4] LC_NUMERIC=C
## [5] LC_TIME=English_United States.1252
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] DALEXtra_2.0 DALEX_2.0.1 reticulate_1.18 h2o_3.32.0.1
## [5] catboost_0.20 mlr_2.18.0 ParamHelpers_1.14 ggplot2_3.3.3
## [9] kableExtra_1.3.1 dplyr_1.0.2
##
## loaded via a namespace (and not attached):
## [1] Rcpp_1.0.5 lattice_0.20-41 digest_0.6.25 R6_2.4.1
## [5] backports_1.1.10 evaluate_0.14 httr_1.4.2 highr_0.8
## [9] pillar_1.4.6 mltools_0.3.5 rlang_0.4.7 rstudioapi_0.11
## [13] data.table_1.13.4 Matrix_1.2-18 checkmate_2.0.0 rmarkdown_2.6
## [17] labeling_0.3 splines_4.0.2 webshot_0.5.2 stringr_1.4.0
## [21] RCurl_1.98-1.2 bit_4.0.4 munsell_0.5.0 compiler_4.0.2
## [25] xfun_0.19 pkgconfig_2.0.3 gbm_2.1.8 BBmisc_1.11
## [29] htmltools_0.5.0 tidyselect_1.1.0 tibble_3.0.3 codetools_0.2-16
## [33] XML_3.99-0.5 viridisLite_0.3.0 crayon_1.3.4 withr_2.3.0
## [37] bitops_1.0-6 rappdirs_0.3.1 grid_4.0.2 jsonlite_1.7.2
## [41] gtable_0.3.0 lifecycle_0.2.0 magrittr_2.0.1 scales_1.1.1
## [45] stringi_1.5.3 farver_2.0.3 parallelMap_1.5.0 xml2_1.3.2
## [49] ellipsis_0.3.1 generics_0.1.0 vctrs_0.3.4 fastmatch_1.1-0
## [53] tools_4.0.2 bit64_4.0.5 glue_1.4.2 purrr_0.3.4
## [57] ingredients_2.0 parallel_4.0.2 survival_3.1-12 yaml_2.2.1
## [61] colorspace_1.4-1 rvest_0.3.6 knitr_1.30
```