Bayesian Regression Models with RStanARM

TJ Mahr
Sept. 21, 2016

Madison R Users Group

GitHub repository
@tjmahr

Overview

Tour of RStanARM

Software ecosystem

What is Stan?

A programming language for probabilistic statistics.

  • Write out a description of the data and model using a mathy syntax.
  • The model is compiled into an executable that does the sampling.
  • PyStan and RStan are interfaces that help you send data to and receive results from Stan models.

Simple Regression Model in Stan

# R version
lm(log(earnings) ~ height + male)
data {
  int<lower=0> N;
  vector[N] earn;
  vector[N] height;
  vector[N] male;
}
transformed data {
  // log transformation
  vector[N] log_earn;
  log_earn = log(earn);
}
parameters {
  vector[3] beta;
  real<lower=0> sigma;
}
model {
  log_earn ~ normal(beta[1] + beta[2] * height + beta[3] * male, sigma);
}
generated quantities {
  // optional
}

Takeaways

  • Program is a description of the data and a model.
  • The language wants more formal information about the data (type, length, bounds), presumably so it can sample more efficiently.
  • The tilde operator ~ is a sampling statement.

Another language?

Gif of Princess Leia saying "no thanks"

RStanARM

What is RStanARM?

RStan Applied Regression Modeling

  • Batteries-included, precompiled versions of common regression models.
  • glm -> stan_glm, glmer -> stan_glmer.
  • Write your regular code, add stan_ to the front, and add a prior (see the sketch below).
  • CRAN page
    • Very good! They have lots of detailed vignettes!
  • Proper successor to the arm package.
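
As a rough sketch of that workflow (y, x, and my_data are hypothetical placeholders, not the dataset used below):

# Classical model
glm(y ~ x, family = poisson, data = my_data)

# Bayesian version: same formula and family, stan_ prefix, plus priors
stan_glm(y ~ x, family = poisson, data = my_data,
         prior = normal(0, 2.5), prior_intercept = normal(0, 5))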

An example: Height and Weight by Sex

# Some toy data
davis <- car::Davis %>% filter(100 < height) %>% as_data_frame

davis
#> # A tibble: 199 × 5
#>       sex weight height repwt repht
#>    <fctr>  <int>  <int> <int> <int>
#> 1       M     77    182    77   180
#> 2       F     58    161    51   159
#> 3       F     53    161    54   158
#> 4       M     68    177    70   175
#> 5       F     59    157    59   155
#> 6       M     76    170    76   165
#> 7       M     76    167    77   165
#> 8       M     69    186    73   180
#> 9       M     71    178    71   175
#> 10      M     65    171    64   170
#> # ... with 189 more rows

Classical model: Summary

model <- lm(weight ~ height * sex, davis)
summary(model) %>% print(digits = 2)
#> 
#> Call:
#> lm(formula = weight ~ height * sex, data = davis)
#> 
#> Residuals:
#>    Min     1Q Median     3Q    Max 
#>  -20.9   -4.8   -0.9    4.5   41.1 
#> 
#> Coefficients:
#>             Estimate Std. Error t value Pr(>|t|)    
#> (Intercept)   -45.71      22.19    -2.1     0.04 *  
#> height          0.62       0.13     4.6    7e-06 ***
#> sexM          -55.62      32.54    -1.7     0.09 .  
#> height:sexM     0.37       0.19     2.0     0.05 .  
#> ---
#> Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
#> 
#> Residual standard error: 8 on 195 degrees of freedom
#> Multiple R-squared:  0.64,   Adjusted R-squared:  0.64 
#> F-statistic: 1.2e+02 on 3 and 195 DF,  p-value: <2e-16

Classical model: Predicted mean over x

Scatterplot of height/weight, colored by sex, with model-predicted mean for each group drawn.

Fitting the model with RStanARM

Load the package

library(rstanarm)
#> Loading required package: Rcpp
#> rstanarm (Version 2.12.1, packaged: 2016-09-12 13:08:24 UTC)
#> - Do not expect the default priors to remain the same in future rstanarm versions.
#> Thus, R scripts should specify priors explicitly, even if they are just the defaults.
#> - For execution on a local, multicore CPU with excess RAM we recommend calling
#> options(mc.cores = parallel::detectCores())
  • So… hard-code the priors.
  • I only set the mc.cores option when I have a big model that's going to take more than a minute to fit.
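
In script form, those two bullet points look something like this (a sketch):

library(rstanarm)
# For slow-to-fit models, run the MCMC chains in parallel
options(mc.cores = parallel::detectCores())
# ...and spell out the priors in every model call, as in the
# stan_glm() call on the next slide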

Fit the model

  • We have to use stan_glm().
    • stan_lm() uses a different specification of the prior.
  • By default, it samples with 4 MCMC chains.
    • Each chain has 2000 iterations, but the first half are warm-up samples.
    • Warm-up samples are discarded. (A sketch of how to override these defaults follows the sampler output below.)
stan_model <- stan_glm(
  weight ~ height * sex,
  data = davis,
  family = gaussian,
  prior = normal(location = 0, scale = 5),
  prior_intercept = normal(0, 10)
)
#> 
#> SAMPLING FOR MODEL 'continuous' NOW (CHAIN 1).
#> 
#> Chain 1, Iteration:    1 / 2000 [  0%]  (Warmup)
#> Chain 1, Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Chain 1, Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Chain 1, Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Chain 1, Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Chain 1, Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Chain 1, Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Chain 1, Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Chain 1, Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Chain 1, Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Chain 1, Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Chain 1, Iteration: 2000 / 2000 [100%]  (Sampling)
#>  Elapsed Time: 1.718 seconds (Warm-up)
#>                1.881 seconds (Sampling)
#>                3.599 seconds (Total)
#> 
#> 
#> SAMPLING FOR MODEL 'continuous' NOW (CHAIN 2).
#> 
#> Chain 2, Iteration:    1 / 2000 [  0%]  (Warmup)
#> Chain 2, Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Chain 2, Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Chain 2, Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Chain 2, Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Chain 2, Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Chain 2, Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Chain 2, Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Chain 2, Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Chain 2, Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Chain 2, Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Chain 2, Iteration: 2000 / 2000 [100%]  (Sampling)
#>  Elapsed Time: 1.924 seconds (Warm-up)
#>                2.052 seconds (Sampling)
#>                3.976 seconds (Total)
#> 
#> 
#> SAMPLING FOR MODEL 'continuous' NOW (CHAIN 3).
#> 
#> Chain 3, Iteration:    1 / 2000 [  0%]  (Warmup)
#> Chain 3, Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Chain 3, Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Chain 3, Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Chain 3, Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Chain 3, Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Chain 3, Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Chain 3, Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Chain 3, Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Chain 3, Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Chain 3, Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Chain 3, Iteration: 2000 / 2000 [100%]  (Sampling)
#>  Elapsed Time: 1.75 seconds (Warm-up)
#>                1.788 seconds (Sampling)
#>                3.538 seconds (Total)
#> 
#> 
#> SAMPLING FOR MODEL 'continuous' NOW (CHAIN 4).
#> 
#> Chain 4, Iteration:    1 / 2000 [  0%]  (Warmup)
#> Chain 4, Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Chain 4, Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Chain 4, Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Chain 4, Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Chain 4, Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Chain 4, Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Chain 4, Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Chain 4, Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Chain 4, Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Chain 4, Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Chain 4, Iteration: 2000 / 2000 [100%]  (Sampling)
#>  Elapsed Time: 1.672 seconds (Warm-up)
#>                2.082 seconds (Sampling)
#>                3.754 seconds (Total)
# comment to make the bottom of the output visible
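
As promised above, the sampling defaults can be overridden through arguments that stan_glm() forwards to the sampler. A sketch (the seed value is arbitrary; refresh = 0 just silences the progress messages shown above):

stan_model <- stan_glm(
  weight ~ height * sex,
  data = davis,
  family = gaussian,
  prior = normal(location = 0, scale = 5),
  prior_intercept = normal(0, 10),
  chains = 4,       # number of MCMC chains (the default)
  iter = 2000,      # iterations per chain; the first half is warm-up (the default)
  seed = 20160921,  # make the draws reproducible
  refresh = 0       # suppress the per-iteration progress messages
)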

First look: Just printing the model object

stan_model
#> stan_glm(formula = weight ~ height * sex, family = gaussian, 
#>     data = davis, prior = normal(location = 0, scale = 5), prior_intercept = normal(0, 
#>         10))
#> 
#> Estimates:
#>             Median MAD_SD
#> (Intercept) -48.1   21.5 
#> height        0.6    0.1 
#> sexM        -49.9   29.8 
#> height:sexM   0.3    0.2 
#> sigma         8.0    0.4 
#> 
#> Sample avg. posterior predictive 
#> distribution of y (X = xbar):
#>          Median MAD_SD
#> mean_PPD 65.3    0.8  
#> 
#> Observations: 199  Number of unconstrained parameters: 5

Getting a summary from the model

summary(stan_model)
#> stan_glm(formula = weight ~ height * sex, family = gaussian, 
#>     data = davis, prior = normal(location = 0, scale = 5), prior_intercept = normal(0, 
#>         10))
#> 
#> Family: gaussian (identity)
#> Algorithm: sampling
#> Posterior sample size: 4000
#> Observations: 199
#> 
#> Estimates:
#>                 mean   sd     2.5%   25%    50%    75%    97.5%
#> (Intercept)    -48.2   21.4  -90.7  -62.3  -48.1  -33.2   -8.1 
#> height           0.6    0.1    0.4    0.5    0.6    0.7    0.9 
#> sexM           -50.0   30.7 -110.2  -70.5  -49.9  -30.6   12.3 
#> height:sexM      0.3    0.2    0.0    0.2    0.3    0.5    0.7 
#> sigma            8.0    0.4    7.3    7.7    8.0    8.3    8.9 
#> mean_PPD        65.3    0.8   63.7   64.8   65.3   65.9   66.9 
#> log-posterior -708.8    1.6 -712.7 -709.6 -708.5 -707.6 -706.7 
#> 
#> Diagnostics:
#>               mcse Rhat n_eff
#> (Intercept)   0.6  1.0  1254 
#> height        0.0  1.0  1258 
#> sexM          0.9  1.0  1086 
#> height:sexM   0.0  1.0  1070 
#> sigma         0.0  1.0  2092 
#> mean_PPD      0.0  1.0  2562 
#> log-posterior 0.0  1.0  1441 
#> 
#> For each parameter, mcse is Monte Carlo standard error, n_eff is a crude measure of effective sample size, and Rhat is the potential scale reduction factor on split chains (at convergence Rhat=1).
# comment to make the bottom of the output visible

Notes on summary()

  • Split into estimation and diagnostic information
  • mean_PPD is the predicted value for a completely average observation
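
A quick sanity check is that mean_PPD should land near the sample mean of the outcome (output omitted here):

mean(davis$weight)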

Inspecting posterior samples

Looking at the posterior parameter samples

Coerce to a data-frame. Columns are parameters. One row per posterior sample.

samples <- stan_model %>% as.data.frame %>% tbl_df
samples
#> # A tibble: 4,000 × 5
#>    `(Intercept)`    height       sexM `height:sexM`    sigma
#>            <dbl>     <dbl>      <dbl>         <dbl>    <dbl>
#> 1      -56.71142 0.6895427 -68.550433    0.44384885 7.537152
#> 2      -56.70521 0.6909141 -69.411803    0.44392578 7.784114
#> 3      -56.34580 0.6865143 -69.497981    0.44949406 7.795977
#> 4      -30.71059 0.5202990 -58.429948    0.40582612 7.991750
#> 5      -34.68467 0.5452743 -64.090664    0.42866842 7.784008
#> 6      -87.38731 0.8865953   2.711235    0.02298263 8.324237
#> 7      -98.27211 0.9459078  15.229156   -0.05056481 8.665427
#> 8      -58.71536 0.6977810 -13.289139    0.13100168 7.894166
#> 9      -63.19315 0.7205729 -22.734855    0.17883580 8.501026
#> 10     -71.94144 0.7791263 -52.550236    0.33906170 8.051117
#> # ... with 3,990 more rows
ggplot(samples) + aes(x = height) + geom_histogram()

Histogram of height effect.

Quantiles are post-data probabilities

  • If we believe there is a “true” value for a parameter, there is 90% probability that this “true” value is in the 90% interval, given our model, prior information, and the data.
  • The 90% interval contains the middle 90% of the parameter values.
    • There is a 5% chance, according to the model, that the height parameter is below the 5% quantile (see the check below).
posterior_interval(stan_model)
#>                        5%         95%
#> (Intercept)  -84.07171214 -13.2795824
#> height         0.42566160   0.8566897
#> sexM        -100.94954649   0.7700961
#> height:sexM    0.04578185   0.6331447
#> sigma          7.41023334   8.7265783
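
As a check, the same intervals can be computed by hand from the posterior samples; the height interval is just a pair of quantiles of the samples data-frame created above (a sketch):

quantile(samples$height, probs = c(.05, .95))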

Plotting the parameters

Basic plot() method shows 80% and 95% intervals.

plot(stan_model)

Basic plot() called on model. Shows 80%, 95% intervals for each parameter.

Plotting some of the parameters

# Match a subset of the parameters
plot(stan_model, regex_pars = "height")

Basic plot() called on model. Shows 80%, 95% intervals. Limited to just parameters matching 'height'.

Exploring with ShinyStan

This thing is pretty awesome.

launch_shinystan(stan_model)

[Demo outside of slides.]

Things I just demonstrated in ShinyStan

  • Diagnostics
    • Mixing of chains
    • R hat, n eff
    • PPcheck
  • Parameters plot
    • Different parameters selected
    • KDE
  • Explore
    • Inspect one parameter in multiview
    • Bivariate view

Plotting model results

Doing Things With The Model

Get a data-frame with the parameters from each sample. We now have 4000 plausible regression lines.

df_model <- stan_model %>% as_data_frame()
df_model
#> # A tibble: 4,000 × 5
#>    `(Intercept)`    height       sexM `height:sexM`    sigma
#>            <dbl>     <dbl>      <dbl>         <dbl>    <dbl>
#> 1      -56.71142 0.6895427 -68.550433    0.44384885 7.537152
#> 2      -56.70521 0.6909141 -69.411803    0.44392578 7.784114
#> 3      -56.34580 0.6865143 -69.497981    0.44949406 7.795977
#> 4      -30.71059 0.5202990 -58.429948    0.40582612 7.991750
#> 5      -34.68467 0.5452743 -64.090664    0.42866842 7.784008
#> 6      -87.38731 0.8865953   2.711235    0.02298263 8.324237
#> 7      -98.27211 0.9459078  15.229156   -0.05056481 8.665427
#> 8      -58.71536 0.6977810 -13.289139    0.13100168 7.894166
#> 9      -63.19315 0.7205729 -22.734855    0.17883580 8.501026
#> 10     -71.94144 0.7791263 -52.550236    0.33906170 8.051117
#> # ... with 3,990 more rows

Apply the group effects to get regression lines for each sex.

df_model2 <- df_model %>%
  mutate(F_Intercept = `(Intercept)`, F_Slope = height,
         M_Intercept = `(Intercept)` + sexM,
         M_Slope = height + `height:sexM`) %>%
  select(F_Intercept:M_Slope)
df_model2
#> # A tibble: 4,000 × 4
#>    F_Intercept   F_Slope M_Intercept   M_Slope
#>          <dbl>     <dbl>       <dbl>     <dbl>
#> 1    -56.71142 0.6895427  -125.26185 1.1333915
#> 2    -56.70521 0.6909141  -126.11701 1.1348399
#> 3    -56.34580 0.6865143  -125.84378 1.1360084
#> 4    -30.71059 0.5202990   -89.14054 0.9261251
#> 5    -34.68467 0.5452743   -98.77534 0.9739427
#> 6    -87.38731 0.8865953   -84.67607 0.9095779
#> 7    -98.27211 0.9459078   -83.04295 0.8953430
#> 8    -58.71536 0.6977810   -72.00450 0.8287827
#> 9    -63.19315 0.7205729   -85.92801 0.8994087
#> 10   -71.94144 0.7791263  -124.49167 1.1181880
#> # ... with 3,990 more rows

The "Pile of Lines" Plot

Plot lines with the median parameter values and a random sample of lines.

fits <- sample_n(df_model2, 200)
medians <- df_model2 %>% summarise_each(funs = funs(median))

p2 <- ggplot(davis) +
  aes(x = height, y = weight, color = sex) +
  geom_abline(aes(color = "F", intercept = F_Intercept,
                  slope = F_Slope), data = fits, alpha = .075) +
  geom_abline(aes(color = "M", intercept = M_Intercept,
                  slope = M_Slope), data = fits, alpha = .075) +
  geom_abline(aes(color = "F", intercept = F_Intercept,
                  slope = F_Slope), data = medians, size = 1.25) +
  geom_abline(aes(color = "M", intercept = M_Intercept,
                  slope = M_Slope), data = medians, size = 1.25) +
  geom_point() +
  theme(legend.position = c(0, 1), legend.justification = c(0, 1))

Pile of lines plot. Scatter plot of height and weight. One thick line for each group's posterior median. 200 other lines, drawn from posterior, for each group also drawn.

"Pile of Lines"

  • Depicts uncertainty by drawing line of best fit plus many other plausible regression lines.
  • Visualization works best when there is only one group and we don't mind having regression lines stretch past the range of the data.

The "Line + Prediction Interval" Plot

Get a sample of height values within each group's range.

# # do some stuff
# ...
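
One way to build such a grid (a sketch, not necessarily the original code) is 80 evenly spaced heights within each sex's observed range:

new_data <- davis %>%
  group_by(sex) %>%
  do(data_frame(height = seq(min(.$height), max(.$height), length.out = 80))) %>%
  ungroup() %>%
  mutate(sex = as.character(sex),
         Observation = as.character(seq_len(n()))) %>%
  select(Observation, sex, height)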

new_data
#> # A tibble: 160 × 3
#>    Observation   sex   height
#>          <chr> <chr>    <dbl>
#> 1            1     F 148.0000
#> 2            2     F 148.3797
#> 3            3     F 148.7595
#> 4            4     F 149.1392
#> 5            5     F 149.5190
#> 6            6     F 149.8987
#> 7            7     F 150.2785
#> 8            8     F 150.6582
#> 9            9     F 151.0380
#> 10          10     F 151.4177
#> # ... with 150 more rows

Get the predicted mean for each point

With the normal predict() function on a classical regression model, we use the model parameters to get the expected mean (\( \mu_i \)) for each row in newdata.

With posterior_linpred(), we do the same thing, but for each of the 4000 posterior samples.

linpreds <- posterior_linpred(stan_model, newdata = new_data)
dim(linpreds)
#> [1] 4000  160
# In long format... 4000 samples x 160 data points
tidy_linpreds <- linpreds %>%
  as_data_frame %>%
  tibble::rownames_to_column("Draw") %>%
  tidyr::gather(Observation, Value, -Draw)

tidy_linpreds
#> # A tibble: 640,000 × 3
#>     Draw Observation    Value
#>    <chr>       <chr>    <dbl>
#> 1      1           1 45.34089
#> 2      2           1 45.55008
#> 3      3           1 45.25832
#> 4      4           1 46.29367
#> 5      5           1 46.01592
#> 6      6           1 43.82879
#> 7      7           1 41.72225
#> 8      8           1 44.55623
#> 9      9           1 43.45163
#> 10    10           1 43.36925
#> # ... with 639,990 more rows

Summarize the posterior predictions for each new data point. Go from 4000 rows per new point to just one row per new data point.

df_predictions <- tidy_linpreds %>%
  group_by(Observation) %>%
  summarise(median = median(Value),
            ymin = quantile(Value, .025),
            ymax = quantile(Value, .975)) %>%
  left_join(new_data)

df_predictions
#> # A tibble: 160 × 6
#>    Observation   median     ymin     ymax   sex   height
#>          <chr>    <dbl>    <dbl>    <dbl> <chr>    <dbl>
#> 1            1 46.26271 41.56672 50.55823     F 148.0000
#> 2           10 48.44859 44.61635 52.04461     F 151.4177
#> 3          100 69.23704 66.83380 71.73073     M 171.1772
#> 4          101 69.65821 67.32576 72.05880     M 171.6076
#> 5          102 70.09021 67.80754 72.40900     M 172.0380
#> 6          103 70.51185 68.29674 72.76241     M 172.4684
#> 7          104 70.93182 68.79742 73.11356     M 172.8987
#> 8          105 71.35920 69.29262 73.47099     M 173.3291
#> 9          106 71.78164 69.78014 73.83019     M 173.7595
#> 10         107 72.20361 70.23412 74.17940     M 174.1899
#> # ... with 150 more rows

Now, we can plot predicted means

p3 <- ggplot(davis) +
  aes(x = height, y = weight, color = sex, group = sex) +
  geom_point() +
  geom_ribbon(aes(ymin = ymin, ymax = ymax, y = NULL, color = NULL),
              data = df_predictions, fill = "grey60", alpha = .4) +
  geom_line(aes(y = median), data = df_predictions, size = 1.25) +
  theme(legend.position = c(0, 1), legend.justification = c(0, 1))

Scatterplot of data with plot of predicted means and 95% credible intervals for each sex.

Limitations of this approach

RStanARM discourages using posterior_linpred(). Its documentation says:

See also: posterior_predict to draw from the posterior predictive distribution of the outcome, which is almost always preferable.


Why? Because we have samples of the error term sigma, and we're not using them!
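
Conceptually, for this Gaussian model, posterior_predict() takes the draws of the linear predictor and adds observation noise using each draw's sigma. Roughly (a sketch, not rstanarm's internals):

# 4000 x 160 matrix of predicted means, as before
mu_draws <- posterior_linpred(stan_model, newdata = new_data)
# one sigma per posterior draw
sigma_draws <- as.data.frame(stan_model)$sigma

# add normal noise so that each row (posterior draw) uses its own sigma
noise <- matrix(
  rnorm(length(mu_draws), mean = 0, sd = rep(sigma_draws, ncol(mu_draws))),
  nrow = nrow(mu_draws))
y_rep <- mu_draws + noise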

Get the model to generate new observations

Same plan as before, but use posterior_predict() to get predictions for observations (\( y_i \)), instead of estimates of the mean (\( \mu_i \)).

# ?posterior_predict
preds <- posterior_predict(stan_model, newdata = new_data)
dim(preds)
#> [1] 4000  160
tidy_preds <- preds %>%
  # Rename columns from V1, V2, ... to 1, 2, ...
  as_data_frame %>% setNames(seq_len(ncol(.))) %>%
  tibble::rownames_to_column("Draw") %>%
  tidyr::gather(Observation, Value, -Draw)

tidy_preds
#> # A tibble: 640,000 × 3
#>     Draw Observation    Value
#>    <chr>       <chr>    <dbl>
#> 1      1           1 48.09020
#> 2      2           1 56.69339
#> 3      3           1 43.43079
#> 4      4           1 45.32523
#> 5      5           1 52.31126
#> 6      6           1 49.02409
#> 7      7           1 45.03329
#> 8      8           1 36.15589
#> 9      9           1 51.22175
#> 10    10           1 67.86039
#> # ... with 639,990 more rows
df_predictions <- tidy_preds %>%
  group_by(Observation) %>%
  summarise(median = median(Value),
            ymin = quantile(Value, .025),
            ymax = quantile(Value, .975)) %>%
  left_join(new_data)

df_predictions
#> # A tibble: 160 × 6
#>    Observation   median     ymin     ymax   sex   height
#>          <chr>    <dbl>    <dbl>    <dbl> <chr>    <dbl>
#> 1            1 46.20671 29.78864 63.33542     F 148.0000
#> 2           10 48.40533 32.28516 64.73203     F 151.4177
#> 3          100 69.18329 53.50192 84.44524     M 171.1772
#> 4          101 69.75366 53.64065 85.40446     M 171.6076
#> 5          102 69.82163 54.19379 85.86467     M 172.0380
#> 6          103 70.37902 54.75848 86.75807     M 172.4684
#> 7          104 71.01157 54.98491 86.40210     M 172.8987
#> 8          105 71.60191 55.75243 87.27941     M 173.3291
#> 9          106 71.55881 56.19622 87.82290     M 173.7595
#> 10         107 72.33715 55.81252 88.64755     M 174.1899
#> # ... with 150 more rows

Plot the predictions

p4 <- ggplot(davis) +
  aes(x = height, y = weight, color = sex, group = sex) +
  geom_point() +
  geom_ribbon(aes(ymin = ymin, ymax = ymax, y = NULL, color = NULL),
              data = df_predictions, fill = "grey60", alpha = .2) +
  geom_line(aes(y = median),
              data = df_predictions, size = 1.25) +
  theme(legend.position = c(0, 1), legend.justification = c(0, 1))

Scatterplot of data with medians of simulated data and 95% posterior predictive intervals for each sex.

Model comparison

Model comparison

Do we need the sex-height interaction?

Do we need the sex predictor at all?

Let's fit alternative models.

stan_model_no_inter <- update(stan_model, weight ~ sex + height)
stan_model_no_group <- update(stan_model, weight ~ height)

Classical model comparison

For classical regression models, we could compare models with AIC and BIC.

model_no_inter <- update(model, weight ~ sex + height)
model_no_group <- update(model, weight ~ height)
model_list <- list(
  no_group = model_no_group,
  no_inter = model_no_inter,
  inter = model)

model_list %>% lapply(AIC) %>% unlist
#> no_group no_inter    inter 
#> 1421.572 1401.590 1399.692

model_list %>% lapply(BIC) %>% unlist
#> no_group no_inter    inter 
#> 1431.452 1414.764 1416.158

Classical model comparison

Or with anova().

anova(model_no_group, model_no_inter, model)
#> Analysis of Variance Table
#> 
#> Model 1: weight ~ height
#> Model 2: weight ~ sex + height
#> Model 3: weight ~ height * sex
#>   Res.Df   RSS Df Sum of Sq       F    Pr(>F)    
#> 1    197 14312                                   
#> 2    196 12815  1   1496.72 23.2250 2.892e-06 ***
#> 3    195 12567  1    248.64  3.8582   0.05092 .  
#> ---
#> Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Approximate Leave One Out Cross Validation

For RStanARM, we compare the models using approximate leave-one-out cross-validation. The details are outside the scope of this talk.

loo1 <- loo(stan_model_no_group)
loo2 <- loo(stan_model_no_inter)
loo3 <- loo(stan_model)
compare(loo1, loo2, loo3)
#>      looic  se_looic elpd_loo se_elpd_loo p_loo  se_p_loo
#> loo3 1401.7   33.9   -700.8     16.9         6.7    2.1  
#> loo2 1402.9   33.3   -701.5     16.6         5.3    1.9  
#> loo1 1422.7   31.9   -711.3     15.9         4.2    1.7
  • Like other information criteria, lower is better. Importantly, the LOOIC comes with a standard error.
  • By point estimate (looic), the model with the interaction is expected to perform best on out-of-sample data.
  • But se_looic indicates considerable overlap with the model without the interaction term (see the pairwise comparison sketched below).
  • Like so much of Bayesian modeling, uncertainty is everything.
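
compare() also accepts exactly two loo objects, in which case it reports the difference in expected predictive accuracy (elpd_diff) along with a standard error for that difference (a sketch; output omitted):

compare(loo2, loo3)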

Comparing the model to the priors

Priors

  • In our models, the prior normal(0, 5) gets rescaled to match each predictor variable.
  • This rescaling is not well documented.
  • Some options:
    • scale() our measurements so that they are scale-free (mean = 0 and SD = 1), as sketched below.
    • Sample from the prior to see what kinds of values are generated by our prior and confirm that our prior information yields sensible model behavior.
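
A sketch of the first option (the height_z name is just for illustration):

# Put height on a standard scale before fitting
davis_scaled <- davis %>%
  mutate(height_z = as.numeric(scale(height)))
# then fit, e.g., stan_glm(weight ~ height_z * sex, data = davis_scaled, ...)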

Sampling from the prior: posterior_vs_prior()

RStanARM will compare samples from the posterior and the prior.

posterior_vs_prior(stan_model)

default posterior_vs_prior() plot. It looks bad because one parameter ruins y-axis.

# Select just parameters with "height" in the name
posterior_vs_prior(stan_model, regex_pars = "height")

posterior_vs_prior() plot on just the 'height' parameters.

# use the faceting options
comparison <- posterior_vs_prior(
  stan_model,
  group_by_parameter = TRUE,
  facet_args = list(scales = "free_y"),
  prob = .95)

comparison + theme_grey() + guides(color = FALSE)

posterior_vs_prior() plot with facets for each variable. This looks best because the y-axis varies freely by facet.

Sampling from the prior: prior_PD = TRUE

We can also sample from the prior directly.

stan_model_prior <- stan_glm(
  weight ~ height * sex,
  data = davis,
  family = gaussian,
  prior = normal(0, 5),
  prior_intercept = normal(0, 10),
  # this line is new:
  prior_PD = TRUE
)

Our initial prior says that a 1 cm increase in height may predict a 0 +/- 7 kg change in weight in women. This is a very generous interval for this effect!

summary(stan_model_prior, probs = c(.1, .5, .9))
#> stan_glm(formula = weight ~ height * sex, family = gaussian, 
#>     data = davis, prior = normal(0, 5), prior_intercept = normal(0, 
#>         10), prior_PD = TRUE)
#> 
#> Family: gaussian (identity)
#> Algorithm: sampling
#> Posterior sample size: 4000
#> Observations: 199
#> 
#> Estimates:
#>                 mean    sd      10%     50%     90%  
#> (Intercept)       5.8  1208.5 -1531.0    10.7  1553.6
#> height           -0.1     6.9    -8.9     0.0     8.8
#> sexM              1.9   126.0  -164.8     4.0   161.9
#> height:sexM       0.0     0.7    -0.9     0.0     0.9
#> sigma            22.5   335.0     0.9     5.2    31.3
#> mean_PPD         -3.8   258.4  -335.8    -1.1   326.4
#> log-posterior   -13.6     1.5   -15.6   -13.2   -11.9
#> 
#> Diagnostics:
#>               mcse Rhat n_eff
#> (Intercept)   19.1  1.0 4000 
#> height         0.1  1.0 4000 
#> sexM           2.0  1.0 4000 
#> height:sexM    0.0  1.0 4000 
#> sigma          5.3  1.0 3941 
#> mean_PPD       4.1  1.0 4000 
#> log-posterior  0.0  1.0 1843 
#> 
#> For each parameter, mcse is Monte Carlo standard error, n_eff is a crude measure of effective sample size, and Rhat is the potential scale reduction factor on split chains (at convergence Rhat=1).
# comment

This is an extra slide about stan_lm().

We didn't use stan_lm() because those models put a single joint prior on all the parameters, expressed as a prior guess for R2.

lm_model <- stan_lm(
  weight ~ height * sex,
  data = davis,
  # Prior: I asked Wolfram Alpha for height and weight
  # correlation in adults and squared it
  prior = R2(0.63))
summary(lm_model)
#> stan_lm(formula = weight ~ height * sex, data = davis, prior = R2(0.63))
#> 
#> Family: gaussian (identity)
#> Algorithm: sampling
#> Posterior sample size: 4000
#> Observations: 199
#> 
#> Estimates:
#>                 mean   sd     2.5%   25%    50%    75%    97.5%
#> (Intercept)    -44.5   21.9  -87.8  -59.3  -44.4  -30.1   -2.1 
#> height           0.6    0.1    0.4    0.5    0.6    0.7    0.9 
#> sexM           -55.4   32.2 -118.6  -76.7  -54.7  -33.6    6.2 
#> height:sexM      0.4    0.2    0.0    0.2    0.4    0.5    0.7 
#> sigma            8.1    0.4    7.3    7.8    8.0    8.3    8.9 
#> log-fit_ratio    0.0    0.0   -0.1    0.0    0.0    0.0    0.1 
#> R2               0.6    0.0    0.6    0.6    0.6    0.7    0.7 
#> mean_PPD        65.3    0.8   63.7   64.7   65.3   65.8   66.8 
#> log-posterior -700.1    2.0 -704.9 -701.3 -699.8 -698.6 -697.2 
#> 
#> Diagnostics:
#>               mcse Rhat n_eff
#> (Intercept)   0.5  1.0  1899 
#> height        0.0  1.0  1903 
#> sexM          0.7  1.0  2355 
#> height:sexM   0.0  1.0  2298 
#> sigma         0.0  1.0  2621 
#> log-fit_ratio 0.0  1.0  2163 
#> R2            0.0  1.0  2440 
#> mean_PPD      0.0  1.0  4000 
#> log-posterior 0.1  1.0  1108 
#> 
#> For each parameter, mcse is Monte Carlo standard error, n_eff is a crude measure of effective sample size, and Rhat is the potential scale reduction factor on split chains (at convergence Rhat=1).

Finally