In this document, we are going to code up some resampling functions that will help us better understand cross-validation. Please open an R script on your computer and follow along.
To start, make sure to put some thought into where you work on this script. I'd recommend setting up a dedicated project folder.
Remember, keeping your computer organized in a methodical way will save your future self time (and pain). Within the project folder, I recommend creating two folders: "R" for your script(s) and "data" for your data.
First things first, let's set up our script properly by installing/loading the necessary packages.
# Setup ----------------------------------------------------------------------------------
# Options
options(stringsAsFactors = F)
# Packages
# devtools::install_github("tidymodels/parsnip")
pacman::p_load(
  tidyverse, data.table, broom, parallel, here, plotly, hrbrthemes
)
Sidebar: pacman is a great package for setting up your scripts/markdown docs/slides, etc. It is what we would call a package manager: a simplified way of organizing the packages we install on our system and load into memory. Very conveniently, it installs any packages that are not already installed. (pacman is also the standard package manager for Arch Linux; the inspiration for the name, I think.)
Sidebar's sidebar: If you ever start using a new package, or if you want to understand how a specific function from a specific package works, try to Google the package's vignette.
A vignette is typically a markdown document that is written by the package maintainers that explains the functionality of a package and its functions through a series of examples. Kind of like a set of lecture notes for the package. However, not every package has a vignette.
Here is a list of a few of my favorite vignettes:
Typically, when I am searching for a vignette on Google/DuckDuckGo, I use the search criteria "<R package name> vignette R". Additionally, you can search for the package on CRAN; once you find the CRAN package page, any official vignettes will usually be linked there. Finally, you can browse vignettes from the RStudio console using the browseVignettes() function from the utils package.
# browseVignettes(package = 'pacman')
## Back to resampling methods
Now load the housing data set that we've been using. Click here to download a zip file; save it to the data folder of whichever project you are using and unzip it.
# Load data ------------------------------------------------------------------------------
# Training data
train_dt = here('data', 'train.csv') %>% fread()
# Testing data
test_dt = here('data', 'test.csv') %>% fread()
We only need a subset of these data for this lab. Let's trim down our data set to four columns:

- id, the house ID (Id)
- sale_price, the sale price of the house (SalePrice)
- age, the age of the house at the time of the sale (the difference between YrSold and YearBuilt)
- area, the non-basement square-foot area of the house (GrLivArea)

We can do this in a few different ways, but dplyr::transmute() is a very convenient function to use here.
# Generate the new columns and keep only what we want
house_df = train_dt %>% transmute(
  id = Id,
  sale_price = SalePrice / 10000,
  age = YrSold - YearBuilt,
  area = GrLivArea
)
It’s always a good idea to look at the new dataframe and make sure it is exactly what we expect it to be. Any of the following is a great way to check:
# Double check the new data set using any of the following functions
summary(house_df)
## id sale_price age area
## Min. : 1.0 Min. : 3.49 Min. : 0.00 Min. : 334
## 1st Qu.: 365.8 1st Qu.:13.00 1st Qu.: 8.00 1st Qu.:1130
## Median : 730.5 Median :16.30 Median : 35.00 Median :1464
## Mean : 730.5 Mean :18.09 Mean : 36.55 Mean :1515
## 3rd Qu.:1095.2 3rd Qu.:21.40 3rd Qu.: 54.00 3rd Qu.:1777
## Max. :1460.0 Max. :75.50 Max. :136.00 Max. :5642
# glimpse(house_df)
# skimr::skim(house_df)
# View(house_df)
Finally, since we are going to use some RNG, the last thing we need to do before our project is set up and ready to go is to set a randomization seed using the set.seed() function. Setting a seed allows us to keep track of the randomness we introduce. By using the same seed, we should be able to reproduce the same results each time we run our code.
Just pick any number. For fun, use your birthday (e.g., "20210624" for 06-24-2021). For simplicity I will use "1234".
# Set seed
set.seed(1234)
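If you want to convince yourself that the seed makes our randomness reproducible, here is a quick, optional sketch (it re-sets the seed at the end so the rest of the script is unaffected):
# Optional check: the same seed reproduces the same random draws
set.seed(1234); runif(3)
set.seed(1234); runif(3) # identical to the previous three draws
set.seed(1234) # re-set the seed before moving on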
Let's start by creating a single validation set composed of 30% of our training data. Since we have already set our seed, we can draw the validation sample randomly using the dplyr function sample_frac(); the size argument lets us choose the desired 30% sample. dplyr's setdiff() will then give us the remaining, non-validation observations from the original training data.
# Draw validation set
validation_df = house_df %>% sample_frac(size = 0.3)
# Find remaining training set
training_df = setdiff(house_df, validation_df)
If you would like to read more about these functions, remember to look at the help files using ?sample_frac() and ?setdiff().
Finally, let's check our work and make sure that training_df + validation_df = house_df (in terms of the number of rows).
# Check that dimensions make sense
nrow(house_df) == nrow(validation_df) + nrow(training_df)
## [1] TRUE
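As an optional extra check, you could also confirm that no observation landed in both sets; a minimal sketch using dplyr's intersect():
# Optional: confirm the validation and training sets do not overlap
nrow(dplyr::intersect(validation_df, training_df)) == 0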
## Model fit
Now that we have a training and validation set, let's fit a model using training_df. Let's define a flexible linear regression model (i.e., step i.):
\[\begin{align*} Price_i = &\beta_0 + \beta_1 * age_i^2 + \beta_2 * age_i + \beta_3 * area_i^2 + \\ &\beta_4 * area_i + \beta_5 * age_i^2 \cdot area_i^2 + \beta_6 * age_i^2 \cdot area_i + \\ & \beta_7 * age_i \cdot area_i^2 + \beta_8 * area_i \cdot age_i \end{align*}\]
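To connect this equation to code: poly(age, 2, raw = T) creates the age and age² terms, poly(area, 2, raw = T) does the same for area, and the * in a formula adds every interaction between them, giving exactly the nine coefficients above. A sketch for the deg_age = 2, deg_area = 2 case (the general version, with the degrees as arguments, is wrapped in a function below):
# Sketch: the specification above, written out for degree 2 in both variables
# lm(sale_price ~ poly(age, 2, raw = T) * poly(area, 2, raw = T), data = training_df)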
Since we want to perform this algorithm (steps i., ii., iii.) several times over, it makes sense to automate it with a function. Doing this will allow us to validate tens, hundreds, or thousands of samples very quickly by throwing it into a for loop.
# Our model-fit function
fit_model = function(deg_age, deg_area) {
  # Estimate the model using the training data
  est_model = lm(
    sale_price ~ poly(age, deg_age, raw = T) * poly(area, deg_area, raw = T),
    data = training_df
  )
  # Make predictions on the validation data
  y_hat = predict(est_model, newdata = validation_df, se.fit = F)
  # Calculate our validation MSE
  mse = mean((validation_df$sale_price - y_hat)^2)
  return(mse)
}
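Before iterating over many specifications, it can be worth calling the function once to make sure it runs, e.g. for a quadratic in both variables (the exact MSE you get depends on your seed and validation draw):
# Single test run: quadratic in age, quadratic in area
fit_model(deg_age = 2, deg_area = 2)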
The two arguments for this function are deg_age and deg_area. They represent the polynomial degrees for age and area that we want in our model (i.e., deg_age = 2 gives \(age^2\)).
We would like to loop over a series of values for deg_age and deg_area, fitting a model for each pair of polynomial degrees. First, let's create a dataframe containing every combination of degrees (6 × 4 = 24 rows, 2 columns) using the expand_grid() function. We will later attach each model's validation MSE as an additional column.
# Take all possible combinations of our degrees
deg_df = expand_grid(deg_age = 1:6, deg_area = 1:4)
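If you want to see what this grid looks like before using it, peek at the first few rows:
# Inspect the grid of candidate polynomial degrees (24 rows, 2 columns)
head(deg_df)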
Now let's iterate over all possible combinations (6 × 4 = 24) of polynomial specifications and see which one produces the smallest validation MSE.
# Iterate over set of possibilities (returns a vector of validation-set MSEs)
# Note: on Windows machines, mc.cores must be 1; Windows users can also use mapply() instead
mse_v = mcmapply(
  FUN = fit_model,
  deg_age = deg_df$deg_age,
  deg_area = deg_df$deg_area,
  mc.cores = 4
)
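If mcmapply() feels opaque (or you are on a Windows machine), a serial for loop is an equivalent, if slower, way to fill the same vector; a sketch:
# Equivalent serial version using a for loop
# mse_v = rep(NA_real_, nrow(deg_df))
# for (i in seq_len(nrow(deg_df))) {
#   mse_v[i] = fit_model(deg_age = deg_df$deg_age[i], deg_area = deg_df$deg_area[i])
# }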
Now that we have a length-24 vector of validation MSEs (one for each polynomial combination), let's attach it as an additional column to the deg_df dataframe we created a moment ago and arrange the rows from smallest to largest MSE.
# Add validation-set MSEs to 'deg_df'
deg_df$mse_v = mse_v
# Which set of parameters minimizes validation-set MSE?
arrange(deg_df, mse_v)
## # A tibble: 24 × 3
## deg_age deg_area mse_v
## <int> <int> <dbl>
## 1 6 2 15.4
## 2 4 2 15.9
## 3 2 3 16.1
## 4 2 1 16.5
## 5 2 4 16.6
## 6 3 1 16.9
## 7 2 2 17.2
## 8 1 3 17.3
## 9 1 2 17.6
## 10 1 1 17.6
## # … with 14 more rows
Now let's plot it using ggplot2's geom_tile() and, for extra flair/analysis, plotly, an interactive graphing library that plays nicely with ggplot2.
# If using RStudio, you can run ggplotly() to get a more interactive plot
mse_gg = ggplot(data = deg_df, aes(x = deg_age, y = deg_area, fill = log(mse_v))) +
  geom_tile() +
  # scale_fill_viridis_c("Logged MSE", option = "inferno", begin = 0.1) +
  hrbrthemes::theme_ipsum(base_size = 12) +
  theme(panel.grid.major = element_blank(),
        panel.grid.minor = element_blank()) +
  labs(
    title = 'Model fit MSE heat map',
    x = "Degrees of age",
    y = 'Degrees of area',
    fill = 'Log of MSE')
ggplotly(mse_gg)
Next week, we are going to build on this. But instead of (just) automating the model specification, we are going to automate sample selection. My current plan is to show how to do leave-one-out and k-fold cross validation.
Also, we used a little parallel computing today. Next week it may become much more useful, and we can talk about what parallel computing is and why it helps. Hopefully that lesson coincides nicely with the second project. I will definitely take some time to go over tips and typical coding errors for that assignment.
Finally, a general tip: **start the next project early!** Send me your questions and I will try my best to help (and tell you when I can't help). Any common errors or misunderstandings I will try to go over in next week's lecture.
Have a great weekend!!
This document is built upon notes created by Stephen Reed, a previous GE for this course, which you can find here.