---
title: "Lecture .mono[007]"
subtitle: "Trees 🌲🌴🌳"
author: "Edward Rubin"
#date: "`r format(Sys.time(), '%d %B %Y')`"
date: "20 February 2020"
output:
xaringan::moon_reader:
css: ['default', 'metropolis', 'metropolis-fonts', 'my-css.css']
# self_contained: true
nature:
highlightStyle: github
highlightLines: true
countIncrementalSlides: false
---
exclude: true
```{R, setup, include = F}
library(pacman)
p_load(
ISLR,
broom, tidyverse,
ggplot2, ggthemes, ggforce, ggridges, cowplot, scales,
latex2exp, viridis, extrafont, gridExtra, plotly, ggformula,
DiagrammeR,
kableExtra, DT, huxtable,
data.table, dplyr, snakecase, janitor,
lubridate, knitr,
caret, rpart, rpart.plot, rattle,
here, magrittr, parallel
)
# Define colors
red_pink = "#e64173"
turquoise = "#20B2AA"
orange = "#FFA500"
red = "#fb6107"
blue = "#3b3b9a"
green = "#8bb174"
grey_light = "grey70"
grey_mid = "grey50"
grey_dark = "grey20"
purple = "#6A5ACD"
slate = "#314f4f"
# Knitr options
opts_chunk$set(
comment = "#>",
fig.align = "center",
fig.height = 7,
fig.width = 10.5,
warning = F,
message = F
)
opts_chunk$set(dev = "svg")
options(device = function(file, width, height) {
svg(tempfile(), width = width, height = height)
})
options(knitr.table.format = "html")
```
---
layout: true
# Admin
---
class: inverse, middle
---
name: admin-today
## Material
Decision trees for regression and classification.
---
name: admin-soon
## Upcoming
.b[Readings]
- .note[Today] .it[ISL] Ch. 8.1
- .note[Next] .it[ISL] Ch. 8.2
.b[Problem sets]
- .it[Classification] Due today
- Let Connor know if you are resubmitting
.b[Project] Project topic due before midnight on Friday.
---
layout: true
# Decision trees
---
class: inverse, middle
---
name: fundamentals
## Fundamentals
.attn[Decision trees]
- split the .it[predictor space] (our $\mathbf{X}$) into regions
- then predict the most-common value within a region
--
.attn[Tree-based methods]
1. work for .hi[both classification and regression]
--
1. are inherently .hi[nonlinear]
--
1. are relatively .hi[simple] and .hi[interpretable]
--
1. often .hi[underperform] relative to competing methods
--
1. easily extend to .hi[very competitive ensemble methods] (*many* trees).super[🌲]
.footnote[
🌲 Though the ensembles will be much less interpretable.
]
---
layout: true
class: clear
---
exclude: true
```{R, data-default, include = F}
# Load 'Default' data from 'ISLR'
default_df = ISLR::Default %>% as_tibble()
```
---
.ex[Example:] .b[A simple decision tree] classifying credit-card default
```{R, tree-graph, echo = F, cache = T}
DiagrammeR::grViz("
digraph {
graph [layout = dot, overlap = false, fontsize = 14]
node [shape = oval, fontname = 'Fira Sans', color = Gray95, style = filled]
s1 [label = 'Bal. > 1,800']
s2 [label = 'Bal. < 1,972']
s3 [label = 'Inc. > 27K']
node [shape = egg, fontname = 'Fira Sans', color = Purple, style = filled, fontcolor = White]
l1 [label = 'No (98%)']
l4 [label = 'No (69%)']
node [shape = egg, fontname = 'Fira Sans', color = Orange, style = filled, fontcolor = White]
l2 [label = 'Yes (76%)']
l3 [label = 'Yes (59%)']
edge [fontname = 'Fira Sans', color = Grey70]
s1 -> l1 [label = 'F']
s1 -> s2 [label = 'T']
s2 -> s3 [label = 'T']
s2 -> l2 [label = 'F']
s3 -> l3 [label = 'T']
s3 -> l4 [label = 'F']
}
")
```
---
name: ex-partition
Let's see how the tree works
--
—starting with our data (default: .orange[Yes] .it[vs.] .purple[No]).
```{R, partition-base, include = F, cache = T}
gg_base = ggplot(
data = default_df,
aes(x = balance, y = income, color = default, alpha = default)
) +
geom_hline(yintercept = 0) +
geom_vline(xintercept = 0) +
geom_point(size = 2) +
scale_y_continuous("Income", labels = dollar) +
scale_x_continuous("Balance", labels = dollar) +
scale_color_manual("Defaulted:", values = c(purple, orange), labels = c("No", "Yes")) +
scale_alpha_manual("Defaulted:", values = c(0.1, 0.8), labels = c("No", "Yes")) +
theme_minimal(base_size = 20, base_family = "Fira Sans Book") +
theme(legend.position = "none")
```
```{R, plot-raw, echo = F}
gg_base
```
---
The .hi-pink[first partition] splits balance at $1,800.
```{R, plot-split1, echo = F, cache = T}
gg_base +
annotate(
"segment",
x = 1800, xend = 1800, y = -Inf, yend = Inf,
color = red_pink, size = 1.2
)
```
---
The .hi-pink[second partition] splits balance at $1,972 (.it[conditional on bal. > $1,800]).
```{R, plot-split2, echo = F, cache = T}
gg_base +
annotate(
"segment",
x = 1800, xend = 1800, y = -Inf, yend = Inf,
linetype = "longdash"
) +
annotate(
"segment",
x = 1972, xend = 1972, y = -Inf, yend = Inf,
color = red_pink, size = 1.2
)
```
---
The .hi-pink[third partition] splits income at $27K .b[for] bal. between $1,800 and $1,972.
```{R, plot-split3, echo = F, cache = T}
gg_base +
annotate(
"segment",
x = 1800, xend = 1800, y = -Inf, yend = Inf,
linetype = "longdash"
) +
annotate(
"segment",
x = 1972, xend = 1972, y = -Inf, yend = Inf,
linetype = "longdash"
) +
annotate(
"segment",
x = 1800, xend = 1972, y = 27e3, yend = 27e3,
color = red_pink, size = 1.2
)
```
---
These three partitions give us four .b[regions]...
```{R, plot-split3b, echo = F, cache = T}
gg_base +
annotate(
"segment",
x = 1800, xend = 1800, y = -Inf, yend = Inf,
linetype = "longdash"
) +
annotate(
"segment",
x = 1972, xend = 1972, y = -Inf, yend = Inf,
linetype = "longdash"
) +
annotate(
"segment",
x = 1800, xend = 1972, y = 27e3, yend = 27e3,
linetype = "longdash"
) +
annotate("text",
x = 900, y = 37500, label = expression(R[1]),
size = 8, family = "Fira Sans Book"
) +
annotate("text",
x = 1886, y = 5.1e4, label = expression(R[2]),
size = 8, family = "Fira Sans Book"
) +
annotate("text",
x = 1886, y = 1e4, label = expression(R[3]),
size = 8, family = "Fira Sans Book"
) +
annotate("text",
x = 2336, y = 37500, label = expression(R[4]),
size = 8, family = "Fira Sans Book"
)
```
---
.b[Predictions] cover each region (_e.g._, using the region's most common class).
```{R, plot-split3c, echo = F, cache = T}
gg_base +
annotate(
"rect",
xmin = 0, xmax = 1800, ymin = 0, ymax = Inf,
fill = purple, alpha = 0.3
) +
annotate(
"segment",
x = 1800, xend = 1800, y = -Inf, yend = Inf,
linetype = "longdash"
) +
annotate(
"segment",
x = 1972, xend = 1972, y = -Inf, yend = Inf,
linetype = "longdash"
) +
annotate(
"segment",
x = 1800, xend = 1972, y = 27e3, yend = 27e3,
linetype = "longdash"
)
```
---
.b[Predictions] cover each region (_e.g._, using the region's most common class).
```{R, plot-split3d, echo = F, cache = T}
gg_base +
annotate(
"rect",
xmin = 0, xmax = 1800, ymin = 0, ymax = Inf,
fill = purple, alpha = 0.3
) +
annotate(
"rect",
xmin = 1800, xmax = 1972, ymin = 27e3, ymax = Inf,
fill = orange, alpha = 0.3
) +
annotate(
"segment",
x = 1800, xend = 1800, y = -Inf, yend = Inf,
linetype = "longdash"
) +
annotate(
"segment",
x = 1972, xend = 1972, y = -Inf, yend = Inf,
linetype = "longdash"
) +
annotate(
"segment",
x = 1800, xend = 1972, y = 27e3, yend = 27e3,
linetype = "longdash"
)
```
---
.b[Predictions] cover each region (_e.g._, using the region's most common class).
```{R, plot-split3e, echo = F, cache = T}
gg_base +
annotate(
"rect",
xmin = 0, xmax = 1800, ymin = 0, ymax = Inf,
fill = purple, alpha = 0.3
) +
annotate(
"rect",
xmin = 1800, xmax = 1972, ymin = 27e3, ymax = Inf,
fill = orange, alpha = 0.3
) +
annotate(
"rect",
xmin = 1800, xmax = 1972, ymin = 0, ymax = 27e3,
fill = purple, alpha = 0.3
) +
annotate(
"segment",
x = 1800, xend = 1800, y = -Inf, yend = Inf,
linetype = "longdash"
) +
annotate(
"segment",
x = 1972, xend = 1972, y = -Inf, yend = Inf,
linetype = "longdash"
) +
annotate(
"segment",
x = 1800, xend = 1972, y = 27e3, yend = 27e3,
linetype = "longdash"
)
```
---
.b[Predictions] cover each region (_e.g._, using the region's most common class).
```{R, plot-split3f, echo = F, cache = T}
gg_base +
annotate(
"rect",
xmin = 0, xmax = 1800, ymin = 0, ymax = Inf,
fill = purple, alpha = 0.3
) +
annotate(
"rect",
xmin = 1800, xmax = 1972, ymin = 27e3, ymax = Inf,
fill = orange, alpha = 0.3
) +
annotate(
"rect",
xmin = 1800, xmax = 1972, ymin = 0, ymax = 27e3,
fill = purple, alpha = 0.3
) +
annotate(
"rect",
xmin = 1972, xmax = Inf, ymin = 0, ymax = Inf,
fill = orange, alpha = 0.3
) +
annotate(
"segment",
x = 1800, xend = 1800, y = -Inf, yend = Inf,
linetype = "longdash"
) +
annotate(
"segment",
x = 1972, xend = 1972, y = -Inf, yend = Inf,
linetype = "longdash"
) +
annotate(
"segment",
x = 1800, xend = 1972, y = 27e3, yend = 27e3,
linetype = "longdash"
)
```
---
name: defn
The .hi-pink[regions] correspond to the tree's .attn[terminal nodes] (or .attn[leaves]).
```{R, tree-leaves, echo = F, cache = T}
DiagrammeR::grViz("
digraph {
graph [layout = dot, overlap = false, fontsize = 14]
node [shape = oval, fontname = 'Fira Sans', color = Gray95, style = filled]
s1 [label = 'Bal. > 1,800']
s2 [label = 'Bal. < 1,972']
s3 [label = 'Inc. > 27K']
node [shape = egg, fontname = 'Fira Sans', color = DeepPink, style = filled, fontcolor = White]
l1 [label = 'No (98%)']
l4 [label = 'No (69%)']
node [shape = egg, fontname = 'Fira Sans', color = DeepPink, style = filled, fontcolor = White]
l2 [label = 'Yes (76%)']
l3 [label = 'Yes (59%)']
edge [fontname = 'Fira Sans', color = Grey70]
s1 -> l1 [label = 'F']
s1 -> s2 [label = 'T']
s2 -> s3 [label = 'T']
s2 -> l2 [label = 'F']
s3 -> l3 [label = 'T']
s3 -> l4 [label = 'F']
}
")
```
---
The graph's .hi-pink[separating lines] correspond to the tree's .attn[internal nodes].
```{R, tree-internal, echo = F, cache = T}
DiagrammeR::grViz("
digraph {
graph [layout = dot, overlap = false, fontsize = 14]
node [shape = oval, fontname = 'Fira Sans', color = DeepPink, style = filled, fontcolor = White]
s1 [label = 'Bal. > 1,800']
s2 [label = 'Bal. < 1,972']
s3 [label = 'Inc. > 27K']
node [shape = egg, fontname = 'Fira Sans', color = Grey95, style = filled, fontcolor = White]
l1 [label = 'No (98%)']
l4 [label = 'No (69%)']
node [shape = egg, fontname = 'Fira Sans', color = Grey95, style = filled, fontcolor = White]
l2 [label = 'Yes (76%)']
l3 [label = 'Yes (59%)']
edge [fontname = 'Fira Sans', color = Grey70]
s1 -> l1 [label = 'F']
s1 -> s2 [label = 'T']
s2 -> s3 [label = 'T']
s2 -> l2 [label = 'F']
s3 -> l3 [label = 'T']
s3 -> l4 [label = 'F']
}
")
```
---
The segments connecting the nodes within the tree are its .attn[branches].
```{R, tree-branches, echo = F, cache = T}
DiagrammeR::grViz("
digraph {
graph [layout = dot, overlap = false, fontsize = 14]
node [shape = oval, fontname = 'Fira Sans', color = Grey95, style = filled, fontcolor = White]
s1 [label = 'Bal. > 1,800']
s2 [label = 'Bal. < 1,972']
s3 [label = 'Inc. > 27K']
node [shape = egg, fontname = 'Fira Sans', color = Grey95, style = filled, fontcolor = White]
l1 [label = 'No (98%)']
l4 [label = 'No (69%)']
node [shape = egg, fontname = 'Fira Sans', color = Grey95, style = filled, fontcolor = White]
l2 [label = 'Yes (76%)']
l3 [label = 'Yes (59%)']
edge [fontname = 'Fira Sans', color = DeepPink]
s1 -> l1 [label = 'F']
s1 -> s2 [label = 'T']
s2 -> s3 [label = 'T']
s2 -> l2 [label = 'F']
s3 -> l3 [label = 'T']
s3 -> l4 [label = 'F']
}
")
```
---
class: middle
You now know the anatomy of a decision tree.
But where do trees come from—how do we train.super[🌲] a tree?
.footnote[
🌲 grow
]
---
layout: true
# Decision trees
---
name: growth
## Growing trees
We will start with .attn[regression trees], _i.e._, trees used in regression settings.
--
As we saw, the task of .hi[growing a tree] involves two main steps:
1. .b[Divide the predictor space] into $J$ regions (using predictors $\mathbf{x}_1,\ldots,\mathbf{x}_p$)
--
1. .b[Make predictions] using the regions' mean outcome.
For region $R_j$ predict $\hat{y}_{R_j}$ where
$$
\begin{align}
\hat{y}_{R_j} = \frac{1}{n_j} \sum_{i \in R_j} y_i
\end{align}
$$
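As a quick sketch (toy data of mine, not the lecture's), the prediction in each region is just that region's mean outcome:

```{R, sketch-region-means, eval = F}
# Toy data: outcome y and an (assumed, already-defined) region id
toy = data.frame(y = c(0, 8, 6), region = c(1, 2, 2))
# The tree's prediction for each region is the region's mean outcome
aggregate(y ~ region, data = toy, FUN = mean)
#> region 1 predicts 0; region 2 predicts (8 + 6) / 2 = 7
```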
---
## Growing trees
We .hi[choose the regions to minimize RSS] .it[across all] $J$ regions, _i.e._,
$$
\begin{align}
\sum_{j=1}^{J} \sum_{i \in R_j} \left( y_i - \hat{y}_{R_j} \right)^2
\end{align}
$$
--
.b[Problem:] Examining every possible partition is computationally infeasible.
--
.b[Solution:] a .it[top-down, greedy] algorithm named .attn[recursive binary splitting]
- .attn[recursive] start with the "best" split, then find the next "best" split, ...
- .attn[binary] each split creates two branches—"yes" and "no"
- .attn[greedy] each step makes the .it[best] split—no consideration of the overall process
---
## Growing trees: Choosing a split
.ex[Recall] Regression trees choose the split that minimizes RSS.
To find this split, we need
1. a .purple[predictor], $\color{#6A5ACD}{\mathbf{x}_j}$
1. a .attn[cutoff] $\color{#e64173}{s}$ that splits $\color{#6A5ACD}{\mathbf{x}_j}$ into two parts: (1) $\color{#6A5ACD}{\mathbf{x}_j} < \color{#e64173}{s}$ and (2) $\color{#6A5ACD}{\mathbf{x}_j} \ge \color{#e64173}{s}$
--
Searching across each of our .purple[predictors] $\color{#6A5ACD}{j}$ and all of their .pink[cutoffs] $\color{#e64173}{s}$,
we choose the combination that .b[minimizes RSS].
---
layout: true
# Decision trees
## Example: Splitting
---
name: ex-split
.ex[Example] Consider the dataset
```{R, data-ex-split, echo = F}
ex_df = tibble(
"i" = 1:3,
"pred." = c(0, 0, 0),
"y" = c(0, 8, 6),
"x.sub[1]" = c(1, 3, 5),
"x.sub[2]" = c(4, 2, 6)
)
ex_df %>% hux() %>%
add_colnames() %>%
set_align(1:4, 1:5, "center") %>%
set_bold(1, 1:5, T) %>%
set_bottom_border(1, c(1,3:5), 1) %>%
set_text_color(1:4, 2, "white")
```
--
With just three observations, each variable only has two actual splits..super[🌲]
.footnote[
🌲 You can think about cutoffs as the ways we divide observations into two groups.
]
---
One possible split: x.sub[1] at 2, which yields .purple[(.b[1]) x.sub[1] < 2] .it[vs.] .pink[(.b[2]) x.sub[1] ≥ 2]
```{R, ex-split1, echo = F}
split1 = ex_df %>% mutate("pred." = c(0, 7, 7)) %>%
hux() %>%
add_colnames() %>%
set_align(1:4, 1:5, "center") %>%
set_bold(1, 1:5, T) %>%
set_bottom_border(1, 1:5, 1) %>%
set_text_color(2, 1:4, purple) %>%
set_text_color(3:4, 1:4, red_pink)
split1 %>% set_text_color(1:4, 2, "white") %>% set_bottom_border(1, 2, 0)
```
---
One possible split: x.sub[1] at 2, which yields .purple[(.b[1]) x.sub[1] < 2] .it[vs.] .pink[(.b[2]) x.sub[1] ≥ 2]
```{R, ex-split1b, echo = F}
split1 = ex_df %>% mutate("pred." = c(0, 7, 7)) %>%
hux() %>%
add_colnames() %>%
set_align(1:4, 1:5, "center") %>%
set_bold(1, 1:5, T) %>%
set_bottom_border(1, 1:5, 1) %>%
set_text_color(2, 1:4, purple) %>%
set_text_color(3:4, 1:4, red_pink) %>%
set_bold(1:4, 2, T)
split1
```
This split yields an RSS of .purple[0.super[2]] + .pink[1.super[2]] + .pink[(-1).super[2]] = 2.
--
.note[Note.sub[1]] Splitting x.sub[1] at 2 yields the same result as 1.5, 2.5—anything in (1, 3).
--
.note[Note.sub[2]] Trees often grow until they hit some number of observations in a leaf.
---
An alternative split: x.sub[1] at 4, which yields .purple[(.b[1]) x.sub[1] < 4] .it[vs.] .pink[(.b[2]) x.sub[1] ≥ 4]
```{R, ex-split2, echo = F}
ex_df %>% mutate("pred." = c(4, 4, 6)) %>%
hux() %>%
add_colnames() %>%
set_align(1:4, 1:5, "center") %>%
set_bold(1, 1:5, T) %>%
set_bottom_border(1, 1:5, 1) %>%
set_text_color(2:3, 1:4, purple) %>%
set_text_color(4, 1:4, red_pink) %>%
set_bold(1:4, 2, T)
```
This split yields an RSS of .purple[(-4).super[2]] + .purple[4.super[2]] + .pink[0.super[2]] = 32.
--
.it[Previous:] Splitting x.sub[1] at 2 yielded RSS = 2. .it.grey-light[(Much better)]
---
Another split: x.sub[2] at 3, which yields .purple[(.b[1]) x.sub[2] < 3] .it[vs.] .pink[(.b[2]) x.sub[2] ≥ 3]
```{R, ex-split3, echo = F}
ex_df %>% mutate("pred." = c(3, 8, 3)) %>%
hux() %>%
add_colnames() %>%
set_align(1:4, 1:5, "center") %>%
set_bold(1, 1:5, T) %>%
set_bottom_border(1, 1:5, 1) %>%
set_text_color(c(2,4), c(1:3,5), red_pink) %>%
set_text_color(3, c(1:3,5), purple) %>%
set_bold(1:4, 2, T)
```
This split yields an RSS of .pink[(-3).super[2]] + .purple[0.super[2]] + .pink[3.super[2]] = 18.
---
Final split: x.sub[2] at 5, which yields .purple[(.b[1]) x.sub[2] < 5] .it[vs.] .pink[(.b[2]) x.sub[2] ≥ 5]
```{R, ex-split4, echo = F}
ex_df %>% mutate("pred." = c(4, 4, 6)) %>%
hux() %>%
add_colnames() %>%
set_align(1:4, 1:5, "center") %>%
set_bold(1, 1:5, T) %>%
set_bottom_border(1, 1:5, 1) %>%
set_text_color(2:3, c(1:3,5), purple) %>%
set_text_color(4, c(1:3,5), red_pink) %>%
set_bold(1:4, 2, T)
```
This split yields an RSS of .pink[(-4).super[2]] + .pink[4.super[2]] + .purple[0.super[2]] = 32.
---
Across our four possible splits (two variables each with two splits)
- x.sub[1] with a cutoff of 2: .b[RSS] = 2
- x.sub[1] with a cutoff of 4: .b[RSS] = 32
- x.sub[2] with a cutoff of 3: .b[RSS] = 18
- x.sub[2] with a cutoff of 5: .b[RSS] = 32
our split of x.sub[1] at 2 generates the lowest RSS.
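These RSS values are easy to verify. The sketch below (mine, not the lecture's code) computes the RSS for any variable and cutoff on the toy data:

```{R, sketch-split-search, eval = F}
# Toy data from the example
y = c(0, 8, 6); x1 = c(1, 3, 5); x2 = c(4, 2, 6)
# RSS from splitting variable x at cutoff s, predicting each side's mean
split_rss = function(x, s, y) {
  left = x < s
  pred = ifelse(left, mean(y[left]), mean(y[!left]))
  sum((y - pred)^2)
}
split_rss(x1, 2, y)  #> 2
split_rss(x1, 4, y)  #> 32
split_rss(x2, 3, y)  #> 18
split_rss(x2, 5, y)  #> 32
```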
---
layout: false
class: clear, middle
.note[Note:] Categorical predictors work in exactly the same way.
We want to try .b[all possible combinations] of the categories.
.ex[Ex:] For a four-level categorical predictor (levels: A, B, C, D)
.col-left[
- Split 1: .pink[A|B|C] .it[vs.] .purple[D]
- Split 2: .pink[A|B|D] .it[vs.] .purple[C]
- Split 3: .pink[A|C|D] .it[vs.] .purple[B]
- Split 4: .pink[B|C|D] .it[vs.] .purple[A]
]
.col-right[
- Split 5: .pink[A|B] .it[vs.] .purple[C|D]
- Split 6: .pink[A|C] .it[vs.] .purple[B|D]
- Split 7: .pink[A|D] .it[vs.] .purple[B|C]
]
.clear-up[
we would need to try 7 possible splits.
]
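More generally, a $K$-level unordered categorical predictor admits $2^{K-1} - 1$ binary splits (each level goes left or right, halved for symmetry, minus the empty split). A quick sketch:

```{R, sketch-cat-splits, eval = F}
# Number of binary splits for K unordered categories
n_splits = function(K) 2^(K - 1) - 1
n_splits(4)  #> 7, matching the seven splits listed above
```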
---
layout: true
# Decision trees
---
name: splits-more
## More splits
Once we make our first split, we then continue splitting,
.b[conditional] on the regions from our previous splits.
So if our first split creates R.sub[1] and R.sub[2], then our next split
searches the predictor space only in R.sub[1] or R.sub[2]..super[🌲]
.footnote[
🌲 We are no longer searching the full space—it is conditional on the previous splits.
]
--
The tree continues to .b[grow until] it hits some specified threshold,
_e.g._, at most 5 observations in each leaf.
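In `rpart`, these stopping thresholds are set via `rpart.control()`; a sketch (the argument values here are illustrative, not the lecture's settings):

```{R, sketch-stopping, eval = F}
library(rpart)
ctrl = rpart.control(
  minbucket = 5,  # min. observations allowed in any leaf
  minsplit = 10,  # min. observations in a node before trying a split
  cp = 0          # no complexity penalty while growing (prune later)
)
# e.g., rpart(default ~ ., data = default_df, control = ctrl)
```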
---
## Too many splits?
One can have too many splits.
.qa[Q] Why?
--
.qa[A] "More splits" means
1. more flexibility (think about the bias-variance tradeoff/overfitting)
1. less interpretability (one of the selling points for trees)
--
.qa[Q] So what can we do?
--
.qa[A] Prune your trees!
---
name: pruning
## Pruning
.attn[Pruning] allows us to trim our trees back to their "best selves."
.note[The idea:] Some regions may increase .hi[variance] more than they reduce .hi[bias].
By removing these regions, we can lower test MSE.
.note[Candidates for trimming:] Regions that do not .b[reduce RSS] very much.
--
.note[Updated strategy:] Grow big trees $T_0$ and then trim $T_0$ to an optimal .attn[subtree].
--
.note[Updated problem:] Considering all possible subtrees can get expensive.
---
## Pruning
.attn[Cost-complexity pruning].super[🌲] offers a solution.
.footnote[
🌲 Also called: .it[weakest-link pruning].
]
Just as we did with lasso, .attn[cost-complexity pruning] forces the tree to pay a price (penalty) to become more complex.
.it[Complexity] here is defined as the number of regions $|T|$.
---
## Pruning
Specifically, .attn[cost-complexity pruning] adds a penalty of $\alpha |T|$ to the RSS, _i.e._,
$$
\begin{align}
\sum_{m=1}^{|T|} \sum_{i: x_i \in R_m} \left( y_i - \hat{y}_{R_m} \right)^2 + \alpha |T|
\end{align}
$$
For any value of $\alpha (\ge 0)$, we get a subtree $T\subset T_0$.
--
$\alpha = 0$ generates $T_0$, but as $\alpha$ increases, we begin to cut back the tree.
--
We choose $\alpha$ via cross validation.
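In `rpart`, the complexity parameter `cp` plays the role of $\alpha$ (rescaled by the root's RSS). A sketch of growing big and pruning back, using the built-in `mtcars` data for illustration:

```{R, sketch-prune, eval = F}
library(rpart)
# Grow a deliberately large regression tree (no penalty while growing)
big_tree = rpart(
  mpg ~ ., data = mtcars,
  control = rpart.control(cp = 0, minsplit = 2, minbucket = 1)
)
printcp(big_tree)  # table of cp values with cross-validated error (xerror)
# Prune back to the subtree implied by a larger penalty
small_tree = prune(big_tree, cp = 0.1)
```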
---
name: ctree
## Classification trees
Classification with trees is very similar to regression.
--
.col-left[
.hi-purple[Regression trees]
- .hi-slate[Predict:] Region's mean
- .hi-slate[Split:] Minimize RSS
- .hi-slate[Prune:] Penalized RSS
]
--
.col-right[
.hi-pink[Classification trees]
- .hi-slate[Predict:] Region's mode
- .hi-slate[Split:] Min. Gini or entropy.super[🌲]
- .hi-slate[Prune:] Penalized error rate.super[🌴]
]
.footnote[
🌲 Defined on the next slide. 🌴 ... or Gini index or entropy
]
.clear-up[
An additional nuance for .attn[classification trees]: We typically care about the .b[proportions of classes in the leaves]—not just the final prediction.
]
---
name: gini
## The Gini index
Let $\hat{p}_{mk}$ denote the proportion of observations in class $k$ and region $m$.
--
The .attn[Gini index] tells us about a region's "purity".super[🌲]
$$
\begin{align}
G = \sum_{k=1}^{K} \hat{p}_{mk} \left( 1 - \hat{p}_{mk} \right)
\end{align}
$$
If a region is very homogeneous, then the Gini index will be small.
.footnote[
🌲 This vocabulary is Voldemort's contribution to the machine-learning literature.
]
Homogeneous regions are easier to predict.
Reducing the Gini index yields more homogeneous regions
∴ We want to minimize the Gini index.
---
name: entropy
## Entropy
Let $\hat{p}_{mk}$ denote the proportion of observations in class $k$ and region $m$.
.attn[Entropy] also measures the "purity" of a node/leaf
$$
\begin{align}
D = - \sum_{k=1}^{K} \hat{p}_{mk} \log \left( \hat{p}_{mk} \right)
\end{align}
$$
.attn[Entropy] is also minimized when $\hat{p}_{mk}$ values are close to 0 and 1.
---
name: class-why
## Rationale
.qa[Q] Why are we using the Gini index or entropy (*vs.* error rate)?
--
.qa[A] The error rate isn't sufficiently sensitive to grow good trees.
The Gini index and entropy tell us about the .b[composition] of the leaf.
--
.ex[Ex.] Consider two different leaves in a three-level classification.
.col-left[
.b[Leaf 1]
- .b[A:] 51, .b[B:] 49, .b[C:] 00
- .hi-orange[Error rate:] 49%
- .hi-purple[Gini index:] 0.4998
- .hi-pink[Entropy:] 0.6929
]
.col-right[
.b[Leaf 2]
- .b[A:] 51, .b[B:] 25, .b[C:] 24
- .hi-orange[Error rate:] 49%
- .hi-purple[Gini index:] 0.6198
- .hi-pink[Entropy:] 1.0325
]
.clear-up[
The .hi-purple[Gini index] and .hi-pink[entropy] tell us about the distribution.
]
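A quick sketch verifying these numbers (using natural logs for entropy):

```{R, sketch-impurity, eval = F}
# Impurity measures for a vector of class proportions
gini = function(p) sum(p * (1 - p))
entropy = function(p) -sum(ifelse(p > 0, p * log(p), 0))
# Leaf 1: A 51%, B 49%, C 0%
gini(c(0.51, 0.49, 0))        #> 0.4998
entropy(c(0.51, 0.49, 0))     #> approx. 0.6929
# Leaf 2: A 51%, B 25%, C 24%
gini(c(0.51, 0.25, 0.24))     #> 0.6198
entropy(c(0.51, 0.25, 0.24))  #> approx. 1.0325
```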
---
## Classification trees
When .b[growing] classification trees, we want to use the Gini index or entropy.
However, when .b[pruning], the error rate is typically fine—especially if accuracy will be the final criterion.
---
name: in-r
## In R
To train decision trees in R, we can use `caret`, which draws upon `rpart`.
--
To `train()` our model in `caret`
- our `method` is `"rpart"`
- the main tuning parameter is `cp`, the .note[complexity parameter] (penalty)
```{R, train-rpart, cache = T}
# Set seed
set.seed(12345)
# CV and train
default_tree = train(
default ~ .,
data = default_df,
method = "rpart",
trControl = trainControl("cv", number = 5),
tuneGrid = data.frame(cp = seq(0, 0.2, by = 0.005))
)
```
---
layout: true
class: clear
---
.b[Accuracy and complexity] via `cp`, the penalty for complexity
```{R, plot-cv-cp, echo = F}
ggplot(
data = default_tree$results,
aes(x = cp, y = Accuracy)
) +
geom_line(size = 0.4) +
geom_point(size = 3.5) +
theme_minimal(base_size = 20, base_family = "Fira Sans Book")
```
---
class: middle
To plot the CV-chosen tree, we need to
1. .b[extract] the fitted model, _e.g._, `default_tree$finalModel`
1. apply a .b[plotting function] _e.g._, `rpart.plot()` from `rpart.plot`
---
class: clear, middle
```{R, plot-rpart-cv, echo = F}
rpart.plot(
default_tree$finalModel,
extra = 104,
box.palette = "Oranges",
branch.lty = 3,
shadow.col = "gray",
nn = TRUE,
cex = 1.3
)
```
---
class: clear, middle
which we can compare to a less-pruned tree (`cp = 0.005`)
---
class: clear, middle
```{R, plot-tree_complex, echo = F}
tree_complex = train(
default ~ .,
data = default_df,
method = "rpart",
trControl = trainControl("none"),
tuneGrid = data.frame(cp = 0.005)
)
rpart.plot(
tree_complex$finalModel,
extra = 104,
box.palette = "Oranges",
branch.lty = 3,
shadow.col = "gray",
nn = TRUE,
cex = 1.2
)
```
---
class: clear, middle
And now for a more penalized tree (`cp = 0.1`)...
---
class: clear, middle
```{R, plot-tree_simple, echo = F}
tree_simple = train(
default ~ .,
data = default_df,
method = "rpart",
trControl = trainControl("none"),
tuneGrid = data.frame(cp = 0.1)
)
rpart.plot(
tree_simple$finalModel,
extra = 104,
box.palette = "Oranges",
branch.lty = 3,
shadow.col = "gray",
nn = TRUE,
cex = 1.3
)
```
---
layout: true
class: clear, middle
---
name: linearity
.qa[Q] How do trees compare to linear models?
.tran[.b[A] It depends how linear the true boundary is.]
---
.qa[Q] How do trees compare to linear models?
.qa[A] It depends how linear the true boundary is.
---
.b[Linear boundary:] trees struggle to recreate a line.
```{R, fig-compare-linear, echo = F}
knitr::include_graphics("images/compare-linear.png")
```
.ex.small[Source: ISL, p. 315]
---
.b[Nonlinear boundary:] trees easily replicate the nonlinear boundary.
```{R, fig-compare-nonlinear, echo = F}
knitr::include_graphics("images/compare-nonlinear.png")
```
.ex.small[Source: ISL, p. 315]
---
layout: false
name: tree-pro-con
# Decision trees
## Strengths and weaknesses
As with any method, decision trees have tradeoffs.
.col-left.purple.small[
.b[Strengths]
.b[+] Easily explained/interpreted
.b[+] Include several graphical options
.b[+] Mirror human decision making?
.b[+] Handle num. or cat. on LHS/RHS.super[🌳]
]
.footnote[
🌳 Without needing to create lots of dummy variables!
.tran[🌴 Blank]
]
--
.col-right.pink.small[
.b[Weaknesses]
.b[-] Outperformed by other methods
.b[-] Struggle with linearity
.b[-] Can be very "non-robust"
]
.clear-up[
.attn[Non-robust:] Small data changes can cause huge changes in our tree.
]
--
.note[Next:] Create ensembles of trees.super[🌲] to address these weaknesses..super[🌴]
.footnote[
.tran[🌴 Blank]
🌲 Forests! 🌴 Which will also weaken some of the strengths.
]
---
name: sources
layout: false
# Sources
These notes draw upon
- [An Introduction to Statistical Learning](http://faculty.marshall.usc.edu/gareth-james/ISL/) (*ISL*)
James, Witten, Hastie, and Tibshirani
---
# Table of contents
.col-left[
.smallest[
#### Admin
- [Today](#admin-today)
- [Upcoming](#admin-soon)
#### Decision trees
1. [Fundamentals](#fundamentals)
1. [Partitioning predictors](#ex-partition)
1. [Definitions](#defn)
1. [Growing trees](#growth)
1. [Example: Splitting](#ex-split)
1. [More splits](#splits-more)
1. [Pruning](#pruning)
1. [Classification trees](#ctree)
- [The Gini index](#gini)
- [Entropy](#entropy)
- [Rationale](#class-why)
1. [In R](#in-r)
1. [Linearity](#linearity)
1. [Strengths and weaknesses](#tree-pro-con)
]
]
.col-right[
.smallest[
#### Other
- [Sources/references](#sources)
]
]