--- title: "Lecture .mono[007]" subtitle: "Trees 🌲🌴🌳" author: "Edward Rubin" #date: "`r format(Sys.time(), '%d %B %Y')`" date: "20 February 2020" output: xaringan::moon_reader: css: ['default', 'metropolis', 'metropolis-fonts', 'my-css.css'] # self_contained: true nature: highlightStyle: github highlightLines: true countIncrementalSlides: false --- exclude: true ```{R, setup, include = F} library(pacman) p_load( ISLR, broom, tidyverse, ggplot2, ggthemes, ggforce, ggridges, cowplot, scales, latex2exp, viridis, extrafont, gridExtra, plotly, ggformula, DiagrammeR, kableExtra, DT, huxtable, data.table, dplyr, snakecase, janitor, lubridate, knitr, caret, rpart, rpart.plot, rattle, here, magrittr, parallel ) # Define colors red_pink = "#e64173" turquoise = "#20B2AA" orange = "#FFA500" red = "#fb6107" blue = "#3b3b9a" green = "#8bb174" grey_light = "grey70" grey_mid = "grey50" grey_dark = "grey20" purple = "#6A5ACD" slate = "#314f4f" # Knitr options opts_chunk$set( comment = "#>", fig.align = "center", fig.height = 7, fig.width = 10.5, warning = F, message = F ) opts_chunk$set(dev = "svg") options(device = function(file, width, height) { svg(tempfile(), width = width, height = height) }) options(knitr.table.format = "html") ``` --- layout: true # Admin --- class: inverse, middle --- name: admin-today ## Material Decision trees for regression and classification. --- name: admin-soon ## Upcoming .b[Readings] - .note[Today] .it[ISL] Ch. 8.1 - .note[Next] .it[ISL] Ch. 8.2 .b[Problem sets] - .it[Classification] Due today - Let Connor know if you are resubmitting .b[Project] Project topic due before midnight on Friday. --- layout: true # Decision trees --- class: inverse, middle --- name: fundamentals ## Fundamentals .attn[Decision trees] - split the .it[predictor space] (our $\mathbf{X}$) into regions - then predict the most-common value within a region -- .attn[Tree-based methods] 1. work for .hi[both classification and regression] -- 1. are inherently .hi[nonlinear] -- 1. are relatively .hi[simple] and .hi[interpretable] -- 1. often .hi[underperform] relatively to competing methods -- 1. easily extend to .hi[very competitive ensemble methods] (*many* trees).super[🌲] .footnote[ 🌲 Though the ensembles will be much less interpretable. ] --- layout: true class: clear --- exclude: true ```{R, data-default, include = F} # Load 'Defualt' data from 'ISLR' default_df = ISLR::Default %>% as_tibble() ``` --- .ex[Example:] .b[A simple decision tree] classifying credit-card default ```{R, tree-graph, echo = F, cache = T} DiagrammeR::grViz(" digraph { graph [layout = dot, overlap = false, fontsize = 14] node [shape = oval, fontname = 'Fira Sans', color = Gray95, style = filled] s1 [label = 'Bal. > 1,800'] s2 [label = 'Bal. < 1,972'] s3 [label = 'Inc. > 27K'] node [shape = egg, fontname = 'Fira Sans', color = Purple, style = filled, fontcolor = White] l1 [label = 'No (98%)'] l4 [label = 'No (69%)'] node [shape = egg, fontname = 'Fira Sans', color = Orange, style = filled, fontcolor = White] l2 [label = 'Yes (76%)'] l3 [label = 'Yes (59%)'] edge [fontname = 'Fira Sans', color = Grey70] s1 -> l1 [label = 'F'] s1 -> s2 [label = 'T'] s2 -> s3 [label = 'T'] s2 -> l2 [label = 'F'] s3 -> l3 [label = 'T'] s3 -> l4 [label = 'F'] } ") ``` --- name: ex-partition Let's see how the tree works -- —starting with our data (default: .orange[Yes] .it[vs.] .purple[No]). 
```{R, partition-base, include = F, cache = T} gg_base = ggplot( data = default_df, aes(x = balance, y = income, color = default, alpha = default) ) + geom_hline(yintercept = 0) + geom_vline(xintercept = 0) + geom_point(size = 2) + scale_y_continuous("Income", labels = dollar) + scale_x_continuous("Balance", labels = dollar) + scale_color_manual("Defaulted:", values = c(purple, orange), labels = c("No", "Yes")) + scale_alpha_manual("Defaulted:", values = c(0.1, 0.8), labels = c("No", "Yes")) + theme_minimal(base_size = 20, base_family = "Fira Sans Book") + theme(legend.position = "none") ``` ```{R, plot-raw, echo = F} gg_base ``` --- The .hi-pink[first partition] splits balance at $1,800. ```{R, plot-split1, echo = F, cache = T} gg_base + annotate( "segment", x = 1800, xend = 1800, y = -Inf, yend = Inf, color = red_pink, size = 1.2 ) ``` --- The .hi-pink[second partition] splits balance at $1,972 (.it[conditional on bal. > $1,800]). ```{R, plot-split2, echo = F, cache = T} gg_base + annotate( "segment", x = 1800, xend = 1800, y = -Inf, yend = Inf, linetype = "longdash" ) + annotate( "segment", x = 1972, xend = 1972, y = -Inf, yend = Inf, color = red_pink, size = 1.2 ) ``` --- The .hi-pink[third partition] splits income at $27K .b[for] bal. between $1,800 and $1,972. ```{R, plot-split3, echo = F, cache = T} gg_base + annotate( "segment", x = 1800, xend = 1800, y = -Inf, yend = Inf, linetype = "longdash" ) + annotate( "segment", x = 1972, xend = 1972, y = -Inf, yend = Inf, linetype = "longdash" ) + annotate( "segment", x = 1800, xend = 1972, y = 27e3, yend = 27e3, color = red_pink, size = 1.2 ) ``` --- These three partitions give us four .b[regions]... ```{R, plot-split3b, echo = F, cache = T} gg_base + annotate( "segment", x = 1800, xend = 1800, y = -Inf, yend = Inf, linetype = "longdash" ) + annotate( "segment", x = 1972, xend = 1972, y = -Inf, yend = Inf, linetype = "longdash" ) + annotate( "segment", x = 1800, xend = 1972, y = 27e3, yend = 27e3, linetype = "longdash" ) + annotate("text", x = 900, y = 37500, label = expression(R[1]), size = 8, family = "Fira Sans Book" ) + annotate("text", x = 1886, y = 5.1e4, label = expression(R[2]), size = 8, family = "Fira Sans Book" ) + annotate("text", x = 1886, y = 1e4, label = expression(R[3]), size = 8, family = "Fira Sans Book" ) + annotate("text", x = 2336, y = 37500, label = expression(R[4]), size = 8, family = "Fira Sans Book" ) ``` --- .b[Predictions] cover each region (_e.g._, using the region's most common class). ```{R, plot-split3c, echo = F, cache = T} gg_base + annotate( "rect", xmin = 0, xmax = 1800, ymin = 0, ymax = Inf, fill = purple, alpha = 0.3 ) + annotate( "segment", x = 1800, xend = 1800, y = -Inf, yend = Inf, linetype = "longdash" ) + annotate( "segment", x = 1972, xend = 1972, y = -Inf, yend = Inf, linetype = "longdash" ) + annotate( "segment", x = 1800, xend = 1972, y = 27e3, yend = 27e3, linetype = "longdash" ) ``` --- .b[Predictions] cover each region (_e.g._, using the region's most common class).
```{R, plot-split3d, echo = F, cache = T} gg_base + annotate( "rect", xmin = 0, xmax = 1800, ymin = 0, ymax = Inf, fill = purple, alpha = 0.3 ) + annotate( "rect", xmin = 1800, xmax = 1972, ymin = 27e3, ymax = Inf, fill = orange, alpha = 0.3 ) + annotate( "segment", x = 1800, xend = 1800, y = -Inf, yend = Inf, linetype = "longdash" ) + annotate( "segment", x = 1972, xend = 1972, y = -Inf, yend = Inf, linetype = "longdash" ) + annotate( "segment", x = 1800, xend = 1972, y = 27e3, yend = 27e3, linetype = "longdash" ) ``` --- .b[Predictions] cover each region (_e.g._, using the region's most common class). ```{R, plot-split3e, echo = F, cache = T} gg_base + annotate( "rect", xmin = 0, xmax = 1800, ymin = 0, ymax = Inf, fill = purple, alpha = 0.3 ) + annotate( "rect", xmin = 1800, xmax = 1972, ymin = 27e3, ymax = Inf, fill = orange, alpha = 0.3 ) + annotate( "rect", xmin = 1800, xmax = 1972, ymin = 0, ymax = 27e3, fill = purple, alpha = 0.3 ) + annotate( "segment", x = 1800, xend = 1800, y = -Inf, yend = Inf, linetype = "longdash" ) + annotate( "segment", x = 1972, xend = 1972, y = -Inf, yend = Inf, linetype = "longdash" ) + annotate( "segment", x = 1800, xend = 1972, y = 27e3, yend = 27e3, linetype = "longdash" ) ``` --- .b[Predictions] cover each region (_e.g._, using the region's most common class). ```{R, plot-split3f, echo = F, cache = T} gg_base + annotate( "rect", xmin = 0, xmax = 1800, ymin = 0, ymax = Inf, fill = purple, alpha = 0.3 ) + annotate( "rect", xmin = 1800, xmax = 1972, ymin = 27e3, ymax = Inf, fill = orange, alpha = 0.3 ) + annotate( "rect", xmin = 1800, xmax = 1972, ymin = 0, ymax = 27e3, fill = purple, alpha = 0.3 ) + annotate( "rect", xmin = 1972, xmax = Inf, ymin = 0, ymax = Inf, fill = orange, alpha = 0.3 ) + annotate( "segment", x = 1800, xend = 1800, y = -Inf, yend = Inf, linetype = "longdash" ) + annotate( "segment", x = 1972, xend = 1972, y = -Inf, yend = Inf, linetype = "longdash" ) + annotate( "segment", x = 1800, xend = 1972, y = 27e3, yend = 27e3, linetype = "longdash" ) ``` --- name: defn The .hi-pink[regions] correspond to the tree's .attn[terminal nodes] (or .attn[leaves]). ```{R, tree-leaves, echo = F, cache = T} DiagrammeR::grViz(" digraph { graph [layout = dot, overlap = false, fontsize = 14] node [shape = oval, fontname = 'Fira Sans', color = Gray95, style = filled] s1 [label = 'Bal. > 1,800'] s2 [label = 'Bal. < 1,972'] s3 [label = 'Inc. > 27K'] node [shape = egg, fontname = 'Fira Sans', color = DeepPink, style = filled, fontcolor = White] l1 [label = 'No (98%)'] l4 [label = 'No (69%)'] node [shape = egg, fontname = 'Fira Sans', color = DeepPink, style = filled, fontcolor = White] l2 [label = 'Yes (76%)'] l3 [label = 'Yes (59%)'] edge [fontname = 'Fira Sans', color = Grey70] s1 -> l1 [label = 'F'] s1 -> s2 [label = 'T'] s2 -> s3 [label = 'T'] s2 -> l2 [label = 'F'] s3 -> l3 [label = 'T'] s3 -> l4 [label = 'F'] } ") ``` --- The graph's .hi-pink[separating lines] correspond to the tree's .attn[internal nodes]. ```{R, tree-internal, echo = F, cache = T} DiagrammeR::grViz(" digraph { graph [layout = dot, overlap = false, fontsize = 14] node [shape = oval, fontname = 'Fira Sans', color = DeepPink, style = filled, fontcolor = White] s1 [label = 'Bal. > 1,800'] s2 [label = 'Bal. < 1,972'] s3 [label = 'Inc. 
> 27K'] node [shape = egg, fontname = 'Fira Sans', color = Grey95, style = filled, fontcolor = White] l1 [label = 'No (98%)'] l4 [label = 'No (69%)'] node [shape = egg, fontname = 'Fira Sans', color = Grey95, style = filled, fontcolor = White] l2 [label = 'Yes (76%)'] l3 [label = 'Yes (59%)'] edge [fontname = 'Fira Sans', color = Grey70] s1 -> l1 [label = 'F'] s1 -> s2 [label = 'T'] s2 -> s3 [label = 'T'] s2 -> l2 [label = 'F'] s3 -> l3 [label = 'T'] s3 -> l4 [label = 'F'] } ") ``` --- The segments connecting the nodes within the tree are its .attn[branches]. ```{R, tree-branches, echo = F, cache = T} DiagrammeR::grViz(" digraph { graph [layout = dot, overlap = false, fontsize = 14] node [shape = oval, fontname = 'Fira Sans', color = Grey95, style = filled, fontcolor = White] s1 [label = 'Bal. > 1,800'] s2 [label = 'Bal. < 1,972'] s3 [label = 'Inc. > 27K'] node [shape = egg, fontname = 'Fira Sans', color = Grey95, style = filled, fontcolor = White] l1 [label = 'No (98%)'] l4 [label = 'No (69%)'] node [shape = egg, fontname = 'Fira Sans', color = Grey95, style = filled, fontcolor = White] l2 [label = 'Yes (76%)'] l3 [label = 'Yes (59%)'] edge [fontname = 'Fira Sans', color = DeepPink] s1 -> l1 [label = 'F'] s1 -> s2 [label = 'T'] s2 -> s3 [label = 'T'] s2 -> l2 [label = 'F'] s3 -> l3 [label = 'T'] s3 -> l4 [label = 'F'] } ") ``` --- class: middle You now know the anatomy of a decision tree. But where do trees come from—how do we train.super[🌲] a tree? .footnote[ 🌲 grow ] --- layout: true # Decision trees --- name: growth ## Growing trees We will start with .attn[regression trees], _i.e._, trees used in regression settings. -- As we saw, the task of .hi[growing a tree] involves two main steps: 1. .b[Divide the predictor space] into $J$ regions (using predictors $\mathbf{x}_1,\ldots,\mathbf{x}_p$) -- 1. .b[Make predictions] using the regions' mean outcome.
For region $R_j$ predict $\hat{y}_{R_j}$ where $$ \begin{align} \hat{y}_{R_j} = \frac{1}{n_j} \sum_{i\in R_j} y_i \end{align} $$ --- ## Growing trees We .hi[choose the regions to minimize RSS] .it[across all] $J$ regions, _i.e._, $$ \begin{align} \sum_{j=1}^{J} \sum_{i\in R_j} \left( y_i - \hat{y}_{R_j} \right)^2 \end{align} $$ -- .b[Problem:] Examining every possible partition is computationally infeasible. -- .b[Solution:] a .it[top-down, greedy] algorithm named .attn[recursive binary splitting] - .attn[recursive] start with the "best" split, then find the next "best" split, ... - .attn[binary] each split creates two branches—"yes" and "no" - .attn[greedy] each step makes the .it[best] split—no consideration of the overall process --- ## Growing trees: Choosing a split .ex[Recall] Regression trees choose the split that minimizes RSS. To find this split, we need 1. a .purple[predictor], $\color{#6A5ACD}{\mathbf{x}_j}$ 1. a .attn[cutoff] $\color{#e64173}{s}$ that splits $\color{#6A5ACD}{\mathbf{x}_j}$ into two parts: (1) $\color{#6A5ACD}{\mathbf{x}_j} < \color{#e64173}{s}$ and (2) $\color{#6A5ACD}{\mathbf{x}_j} \ge \color{#e64173}{s}$ -- Searching across each of our .purple[predictors] $\color{#6A5ACD}{j}$ and all of their .pink[cutoffs] $\color{#e64173}{s}$,
we choose the combination that .b[minimizes RSS]. --- layout: true # Decision trees ## Example: Splitting --- name: ex-split .ex[Example] Consider the dataset ```{R, data-ex-split, echo = F} ex_df = tibble( "i" = 1:3, "pred." = c(0, 0, 0), "y" = c(0, 8, 6), "x.sub[1]" = c(1, 3, 5), "x.sub[2]" = c(4, 2, 6) ) ex_df %>% hux() %>% add_colnames() %>% set_align(1:4, 1:5, "center") %>% set_bold(1, 1:5, T) %>% set_bottom_border(1, c(1,3:5), 1) %>% set_text_color(1:4, 2, "white") ``` -- With just three observations, each variable only has two actual splits..super[🌲] .footnote[ 🌲 You can think about cutoffs as the ways we divide observations into two groups. ] --- One possible split: x.sub[1] at 2, which yields .purple[(.b[1]) x.sub[1] < 2] .it[vs.] .pink[(.b[2]) x.sub[1] ≥ 2] ```{R, ex-split1, echo = F} split1 = ex_df %>% mutate("pred." = c(0, 7, 7)) %>% hux() %>% add_colnames() %>% set_align(1:4, 1:5, "center") %>% set_bold(1, 1:5, T) %>% set_bottom_border(1, 1:5, 1) %>% set_text_color(2, 1:4, purple) %>% set_text_color(3:4, 1:4, red_pink) split1 %>% set_text_color(1:4, 2, "white") %>% set_bottom_border(1, 2, 0) ``` --- One possible split: x.sub[1] at 2, which yields .purple[(.b[1]) x.sub[1] < 2] .it[vs.] .pink[(.b[2]) x.sub[1] ≥ 2] ```{R, ex-split1b, echo = F} split1 = ex_df %>% mutate("pred." = c(0, 7, 7)) %>% hux() %>% add_colnames() %>% set_align(1:4, 1:5, "center") %>% set_bold(1, 1:5, T) %>% set_bottom_border(1, 1:5, 1) %>% set_text_color(2, 1:4, purple) %>% set_text_color(3:4, 1:4, red_pink) %>% set_bold(1:4, 2, T) split1 ``` This split yields an RSS of .purple[0.super[2]] + .pink[1.super[2]] + .pink[(-1).super[2]] = 2. -- .note[Note.sub[1]] Splitting x.sub[1] at 2 yields the same result as 1.5, 2.5—anything in (1, 3). -- .note[Note.sub[2]] Trees often grow until they hit some number of observations in a leaf. --- An alternative split: x.sub[1] at 4, which yields .purple[(.b[1]) x.sub[1] < 4] .it[vs.] .pink[(.b[2]) x.sub[1] ≥ 4] ```{R, ex-split2, echo = F} ex_df %>% mutate("pred." = c(4, 4, 6)) %>% hux() %>% add_colnames() %>% set_align(1:4, 1:5, "center") %>% set_bold(1, 1:5, T) %>% set_bottom_border(1, 1:5, 1) %>% set_text_color(2:3, 1:4, purple) %>% set_text_color(4, 1:4, red_pink) %>% set_bold(1:4, 2, T) ``` This split yields an RSS of .purple[(-4).super[2]] + .purple[4.super[2]] + .pink[0.super[2]] = 32. -- .it[Previous:] Splitting x.sub[1] at 2 yielded RSS = 2. .it.grey-light[(Much better)] --- Another split: x.sub[2] at 3, which yields .purple[(.b[1]) x.sub[2] < 3] .it[vs.] .pink[(.b[2]) x.sub[2] ≥ 3] ```{R, ex-split3, echo = F} ex_df %>% mutate("pred." = c(3, 8, 3)) %>% hux() %>% add_colnames() %>% set_align(1:4, 1:5, "center") %>% set_bold(1, 1:5, T) %>% set_bottom_border(1, 1:5, 1) %>% set_text_color(c(2,4), c(1:3,5), red_pink) %>% set_text_color(3, c(1:3,5), purple) %>% set_bold(1:4, 2, T) ``` This split yields an RSS of .pink[(-3).super[2]] + .purple[0.super[2]] + .pink[3.super[2]] = 18. --- Final split: x.sub[2] at 5, which yields .purple[(.b[1]) x.sub[2] < 5] .it[vs.] .pink[(.b[2]) x.sub[2] ≥ 5] ```{R, ex-split4, echo = F} ex_df %>% mutate("pred." = c(4, 4, 6)) %>% hux() %>% add_colnames() %>% set_align(1:4, 1:5, "center") %>% set_bold(1, 1:5, T) %>% set_bottom_border(1, 1:5, 1) %>% set_text_color(2:3, c(1:3,5), purple) %>% set_text_color(4, c(1:3,5), red_pink) %>% set_bold(1:4, 2, T) ``` This split yields an RSS of .purple[(-4).super[2]] + .purple[4.super[2]] + .pink[0.super[2]] = 32.
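---

As a check on these hand calculations, here is a small sketch of my own (the names `toy_df` and `split_rss()` are illustrative, not from the lecture code) that brute-forces every predictor and cutoff on the toy data and reports each split's RSS:

```{R, ex-split-search, eval = F}
# Sketch: find the best single split on the toy data by minimizing RSS
library(dplyr)
toy_df = tibble(y = c(0, 8, 6), x1 = c(1, 3, 5), x2 = c(4, 2, 6))
# RSS from splitting predictor 'x' at cutoff 's' (predict each side's mean)
split_rss = function(x, y, s) {
  left = y[x < s]; right = y[x >= s]
  sum((left - mean(left))^2) + sum((right - mean(right))^2)
}
# For each predictor, try every midpoint between adjacent (sorted) values
bind_rows(lapply(c("x1", "x2"), function(v) {
  x_sorted = sort(toy_df[[v]])
  cutoffs = head(x_sorted, -1) + diff(x_sorted) / 2
  tibble(
    variable = v,
    cutoff = cutoffs,
    rss = sapply(cutoffs, function(s) split_rss(toy_df[[v]], toy_df$y, s))
  )
})) %>% arrange(rss)
# x1 at 2: RSS = 2; x2 at 3: RSS = 18; x1 at 4 and x2 at 5: RSS = 32
```

The resulting table matches the four splits summarized on the next slide.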
--- Across our four possible splits (two variables each with two splits) - x.sub[1] with a cutoff of 2: .b[RSS] = 2 - x.sub[1] with a cutoff of 4: .b[RSS] = 32 - x.sub[2] with a cutoff of 3: .b[RSS] = 18 - x.sub[2] with a cutoff of 5: .b[RSS] = 32 our split of x.sub[1] at 2 generates the lowest RSS. --- layout: false class: clear, middle .note[Note:] Categorical predictors work in exactly the same way.
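Before the example, a quick sketch of my own (not from the lecture code; `lvls` is just an illustrative name) that enumerates these splits programmatically:

```{R, cat-splits-sketch, eval = F}
# Sketch: enumerate the binary splits of a four-level factor.
# Keeping only groups that contain "A" avoids counting each split twice
# (once from each side of the split).
lvls = c("A", "B", "C", "D")
groups = unlist(
  lapply(1:3, function(k) combn(lvls, k, simplify = FALSE)),
  recursive = FALSE
)
splits = Filter(function(g) "A" %in% g, groups)
length(splits) # 7
sapply(splits, function(g) {
  paste(paste(g, collapse = "|"), "vs.", paste(setdiff(lvls, g), collapse = "|"))
})
```

The seven results match the A/B/C/D example below.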
We want to try .b[all possible combinations] of the categories. .ex[Ex:] For a four-level categorical predictor (levels: A, B, C, D) .col-left[ - Split 1: .pink[A|B|C] .it[vs.] .purple[D] - Split 2: .pink[A|B|D] .it[vs.] .purple[C] - Split 3: .pink[A|C|D] .it[vs.] .purple[B] - Split 4: .pink[B|C|D] .it[vs.] .purple[A] ] .col-right[ - Split 5: .pink[A|B] .it[vs.] .purple[C|D] - Split 6: .pink[A|C] .it[vs.] .purple[B|D] - Split 7: .pink[A|D] .it[vs.] .purple[B|C] ] .clear-up[ we would need to try 7 possible splits. ] --- layout: true # Decision trees --- name: splits-more ## More splits Once we make a split, we then continue splitting,
.b[conditional] on the regions from our previous splits. So if our first split creates R.sub[1] and R.sub[2], then our next split
searches the predictor space only in R.sub[1] or R.sub[2]..super[🌲] .footnote[ 🌲 We are no longer searching the full space—it is conditional on the previous splits. ] -- The tree continues to .b[grow until] it hits some specified threshold,
_e.g._, at most 5 observations in each leaf. --- ## Too many splits? One can have too many splits. .qa[Q] Why? -- .qa[A] "More splits" means 1. more flexibility (think about the bias-variance tradeoff/overfitting) 1. less interpretability (one of the selling points for trees) -- .qa[Q] So what can we do? -- .qa[A] Prune your trees! --- name: pruning ## Pruning .attn[Pruning] allows us to trim our trees back to their "best selves." .note[The idea:] Some regions may increase .hi[variance] more than they reduce .hi[bias].
By removing these regions, we gain in test MSE. .note[Candidates for trimming:] Regions that do not .b[reduce RSS] very much. -- .note[Updated strategy:] Grow a big tree $T_0$ and then trim $T_0$ to an optimal .attn[subtree]. -- .note[Updated problem:] Considering all possible subtrees can get expensive. --- ## Pruning .attn[Cost-complexity pruning].super[🌲] offers a solution. .footnote[ 🌲 Also called: .it[weakest-link pruning]. ] Just as we did with lasso, .attn[cost-complexity pruning] forces the tree to pay a price (penalty) to become more complex. .it[Complexity] here is defined as the number of regions $|T|$. --- ## Pruning Specifically, .attn[cost-complexity pruning] adds a penalty of $\alpha |T|$ to the RSS, _i.e._, $$ \begin{align} \sum_{m=1}^{|T|} \sum_{i:\,x_i\in R_m} \left( y_i - \hat{y}_{R_m} \right)^2 + \alpha |T| \end{align} $$ For any value of $\alpha (\ge 0)$, we get a subtree $T\subset T_0$. -- $\alpha = 0$ generates $T_0$, but as $\alpha$ increases, we begin to cut back the tree. -- We choose $\alpha$ via cross-validation. --- name: ctree ## Classification trees Classification with trees is very similar to regression. -- .col-left[ .hi-purple[Regression trees] - .hi-slate[Predict:] Region's mean - .hi-slate[Split:] Minimize RSS - .hi-slate[Prune:] Penalized RSS ] -- .col-right[ .hi-pink[Classification trees] - .hi-slate[Predict:] Region's mode - .hi-slate[Split:] Min. Gini or entropy.super[🌲] - .hi-slate[Prune:] Penalized error rate.super[🌴] ] .footnote[ 🌲 Defined on the next slide. 🌴 ... or Gini index or entropy ] .clear-up[ An additional nuance for .attn[classification trees]: We typically care about the .b[proportions of classes in the leaves]—not just the final prediction. ] --- name: gini ## The Gini index Let $\hat{p}_{mk}$ denote the proportion of observations in class $k$ and region $m$. -- The .attn[Gini index] tells us about a region's "purity".super[🌲] $$ \begin{align} G = \sum_{k=1}^{K} \hat{p}_{mk} \left( 1 - \hat{p}_{mk} \right) \end{align} $$ If a region is very homogeneous, then the Gini index will be small. .footnote[ 🌲 This vocabulary is Voldemort's contribution to the machine-learning literature. ] Homogeneous regions are easier to predict.
Reducing the Gini index yields more homogeneous regions
∴ We want to minimize the Gini index. --- name: entropy ## Entropy Let $\hat{p}_{mk}$ denote the proportion of observations in class $k$ and region $m$. .attn[Entropy] also measures the "purity" of a node/leaf $$ \begin{align} D = - \sum_{k=1}^{K} \hat{p}_{mk} \log \left( \hat{p}_{mk} \right) \end{align} $$ .attn[Entropy] is also minimized when $\hat{p}_{mk}$ values are close to 0 or 1. --- name: class-why ## Rationale .qa[Q] Why are we using the Gini index or entropy (*vs.* error rate)? -- .qa[A] The error rate isn't sufficiently sensitive to grow good trees.
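As a numerical check, here is a short sketch of my own (not from ISL or the lecture code; `leaf_measures()` is a hypothetical helper, and entropy uses the natural log) that computes all three measures for the two leaves in the example below:

```{R, gini-entropy-sketch, eval = F}
# Sketch: error rate, Gini index, and entropy for a leaf's class counts
leaf_measures = function(counts) {
  p = counts / sum(counts) # class proportions within the leaf
  p = p[p > 0]             # drop empty classes (0 * log(0) treated as 0)
  c(
    error_rate = 1 - max(p),       # one minus the modal class's share
    gini       = sum(p * (1 - p)), # Gini index
    entropy    = -sum(p * log(p))  # entropy (natural log)
  )
}
leaf_measures(c(A = 51, B = 49, C = 0))  # 0.49, 0.4998, 0.6929
leaf_measures(c(A = 51, B = 25, C = 24)) # 0.49, 0.6198, 1.0325
```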
The Gini index and entropy tell us about the .b[composition] of the leaf. -- .ex[Ex.] Consider two different leaves in a three-level classification. .col-left[ .b[Leaf 1] - .b[A:] 51, .b[B:] 49, .b[C:] 00 - .hi-orange[Error rate:] 49% - .hi-purple[Gini index:] 0.4998 - .hi-pink[Entropy:] 0.6929 ] .col-right[ .b[Leaf 2] - .b[A:] 51, .b[B:] 25, .b[C:] 24 - .hi-orange[Error rate:] 49% - .hi-purple[Gini index:] 0.6198 - .hi-pink[Entropy:] 1.0325 ] .clear-up[ The .hi-purple[Gini index] and .hi-pink[entropy] tell us about the distribution. ] --- ## Classification trees When .b[growing] classification trees, we want to use the Gini index or entropy. However, when .b[pruning], the error rate is typically fine—especially if accuracy will be the final criterion. --- name: in-r ## In R To train decision trees in R, we can use `caret`, which draws upon `rpart`. -- To `train()` our model in `caret` - our `method` is `"rpart"` - the main tuning parameter is `cp`, the .note[complexity parameter] (penalty) ```{R, train-rpart, cache = T} # Set seed set.seed(12345) # CV and train default_tree = train( default ~ ., data = default_df, method = "rpart", trControl = trainControl("cv", number = 5), tuneGrid = data.frame(cp = seq(0, 0.2, by = 0.005)) ) ``` --- layout: true class: clear --- .b[Accuracy and complexity] via `cp`, the penalty for complexity ```{R, plot-cv-cp, echo = F} ggplot( data = default_tree$results, aes(x = cp, y = Accuracy) ) + geom_line(size = 0.4) + geom_point(size = 3.5) + theme_minimal(base_size = 20, base_family = "Fira Sans Book") ``` --- class: middle To plot the CV-chosen tree, we need to 1. .b[extract] the fitted model, _e.g._, `default_tree$finalModel` 1. apply a .b[plotting function], _e.g._, `rpart.plot()` from `rpart.plot` --- class: clear, middle ```{R, plot-rpart-cv, echo = F} rpart.plot( default_tree$finalModel, extra = 104, box.palette = "Oranges", branch.lty = 3, shadow.col = "gray", nn = TRUE, cex = 1.3 ) ``` --- class: clear, middle which we can compare to a lightly penalized tree (`cp = 0.005`) --- class: clear, middle ```{R, plot-tree_complex, echo = F} tree_complex = train( default ~ ., data = default_df, method = "rpart", trControl = trainControl("none"), tuneGrid = data.frame(cp = 0.005) ) rpart.plot( tree_complex$finalModel, extra = 104, box.palette = "Oranges", branch.lty = 3, shadow.col = "gray", nn = TRUE, cex = 1.2 ) ``` --- class: clear, middle And now for a more penalized tree (`cp = 0.1`)... --- class: clear, middle ```{R, plot-tree_simple, echo = F} tree_simple = train( default ~ ., data = default_df, method = "rpart", trControl = trainControl("none"), tuneGrid = data.frame(cp = 0.1) ) rpart.plot( tree_simple$finalModel, extra = 104, box.palette = "Oranges", branch.lty = 3, shadow.col = "gray", nn = TRUE, cex = 1.3 ) ``` --- layout: true class: clear, middle --- name: linearity .qa[Q] How do trees compare to linear models? .tran[.b[A] It depends how linear the true boundary is.] --- .qa[Q] How do trees compare to linear models? .qa[A] It depends how linear the true boundary is. --- .b[Linear boundary:] trees struggle to recreate a line. ```{R, fig-compare-linear, echo = F} knitr::include_graphics("images/compare-linear.png") ``` .ex.small[Source: ISL, p. 315] --- .b[Nonlinear boundary:] trees easily replicate the nonlinear boundary. ```{R, fig-compare-nonlinear, echo = F} knitr::include_graphics("images/compare-nonlinear.png") ``` .ex.small[Source: ISL, p.
315] --- layout: false name: tree-pro-con # Decision trees ## Strengths and weaknesses As with any method, decision trees have tradeoffs. .col-left.purple.small[ .b[Strengths]
.b[+] Easily explained/interpreted
.b[+] Include several graphical options
.b[+] Mirror human decision making?
.b[+] Handle num. or cat. on LHS/RHS.super[🌳] ] .footnote[ 🌳 Without needing to create lots of dummy variables!
.tran[🌴 Blank] ] -- .col-right.pink.small[ .b[Weaknesses]
.b[-] Outperformed by other methods
.b[-] Struggle with linearity
.b[-] Can be very "non-robust" ] .clear-up[ .attn[Non-robust:] Small data changes can cause huge changes in our tree. ] -- .note[Next:] Create ensembles of trees.super[🌲] to strengthen these weaknesses..super[🌴] .footnote[ .tran[🌴 Blank]
🌲 Forests! 🌴 Which will also weaken some of the strengths. ] --- name: sources layout: false # Sources These notes draw upon - [An Introduction to Statistical Learning](http://faculty.marshall.usc.edu/gareth-james/ISL/) (*ISL*)
James, Witten, Hastie, and Tibshirani --- # Table of contents .col-left[ .smallest[ #### Admin - [Today](#admin-today) - [Upcoming](#admin-soon) #### Decision trees 1. [Fundamentals](#fundamentals) 1. [Partitioning predictors](#ex-partition) 1. [Definitions](#defn) 1. [Growing trees](#growth) 1. [Example: Splitting](#ex-split) 1. [More splits](#splits-more) 1. [Pruning](#pruning) 1. [Classification trees](#ctree) - [The Gini index](#gini) - [Entropy](#entropy) - [Rationale](#class-why) 1. [In R](#in-r) 1. [Linearity](#linearity) 1. [Strengths and weaknesses](#tree-pro-con) ] ] .col-right[ .smallest[ #### Other - [Sources/references](#sources) ] ]