Machine learning models are developed in many different tools and languages. On the one hand this is great, since everyone can choose the tool that suits their needs; on the other hand, it makes it hard to compare such models in any way other than simply comparing accuracy scores.
In this vignette we will show how DALEX can be used to compare models across different languages.
We trained four models: gbm and CatBoost in R, the h2o (Java) implementation of gbm, and the scikit-learn implementation of gbm in Python. Then we visually explore their similarities and differences through DALEX explainers.
We use the titanic dataset. It is divided into titanic_train and titanic_test and stored in CSV files. On this dataset we will train binary classifiers that predict the probability of surviving the Titanic disaster.
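Below is a minimal sketch of how the data could be loaded; the file names follow the Python chunk later in this vignette, and we assume the CSV files already contain the one-hot encoded columns shown below.

library("dplyr")
# read the data and split it into features (X) and target (Y)
titanic_train <- read.csv("titanic_train.csv")
titanic_test  <- read.csv("titanic_test.csv")
titanic_train_X <- titanic_train %>% select(-survived)
titanic_train_Y <- titanic_train %>% select(survived)
titanic_test_X  <- titanic_test %>% select(-survived)
titanic_test_Y  <- titanic_test %>% select(survived)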
kable(titanic_test_X %>% head(), "html") %>%
  kable_styling("striped") %>%
  scroll_box(width = "100%")

| gender.female | gender.male | age | class.1st | class.2nd | class.3rd | class.deck.crew | class.engineering.crew | class.restaurant.staff | class.victualling.crew | embarked.Belfast | embarked.Cherbourg | embarked.Queenstown | embarked.Southampton | fare | sibsp | parch |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 0 | 16 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 7.1300 | 0 | 0 | 
| 0 | 1 | 25 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 7.1300 | 0 | 0 | 
| 1 | 0 | 28 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 24.0000 | 1 | 0 | 
| 0 | 1 | 20 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 7.1806 | 0 | 0 | 
| 0 | 1 | 30 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 7.0500 | 0 | 0 | 
| 1 | 0 | 19 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 7.1701 | 1 | 0 | 
kable(titanic_train_X %>% head(), "html") %>%
  kable_styling("striped") %>%
  scroll_box(width = "100%")

| gender.female | gender.male | age | class.1st | class.2nd | class.3rd | class.deck.crew | class.engineering.crew | class.restaurant.staff | class.victualling.crew | embarked.Belfast | embarked.Cherbourg | embarked.Queenstown | embarked.Southampton | fare | sibsp | parch |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 42 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 7.1100 | 0 | 0 | 
| 0 | 1 | 13 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 20.0500 | 0 | 2 | 
| 0 | 1 | 16 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 20.0500 | 1 | 1 | 
| 1 | 0 | 39 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 20.0500 | 1 | 1 | 
| 0 | 1 | 30 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 24.0000 | 1 | 0 | 
| 0 | 1 | 27 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 18.1509 | 0 | 0 | 
First, we train the R implementation of gbm, which we will access through the mlr package as a wrapper. Specifying most of the parameters explicitly helps us fit similar models across languages, at least in theory.
library("mlr")
set.seed(123, "L'Ecuyer")
task <- makeClassifTask(
            id = "R",
            data = cbind(titanic_train_X, titanic_train_Y),
            target = "survived"
          )
learner <- makeLearner(
            "classif.gbm",
            par.vals = list(
              distribution = "bernoulli",
              n.trees = 5000,
              interaction.depth = 4,
              n.minobsinnode = 12,
              shrinkage = 0.001,
              bag.fraction = 0.5,
              train.fraction = 1
            ),
            predict.type = "prob"
          )
r_gbm <- train(learner, task)
performance(predict(r_gbm, 
                    newdata = cbind(titanic_test_X, titanic_test_Y)), 
            measures = auc)

##      auc 
## 0.826275
CatBoost is a machine learning algorithm that uses gradient boosting on decision trees. It is similar in spirit to gbm, so we will see how similar the trained models turn out to be.
library("catboost")
pool_train <- catboost.load_pool(titanic_train_X, titanic_train_Y$survived)
pool_test  <- catboost.load_pool(titanic_test_X)
r_catboost <- catboost.train(pool_train, 
               test_pool = NULL, 
               params = list(
                 custom_loss = "AUC",
                 iterations = 5000,
                 depth = 4,
                 logging_level = "Silent"
               ))
preds <- catboost.predict(r_catboost, pool_test, prediction_type = "Probability")
# AUC on the test set
mltools::auc_roc(preds, titanic_test_Y$survived)

## [1] 0.8146678
We will access H2O models via the h2o R package. Using the h2o documentation, we can match as many parameters as possible, which helps compare the models objectively.
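The call below uses the H2OFrames titanic_h2o and titanic_test_h2o. Here is a minimal sketch of how they could be created from the data frames above, assuming a local H2O instance:

library("h2o")
h2o.init()
h2o.no_progress()
# convert to H2OFrames; survived is encoded as a factor so that
# h2o treats the task as binary classification
titanic_h2o <- as.h2o(cbind(titanic_train_X, titanic_train_Y))
titanic_h2o["survived"] <- h2o.asfactor(titanic_h2o["survived"])
titanic_test_h2o <- as.h2o(cbind(titanic_test_X, titanic_test_Y))
titanic_test_h2o["survived"] <- h2o.asfactor(titanic_test_h2o["survived"])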
set.seed(123, "L'Ecuyer")
java_h2o_gbm <- h2o.gbm(
                  training_frame = titanic_h2o,
                  y = "survived",
                  distribution = "bernoulli",
                  ntrees = 5000,
                  max_depth = 4,
                  min_rows =  12,
                  learn_rate = 0.001
                )
h2o.auc(h2o.performance(java_h2o_gbm, newdata = titanic_test_h2o))

## [1] 0.8146509
Inspecting models created in Python from R is not as hard as it may seem. It is possible thanks to two packages: reticulate on the R side and pickle on the Python side.
from pandas import DataFrame, read_csv
import pandas as pd 
import pickle
import sklearn.ensemble
from sklearn.metrics import auc, accuracy_score, confusion_matrix, mean_squared_error
titanic_train_X = pd.read_csv("titanic_train.csv").drop("survived", axis=1)
titanic_train_Y = pd.read_csv("titanic_train.csv").survived
model = sklearn.ensemble.GradientBoostingClassifier(
  n_estimators= 5000,
  learning_rate=0.001, 
  max_depth=4, 
  min_samples_split = 12
)
model = model.fit(titanic_train_X, titanic_train_Y)
pickle.dump(model, open("gbm.pkl", "wb"))

library("reticulate")
# if needed install
# py_install("pandas")
# py_install("scikit-learn")
python_scikitlearn_gbm <- py_load_object("gbm.pkl", pickle = "pickle")
# true labels of the test set, reused later by the explainers
y <- titanic_test_Y$survived
preds <- python_scikitlearn_gbm$predict_proba(titanic_test_X)[, 2]
mltools::auc_roc(preds, y)

## [1] 0.8244396
scikit-learn turned out to be better than h2o and slightly better than the R model. But is it a big difference? Let's explore these models in detail.
Because all four packages return slightly different objects, we have to create DALEX wrappers around them. For three of them we will use the DALEXtra package, and for CatBoost we will use a custom predict function.
# catboost wrapper 
catboost_predict <- function(object, newdata) {
  newdata_pool <- catboost.load_pool(newdata)
  return( catboost.predict(object, newdata_pool, prediction_type = "Probability"))
}

Now we can create DALEX wrappers around our models.
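But first, a quick sanity check of the wrapper, reusing r_catboost and titanic_test_X from above; it should return survival probabilities:

# first few predicted probabilities from the CatBoost wrapper
head(catboost_predict(r_catboost, titanic_test_X))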
library("DALEX")
library("DALEXtra")
r_explain <- DALEXtra::explain_mlr(r_gbm,
                            data = titanic_test_X, 
                            y = y, 
                            label = "gmb (R)",
                            type = "classification",
                            verbose = FALSE)
catboost_explain <- DALEX::explain(r_catboost,
                                   data = titanic_test_X,
                                   y = y, 
                                   label = "CatBoost (R)",
                                   predict_function = catboost_predict,
                                   type = "classification",
                                   verbose = FALSE)
h2o_explain <- DALEXtra::explain_h2o(java_h2o_gbm,
                              data = titanic_test_X, 
                              y = y, 
                              label = "gbm (h2o/java)",
                              type = "classification",
                              verbose = FALSE)
py_explain <- DALEXtra::explain_scikitlearn("gbm.pkl",
                      data = titanic_test_X, 
                      y = y, 
                      label = "gbm (python/sklearn)",
                      type = "classification",
                      verbose = FALSE)

With explainers ready, we can compare our models in order to find possible differences. Model performance and residual distributions get our first look.
plot(
  model_performance(r_explain),
  model_performance(h2o_explain),
  model_performance(py_explain),
  model_performance(catboost_explain)
  )

As we can see, some models differ from each other, although the Java and Python models are quite similar.
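The same comparison can also be shown as boxplots of absolute residuals; this is a hedged alternative view that assumes the geom argument of plot.model_performance() available in DALEX 2.x:

plot(
  model_performance(r_explain),
  model_performance(h2o_explain),
  model_performance(py_explain),
  model_performance(catboost_explain),
  geom = "boxplot"
  )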
Here the drop in 1 - AUC is used to compare variable importance. Keep in mind that when defining a custom_loss_function you have to provide the arguments in the correct order: real values of y first, predictions second.
custom_loss_function <- function(y, yhat) {
  1 - mltools::auc_roc(yhat, y)
}
mp1 <- model_parts(r_explain, type = "difference", loss_function = custom_loss_function)
mp2 <- model_parts(h2o_explain, type = "difference", loss_function = custom_loss_function)
mp3 <- model_parts(py_explain, type = "difference", loss_function = custom_loss_function)
mp4 <- model_parts(catboost_explain, type = "difference", loss_function = custom_loss_function)

Let's see what features are important in two selected models, created using h2o and sklearn; they are plotted side by side below.
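A single plot call puts the importances of both models on one chart (mp2 holds the h2o results, mp3 the sklearn ones):

plot(mp2, mp3)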
We can see a significant difference. The h2o model figured out the correlation between gender.male and gender.female and dropped one of them; sklearn (and the other models too) uses both of those columns. What is interesting, the next four most important variables are the same. Next, let's compare partial dependence profiles, starting with fare.
pdp_r   <- model_profile(r_explain, 
                          variable_splits = list(fare = seq(0,100,0.1)))
pdp_h2o  <- model_profile(h2o_explain, 
                          variable_splits = list(fare = seq(0,100,0.1)))
pdp_py <- model_profile(py_explain,
                         variable_splits = list(fare = seq(0,100,0.1)))
pdp_catboost <- model_profile(catboost_explain, 
                         variable_splits = list(fare = seq(0,100,0.1)))
plot(pdp_r, pdp_h2o, pdp_py, pdp_catboost)

We can see the difference in how our models behave for different values of fare. Let's draw the same profiles for age.
pdp_r   <- model_profile(r_explain, 
                          variable_splits = list(age = seq(0,80,0.1)))
pdp_h2o  <- model_profile(h2o_explain, 
                          variable_splits = list(age = seq(0,80,0.1)))
pdp_py <- model_profile(py_explain,
                         variable_splits = list(age = seq(0,80,0.1)))
pdp_catboost <- model_profile(catboost_explain, 
                         variable_splits = list(age = seq(0,80,0.1)))
plot(pdp_r, pdp_h2o, pdp_py, pdp_catboost)

Session info:

## R version 4.0.2 (2020-06-22)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 10 x64 (build 18363)
## 
## Matrix products: default
## 
## Random number generation:
##  RNG:     L'Ecuyer-CMRG 
##  Normal:  Inversion 
##  Sample:  Rejection 
##  
## locale:
## [1] LC_COLLATE=English_United States.1252 
## [2] LC_CTYPE=English_United States.1252   
## [3] LC_MONETARY=English_United States.1252
## [4] LC_NUMERIC=C                          
## [5] LC_TIME=English_United States.1252    
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] DALEXtra_2.0      DALEX_2.0.1       reticulate_1.18   h2o_3.32.0.1     
##  [5] catboost_0.20     mlr_2.18.0        ParamHelpers_1.14 ggplot2_3.3.3    
##  [9] kableExtra_1.3.1  dplyr_1.0.2      
## 
## loaded via a namespace (and not attached):
##  [1] Rcpp_1.0.5        lattice_0.20-41   digest_0.6.25     R6_2.4.1         
##  [5] backports_1.1.10  evaluate_0.14     httr_1.4.2        highr_0.8        
##  [9] pillar_1.4.6      mltools_0.3.5     rlang_0.4.7       rstudioapi_0.11  
## [13] data.table_1.13.4 Matrix_1.2-18     checkmate_2.0.0   rmarkdown_2.6    
## [17] labeling_0.3      splines_4.0.2     webshot_0.5.2     stringr_1.4.0    
## [21] RCurl_1.98-1.2    bit_4.0.4         munsell_0.5.0     compiler_4.0.2   
## [25] xfun_0.19         pkgconfig_2.0.3   gbm_2.1.8         BBmisc_1.11      
## [29] htmltools_0.5.0   tidyselect_1.1.0  tibble_3.0.3      codetools_0.2-16 
## [33] XML_3.99-0.5      viridisLite_0.3.0 crayon_1.3.4      withr_2.3.0      
## [37] bitops_1.0-6      rappdirs_0.3.1    grid_4.0.2        jsonlite_1.7.2   
## [41] gtable_0.3.0      lifecycle_0.2.0   magrittr_2.0.1    scales_1.1.1     
## [45] stringi_1.5.3     farver_2.0.3      parallelMap_1.5.0 xml2_1.3.2       
## [49] ellipsis_0.3.1    generics_0.1.0    vctrs_0.3.4       fastmatch_1.1-0  
## [53] tools_4.0.2       bit64_4.0.5       glue_1.4.2        purrr_0.3.4      
## [57] ingredients_2.0   parallel_4.0.2    survival_3.1-12   yaml_2.2.1       
## [61] colorspace_1.4-1  rvest_0.3.6       knitr_1.30