class: center, middle, inverse, title-slide .title[ # Lecture .mono[007] ] .subtitle[ ## Trees 🌲🌴🌳 ] .author[ ### Edward Rubin ] --- exclude: true --- layout: true # Admin --- class: inverse, middle --- name: admin-today ## Material .note[Last time:] Classification .note[Today!] Decision trees for regression and classification. --- name: admin-soon ## Upcoming .b[Readings] - .note[Today] .it[ISL] Ch. 8.1 - .note[Next] .it[ISL] Ch. 8.2 .b[Problem sets] - .it[Penalized regression and CV:] Coming soon! - .it[Classification:] Next! .b[Project] Start putting it together! --- layout: true # Decision trees --- class: inverse, middle --- name: fundamentals ## Fundamentals .attn[Decision trees] - split the .it[predictor space] (our `\(\mathbf{X}\)`) into regions - then predict the most-common value within a region -- .attn[Decision trees] 1. work for .hi[both classification and regression] -- 1. are inherently .hi[nonlinear] -- 1. are relatively .hi[simple] and .hi[interpretable] -- 1. often .hi[underperform] relative to competing methods -- 1. easily extend to .hi[very competitive ensemble methods] (*many* trees).super[🌲] .footnote[ 🌲 Though the ensembles will be much less interpretable. ] --- layout: true class: clear --- exclude: true --- .ex[Example:] .b[A simple decision tree] classifying credit-card default
--- name: ex-partition Let's see how the tree works -- —starting with our data (default: .orange[Yes] .it[vs.] .purple[No]). <img src="slides_files/figure-html/plot-raw-1.svg" style="display: block; margin: auto;" /> --- The .hi-pink[first partition] splits balance at $1,800. <img src="slides_files/figure-html/plot-split1-1.svg" style="display: block; margin: auto;" /> --- The .hi-pink[second partition] splits balance at $1,972, (.it[conditional on bal. > $1,800]). <img src="slides_files/figure-html/plot-split2-1.svg" style="display: block; margin: auto;" /> --- The .hi-pink[third partition] splits income at $27K .b[for] bal. between $1,800 and $1,972. <img src="slides_files/figure-html/plot-split3-1.svg" style="display: block; margin: auto;" /> --- These three partitions give us four .b[regions]... <img src="slides_files/figure-html/plot-split3b-1.svg" style="display: block; margin: auto;" /> --- .b[Predictions] cover each region (_e.g._, using the region's most common class). <img src="slides_files/figure-html/plot-split3c-1.svg" style="display: block; margin: auto;" /> --- .b[Predictions] cover each region (_e.g._, using the region's most common class). <img src="slides_files/figure-html/plot-split3d-1.svg" style="display: block; margin: auto;" /> --- .b[Predictions] cover each region (_e.g._, using the region's most common class). <img src="slides_files/figure-html/plot-split3e-1.svg" style="display: block; margin: auto;" /> --- .b[Predictions] cover each region (_e.g._, using the region's most common class). <img src="slides_files/figure-html/plot-split3f-1.svg" style="display: block; margin: auto;" /> --- name: defn The .hi-pink[regions] correspond to the tree's .attn[terminal nodes] (or .attn[leaves]).
--- The graph's .hi-pink[separating lines] correspond to the tree's .attn[internal nodes].
--- The segments connecting the nodes within the tree are its .attn[branches].
--- class: middle You now know the anatomy of a decision tree. But where do trees come from—how do we train.super[🌲] a tree? .footnote[ 🌲 grow ]
--- layout: true # Decision trees
--- name: growth ## Growing trees We will start with .attn[regression trees], _i.e._, trees used in regression settings. -- As we saw, the task of .hi[growing a tree] involves two main steps: 1. .b[Divide the predictor space] into `\(J\)` regions (using predictors `\(\mathbf{x}_1,\ldots,\mathbf{x}_p\)`) -- 1. .b[Make predictions] using the regions' mean outcome. <br>For region `\(R_j\)`, predict `\(\hat{y}_{R_j}\)`, where $$ `\begin{align} \hat{y}_{R_j} = \frac{1}{n_j} \sum_{i\in R_j} y_i \end{align}` $$
--- ## Growing trees We .hi[choose the regions to minimize RSS] .it[across all] `\(J\)` .note[regions], _i.e._, $$ `\begin{align} \sum_{j=1}^{J} \sum_{i\in R_j} \left( y_i - \hat{y}_{R_j} \right)^2 \end{align}` $$ -- .b[Problem:] Examining every possible partition is computationally infeasible. -- .b[Solution:] a .it[top-down, greedy] algorithm named .attn[recursive binary splitting] - .attn[recursive] start with the "best" split, then find the next "best" split, ... - .attn[binary] each split creates two branches—"yes" and "no" - .attn[greedy] each step makes the .it[best] split—no consideration of the overall process
--- ## Growing trees: Choosing a split .ex[Recall] Regression trees choose the split that minimizes RSS. To find this split, we need 1. a .purple[predictor], `\(\color{#6A5ACD}{\mathbf{x}_j}\)` 1. a .attn[cutoff] `\(\color{#e64173}{s}\)` that splits `\(\color{#6A5ACD}{\mathbf{x}_j}\)` into two parts: (1) `\(\color{#6A5ACD}{\mathbf{x}_j} < \color{#e64173}{s}\)` and (2) `\(\color{#6A5ACD}{\mathbf{x}_j} \ge \color{#e64173}{s}\)` -- Searching across each of our .purple[predictors] `\(\color{#6A5ACD}{j}\)` and all of their .pink[cutoffs] `\(\color{#e64173}{s}\)`, <br>we choose the combination that .b[minimizes RSS].
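--- ## Growing trees: Choosing a split .note[Aside] A minimal base-R sketch of the objective above (this helper is for illustration only, not part of the lecture code): predict each region's mean and sum the squared errors for one candidate split.

```r
# RSS from splitting predictor x at cutoff s (regions: x < s and x >= s)
# Assumes the cutoff leaves at least one observation on each side
split_rss = function(x, y, s) {
  left = x < s
  # Each region predicts its mean outcome (the region's y-hat)
  sum((y[left] - mean(y[left]))^2) + sum((y[!left] - mean(y[!left]))^2)
}
# E.g., with y = (0, 8, 6) and x = (1, 3, 5), as in the example that follows,
# split_rss(x = c(1, 3, 5), y = c(0, 8, 6), s = 2) returns 2.
```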
--- layout: true # Decision trees ## Example: Splitting
--- name: ex-split .ex[Example] Consider the dataset

| i | pred. | y | x.sub[1] | x.sub[2] |
|:-:|:-----:|:-:|:--------:|:--------:|
| 1 | 0 | 0 | 1 | 4 |
| 2 | 0 | 8 | 3 | 2 |
| 3 | 0 | 6 | 5 | 6 |

-- With just three observations, each variable only has two actual splits..super[🌲] .footnote[ 🌲 You can think about cutoffs as the ways we divide observations into two groups. ] --- One possible split: x.sub[1] at 2, which yields .purple[(.b[1]) x.sub[1] < 2] .it[vs.] .pink[(.b[2]) x.sub[1] ≥ 2]

| i | pred. | y | x.sub[1] | x.sub[2] |
|:-:|:-----:|:-:|:--------:|:--------:|
| 1 | 0 | 0 | 1 | 4 |
| 2 | 7 | 8 | 3 | 2 |
| 3 | 7 | 6 | 5 | 6 |

--- One possible split: x.sub[1] at 2, which yields .purple[(.b[1]) x.sub[1] < 2] .it[vs.] .pink[(.b[2]) x.sub[1] ≥ 2]

| i | pred. | y | x.sub[1] | x.sub[2] |
|:-:|:-----:|:-:|:--------:|:--------:|
| 1 | 0 | 0 | 1 | 4 |
| 2 | 7 | 8 | 3 | 2 |
| 3 | 7 | 6 | 5 | 6 |

This split yields an RSS of .purple[0.super[2]] + .pink[1.super[2]] + .pink[(-1).super[2]] = 2. -- .note[Note.sub[1]] Splitting x.sub[1] at 2 yields the same result as splitting at 1.5, 2.5—anything in (1, 3). -- <br>.note[Note.sub[2]] Trees often grow until each leaf hits some minimum number of observations.
--- An alternative split: x.sub[1] at 4, which yields .purple[(.b[1]) x.sub[1] < 4] .it[vs.] .pink[(.b[2]) x.sub[1] ≥ 4]

| i | pred. | y | x.sub[1] | x.sub[2] |
|:-:|:-----:|:-:|:--------:|:--------:|
| 1 | 4 | 0 | 1 | 4 |
| 2 | 4 | 8 | 3 | 2 |
| 3 | 6 | 6 | 5 | 6 |

This split yields an RSS of .purple[(-4).super[2]] + .purple[4.super[2]] + .pink[0.super[2]] = 32. -- .it[Previous:] Splitting x.sub[1] at 2 yielded RSS = 2. .it.grey-light[(Much better)]
--- Another split: x.sub[2] at 3, which yields .purple[(.b[1]) x.sub[2] < 3] .it[vs.] .pink[(.b[2]) x.sub[2] ≥ 3]

| i | pred. | y | x.sub[1] | x.sub[2] |
|:-:|:-----:|:-:|:--------:|:--------:|
| 1 | 3 | 0 | 1 | 4 |
| 2 | 8 | 8 | 3 | 2 |
| 3 | 3 | 6 | 5 | 6 |

This split yields an RSS of .pink[(-3).super[2]] + .purple[0.super[2]] + .pink[3.super[2]] = 18. --- Final split: x.sub[2] at 5, which yields .purple[(.b[1]) x.sub[2] < 5] .it[vs.] .pink[(.b[2]) x.sub[2] ≥ 5]

| i | pred. | y | x.sub[1] | x.sub[2] |
|:-:|:-----:|:-:|:--------:|:--------:|
| 1 | 4 | 0 | 1 | 4 |
| 2 | 4 | 8 | 3 | 2 |
| 3 | 6 | 6 | 5 | 6 |

This split yields an RSS of .purple[(-4).super[2]] + .purple[4.super[2]] + .pink[0.super[2]] = 32.
--- Across our four possible splits (two variables each with two splits) - x.sub[1] with a cutoff of 2: .b[RSS] = 2 - x.sub[1] with a cutoff of 4: .b[RSS] = 32 - x.sub[2] with a cutoff of 3: .b[RSS] = 18 - x.sub[2] with a cutoff of 5: .b[RSS] = 32 our split of x.sub[1] at 2 generates the lowest RSS.
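--- .note[Aside] A quick sanity check in base R: brute-force the four candidate splits and confirm the RSS values above. This is a minimal sketch; the lecture's actual model fitting happens via `parsnip` and `rpart` later.

```r
# The three-observation example data
ex = data.frame(y = c(0, 8, 6), x1 = c(1, 3, 5), x2 = c(4, 2, 6))
# RSS from splitting predictor x at cutoff s, predicting each region's mean
# (same helper as the earlier sketch)
split_rss = function(x, y, s) {
  left = x < s
  sum((y[left] - mean(y[left]))^2) + sum((y[!left] - mean(y[!left]))^2)
}
# The four candidate splits from the slides
grid = data.frame(
  var = c("x1", "x1", "x2", "x2"),
  s   = c(2, 4, 3, 5)
)
grid$rss = mapply(function(v, s) split_rss(ex[[v]], ex$y, s), grid$var, grid$s)
grid                         # RSS should come out as 2, 32, 18, 32
grid[which.min(grid$rss), ]  # x1 at 2 wins, matching the slides
```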
--- layout: false class: clear, middle .note[Note:] Categorical predictors work in exactly the same way. <br>We want to try .b[all possible combinations] of the categories. .ex[Ex:] For a four-level categorical predictor (levels: A, B, C, D) .col-left[ - Split 1: .pink[A|B|C] .it[vs.] .purple[D] - Split 2: .pink[A|B|D] .it[vs.] .purple[C] - Split 3: .pink[A|C|D] .it[vs.] .purple[B] - Split 4: .pink[B|C|D] .it[vs.] .purple[A] ] .col-right[ - Split 5: .pink[A|B] .it[vs.] .purple[C|D] - Split 6: .pink[A|C] .it[vs.] .purple[B|D] - Split 7: .pink[A|D] .it[vs.] .purple[B|C] ] .clear-up[ we would need to try 7 possible splits. ]
--- layout: true # Decision trees
--- name: splits-more ## More splits Once we make our first split, we then continue splitting, <br>.b[conditional] on the regions from our previous splits. So if our first split creates R.sub[1] and R.sub[2], then our next split <br>searches the predictor space only in R.sub[1] or R.sub[2]..super[🌲] .footnote[ 🌲 We are no longer searching the full space—it is conditional on the previous splits. ] -- The tree continues to .b[grow until] it hits some specified threshold, <br>_e.g._, at most 5 observations in each leaf.
--- ## Too many splits? One can have too many splits. .qa[Q] Why? -- .qa[A] "More splits" means 1. more flexibility (think about the bias-variance tradeoff/overfitting) 1. less interpretability (one of the selling points for trees) -- .qa[Q] So what can we do? -- .qa[A] Prune your trees!
--- name: pruning ## Pruning .attn[Pruning] allows us to trim our trees back to their "best selves." .note[The idea:] Some regions may increase .hi[variance] more than they reduce .hi[bias]. <br> By removing these regions, we can improve test MSE. .note[Candidates for trimming:] Regions that do not .b[reduce RSS] very much. -- .note[Updated strategy:] Grow a big tree `\(T_0\)` and then trim `\(T_0\)` back to an optimal .attn[subtree]. -- .note[Updated problem:] Considering all possible subtrees can get expensive.
--- ## Pruning .attn[Cost-complexity pruning].super[🌲] offers a solution. .footnote[ 🌲 Also called: .it[weakest-link pruning]. ] Just as we did with lasso, .attn[cost-complexity pruning] forces the tree to pay a price (penalty) to become more complex. .it[Complexity] here is defined as the number of regions `\(|T|\)`.
--- ## Pruning Specifically, .attn[cost-complexity pruning] adds a penalty of `\(\alpha |T|\)` to the RSS, _i.e._, $$ `\begin{align} \sum_{m=1}^{|T|} \sum_{i:\, x_i\in R_m} \left( y_i - \hat{y}_{R_m} \right)^2 + \alpha |T| \end{align}` $$ For any value of `\(\alpha (\ge 0)\)`, we get a subtree `\(T\subset T_0\)`. -- `\(\alpha = 0\)` generates `\(T_0\)`, but as `\(\alpha\)` increases, we begin to cut back the tree. -- We choose `\(\alpha\)` via cross validation.
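--- ## Pruning .note[Aside] A minimal sketch of growing-then-pruning with `rpart` directly (the engine we will use through `parsnip` shortly), assuming the credit-card data from earlier lives in `default_df` with a factor outcome `default`. In `rpart`, the `cp` argument plays the role of the penalty `\(\alpha\)`, rescaled relative to the root node's error.

```r
library(rpart)
# Grow a deliberately large tree: cp = 0 means no complexity penalty while growing
big_tree = rpart(
  default ~ ., data = default_df, method = "class",
  control = rpart.control(cp = 0, minbucket = 5)
)
# rpart records the cost-complexity path as it grows (cp, tree size, CV error)
printcp(big_tree)
plotcp(big_tree)  # cross-validated error as a function of cp
# Prune back to the subtree implied by a heavier penalty
pruned_tree = prune(big_tree, cp = 0.01)
```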
--- name: ctree ## Classification trees Classification with trees is very similar to regression. -- .col-left[ .hi-purple[Regression trees] - .hi-slate[Predict:] Region's mean - .hi-slate[Split:] Minimize RSS - .hi-slate[Prune:] Penalized RSS ] -- .col-right[ .hi-pink[Classification trees] - .hi-slate[Predict:] Region's mode - .hi-slate[Split:] Min. Gini or entropy.super[🌲] - .hi-slate[Prune:] Penalized error rate.super[🌴] ] .footnote[ 🌲 Defined on the next slide. 🌴 ... or Gini index or entropy ] .clear-up[ An additional nuance for .attn[classification trees]: We typically care about the .b[proportions of classes in the leaves]—not just the final prediction. ]
--- name: gini ## The Gini index Let `\(\hat{p}_{mk}\)` denote the proportion of observations in class `\(k\)` and region `\(m\)`. -- The .attn[Gini index] tells us about a region's "purity".super[🌲] $$ `\begin{align} G = \sum_{k=1}^{K} \hat{p}_{mk} \left( 1 - \hat{p}_{mk} \right) \end{align}` $$ If a region is very homogeneous, then the Gini index will be small. .footnote[ 🌲 This vocabulary is Voldemort's contribution to the machine-learning literature. ] Homogeneous regions are easier to predict. <br>Reducing the Gini index leads to more homogeneous regions <br>∴ We want to minimize the Gini index.
--- layout: false class: clear .b.pink[Gini as a function of 'purity'] <img src="slides_files/figure-html/plot-gini-1.svg" style="display: block; margin: auto;" />
--- layout: true # Decision trees
--- name: entropy ## Entropy Let `\(\hat{p}_{mk}\)` denote the proportion of observations in class `\(k\)` and region `\(m\)`. .attn[Entropy] also measures the "purity" of a node/leaf $$ `\begin{align} D = - \sum_{k=1}^{K} \hat{p}_{mk} \log \left( \hat{p}_{mk} \right) \end{align}` $$ .attn[Entropy] is also minimized when the `\(\hat{p}_{mk}\)` values are close to 0 or 1.
--- layout: false class: clear .b.pink[Entropy as a function of 'purity'] <img src="slides_files/figure-html/plot-entropy-1.svg" style="display: block; margin: auto;" />
--- layout: true # Decision trees
--- name: class-why ## Rationale .qa[Q] Why are we using the Gini index or entropy (*vs.* error rate)? -- .qa[A] The error rate isn't sufficiently sensitive to grow good trees. <br> The Gini index and entropy tell us about the .b[composition] of the leaf. -- .ex[Ex.] Consider two different leaves in a three-level classification. .col-left[ .b[Leaf 1] - .b[A:] 51, .b[B:] 49, .b[C:] 00 - .hi-orange[Error rate:] 49% - .hi-purple[Gini index:] 0.4998 - .hi-pink[Entropy:] 0.6929 ] .col-right[ .b[Leaf 2] - .b[A:] 51, .b[B:] 25, .b[C:] 24 - .hi-orange[Error rate:] 49% - .hi-purple[Gini index:] 0.6198 - .hi-pink[Entropy:] 1.0325 ] .clear-up[ The .hi-purple[Gini index] and .hi-pink[entropy] tell us about the distribution. ]
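--- ## Gini and entropy, by hand .note[Aside] A few lines of R (a sketch, not lecture code) reproduce the numbers above. Note that the entropy values use the natural log.

```r
# Gini index and entropy for one leaf's vector of class proportions
gini    = function(p) sum(p * (1 - p))
entropy = function(p) {
  p = p[p > 0]       # treat 0 * log(0) as 0
  -sum(p * log(p))   # natural log, matching the values above
}
leaf1 = c(A = 51, B = 49, C = 0)  / 100
leaf2 = c(A = 51, B = 25, C = 24) / 100
c(gini(leaf1), gini(leaf2))        # ~0.4998 and ~0.6198
c(entropy(leaf1), entropy(leaf2))  # ~0.6929 and ~1.0325
```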
--- ## Classification trees When .b[growing] classification trees, we want to use the Gini index or entropy. However, when .b[pruning], the error rate is typically fine—especially if accuracy will be the final criterion.
--- name: in-r ## In R To train decision trees in R, we can use `parsnip`, which draws upon `rpart`. In `parsnip`, we use the aptly named `decision_tree()` function. -- The `decision_tree()` model (with `rpart` engine) wants four inputs: - `mode`: `"regression"` or `"classification"` - `cost_complexity`: the cost (penalty) paid for complexity - `tree_depth`: *max.* tree depth (max. number of splits in a "branch") - `min_n`: *min.* # of observations for a node to split
--- layout: false class: clear

```r
# Define our CV split
set.seed(12345)
default_cv = default_df %>% vfold_cv(v = 5)
# Define the decision tree
default_tree = decision_tree(
  mode = "classification",
  cost_complexity = tune(),
  tree_depth = tune(),
  min_n = 10 # Arbitrarily choosing '10'
) %>% set_engine("rpart")
# Define recipe
default_recipe = recipe(default ~ ., data = default_df)
# Define the workflow
default_flow = workflow() %>% add_model(default_tree) %>% add_recipe(default_recipe)
# Tune!
default_cv_fit = default_flow %>% tune_grid(
  default_cv,
  grid = expand_grid(
    cost_complexity = seq(0, 0.15, by = 0.01),
    tree_depth = c(1, 2, 5, 10)
  ),
  metrics = metric_set(accuracy, roc_auc)
)
```

--- layout: true class: clear
--- .b[Accuracy, complexity, and depth] <img src="slides_files/figure-html/plot-cv-cp-1.svg" style="display: block; margin: auto;" />
--- exclude: true .b[ROC AUC, complexity, and depth] <img src="slides_files/figure-html/plot-cv-auc-1.svg" style="display: block; margin: auto;" />
--- class: middle .b.slate[To plot the CV-chosen tree...] 1\. .b.pink[Fit] the chosen/best model.

``` r
best_flow = default_flow %>%
  finalize_workflow(select_best(default_cv_fit, metric = "accuracy")) %>%
  fit(data = default_df)
```

2\. .b.purple[Extract] the fitted model, *e.g.*, with `extract_fit_parsnip`. <br> .note[Old/deprecated way:] `pull_workflow_fit()`

``` r
best_tree = best_flow %>% extract_fit_parsnip()
```

3\. .b.orange[Plot] the tree, *e.g.*, with `rpart.plot()` from `rpart.plot`.

``` r
best_tree$fit %>% rpart.plot()
```
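--- class: middle .note[Aside] Before looking at the plotted tree, we can check which hyperparameters won and how the alternatives compared, using standard `tune` helpers on `default_cv_fit`. A quick sketch:

``` r
# The winning (cost_complexity, tree_depth) combination by CV accuracy
select_best(default_cv_fit, metric = "accuracy")
# The top few combinations, ranked by accuracy
show_best(default_cv_fit, metric = "accuracy", n = 5)
# All CV metrics, e.g., for making your own plots
collect_metrics(default_cv_fit)
```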
--- class: clear, middle <img src="slides_files/figure-html/plot-rpart-cv-1.svg" style="display: block; margin: auto;" />
--- class: clear, middle The previous tree has cost complexity of 0.03 (and a max. depth of 5). We can compare this "best" tree to a less pruned/penalized tree - `cost_complexity = 0.005` - `tree_depth = 5`
--- class: clear, middle <img src="slides_files/figure-html/plot-tree_complex-1.svg" style="display: block; margin: auto;" />
--- class: clear, middle What if we hold the cost complexity constant but increase the max. depth? - `cost_complexity = 0.005` - `tree_depth = 10` (moved up from `5`)
--- class: clear, middle <img src="slides_files/figure-html/plot-tree_complexer-1.svg" style="display: block; margin: auto;" />
--- class: clear, middle What if we ratchet up the cost complexity? - `cost_complexity = 0.1` (increased from `0.005`) - `tree_depth = 10`
--- class: clear, middle <img src="slides_files/figure-html/plot-tree_simple-1.svg" style="display: block; margin: auto;" />
--- layout: false class: clear, middle .note[Fun tool] Grant McDermott's [`parttree` package](https://grantmcdermott.com/parttree/).
--- layout: true class: clear, middle
--- name: linearity .qa[Q] How do trees compare to linear models? .tran[.b[A] It depends how linear the true boundary is.]
--- .qa[Q] How do trees compare to linear models? .qa[A] It depends how linear the true boundary is.
--- .b[Linear boundary:] trees struggle to recreate a line. <img src="images/compare-linear.png" width="5784" style="display: block; margin: auto;" /> .ex.small[Source: ISL, p. 315]
--- .b[Nonlinear boundary:] trees easily replicate the nonlinear boundary. <img src="images/compare-nonlinear.png" width="5784" style="display: block; margin: auto;" /> .ex.small[Source: ISL, p. 315]
--- layout: false name: tree-pro-con # Decision trees ## Strengths and weaknesses As with any method, decision trees have tradeoffs.
.col-left.purple.small[ .b[Strengths] <br>.b[+] Easily explained/interpreted <br>.b[+] Include several graphical options <br>.b[+] Mirror human decision making? <br>.b[+] Handle num. or cat. on LHS/RHS.super[🌳] ] .footnote[ 🌳 Without needing to create lots of dummy variables! <br> .tran[🌴 Blank] ] -- .col-right.pink.small[ .b[Weaknesses] <br>.b[-] Outperformed by other methods <br>.b[-] Struggle with linearity <br>.b[-] Can be very "non-robust" ] .clear-up[ .attn[Non-robust:] Small data changes can cause huge changes in our tree. ] -- .note[Next:] Create ensembles of trees.super[🌲] to address these weaknesses..super[🌴] .footnote[ .tran[🌴 Blank] <br> 🌲 Forests! 🌴 Which will also weaken some of the strengths. ]
--- name: sources layout: false # Sources These notes draw upon - [An Introduction to Statistical Learning](http://faculty.marshall.usc.edu/gareth-james/ISL/) (*ISL*)<br>James, Witten, Hastie, and Tibshirani
--- # Table of contents .col-left[ .smallest[ #### Admin - [Today](#admin-today) - [Upcoming](#admin-soon) #### Decision trees 1. [Fundamentals](#fundamentals) 1. [Partitioning predictors](#ex-partition) 1. [Definitions](#defn) 1. [Growing trees](#growth) 1. [Example: Splitting](#ex-split) 1. [More splits](#splits-more) 1. [Pruning](#pruning) 1. [Classification trees](#ctree) - [The Gini index](#gini) - [Entropy](#entropy) - [Rationale](#class-why) 1. [In R](#in-r) 1. [Linearity](#linearity) 1. [Strengths and weaknesses](#tree-pro-con) ] ] .col-right[ .smallest[ #### Other - [Sources/references](#sources) ] ]