Based on the hal9001 R package by Rachael Phillips, Jeremy Coyle, and Nima Hejazi.

Objectives

By the end of this demo, you will be able to:

  • Fit a HAL with a CV-selected penalty and an undersmoothed HAL.
  • Interpret the results of a HAL fit.
  • Estimate the target parameter with HAL as a working model and obtain delta-method-based confidence intervals for:
    • Average treatment effects.
    • Conditional average treatment effects.
    • Dose response curves.

1. Introduction

The Highly Adaptive Lasso (HAL) is a non-parametric machine learning algorithm that “creates efficient estimators of realistic statistical models of pathwise differential parameters. Furthermore, these models provide asymptotically normal estimators for non-pathwise-differential parameters, which can enable investigators to use HAL-based substitution estimators and the delta method rather than rely on parametric models.” [Highly Adaptive LASSO: Machine Learning That Provides Valid Nonparametric Inference in Realistic Models]

This GitHub page provides example code for straightforward uses of HAL.

2. How to Fit HAL

This section demonstrates the functions used to fit HAL.

Fitting any HAL consists of the following two steps:
1. Define the outcome vector, Y, and the covariate matrix, X.
2. Fit the HAL with the selected smoothness order and other parameters.

Running example with WASH Benefits dataset:
We will use the WASH Benefits Bangladesh study as a running example to guide this demo of HAL-based plug-in estimators.

Preliminaries

First, we need to load the data, relevant functions, and relevant packages into the R session, and then make sure there are no missing values in the dataset.

Load the data

library("data.table")

washb_data <- fread(
  paste0(
    "https://raw.githubusercontent.com/tlverse/tlverse-data/master/",
    "wash-benefits/washb_data.csv"
  ),
  stringsAsFactors = TRUE
)

Take a look at the first few rows of the WASH dataset:

head(washb_data)
##      whz      tr fracode month aged    sex momage          momedu momheight
## 1:  0.00 Control  N05265     9  268   male     30  Primary (1-5y)    146.40
## 2: -1.16 Control  N05265     9  286   male     25  Primary (1-5y)    148.75
## 3: -1.05 Control  N08002     9  264   male     25  Primary (1-5y)    152.15
## 4: -1.26 Control  N08002     9  252 female     28  Primary (1-5y)    140.25
## 5: -0.59 Control  N06531     9  336 female     19 Secondary (>5y)    150.95
## 6: -0.51 Control  N06531     9  304   male     20 Secondary (>5y)    154.20
##                     hfiacat Nlt18 Ncomp watmin elec floor walls roof
## 1:              Food Secure     3    11      0    1     0     1    1
## 2: Moderately Food Insecure     2     4      0    1     0     1    1
## 3:              Food Secure     1    10      0    0     0     1    1
## 4:              Food Secure     3     5      0    1     0     1    1
## 5:              Food Secure     2     7      0    1     0     1    1
## 6:   Severely Food Insecure     0     3      1    1     0     1    1
##    asset_wardrobe asset_table asset_chair asset_khat asset_chouki asset_tv
## 1:              0           1           1          1            0        1
## 2:              0           1           0          1            1        0
## 3:              0           0           1          0            1        0
## 4:              1           1           1          1            0        0
## 5:              1           1           1          1            1        0
## 6:              0           0           0          0            1        0
##    asset_refrig asset_bike asset_moto asset_sewmach asset_mobile
## 1:            0          0          0             0            1
## 2:            0          0          0             0            1
## 3:            0          0          0             0            1
## 4:            0          1          0             0            1
## 5:            0          0          0             0            1
## 6:            0          0          0             0            1

Remove missing values

Here we simply keep the complete cases by removing the rows with missing values. You can choose any other method to deal with missing values.

washb_data <- washb_data[complete.cases(washb_data), ]
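Before moving on, it can help to confirm what the complete-case filter did. The sketch below uses a small made-up data frame (the values are hypothetical, not taken from the WASH data) to show how to count dropped rows and check that nothing is still missing; the same two checks can be run on washb_data.

```r
# Toy data frame (hypothetical values) with two missing entries
toy <- data.frame(
  whz    = c(0.10, NA, -1.05, -0.59),
  momage = c(30, 25, NA, 19)
)

n_before <- nrow(toy)
toy_complete <- toy[complete.cases(toy), ]

# How many rows were dropped, and is anything still missing?
n_dropped <- n_before - nrow(toy_complete)
any_missing <- anyNA(toy_complete)
print(c(rows_dropped = n_dropped, any_missing = any_missing))
```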

Since the goal is to demonstrate how to use HAL, we will only use observations from the control and WSH intervention arms. We then recode the treatment variable (tr) to be 1 if the observation received the WSH interventions and 0 otherwise.

washb_data_demo <- washb_data[washb_data$tr %in% c("Control", "WSH"), ]
washb_data_demo$tr = as.numeric(washb_data_demo$tr == "WSH")
table(washb_data_demo$tr)
## 
##    0    1 
## 1165  583

Install hal9001 package (as needed)

Before installing any package in R, we recommend clearing the R workspace and then restarting the R session. In RStudio, you can do this by opening the “Session” menu and selecting “Clear Workspace,” then opening “Session” again and choosing “Restart R.”

For standard use, we recommend installing the package from CRAN via:

install.packages("hal9001")

Load hal9001 package

Once the hal9001 package is installed, we can load it and other R packages:

library(hal9001)

Load the dplyr package for easy and quick data manipulation.
Load the glmnet package, which is used in the HAL undersmoothing process.

library(dplyr)
library(glmnet)

Source relevant functions

You need to download the ‘functions_UHAL_deltaCI.R’ file from the GitHub repo. Then source it to load the functions for fitting undersmoothed HAL and obtaining confidence intervals using the delta method.

source('functions_UHAL_deltaCI.R')

2.1 Prepare data

First, identify the outcome variable and save it as a numeric vector. In this example, the outcome of interest is the weight-for-height z-score (whz), which is continuous.

Y <- as.numeric(washb_data_demo$whz)
head(Y)
## [1]  0.00 -1.16 -1.05 -1.26 -0.59 -0.51

Then, identify the treatment variable and other covariates, and save them as a numeric data frame or a numeric matrix. If the treatment or other variables are categorical or ordinal, we need to convert them into numeric or dummy variables. In the code below, the factor hfiacat is converted to its integer level codes.
For the WASH dataset, there are many possible covariate combinations. Since the goal is to demonstrate the use of HAL, we will pick a smaller set of covariates to keep run times short.

# x_names = names(washb_data_demo)[names(washb_data_demo) != 'whz']
washb_data_demo$sexF = as.numeric(washb_data_demo$sex == "female")
x_names = c("tr", "month", "aged", "sexF", "momage", "momheight", "hfiacat", "Nlt18", "Ncomp")
X <- washb_data_demo %>% 
    select(all_of(x_names)) %>% 
    mutate_if(sapply(., is.factor), as.numeric)
head(X)
##    tr month aged sexF momage momheight hfiacat Nlt18 Ncomp
## 1:  0     9  268    0     30    146.40       1     3    11
## 2:  0     9  286    0     25    148.75       3     2     4
## 3:  0     9  264    0     25    152.15       1     1    10
## 4:  0     9  252    1     28    140.25       1     3     5
## 5:  0     9  336    1     19    150.95       1     2     7
## 6:  0     9  304    0     20    154.20       4     0     3
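The two conversion options mentioned above (numeric codes versus dummy variables) behave differently, so here is a small base-R sketch of both. The factor below is made up for illustration; the level "Mildly Food Insecure" is an assumption, since it does not appear in the rows printed above.

```r
# Hypothetical food-security factor; "Mildly Food Insecure" is an assumed level
hfiacat <- factor(c("Food Secure", "Moderately Food Insecure",
                    "Food Secure", "Severely Food Insecure",
                    "Mildly Food Insecure"))

# Option 1: integer level codes (levels are sorted alphabetically by default)
codes <- as.numeric(hfiacat)

# Option 2: one 0/1 dummy column per non-reference level
dummies <- model.matrix(~ hfiacat)[, -1, drop = FALSE]

print(levels(hfiacat))
print(codes)
print(dim(dummies))
```

Level codes treat the categories as ordered and equally spaced, which may or may not be reasonable for a given variable; dummy columns avoid that assumption at the cost of more columns in X.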

2.2 Fit HAL

2.2.1 HAL (CV-HAL)

To fit the CV-HAL using the X and Y prepared above, we call the fit_hal function from the hal9001 package. The following code uses the default configuration to train HAL: 10-fold cross-validation to select lambda, smoothness_orders set to 1, family set to “gaussian”, max_degree set to 3 (the default when X has fewer than 20 columns), and default values for additional arguments such as num_knots. You can look up fit_hal in the help pane in RStudio or go to https://rdrr.io/cran/hal9001/man/fit_hal.html for more information. To make the results reproducible, we first set a random seed.

set.seed(123)

hal_fit <- fit_hal(X = X, Y = Y)

hal_fit$times
##                   user.self sys.self elapsed user.child sys.child
## enumerate_basis       0.081    0.006   0.087          0         0
## design_matrix         0.867    0.028   0.899          0         0
## reduce_basis          0.000    0.000   0.000          0         0
## remove_duplicates     0.000    0.000   0.000          0         0
## lasso               108.873    0.392 109.328          0         0
## total               109.821    0.426 110.315          0         0

Specifying smoothness of the HAL model
We can use the smoothness_orders argument to enforce smoothness on the functional form of the HAL fit.
* smoothness_orders = 0 provides a piecewise constant approximation (via zero-order basis functions) that accommodates abrupt changes in the function. This is useful when we want to avoid assuming any smoothness or continuity of the underlying function.
* smoothness_orders = 1 gives a piecewise linear fit (via first-order basis functions), which is continuous and differentiable everywhere except at the knots.
* smoothness_orders = k gives a piecewise polynomial fit (via k-order basis functions).

Mathematically, HAL fitting will find the optimal fit under the constraint that the total variation of the function’s k-th derivative is bounded by some constant, which is selected through cross-validation.
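To make the difference between zero-order and first-order basis functions concrete, here is a minimal base-R sketch for a single covariate and a single knot. The helper names basis_order0 and basis_order1 are hypothetical and written only for this illustration; they are not functions from hal9001.

```r
# Zero-order basis: indicator that x is at or beyond the knot
basis_order0 <- function(x, knot) as.numeric(x >= knot)

# First-order basis: zero before the knot, then linear in (x - knot)
basis_order1 <- function(x, knot) as.numeric(x >= knot) * (x - knot)

x <- c(0, 1, 2, 3, 4)
knot <- 2

b0 <- basis_order0(x, knot)
b1 <- basis_order1(x, knot)
print(rbind(x = x, order0 = b0, order1 = b1))
```

The zero-order basis jumps from 0 to 1 at the knot, while the first-order basis is continuous and only changes slope there.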

When unsure about smoothness_orders, we suggest using the default value of 1.

hal_fit_zeroorder <- fit_hal(X = X, Y = Y, smoothness_orders = 0)
hal_fit_firstorder <- fit_hal(X = X, Y = Y, smoothness_orders = 1)

2.2.2 Undersmoothed HAL

Fitting an undersmoothed HAL takes three steps:
1. Fit a CV-HAL with the argument return_x_basis = TRUE by calling the fit_hal function.
2. Pass X, Y, the CV-HAL fit object, and family to the undersmooth_hal function (from the ‘functions_UHAL_deltaCI.R’ file) to get the undersmoothed lambda.
3. Fit the undersmoothed HAL by calling the fit_hal function again with:
- lambda set to the undersmoothed lambda,
- fit_control = list(cv_select = FALSE), indicating that lambda should not be selected through cross-validation,
- and all other arguments the same as in the initial CV-HAL fit.

set.seed(123)
hal_cv_fit <- fit_hal(X = X, Y = Y, return_x_basis = TRUE)

undersmooth_hal_results <- undersmooth_hal(X, Y, 
                                           fit_init = hal_cv_fit,
                                           family = "gaussian")

hal_u_fit <- fit_hal(X = X, Y = Y, 
                return_x_basis = TRUE,
                fit_control = list(cv_select = FALSE),
                lambda = undersmooth_hal_results$lambda_under)

In some situations, undersmoothing is not needed or the fit cannot be undersmoothed any further. We can check this by comparing lambda_init and lambda_under in the results returned by the undersmooth_hal function. If they are the same, we can simply proceed with hal_cv_fit.
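A minimal sketch of that check is shown below. The helper lambda_changed is hypothetical, and the two result lists are mock objects with made-up numbers; the only assumption about undersmooth_hal is that its output has the elements lambda_init and lambda_under mentioned above.

```r
# Hypothetical helper: TRUE when undersmoothing actually changed lambda.
# `res` is assumed to be a list with elements lambda_init and lambda_under.
lambda_changed <- function(res, tol = 1e-12) {
  abs(res$lambda_init - res$lambda_under) > tol
}

# Mock results (made-up numbers, for illustration only)
res_same <- list(lambda_init = 0.02, lambda_under = 0.02)
res_diff <- list(lambda_init = 0.02, lambda_under = 0.005)

print(lambda_changed(res_same))
print(lambda_changed(res_diff))
```

In the demo, the same comparison would be applied to undersmooth_hal_results to decide between hal_u_fit and hal_cv_fit.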

2.3 Summarizing HAL fits

The summary of a HAL fit object contains a table of the non-zero coefficients and their corresponding basis functions. The lambda used for the fit is also reported.

hal_fit_summary <- summary(hal_u_fit)
print(hal_fit_summary)
## 
## 
## Summary of non-zero coefficients is based on lambda of 10.99614 
## 
##              coef
##     -7.219003e-01
##      9.692868e-05
##      8.841356e-05
##      6.517827e-05
##     -6.044765e-05
## ---              
##      8.379390e-09
##     -5.753798e-09
##      3.231291e-10
##      1.834215e-10
##     -1.001140e-10
##                                                                                                                            term
##                                                                                                                     (Intercept)
##             [ I(aged >= 42)*(aged - 42)^1 ] * [ I(momheight >= 152.3)*(momheight - 152.3)^1 ] * [ I(Nlt18 >= 1)*(Nlt18 - 1)^1 ]
##                                   [ I(tr >= 0)*(tr - 0)^1 ] * [ I(aged >= 42)*(aged - 42)^1 ] * [ I(Ncomp >= 4)*(Ncomp - 4)^1 ]
##         [ I(aged >= 42)*(aged - 42)^1 ] * [ I(momheight >= 155.05)*(momheight - 155.05)^1 ] * [ I(Ncomp >= 10)*(Ncomp - 10)^1 ]
##       [ I(aged >= 332)*(aged - 332)^1 ] * [ I(momheight >= 120.65)*(momheight - 120.65)^1 ] * [ I(Ncomp >= 12)*(Ncomp - 12)^1 ]
## ---                                                                                                                            
##         [ I(aged >= 299)*(aged - 299)^1 ] * [ I(momheight >= 143.65)*(momheight - 143.65)^1 ] * [ I(Ncomp >= 8)*(Ncomp - 8)^1 ]
##       [ I(aged >= 42)*(aged - 42)^1 ] * [ I(momage >= 18)*(momage - 18)^1 ] * [ I(momheight >= 145.75)*(momheight - 145.75)^1 ]
##     [ I(aged >= 233)*(aged - 233)^1 ] * [ I(momage >= 14)*(momage - 14)^1 ] * [ I(momheight >= 148.75)*(momheight - 148.75)^1 ]
##     [ I(aged >= 196)*(aged - 196)^1 ] * [ I(momage >= 14)*(momage - 14)^1 ] * [ I(momheight >= 120.65)*(momheight - 120.65)^1 ]
##     [ I(aged >= 260)*(aged - 260)^1 ] * [ I(momage >= 18)*(momage - 18)^1 ] * [ I(momheight >= 120.65)*(momheight - 120.65)^1 ]

3. Obtaining Predictions

3.1 Predictions

We will use the previously fitted undersmoothed HAL object, hal_u_fit, to get the predicted whz value for each individual in the dataset.

hal_u_preds <- predict(hal_u_fit, new_data = X)
head(hal_u_preds)
## [1] -0.7095915 -0.5612373 -0.4113562 -0.6926089 -0.6274143 -0.4830798

3.2 Counterfactual predictions

Counterfactual predictions are predicted values under an intervention of interest.
Continuing the WASH Benefits Bangladesh example with the same subset of the data, suppose we would like to predict every subject’s weight-for-height z-score (whz) under an intervention that sets the treatment (tr) to the WSH interventions. This takes two simple steps: 1) create a copy of the X matrix and intervene on tr in the copy; 2) obtain predictions using the intervened matrix.

# 1. copy data & intervene
X_intervene <- X
X_intervene$tr = 1 
head(X_intervene)
##    tr month aged sexF momage momheight hfiacat Nlt18 Ncomp
## 1:  1     9  268    0     30    146.40       1     3    11
## 2:  1     9  286    0     25    148.75       3     2     4
## 3:  1     9  264    0     25    152.15       1     1    10
## 4:  1     9  252    1     28    140.25       1     3     5
## 5:  1     9  336    1     19    150.95       1     2     7
## 6:  1     9  304    0     20    154.20       4     0     3
# 2. predictions
counterfactual_pred <- predict(hal_u_fit, new_data = X_intervene)
head(counterfactual_pred)
## [1] -0.5561809 -0.5592996 -0.2424205 -0.6498592 -0.6191041 -0.4957763
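Averaging such counterfactual predictions gives a plug-in estimate of the counterfactual mean. The self-contained sketch below repeats the copy-intervene-predict pattern on simulated data, with an ordinary lm fit standing in for hal_u_fit so that it runs without hal9001. The data-generating values are made up for illustration.

```r
set.seed(1)

# Simulated data (hypothetical): binary treatment tr, one covariate w
n <- 500
w <- rnorm(n)
tr <- rbinom(n, 1, 0.5)
y <- 0.5 * tr + w + rnorm(n)
dat <- data.frame(y = y, tr = tr, w = w)

# Stand-in outcome model; in the demo this role is played by hal_u_fit
fit <- lm(y ~ tr * w, data = dat)

# 1. copy the data and intervene on tr
dat_treated <- dat
dat_treated$tr <- 1

# 2. predict under the intervention, then average
pred_treated <- predict(fit, newdata = dat_treated)
plug_in_mean <- mean(pred_treated)
print(plug_in_mean)
```

With the HAL fit from this demo, the analogous quantity would be mean(counterfactual_pred).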

4. Obtaining Estimates and Inferences

4.1 Counterfactual Outcome Mean

The counterfactual outcome mean is the average outcome that all individuals in the target population would have experienced if they had received a particular treatment or exposure value. Suppose we are interested in the average weight-for-height z-score (whz) if everyone in Bangladesh had received the WSH interventions (i.e. the counterfactual outcome mean): \(\mathbb{E}_P(Y_1)\).

CMean_infr = CounterfactualMean_inferece(X, Y, treatment = "tr", treatment_level = 1, hal_fit = hal_u_fit, family = "gaussian")
print(CMean_infr)
## $Target_parameter
## [1] "EE[Y|A=1, W]"
## 
## $Estimate
## [1] -0.4579066
## 
## $SE
## [1] 3.695635
## 
## $CI
## [1] -7.701218  6.785405
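The reported interval appears consistent with a standard two-sided 95% Wald interval, that is, the estimate plus or minus the 0.975 normal quantile times the standard error. The sketch below recomputes it from the printed numbers; it is only a check on the output shown above, not a description of the internals of the provided function.

```r
# Numbers copied from the printed output above
estimate <- -0.4579066
se <- 3.695635

# Two-sided 95% Wald interval: estimate +/- z_{0.975} * SE
z <- qnorm(0.975)
ci <- estimate + c(-1, 1) * z * se
print(round(ci, 6))
```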

4.2 ATE

The average treatment effect (ATE) measures the difference in the outcome of interest between the treatment intervention and the control.
Now suppose we are interested in the causal effect of the WSH interventions on weight-for-height z-score (whz), which is the difference between the average counterfactual weight-for-height z-score if everyone in Bangladesh had received the WSH interventions and the average if no one had (i.e. the average treatment effect): \[\Psi(P) = \mathbb{E}_P(Y_1) - \mathbb{E}_P(Y_0)\] where \(Y_a\) denotes the counterfactual outcome (weight-for-height z-score) that a person would have had if, possibly contrary to fact, their intervention status had been A = a.

# Using the provided function to obtain the ATE: 
ATE_infr = ATE_inferece(X, Y, treatment = "tr", hal_fit = hal_u_fit, family = "gaussian")

print(ATE_infr)
## $Target_parameter
## [1] "E[Y_1] - E[Y_0]"
## 
## $Estimate
## [1] 0.1588137
## 
## $SE
## [1] 0.342477
## 
## $CI
## [1] -0.5124289  0.8300563
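To give some intuition for delta-method inference with a working model, here is a minimal sketch under loudly stated assumptions: the working model is an ordinary least-squares fit (lm) on simulated data rather than a HAL fit, and the gradient is treated as fixed. It is not the implementation in ‘functions_UHAL_deltaCI.R’. When the model is linear in its coefficients, the plug-in ATE is a linear combination of those coefficients, so its variance follows directly from their covariance matrix.

```r
set.seed(2)

# Simulated data (hypothetical); the true ATE is 0.5
n <- 1000
w <- rnorm(n)
tr <- rbinom(n, 1, 0.5)
y <- 0.5 * tr + w + 0.3 * tr * w + rnorm(n)
dat <- data.frame(y = y, tr = tr, w = w)

# Working model that is linear in its coefficients (stand-in for a HAL basis)
fit <- lm(y ~ tr * w, data = dat)

# Design matrices with tr set to 1 and to 0 for everyone
rhs <- delete.response(terms(fit))
X1 <- model.matrix(rhs, data = transform(dat, tr = 1))
X0 <- model.matrix(rhs, data = transform(dat, tr = 0))

# Plug-in ATE is a linear map of the coefficients: grad' beta
grad <- colMeans(X1 - X0)
ate <- sum(grad * coef(fit))

# Delta method: Var(grad' beta) = grad' Var(beta) grad, treating grad as fixed
se <- sqrt(as.numeric(t(grad) %*% vcov(fit) %*% grad))
ci <- ate + c(-1, 1) * qnorm(0.975) * se
print(c(estimate = ate, se = se, lower = ci[1], upper = ci[2]))
```

The estimate equals the average difference between predictions under tr = 1 and tr = 0; the sketch only makes the dependence on the coefficients explicit so that the variance can be propagated.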

4.3 CATE

The conditional average treatment effect (CATE) is the average treatment effect on an outcome variable for a specific subgroup or conditional population defined by a set of covariates or characteristics.

Suppose we are interested in the causal effect of the WSH interventions on weight-for-height z-score (whz) among females, i.e. the subgroup with sexF = 1 (the conditional average treatment effect among females): \[\Psi(P_{female}) = \mathbb{E}_{P_{female}}(Y_1) - \mathbb{E}_{P_{female}}(Y_0)\].

CATE_infr = ATE_inferece(X, Y, treatment = "tr", condition = "sexF", condition_level = 1, hal_fit = hal_u_fit, family = "gaussian")

print(CATE_infr)
## $Target_parameter
## [1] "E[Y_1|sexF=1] - E[Y_0|sexF=1]"
## 
## $Estimate
## [1] 0.1604635
## 
## $SE
## [1] 0.3391918
## 
## $CI
## [1] -0.5043403  0.8252673
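The point estimate for a subgroup follows the same plug-in recipe as before, restricted to the rows that satisfy the condition. The sketch below illustrates this for the subgroup with sexF = 1 on simulated data, again with lm as a stand-in for the HAL fit; the effect sizes are made up for illustration.

```r
set.seed(3)

# Simulated data (hypothetical): effect of tr is 0.2 when sexF = 0, 0.8 when sexF = 1
n <- 2000
sexF <- rbinom(n, 1, 0.5)
tr <- rbinom(n, 1, 0.5)
y <- 0.2 * tr + 0.6 * tr * sexF + rnorm(n)
dat <- data.frame(y = y, tr = tr, sexF = sexF)

# Stand-in outcome model; the demo uses hal_u_fit here
fit <- lm(y ~ tr * sexF, data = dat)

# Keep only the subgroup of interest, then contrast tr = 1 with tr = 0
sub <- dat[dat$sexF == 1, ]
pred1 <- predict(fit, newdata = transform(sub, tr = 1))
pred0 <- predict(fit, newdata = transform(sub, tr = 0))
cate_sexF1 <- mean(pred1 - pred0)
print(cate_sexF1)
```

Note that the outcome model is still fit on the full dataset; only the averaging step is restricted to the subgroup.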

5. Dose Response Curve Example with a Simulated Dataset

5.1 Fit HAL

5.2 Dose response curve estimation

5.3 Confidence intervals