Model building

AIM: build up a model to aid your understanding of the data.

library(tidyverse)
## -- Attaching packages --------------------------------------- tidyverse 1.3.0 --
## v ggplot2 3.3.2     v purrr   0.3.4
## v tibble  3.0.4     v dplyr   1.0.2
## v tidyr   1.1.2     v stringr 1.4.0
## v readr   1.4.0     v forcats 0.5.0
## -- Conflicts ------------------------------------------ tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag()    masks stats::lag()
library(modelr)
options(na.action = na.warn)

library(nycflights13)
library(lubridate)
## 
## Attaching package: 'lubridate'
## The following objects are masked from 'package:base':
## 
##     date, intersect, setdiff, union

Why are low quality diamonds more expensive?

ggplot(diamonds, aes(cut, price)) + geom_boxplot()

ggplot(diamonds, aes(color, price)) + geom_boxplot()

ggplot(diamonds, aes(clarity, price)) + geom_boxplot()

Price and carat

It looks like lower quality diamonds have higher prices because there is an important confounding variable: the weight (carat) of the diamond. The weight of the diamond is the single most important factor for determining the price of the diamond, and lower quality diamonds tend to be larger.

ggplot(diamonds, aes(carat, price)) + 
  geom_hex(bins = 50)

We can make it easier to see how the other attributes of a diamond affect its relative price by fitting a model to separate out the effect of carat.

Focus on diamonds of 2.5 carats or less (99.7% of the data).

Log-transform the carat and price variables.

diamonds2 <- diamonds %>% 
  filter(carat <= 2.5) %>% 
  mutate(lprice = log2(price), lcarat = log2(carat))

ggplot(diamonds2, aes(lcarat, lprice)) + 
  geom_hex(bins = 50)
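
One reason logging both variables can straighten the pattern: if price follows a power law in carat, price = a * carat^b, then log2(price) = log2(a) + b * log2(carat), which is a straight line. Here is a minimal sketch with simulated data. The constants 4000 and 1.7 and the noise level are made up for illustration, not estimated from diamonds.

```r
# Simulated power-law data (hypothetical constants, not the diamonds data)
set.seed(1)
carat <- runif(200, min = 0.2, max = 2.5)
price <- 4000 * carat^1.7 * 2^rnorm(200, sd = 0.1)  # multiplicative noise

# On the log2-log2 scale the relationship is linear, so lm() can recover it
fit <- lm(log2(price) ~ log2(carat))

# The intercept should land near log2(4000) and the slope near 1.7
coef(fit)
log2(4000)
```

The slope on the log scale plays the role of the exponent b, and the intercept is log2(a).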

The log-transformation is particularly useful here because it makes the pattern linear, and linear patterns are the easiest to work with. Let’s take the next step and remove that strong linear pattern. We first make the pattern explicit by fitting a model:

mod_diamond <- lm(lprice ~ lcarat, data = diamonds2)
summary(mod_diamond)
## 
## Call:
## lm(formula = lprice ~ lcarat, data = diamonds2)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -1.96407 -0.24549 -0.00844  0.23930  1.93486 
## 
## Coefficients:
##              Estimate Std. Error t value Pr(>|t|)    
## (Intercept) 12.193863   0.001969  6194.5   <2e-16 ***
## lcarat       1.681371   0.001936   868.5   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 0.3767 on 53812 degrees of freedom
## Multiple R-squared:  0.9334, Adjusted R-squared:  0.9334 
## F-statistic: 7.542e+05 on 1 and 53812 DF,  p-value: < 2.2e-16
grid <- diamonds2 %>% 
  data_grid(carat = seq_range(carat, 20)) %>% 
  mutate(lcarat = log2(carat)) %>% 
  add_predictions(mod_diamond, "lprice") %>% 
  mutate(price = 2 ^ lprice)

ggplot(diamonds2, aes(carat, price)) + 
  geom_hex(bins = 50) + 
  geom_line(data = grid, colour = "red", size = 1)

That tells us something interesting about our data. If we believe our model, then the large diamonds are much cheaper than expected. This is probably because no diamond in this dataset costs more than $19,000.
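
To read the fitted line on the original scale, we can back-transform the two coefficients printed in the summary above. This is only a rough sketch: the estimates are copied by hand from that output, and price_at() is a helper name introduced here for illustration.

```r
# Coefficients copied from the summary(mod_diamond) output above
intercept <- 12.193863
slope <- 1.681371

# On the original scale the model reads: price = 2^intercept * carat^slope
price_at <- function(carat) 2^intercept * carat^slope

# Predicted price of a 1-carat diamond (log2(1) = 0, so this is 2^intercept)
price_at(1)

# Doubling the weight multiplies the predicted price by 2^slope
price_at(2) / price_at(1)
```

The ratio is a little over 3, so under this model a diamond twice as heavy is predicted to cost roughly three times as much.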

Now we can look at the residuals to verify that we’ve successfully removed the strong linear pattern:

diamonds2 <- diamonds2 %>% 
  add_residuals(mod_diamond, "lresid")

ggplot(diamonds2, aes(lcarat, lresid)) + 
  geom_hex(bins = 50)

Importantly, we can now re-do our motivating plots using those residuals instead of price.

ggplot(diamonds2, aes(cut, lresid)) + geom_boxplot()

ggplot(diamonds2, aes(color, lresid)) + geom_boxplot()

ggplot(diamonds2, aes(clarity, lresid)) + geom_boxplot()

Now we see the relationship we expect: as the quality of the diamond increases, so too does its relative price. To interpret the y axis, we need to think about what the residuals are telling us, and what scale they are on. A residual of -1 indicates that lprice was 1 unit lower than a prediction based solely on its weight. Since 2^-1 is 1/2, points with a residual of -1 are half the expected price, and points with a residual of 1 are twice the predicted price.
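
Because the residuals are on the log2 scale, raising 2 to the residual turns it into a ratio of actual to predicted price. A small sketch with a few hand-picked residual values (not taken from the data):

```r
# Hand-picked log2 residuals for illustration
lresid <- c(-1, 0, 1, 2)

# 2^lresid is the ratio actual price / predicted price
data.frame(lresid = lresid, price_ratio = 2^lresid)
```

A residual of 0 means the diamond costs exactly what the model predicts, and each additional unit doubles the ratio.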

A more complicated model

If we wanted to, we could continue to build up our model, moving the effects we’ve observed into the model to make them explicit. For example, we could include color, cut, and clarity into the model so that we also make explicit the effect of these three categorical variables:

mod_diamond2 <- lm(lprice ~ lcarat + color + cut + clarity, data = diamonds2)

# If the model needs variables that you haven’t explicitly supplied, data_grid() will automatically fill them in with a “typical” value.
grid <- diamonds2 %>% 
  data_grid(cut, .model = mod_diamond2) %>% 
  add_predictions(mod_diamond2)
grid
## # A tibble: 5 x 5
##   cut       lcarat color clarity  pred
##   <ord>      <dbl> <chr> <chr>   <dbl>
## 1 Fair      -0.515 G     VS2      11.2
## 2 Good      -0.515 G     VS2      11.3
## 3 Very Good -0.515 G     VS2      11.4
## 4 Premium   -0.515 G     VS2      11.4
## 5 Ideal     -0.515 G     VS2      11.4
ggplot(grid, aes(cut, pred)) + 
  geom_point()

diamonds2 <- diamonds2 %>% 
  add_residuals(mod_diamond2, "lresid2")

ggplot(diamonds2, aes(lcarat, lresid2)) + 
  geom_hex(bins = 50)

This plot indicates that there are some diamonds with quite large residuals. Remember, a residual of 2 indicates that the diamond is 4x the price that we expected. It’s often useful to look at unusual values individually:

diamonds2 %>% 
  filter(abs(lresid2) > 1) %>% 
  add_predictions(mod_diamond2) %>% 
  mutate(pred = round(2 ^ pred)) %>% 
  select(price, pred, carat:table, x:z) %>% 
  arrange(price)
## # A tibble: 16 x 11
##    price  pred carat cut       color clarity depth table     x     y     z
##    <int> <dbl> <dbl> <ord>     <ord> <ord>   <dbl> <dbl> <dbl> <dbl> <dbl>
##  1  1013   264 0.25  Fair      F     SI2      54.4    64  4.3   4.23  2.32
##  2  1186   284 0.25  Premium   G     SI2      59      60  5.33  5.28  3.12
##  3  1186   284 0.25  Premium   G     SI2      58.8    60  5.33  5.28  3.12
##  4  1262  2644 1.03  Fair      E     I1       78.2    54  5.72  5.59  4.42
##  5  1415   639 0.35  Fair      G     VS2      65.9    54  5.57  5.53  3.66
##  6  1415   639 0.35  Fair      G     VS2      65.9    54  5.57  5.53  3.66
##  7  1715   576 0.32  Fair      F     VS2      59.6    60  4.42  4.34  2.61
##  8  1776   412 0.290 Fair      F     SI1      55.8    60  4.48  4.41  2.48
##  9  2160   314 0.34  Fair      F     I1       55.8    62  4.72  4.6   2.6 
## 10  2366   774 0.3   Very Good D     VVS2     60.6    58  4.33  4.35  2.63
## 11  3360  1373 0.51  Premium   F     SI1      62.7    62  5.09  4.96  3.15
## 12  3807  1540 0.61  Good      F     SI2      62.5    65  5.36  5.29  3.33
## 13  3920  1705 0.51  Fair      F     VVS2     65.4    60  4.98  4.9   3.23
## 14  4368  1705 0.51  Fair      F     VVS2     60.7    66  5.21  5.11  3.13
## 15 10011  4048 1.01  Fair      D     SI2      64.6    58  6.25  6.2   4.02
## 16 10470 23622 2.46  Premium   E     SI2      59.7    59  8.82  8.76  5.25

Nothing really jumps out at me here, but it’s probably worth spending time considering if this indicates a problem with our model, or if there are errors in the data. If there are mistakes in the data, this could be an opportunity to buy diamonds that have been priced low incorrectly. :D

What affects the number of daily flights?

Let’s work through a similar process for a dataset that seems even simpler at first glance: the number of flights that leave NYC per day. This is a really small dataset — only 365 rows and 2 columns — and we’re not going to end up with a fully realised model, but as you’ll see, the steps along the way will help us better understand the data. Let’s get started by counting the number of flights per day and visualising it with ggplot2.

library(lubridate)

daily <- flights %>% 
  mutate(date = make_date(year, month, day)) %>% 
  group_by(date) %>% 
  summarise(n = n())
## `summarise()` ungrouping output (override with `.groups` argument)
daily
## # A tibble: 365 x 2
##    date           n
##    <date>     <int>
##  1 2013-01-01   842
##  2 2013-01-02   943
##  3 2013-01-03   914
##  4 2013-01-04   915
##  5 2013-01-05   720
##  6 2013-01-06   832
##  7 2013-01-07   933
##  8 2013-01-08   899
##  9 2013-01-09   902
## 10 2013-01-10   932
## # ... with 355 more rows
ggplot(daily, aes(date, n)) + 
  geom_line()

Day of week

Understanding the long-term trend is challenging because there’s a very strong day-of-week effect that dominates the subtler patterns. Let’s start by looking at the distribution of flight numbers by day-of-week:

daily <- daily %>% 
  mutate(wday = wday(date, label = TRUE))

ggplot(daily, aes(wday, n)) + 
  geom_boxplot()

There are fewer flights on weekends because most travel is for business. The effect is particularly pronounced on Saturday: you might sometimes leave on Sunday for a Monday morning meeting, but it’s very rare that you’d leave on Saturday as you’d much rather be at home with your family.

One way to remove this strong pattern is to use a model. First, we fit the model, and display its predictions overlaid on the original data:

mod <- lm(n ~ wday, data = daily)

grid <- daily %>% 
  data_grid(wday) %>% 
  add_predictions(mod, "n")

ggplot(daily, aes(wday, n)) + 
  geom_boxplot() +
  geom_point(data = grid, colour = "red", size = 4)
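
The red points sit at the mean of each day: a linear model with a single categorical predictor predicts the group mean. Here is a minimal sketch on simulated counts that checks this. The toy data frame and its numbers are invented, and it uses a plain (unordered) factor rather than the ordered factor that wday() returns, which changes the coefficients but not the predictions.

```r
# Simulated daily counts for three days of the week (made-up numbers)
set.seed(42)
toy <- data.frame(
  wday = factor(rep(c("Sat", "Sun", "Mon"), each = 10)),
  n = c(rnorm(10, 750, 20), rnorm(10, 890, 20), rnorm(10, 975, 20))
)

toy_mod <- lm(n ~ wday, data = toy)

# Mean of n within each day, in factor-level order
group_means <- tapply(toy$n, toy$wday, mean)

# Model predictions for one row per day, in the same order
new_days <- data.frame(wday = factor(names(group_means), levels = levels(toy$wday)))
preds <- predict(toy_mod, newdata = new_days)

# The predictions match the group means up to floating-point error
all.equal(unname(preds), as.vector(group_means))
```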

daily <- daily %>% 
  add_residuals(mod)

daily %>% 
  ggplot(aes(date, resid)) + 
  geom_ref_line(h = 0) + 
  geom_line()

Note the change in the y-axis: now we are seeing the deviation from the expected number of flights, given the day of week. This plot is useful because now that we’ve removed much of the large day-of-week effect, we can see some of the subtler patterns that remain:

  1. Our model seems to fail starting in June: you can still see a strong regular pattern that our model hasn’t captured. Drawing a plot with one line for each day of the week makes the cause easier to see:
ggplot(daily, aes(date, resid, colour = wday)) + 
  geom_ref_line(h = 0) + 
  geom_line()

Our model fails to accurately predict the number of flights on Saturday: during summer there are more flights than we expect, and during Fall there are fewer. We’ll see how we can do better to capture this pattern in the next section.

  2. There are some days with far fewer flights than expected:
daily %>% 
  filter(resid < -100)
## # A tibble: 11 x 4
##    date           n wday  resid
##    <date>     <int> <ord> <dbl>
##  1 2013-01-01   842 Tue   -109.
##  2 2013-01-20   786 Sun   -105.
##  3 2013-05-26   729 Sun   -162.
##  4 2013-07-04   737 Thu   -229.
##  5 2013-07-05   822 Fri   -145.
##  6 2013-09-01   718 Sun   -173.
##  7 2013-11-28   634 Thu   -332.
##  8 2013-11-29   661 Fri   -306.
##  9 2013-12-24   761 Tue   -190.
## 10 2013-12-25   719 Wed   -244.
## 11 2013-12-31   776 Tue   -175.

If you’re familiar with American public holidays, you might spot New Year’s day, July 4th, Thanksgiving and Christmas. There are some others that don’t seem to correspond to public holidays. You’ll work on those in one of the exercises.

  3. There seems to be some smoother long term trend over the course of a year. We can highlight that trend with geom_smooth():
daily %>% 
  ggplot(aes(date, resid)) + 
  geom_ref_line(h = 0) + 
  geom_line(colour = "grey50") + 
  geom_smooth(se = FALSE, span = 0.20)
## `geom_smooth()` using method = 'loess' and formula 'y ~ x'

There are fewer flights in January (and December), and more in summer (May-Sep). We can’t do much with this pattern quantitatively, because we only have a single year of data. But we can use our domain knowledge to brainstorm potential explanations.

Seasonal Saturday effect

daily %>% 
  filter(wday == "Sat") %>% 
  ggplot(aes(date, n)) + 
    geom_point() + 
    geom_line() +
    scale_x_date(NULL, date_breaks = "1 month", date_labels = "%b")

I suspect this pattern is caused by summer holidays: many people go on holiday in the summer, and people don’t mind travelling on Saturdays for vacation. Looking at this plot, we might guess that summer holidays are from early June to late August. That seems to line up fairly well with the state’s school terms: summer break in 2013 was Jun 26–Sep 9.

Why are there more Saturday flights in the Spring than the Fall? I asked some American friends and they suggested that it’s less common to plan family vacations during the Fall because of the big Thanksgiving and Christmas holidays. We don’t have the data to know for sure, but it seems like a plausible working hypothesis.

term <- function(date) {
  cut(date, 
    breaks = ymd(20130101, 20130605, 20130825, 20140101),
    labels = c("spring", "summer", "fall") 
  )
}

daily <- daily %>% 
  mutate(term = term(date))
daily %>% 
  filter(wday == "Sat") %>% 
  ggplot(aes(date, n, colour = term)) +
  geom_point(alpha = 1/3) + 
  geom_line() +
  scale_x_date(NULL, date_breaks = "1 month", date_labels = "%b")
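
The term() helper relies on cut() to bin dates, so it is worth checking which term a date gets when it falls exactly on a break. The sketch below reuses the same break dates but builds them with base R’s as.Date() instead of ymd(); the probe dates are hypothetical and chosen to sit on either side of each break.

```r
# Break points copied from term() above, built with base R instead of ymd()
breaks <- as.Date(c("2013-01-01", "2013-06-05", "2013-08-25", "2014-01-01"))

# Hypothetical probe dates on either side of each break
probe <- as.Date(c("2013-01-01", "2013-06-04", "2013-06-05",
                   "2013-08-24", "2013-08-25", "2013-12-31"))

terms <- cut(probe, breaks = breaks, labels = c("spring", "summer", "fall"))
data.frame(date = probe, term = terms)
```

For dates, cut() treats each interval as closed on the left and open on the right by default, so June 5 is labelled summer and August 25 is labelled fall.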

It’s useful to see how this new variable affects the other days of the week:

daily %>% 
  ggplot(aes(wday, n, colour = term)) +
    geom_boxplot()

It looks like there is significant variation across the terms, so fitting a separate day of week effect for each term is reasonable. This improves our model, but not as much as we might hope:

mod1 <- lm(n ~ wday, data = daily)
mod2 <- lm(n ~ wday * term, data = daily)

daily %>% 
  gather_residuals(without_term = mod1, with_term = mod2) %>% 
  ggplot(aes(date, resid, colour = model)) +
    geom_line(alpha = 0.75)

grid <- daily %>% 
  data_grid(wday, term) %>% 
  add_predictions(mod2, "n")

ggplot(daily, aes(wday, n)) +
  geom_boxplot() + 
  geom_point(data = grid, colour = "red") + 
  facet_wrap(~ term)

Our model is finding the mean effect, but we have a lot of big outliers, so the mean tends to be far away from the typical value. We can alleviate this problem by using a model that is robust to the effect of outliers: MASS::rlm(). This greatly reduces the impact of the outliers on our estimates, and gives a model that does a good job of removing the day of week pattern:

mod3 <- MASS::rlm(n ~ wday * term, data = daily)

daily %>% 
  add_residuals(mod3, "resid") %>% 
  ggplot(aes(date, resid)) + 
  geom_hline(yintercept = 0, size = 2, colour = "white") + 
  geom_line()
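
To see why a robust fit helps, compare lm() and MASS::rlm() on a toy example where one day is wildly off. This is only a sketch: the data are invented, with nine ordinary Saturdays plus one extreme low day standing in for a holiday, and ten ordinary Mondays.

```r
# Toy data: ten Saturdays with one extreme low day, ten ordinary Mondays
toy <- data.frame(
  wday = factor(rep(c("Sat", "Mon"), each = 10)),
  n = c(seq(700, 780, by = 10), 300, seq(900, 990, by = 10))
)

ls_mod  <- lm(n ~ wday, data = toy)         # least squares: group means
rob_mod <- MASS::rlm(n ~ wday, data = toy)  # robust fit: downweights outliers

# Predicted Saturday count from each model
sat <- data.frame(wday = factor("Sat", levels = levels(toy$wday)))
ls_sat  <- unname(predict(ls_mod, sat))
rob_sat <- unname(predict(rob_mod, sat))

# Compare both predictions with the Saturday median
c(least_squares = ls_sat,
  robust = rob_sat,
  median = median(toy$n[toy$wday == "Sat"]))
```

The least-squares prediction is dragged down by the single outlier, while the robust prediction stays close to the typical Saturday.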

Computed variables

RECOMMENDATION

compute_vars <- function(data) {
  data %>% 
    mutate(
      term = term(date), 
      wday = wday(date, label = TRUE)
    )
}

Time of year: an alternative approach

library(splines)
mod <- MASS::rlm(n ~ wday * ns(date, 5), data = daily)

daily %>% 
  data_grid(wday, date = seq_range(date, n = 13)) %>% 
  add_predictions(mod) %>% 
  ggplot(aes(date, pred, colour = wday)) + 
    geom_line() +
    geom_point()