Goals

  • Look at latent Dirichlet allocation, a form of topic modeling, as an example of a mixture model
  • Understand how to fit common Stan models using R syntax with the rstanarm package
  • Recognize the advantages and limitations of pre-specified models

Setup

The new package this time is rstanarm, which fits pre-built Stan models using an R formula-style syntax.

library("rstan")
library("rstanarm")
library("tidyverse")
library("bayesplot")

options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)
theme_set(theme_minimal())
knitr::opts_chunk$set(fig.align = "center")

set.seed(123)

Mixture models

Latent Dirichlet allocation (LDA) is a common form of topic modeling for text data. Each topic is a probability distribution over words, and each document is a probability distribution over topics, so every word in a document is a draw from a mixture. We can write the LDA model in Stan.
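
As a sketch of what that means, the generative story can be written out using the symbols from the Stan program shown further down. The topic indicator \(z_n\) is an assumed symbol: it does not appear in the program, because the model sums over it rather than sampling it.

```latex
\begin{aligned}
\theta_m &\sim \text{Dirichlet}(\alpha), & m &= 1, \dots, M \\
\phi_k &\sim \text{Dirichlet}(\beta), & k &= 1, \dots, K \\
z_n \mid \theta &\sim \text{Categorical}(\theta_{\text{doc}[n]}), & n &= 1, \dots, N \\
w_n \mid z_n, \phi &\sim \text{Categorical}(\phi_{z_n}), & n &= 1, \dots, N \\[6pt]
p(w_n \mid \theta, \phi) &= \sum_{k=1}^{K} \theta_{\text{doc}[n],\,k}\; \phi_{k,\,w_n}
\end{aligned}
```

The last line is the mixture likelihood for a single word instance after summing over the \(K\) possible topic assignments.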

The data we’ll use are the words from 100 Associated Press articles. You can see what package they came from and the code that formatted them this way in the appendix.

ap <- read_csv("data/associated_press.csv")
## Parsed with column specification:
## cols(
##   document = col_integer(),
##   term = col_character(),
##   w = col_integer()
## )

This model has many indices, many parameters, and a complicated likelihood. We’ll talk through it!

m_lda <- stan_model("stan/lda.stan")
m_lda
## S4 class stanmodel 'lda' coded as follows:
## data {
##   int<lower=2> K;               // num topics
##   int<lower=2> V;               // num words
##   int<lower=1> M;               // num docs
##   int<lower=1> N;               // total word instances
##   int<lower=1, upper=V> w[N];   // word n
##   int<lower=1, upper=M> doc[N]; // doc ID for word n
##   vector<lower=0>[K] alpha;     // topic prior
##   vector<lower=0>[V] beta;      // word prior
## }
## parameters {
##   simplex[K] theta[M];   // topic dist for doc m
##   simplex[V] phi[K];     // word dist for topic k
## }
## model {
##   for (m in 1:M)
##     theta[m] ~ dirichlet(alpha);  // prior
##   for (k in 1:K)
##     phi[k] ~ dirichlet(beta);     // prior
##   for (n in 1:N) {
##     real gamma[K];
##     for (k in 1:K)
##       gamma[k] = log(theta[doc[n], k]) + log(phi[k, w[n]]);
##     target += log_sum_exp(gamma); // likelihood
##   }
## }
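
The likelihood loop is the part that usually needs the most talking through. A minimal base R sketch of what it computes for one word instance follows. The probabilities are made-up values for illustration, and `log_sum_exp()` here is a hand-written stand-in for the Stan function of the same name.

```r
# Numerically stable log(sum(exp(x))), mirroring what Stan's log_sum_exp() computes.
log_sum_exp <- function(x) {
  m <- max(x)
  m + log(sum(exp(x - m)))
}

# Hypothetical values for one word instance n with K = 3 topics.
theta_doc <- c(0.7, 0.2, 0.1)     # topic proportions for the document containing word n
phi_word  <- c(0.01, 0.05, 0.001) # probability of this particular word under each topic

# gamma[k] = log(theta[doc[n], k]) + log(phi[k, w[n]])
gamma <- log(theta_doc) + log(phi_word)

# Contribution of this word to the log likelihood (the target += line).
log_lik_n <- log_sum_exp(gamma)

# The same quantity computed directly on the probability scale.
direct <- log(sum(theta_doc * phi_word))
all.equal(log_lik_n, direct)
```

Working on the log scale matters because the products of small probabilities can underflow when V is large.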

The words and documents are formatted as long vectors already, but we need to decide the number of topics K, calculate the other indices, and set hyperparameters for the two Dirichlet distributions.

d <- list(
  doc = ap$document, 
  w = ap$w
)

d$K <- 5
d$V <- length(unique(d$w))
d$M <- length(unique(d$doc))
d$N <- length(d$w)
d$alpha <- rep(1, times = d$K)
d$beta <- rep(0.5, times = d$V) 
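
To get a feel for what these hyperparameters imply, here is a small base R sketch that draws from symmetric Dirichlet distributions by normalizing gamma draws. `rdirichlet_simple()` is a hypothetical helper written for this illustration, and the vocabulary size of 50 is an assumption chosen to keep the example small.

```r
# Draw n vectors from Dirichlet(alpha) by normalizing independent gamma draws.
rdirichlet_simple <- function(n, alpha) {
  k <- length(alpha)
  # Column j of the matrix holds gamma draws with shape alpha[j].
  g <- matrix(rgamma(n * k, shape = rep(alpha, each = n)), nrow = n, ncol = k)
  g / rowSums(g)
}

set.seed(1)
flat_draws   <- rdirichlet_simple(1000, rep(1, 50))    # concentration 1, as in alpha
sparse_draws <- rdirichlet_simple(1000, rep(0.5, 50))  # concentration 0.5, as in beta

# Every draw is a simplex: its entries sum to 1.
range(rowSums(flat_draws))

# Average size of the largest entry per draw. A concentration below 1
# favors vectors where a few entries carry most of the probability.
max_flat   <- mean(apply(flat_draws, 1, max))
max_sparse <- mean(apply(sparse_draws, 1, max))
c(flat = max_flat, sparse = max_sparse)
```

A concentration of 1 is uniform over the simplex, while 0.5 nudges each topic toward putting its mass on fewer words.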

We’ll fit the model using variational inference (vb() instead of sampling()). This is faster than MCMC, but it only approximates the posterior, so it is less accurate. This is a tough model to fit! (Dedicated text analysis packages are even faster, but it’s still pretty neat that we can write the model in Stan.)

fit_lda <- vb(m_lda, data = d, algorithm = "meanfield")
## Chain 1: ------------------------------------------------------------
## Chain 1: EXPERIMENTAL ALGORITHM:
## Chain 1:   This procedure has not been thoroughly tested and may be unstable
## Chain 1:   or buggy. The interface is subject to change.
## Chain 1: ------------------------------------------------------------
## Chain 1: 
## Chain 1: 
## Chain 1: 
## Chain 1: Gradient evaluation took 0.037247 seconds
## Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 372.47 seconds.
## Chain 1: Adjust your expectations accordingly!
## Chain 1: 
## Chain 1: 
## Chain 1: Begin eta adaptation.
## Chain 1: Iteration:   1 / 250 [  0%]  (Adaptation)
## Chain 1: Iteration:  50 / 250 [ 20%]  (Adaptation)
## Chain 1: Iteration: 100 / 250 [ 40%]  (Adaptation)
## Chain 1: Iteration: 150 / 250 [ 60%]  (Adaptation)
## Chain 1: Iteration: 200 / 250 [ 80%]  (Adaptation)
## Chain 1: Success! Found best value [eta = 1] earlier than expected.
## Chain 1: 
## Chain 1: Begin stochastic gradient ascent.
## Chain 1:   iter             ELBO   delta_ELBO_mean   delta_ELBO_med   notes 
## Chain 1:    100      -165862.849             1.000            1.000
## Chain 1:    200      -163604.699             0.507            1.000
## Chain 1:    300      -162871.867             0.339            0.014
## Chain 1:    400      -162602.532             0.255            0.014
## Chain 1:    500      -162488.566             0.204            0.004   MEDIAN ELBO CONVERGED
## Chain 1: 
## Chain 1: Drawing a sample of size 1000 from the approximate posterior... 
## Chain 1: COMPLETED.
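
A natural next step is to look at the highest-probability words in each topic. The sketch below uses a made-up K x V matrix so that it runs on its own. `top_terms()` is a hypothetical helper, and the comment about building the matrix from `fit_lda` is an assumption that is not run here.

```r
# Hypothetical helper: the n highest-probability terms for each topic.
# phi_mean: K x V matrix, each row a word distribution for one topic.
# vocab:    character vector of length V, with vocab[v] the term coded as w == v.
top_terms <- function(phi_mean, vocab, n = 3) {
  lapply(seq_len(nrow(phi_mean)), function(k) {
    idx <- order(phi_mean[k, ], decreasing = TRUE)[seq_len(n)]
    data.frame(topic = k, term = vocab[idx], prob = phi_mean[k, idx],
               stringsAsFactors = FALSE)
  })
}

# Toy example with K = 2 topics and V = 4 terms (made-up numbers).
# In the lab, phi_mean could plausibly be built from the approximate posterior
# draws, e.g. by averaging rstan::extract(fit_lda, pars = "phi")$phi over draws.
vocab <- c("police", "court", "market", "stocks")
phi_mean <- rbind(
  c(0.50, 0.40, 0.05, 0.05),
  c(0.05, 0.05, 0.30, 0.60)
)

top <- do.call(rbind, top_terms(phi_mean, vocab, n = 2))
top
```

Topic labels are arbitrary, so the same topics can come out in a different order from one fit to the next.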

Review of models

This section reviews models we’ve learned throughout the quarter and rewrites them using rstanarm. The data sets come from previous labs and homework assignments.

Linear regression

We can fit a linear regression using stan_glm() with the gaussian() likelihood family, which is the default.

This is like glm(), but with priors that you specify in the arguments. If you leave the priors out, you’ll get sensible defaults, but it’s better to be explicit. If you want flat priors (and you probably don’t!), you have to ask for them by setting prior = NULL.

unionDensity <- read_csv("data/unionDensity.csv")

fit_normal <- stan_glm(union ~ left + size + concen, 
                       data = unionDensity, 
                       family = gaussian(),
                       prior = normal(0, 1),
                       prior_intercept = normal(0, 10),
                       prior_aux = exponential(1))

We didn’t scale the data. Why not? Each prior has an autoscale option, which is TRUE by default in the rstanarm version used here (2.18.2), so the prior scales are adjusted automatically based on the data. We can see that with prior_summary():

prior_summary(fit_normal)
## Priors for model 'fit_normal' 
## ------
## Intercept (after predictors centered)
##  ~ normal(location = 0, scale = 10)
##      **adjusted scale = 187.53
## 
## Coefficients
##  ~ normal(location = [0,0,0], scale = [1,1,1])
##      **adjusted scale = [ 0.56,11.54,58.04]
## 
## Auxiliary (sigma)
##  ~ exponential(rate = 1)
##      **adjusted scale = 18.75 (adjusted rate = 1/adjusted scale)
## ------
## See help('prior_summary.stanreg') for more details
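
The adjusted scales above are consistent with the rescaling rule as I understand it for Gaussian models in this rstanarm version: the intercept and sigma scales are multiplied by sd(y), and each coefficient scale is multiplied by sd(y) / sd(x). Note that 187.53 / 10 matches the adjusted scale of 18.75 for sigma. Here is a base R sketch of that arithmetic on simulated data. `adjusted_scales()` is a hypothetical helper, not an rstanarm function.

```r
set.seed(42)
n <- 20
x <- rnorm(n, mean = 50, sd = 30)       # stand-in for one predictor
y <- 10 + 0.3 * x + rnorm(n, sd = 10)   # stand-in for the outcome

# Hypothetical helper reproducing the assumed autoscale arithmetic.
adjusted_scales <- function(y, x, coef_scale = 1, intercept_scale = 10, aux_scale = 1) {
  c(
    intercept = intercept_scale * sd(y),
    coef      = coef_scale * sd(y) / sd(x),
    sigma     = aux_scale * sd(y)
  )
}

adj <- adjusted_scales(y, x)
adj
```

The practical upshot is that normal(0, 1) on a coefficient is a statement in standardized units, even though the model is fit on the raw data.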

And the results of the model:

summary(fit_normal)
## 
## Model Info:
## 
##  function:     stan_glm
##  family:       gaussian [identity]
##  formula:      union ~ left + size + concen
##  algorithm:    sampling
##  priors:       see help('prior_summary')
##  sample:       4000 (posterior sample size)
##  observations: 20
##  predictors:   4
## 
## Estimates:
##                 mean   sd    2.5%   25%   50%   75%   97.5%
## (Intercept)    89.2   54.6 -18.7   54.0  90.0 126.9 194.7  
## left            0.3    0.1   0.1    0.2   0.3   0.3   0.4  
## size           -5.9    3.6 -12.9   -8.3  -5.9  -3.5   1.2  
## concen          3.0   18.5 -32.3   -9.4   3.0  15.4  40.2  
## sigma          10.9    2.0   7.9    9.5  10.6  12.0  15.8  
## mean_PPD       54.1    3.5  47.1   51.7  54.1  56.4  61.0  
## log-posterior -86.1    1.8 -90.4  -87.1 -85.7 -84.8 -83.8  
## 
## Diagnostics:
##               mcse Rhat n_eff
## (Intercept)   1.2  1.0  2087 
## left          0.0  1.0  3071 
## size          0.1  1.0  2125 
## concen        0.4  1.0  2136 
## sigma         0.0  1.0  2074 
## mean_PPD      0.1  1.0  3652 
## log-posterior 0.1  1.0  1167 
## 
## For each parameter, mcse is Monte Carlo standard error, n_eff is a crude measure of effective sample size, and Rhat is the potential scale reduction factor on split chains (at convergence Rhat=1).

plot(fit_normal)

Multilevel models

For multilevel models, stan_glmer() is like lme4::glmer().

bangladesh <- read_csv("data/bangladesh.csv")

fit_mlm <- stan_glmer(use.contraception ~ living.children + age.centered + 
                        urban + (1 | district), 
                      data = bangladesh, 
                      family = binomial("logit"),
                      prior = normal(0, 2.5),
                      prior_intercept = normal(0, 10), 
                      prior_covariance = decov(shape = 1,
                                               scale = 1))

What’s a bit different here is prior_covariance. We’ll talk about priors on covariance and correlation matrices in lecture. In this case, since there’s only a varying intercept, this reduces to a prior on the variance. The prior turns out to be a Gamma distribution on the standard deviation, which with shape = scale = 1 simplifies to tau ~ exponential(1);.
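
The last simplification is easy to check numerically: a Gamma distribution with shape 1 and scale 1 has the same density as an exponential distribution with rate 1. A quick base R sketch:

```r
# Grid of candidate values for the group-level standard deviation tau.
tau <- seq(0.1, 5, by = 0.1)

gamma_density <- dgamma(tau, shape = 1, scale = 1)
exp_density   <- dexp(tau, rate = 1)

# Both are exp(-tau), so the two priors are the same distribution.
all.equal(gamma_density, exp_density)
```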

prior_summary(fit_mlm)
## Priors for model 'fit_mlm' 
## ------
## Intercept (after predictors centered)
##  ~ normal(location = 0, scale = 10)
## 
## Coefficients
##  ~ normal(location = [0,0,0], scale = [2.5,2.5,2.5])
##      **adjusted scale = [2.01,0.28,2.50]
## 
## Covariance
##  ~ decov(reg. = 1, conc. = 1, shape = 1, scale = 1)
## ------
## See help('prior_summary.stanreg') for more details

print(fit_mlm)
## stan_glmer
##  family:       binomial [logit]
##  formula:      use.contraception ~ living.children + age.centered + urban + 
##     (1 | district)
##  observations: 1934
## ------
##                 Median MAD_SD
## (Intercept)     -1.9    0.2  
## living.children  0.4    0.1  
## age.centered     0.0    0.0  
## urban            0.7    0.1  
## 
## Error terms:
##  Groups   Name        Std.Dev.
##  district (Intercept) 0.5     
## Num. levels: district 60 
## 
## Sample avg. posterior predictive distribution of y:
##          Median MAD_SD
## mean_PPD 0.4    0.0   
## 
## ------
## * For help interpreting the printed output see ?print.stanreg
## * For info on the priors used see ?prior_summary.stanreg

GLMs

With different families for the likelihood, we get GLMs for count data.

roaches <- read_csv("data/roaches.csv")

roaches <- 
  roaches %>%
  mutate(roach1 = roach1/100)

fit_poisson <- stan_glm(y ~ roach1 + treatment + senior,
                        offset = log(exposure2),
                        data = roaches,
                        family = poisson(link = "log"), 
                        prior = normal(0, 2.5), 
                        prior_intercept = normal(0, 10))
fit_negbinom <- stan_glm(y ~ roach1 + treatment + senior,
                         offset = log(exposure2),
                         data = roaches,
                         family = neg_binomial_2(link = "log"), 
                         prior = normal(0, 2.5), 
                         prior_intercept = normal(0, 10), 
                         prior_aux = exponential(1))
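
The reason to fit both families is overdispersion. The Poisson forces the variance to equal the mean, while the negative binomial adds a dispersion parameter that lets the variance exceed it. The simulation below is a base R sketch with made-up values. The `size` argument of rnbinom() plays the role of that dispersion parameter, giving variance mu + mu^2 / size.

```r
set.seed(2019)
mu   <- 20    # common mean for both distributions (made-up value)
size <- 2     # hypothetical dispersion; smaller values mean more overdispersion
n    <- 1e5

y_pois <- rpois(n, lambda = mu)
y_nb   <- rnbinom(n, mu = mu, size = size)

# Poisson: sample variance is close to the mean.
c(mean = mean(y_pois), var = var(y_pois))

# Negative binomial: same mean, much larger variance.
c(mean = mean(y_nb), var = var(y_nb))

# Theoretical negative binomial variance: 20 + 400 / 2 = 220.
theoretical_nb_var <- mu + mu^2 / size
theoretical_nb_var
```

If the counts are much more variable than a Poisson allows, the Poisson fit will tend to be overconfident about the coefficients.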

Shrinkage

The global scales for shrinkage below come approximately from the estimates in Lab 8.

The lasso prior isn’t quite the same as the laplace prior: laplace() uses the scale you give it, while lasso() puts a prior on the global scale and estimates it.

Prostate <- read_csv("data/Prostate.csv")

f <- lpsa ~ lcavol + lweight + age + lbph + svi + lcp + gleason + pgg45 

fit_ridge <- stan_glm(f, data = Prostate, 
                      family = gaussian(), 
                      prior = normal(0, 0.33), 
                      prior_intercept = normal(0, 10), 
                      prior_aux = exponential(1))

fit_lasso <- stan_glm(f, data = Prostate, 
                      family = gaussian(), 
                      prior = laplace(0, 0.26), 
                      prior_intercept = normal(0, 10), 
                      prior_aux = exponential(1))

fit_lasso_v2 <- stan_glm(f, data = Prostate, 
                         family = gaussian(), 
                         prior = lasso(df = 1, location = 0, scale = 2.5),
                         prior_intercept = normal(0, 10), 
                         prior_aux = exponential(1))

You can compare models of the same outcome with loo:

compare_models(loo(fit_ridge), 
               loo(fit_lasso))
## 
## Model comparison: 
## (negative 'elpd_diff' favors 1st model, positive favors 2nd) 
## 
## elpd_diff        se 
##       0.2       0.4
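
One rough way to read this output is to compare the difference to its standard error. The sketch below uses the two numbers printed above and a plus-or-minus two standard error band, which is a common heuristic rather than a formal test.

```r
elpd_diff <- 0.2   # from the comparison output above
se        <- 0.4

# Rough interval for the difference in expected log predictive density.
interval <- elpd_diff + c(-2, 2) * se
interval

# TRUE when zero lies inside the band, i.e. neither model is clearly preferred.
includes_zero <- interval[1] < 0 && interval[2] > 0
includes_zero
```

Here the band comfortably includes zero, so the ridge and lasso fits predict about equally well by this measure.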

hs is not quite the original horseshoe prior we saw before; it’s what’s called a regularized horseshoe. (As mentioned before, hierarchical shrinkage to induce sparsity is an active area of Bayesian research!)

# `hs` doesn't autoscale, so we scale the data
Prostate_scaled <- 
  Prostate %>%
  mutate_all(function(x) scale(x)[, 1])

# calculate a global scale
# following Piironen and Vehtari (2017)
K <- 8  # total number of coefficients
k0 <- 2  # guess for number of non-zero coefficients
global_scale <- k0 / (K - k0) / sqrt(nrow(Prostate_scaled))

fit_hs <- stan_glm(f, data = Prostate_scaled,
                   family = gaussian(),
                   prior = hs(df = 1,
                              global_df = 1,
                              global_scale =  global_scale,
                              slab_df = 4,
                              slab_scale = 2.5),
                   prior_intercept = normal(0, 10),
                   prior_aux = exponential(1))
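
To see how sensitive the global scale is to the guess k0, the formula from the chunk above can be wrapped in a small function. `hs_global_scale()` is a hypothetical helper, and the sample size of 97 is an assumption for illustration; in the lab it would be nrow(Prostate_scaled).

```r
# Hypothetical helper: global scale as computed in the chunk above,
# k0 / (K - k0) / sqrt(n).
hs_global_scale <- function(k0, K, n) {
  stopifnot(k0 > 0, k0 < K, n > 0)
  k0 / (K - k0) / sqrt(n)
}

n_obs <- 97  # assumed sample size for illustration

# Global scale for guesses of 1, 2, and 4 non-zero coefficients out of K = 8.
scales <- sapply(c(1, 2, 4), hs_global_scale, K = 8, n = n_obs)
scales
```

A larger guess for the number of non-zero coefficients gives a larger global scale, and so less shrinkage overall.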

Limitations of rstanarm

What can’t you do in rstanarm? Some important and useful models are missing from above. Two major cases we’ve learned about this quarter:

Robust regression: There’s no Student-t likelihood, with either fixed or estimated nu.

Regularized regression: Except for the lasso, there’s no way to put a prior on the global scale and estimate it.

With rstanarm, you’re giving up extensibility and flexibility for speed and convenience. There are definitely some cases where you’ll want to write models in Stan directly, but rstanarm is often a good place to start.

What else can rstanarm do?

Some examples:

stan_lm takes a different approach from stan_glm: It does a QR decomposition of X and puts a prior on \(R^2\). http://mc-stan.org/rstanarm/articles/lm.html

stan_polr exists for ordinal regression, and does interesting things to the priors on the cutpoints. http://mc-stan.org/rstanarm/articles/polr.html

See the articles here for more: http://mc-stan.org/rstanarm/articles/index.html

Appendix

This is how the text data were generated:

library(tidytext)
library(tidyverse)

data("AssociatedPress", package = "topicmodels")

# https://www.tidytextmining.com/dtm.html#tidy-dtm
ap_tidy <- tidy(AssociatedPress)

ap_tidy_long <-
  ap_tidy %>%
  # only the first 100 documents
  filter(document %in% 1:100) %>%
  # https://github.com/tidyverse/tidyr/issues/279
  mutate(ids = map(count, seq_len)) %>% 
  unnest(ids)

ap_tidy_long <- 
  ap_tidy_long %>%
  mutate(w = as.integer(as_factor(term))) %>%
  select(-ids, -count)

write_csv(ap_tidy_long, "data/associated_press.csv")
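
The map() and unnest() step above turns one row per (document, term, count) into one row per word instance. A base R sketch of the same idea on a made-up three-row table, using rep() in place of the tidyverse calls:

```r
# Made-up counts: one row per document-term pair.
counts <- data.frame(
  document = c(1L, 1L, 2L),
  term     = c("police", "court", "police"),
  count    = c(2, 1, 3),
  stringsAsFactors = FALSE
)

# Repeat each row `count` times, so every word instance gets its own row.
long <- counts[rep(seq_len(nrow(counts)), times = counts$count), c("document", "term")]
rownames(long) <- NULL

# Integer word IDs in order of first appearance, which is how
# forcats::as_factor() orders the levels of a character vector.
long$w <- as.integer(factor(long$term, levels = unique(long$term)))
long
```

The resulting document and w columns have one entry per word instance, which is the shape the doc and w arrays in the Stan program expect.
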
sessionInfo()
## R version 3.5.3 (2019-03-11)
## Platform: x86_64-apple-darwin15.6.0 (64-bit)
## Running under: macOS High Sierra 10.13.2
## 
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] bindrcpp_0.2.2     bayesplot_1.6.0    forcats_0.3.0     
##  [4] stringr_1.3.1      dplyr_0.7.5        purrr_0.2.5       
##  [7] readr_1.1.1        tidyr_0.8.1        tibble_1.4.2      
## [10] tidyverse_1.2.1    rstanarm_2.18.2    Rcpp_1.0.0        
## [13] rstan_2.18.2       StanHeaders_2.18.1 ggplot2_3.1.0     
## 
## loaded via a namespace (and not attached):
##  [1] nlme_3.1-137       matrixStats_0.53.1 xts_0.10-2        
##  [4] lubridate_1.7.4    threejs_0.3.1      httr_1.3.1        
##  [7] rprojroot_1.3-2    tools_3.5.3        backports_1.1.2   
## [10] R6_2.3.0           DT_0.4             lazyeval_0.2.1    
## [13] colorspace_1.4-0   withr_2.1.2        tidyselect_0.2.4  
## [16] gridExtra_2.3      prettyunits_1.0.2  mnormt_1.5-5      
## [19] processx_3.1.0     compiler_3.5.3     rvest_0.3.2       
## [22] cli_1.0.0          xml2_1.2.0         shinyjs_1.0       
## [25] labeling_0.3       colourpicker_1.0   scales_1.0.0      
## [28] dygraphs_1.1.1.4   psych_1.8.4        ggridges_0.5.0    
## [31] callr_2.0.4        digest_0.6.18      foreign_0.8-71    
## [34] minqa_1.2.4        rmarkdown_1.10     base64enc_0.1-3   
## [37] pkgconfig_2.0.1    htmltools_0.3.6    lme4_1.1-17       
## [40] readxl_1.1.0       htmlwidgets_1.2    rlang_0.3.0.1     
## [43] rstudioapi_0.7     shiny_1.2.0        bindr_0.1.1       
## [46] zoo_1.8-2          jsonlite_1.6       crosstalk_1.0.0   
## [49] gtools_3.8.1       inline_0.3.15      magrittr_1.5      
## [52] loo_2.1.0          Matrix_1.2-15      munsell_0.5.0     
## [55] stringi_1.2.3      yaml_2.1.19        MASS_7.3-51.1     
## [58] pkgbuild_1.0.3     plyr_1.8.4         grid_3.5.3        
## [61] parallel_3.5.3     promises_1.0.1     crayon_1.3.4      
## [64] miniUI_0.1.1.1     lattice_0.20-38    haven_1.1.1       
## [67] splines_3.5.3      hms_0.4.2          knitr_1.20        
## [70] pillar_1.2.3       igraph_1.2.1       markdown_0.9      
## [73] shinystan_2.5.0    reshape2_1.4.3     codetools_0.2-16  
## [76] stats4_3.5.3       rstantools_1.5.0   glue_1.3.1        
## [79] evaluate_0.10.1    modelr_0.1.2       nloptr_1.2.1      
## [82] httpuv_1.4.5.1     cellranger_1.1.0   gtable_0.2.0      
## [85] assertthat_0.2.0   mime_0.6           xtable_1.8-3      
## [88] broom_0.4.4        later_0.7.5        rsconnect_0.8.8   
## [91] survival_2.43-3    shinythemes_1.1.1