knitr::opts_chunk$set(fig.align = "center")

Goals

  • Build intuition for what MCMC is
  • Understand computation for a simple MCMC algorithm (Metropolis)
  • Be aware of the variety of MCMC algorithms that are used in practice
    • Metropolis-Hastings -> Gibbs (JAGS)
    • HMC -> NUTS (Stan)
  • Know what properties a “good” sampler and “good” Markov chains have
    • representativeness / convergence
    • accuracy / stability
    • efficiency

Relevant readings:

  • Kruschke Chapter 7 (MCMC in general)
  • Kruschke Chapter 14, Section 1 (HMC)

The Metropolis algorithm

The Metropolis algorithm is the simplest way to build a Markov chain of samples.

What it does:

  • start somewhere
  • propose a move somewhere else
    • if the probability density is higher there, definitely go there
    • if it’s lower, maybe go there anyway
  • repeat until you’ve approximated the joint posterior

The key to why this works is how you decide to accept/reject the proposed jump:

p_move = min(1, P(proposed) / P(current))

In other words: if the proposed spot has higher probability than the current one, p_move is just 1 and the jump always happens; otherwise, the jump happens with probability equal to the ratio.

Discrete islands

(Adapted from Kruschke 7.2 and McElreath 8.1.)

Imagine a politician visiting a chain of islands. The politician wants to spend time on each island relative to its population. They can only move to adjacent islands. They don’t know the actual population of any island, but they can figure out the population of an adjacent island relative to their current one when they propose a move.

To keep it simple, let’s assume the 7 islands have relative populations of 1 through 7. (For example, standing on island 5 and proposing island 4 gives p_move = 4/5 = 0.8, while proposing island 6 gives a ratio greater than 1, so that move always happens.)

set.seed(123)

relative_pops <- 1:7 # relative population of each island

n_days <- 1e4
positions <- rep(NA, n_days) # we'll fill this up as we go
current <- 4 # pick an island to start on

for (i in 1:n_days) {
  # where are they now?
  positions[i] <- current
  
  # where do they propose to go?
  # they can go left (-1) or right (+1)
  direction <- sample(c(-1, 1), size = 1)
  proposal <- current + direction
  
  # if they propose walking off the island chain, they can't move
  if (proposal < 1 | proposal > 7) {
    prob_move <- 0
  } else {
    # the probability of moving is the ratio of island populations
    prob_move <- relative_pops[proposal] / relative_pops[current]
    # NOTE: in our example, this is the same as: 
    # prob_move <- proposal / current
  }
  
  # okay, do they move to the next island or not?
  current <- ifelse(runif(1) < prob_move, proposal, current)
}

Now, let’s look at the politician’s trip:

library(tidyverse)

# make a data frame
df <- 
  tibble(
    n = 1:n_days,
    position = positions
  )

# let's look at the first part of their trip
df_small <- head(df, 100)
ggplot(df_small, aes(x = n, y = position)) + geom_point() + geom_line()

# and a summary of where they were
ggplot(df, aes(x = position)) + geom_bar()

# how does that compare to what we'd expect? 
# try increasing n_days in the algorithm to do better!
df %>% 
  count(position) %>% 
  mutate(frac = nn/sum(nn), 
         expected = position/sum(position)) 
## # A tibble: 7 x 4
##   position    nn   frac expected
##      <dbl> <int>  <dbl>    <dbl>
## 1        1   355 0.0355   0.0357
## 2        2   654 0.0654   0.0714
## 3        3  1068 0.107    0.107 
## 4        4  1534 0.153    0.143 
## 5        5  1864 0.186    0.179 
## 6        6  2142 0.214    0.214 
## 7        7  2383 0.238    0.25

A continuous parameter: Poisson \(\lambda\)

What if there are more possibilities for making a proposal than just left/right?

We can make a symmetric proposal using a normal distribution with some appropriate standard deviation.

  • We don’t want the proposal to be too wide, because we’ll reject too many proposals
  • We don’t want it to be too narrow, because the chain will take too long to move through the posterior

Here’s a tweet-sized example from a Poisson distribution:

# https://twitter.com/rlmcelreath/status/732947118785191936
set.seed(123)

N <- 20
lambda <- 2
y <- sum(rpois(N, lambda = lambda)) # total count across the N observations

n_chain <- 1e4
p <- rep(1, n_chain) # we'll fill this up as we go

sd_proposal <- 1/9

for (i in 2:n_chain) {
  r <- p[i-1] # current location
  q <- exp(log(r) + rnorm(1, mean = 0, sd = sd_proposal)) # proposed new location
  p[i] <- ifelse(runif(1) < q^y * r^(-y) * exp(-N * (q - r)), 
                 q, 
                 r) # accept or reject? 
}

What’s this mess: q^y * r^(-y) * exp(-N * (q - r))? It’s the ratio of the Poisson likelihoods at the proposed value q and the current value r.
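Written out: the likelihood of \(N\) iid Poisson(\(\lambda\)) observations with total count \(y\) is proportional to \(\lambda^{y} e^{-N\lambda}\), so the ratio at the proposed value \(q\) versus the current value \(r\) is

\[
\frac{q^{y} e^{-Nq}}{r^{y} e^{-Nr}} = q^{y}\, r^{-y}\, e^{-N(q - r)}.
\]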

Let’s look at p:

plot(p, type = "l") # this is a "trace plot"

summary(p[-c(1:1000)]) # drop the "burn-in" period
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##   1.288   2.120   2.342   2.354   2.581   3.911
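As the bullets above suggested, the width of the proposal distribution matters. Here’s a quick, hand-rolled check, reusing y and N from the chunk above; run_chain is just a helper defined here for the comparison, and the widths are arbitrary:

# acceptance rate and posterior mean for a given proposal width
run_chain <- function(sd_proposal, n_chain = 1e4) {
  p <- rep(1, n_chain)
  n_accept <- 0
  for (i in 2:n_chain) {
    r <- p[i - 1]
    q <- exp(log(r) + rnorm(1, mean = 0, sd = sd_proposal))
    if (runif(1) < q^y * r^(-y) * exp(-N * (q - r))) {
      p[i] <- q
      n_accept <- n_accept + 1
    } else {
      p[i] <- r
    }
  }
  c(acceptance_rate = n_accept / (n_chain - 1),
    posterior_mean  = mean(p[-(1:1000)]))
}

run_chain(1/9)    # the width used above
run_chain(2)      # too wide: most proposals get rejected
run_chain(0.005)  # too narrow: almost everything is accepted, but the chain crawls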

Two continuous parameters

To visualize how Metropolis works for two continuous parameters, go to this demo here: https://chi-feng.github.io/mcmc-demo/app.html

Choose “RandomWalkMH” as the algorithm and “standard” as the target distribution.

Gibbs sampling

Can we make smarter proposals than random ones? If we can, we’ll sample more efficiently. Gibbs sampling exploits conjugacy to sample more efficiently than the Metropolis algorithm.

The idea is that, even if the joint posterior is impossible to calculate analytically, we might be able to calculate the conditional posterior for one parameter if we hold all the others constant. Because proposals are drawn from this conditional posterior, they can always be accepted. Gibbs sampling cycles through all the parameters, one at a time.

Note that, despite the name (“Just Another Gibbs Sampler”), a Gibbs sampler is only one of the samplers that JAGS uses.

Here’s an implementation of a Gibbs sampler in R for the mean and precision of a normal distribution, from https://stats.stackexchange.com/questions/266665/gibbs-sampler-examples-in-r.

It alternates between two conjugate updates: a normal draw for the mean (treating the precision as known) and a gamma draw for the precision (treating the mean as known): https://en.wikipedia.org/wiki/Conjugate_prior#When_likelihood_function_is_a_continuous_distribution
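Concretely, the sampling loop below alternates between these two full conditionals (the normal is written in terms of its variance, the gamma in terms of its rate):

\[
\mu \mid \tau, y \sim \mathrm{Normal}\!\left(\bar{y},\ \frac{1}{n\tau}\right),
\qquad
\tau \mid \mu, y \sim \mathrm{Gamma}\!\left(\frac{n}{2},\ \mathrm{rate} = \frac{(n-1)s^2 + n(\mu - \bar{y})^2}{2}\right)
\]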

# summary statistics of sample
n    <- 30
ybar <- 15
s2   <- 3

# sample from the joint posterior (mu, tau | data)
mu       <- rep(NA, 11000)
tau      <- rep(NA, 11000)
T_burnin <- 1000    # burnin
tau[1]   <- 1  # initialisation
for (i in 2:11000) {   
    # draw mu from its full conditional, given the current value of tau
    mu[i]  <- rnorm(n = 1, mean = ybar, sd = sqrt(1 / (n * tau[i - 1])))    
    # draw tau from its full conditional, given the mu we just drew
    tau[i] <- rgamma(n = 1, shape = n / 2, scale = 2 / ((n - 1) * s2 + n * (mu[i] - ybar)^2))
}
mu  <- mu[-(1:T_burnin)]   # remove burnin
tau <- tau[-(1:T_burnin)] # remove burnin
hist(mu)

hist(tau)

Here’s code that implements a Gibbs sampler in R for simple linear regression:

https://github.com/stablemarkets/BayesianTutorials/blob/master/SimpleLinearReg/BayesModel.R

https://www.r-bloggers.com/bayesian-simple-linear-regression-with-gibbs-sampling-in-r/

HMC

Hamiltonian Monte Carlo (HMC) uses a physics analogy, treating the log posterior as a surface and simulating the motion of a frictionless particle across it, to improve upon the basic Metropolis sampler and even the Gibbs sampler in certain circumstances. Because the simulated trajectories follow the shape of the posterior, HMC can make distant proposals that are still very likely to be accepted. The key thing HMC needs is to be able to calculate the gradient of the log posterior at a given point.
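To make the idea concrete, here’s a minimal hand-rolled sketch in R, sampling a single parameter whose target is a standard normal. Everything here (the names U, grad_U, q_samples, the step size, the number of leapfrog steps) is made up for illustration; real implementations like Stan tune these automatically.

# potential energy = negative log density (up to a constant), and its gradient
U      <- function(q) 0.5 * q^2
grad_U <- function(q) q

set.seed(123)
n_iter    <- 5000
step_size <- 0.1  # leapfrog step size
n_steps   <- 20   # leapfrog steps per proposal

q_samples <- numeric(n_iter)
q <- 0  # starting position
for (i in 1:n_iter) {
  p <- rnorm(1)  # draw a fresh random momentum
  q_new <- q
  p_new <- p - step_size * grad_U(q_new) / 2  # half step for momentum
  for (s in 1:n_steps) {
    q_new <- q_new + step_size * p_new                           # full step for position
    if (s < n_steps) p_new <- p_new - step_size * grad_U(q_new)  # full step for momentum
  }
  p_new <- p_new - step_size * grad_U(q_new) / 2  # final half step for momentum
  
  # Metropolis accept/reject on the change in total energy (potential + kinetic)
  H_current  <- U(q) + 0.5 * p^2
  H_proposed <- U(q_new) + 0.5 * p_new^2
  if (runif(1) < exp(H_current - H_proposed)) q <- q_new
  q_samples[i] <- q
}

hist(q_samples)  # should look standard normal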

Look at different versions of HMC on the interactive demo page, including at least one version of NUTS:

https://chi-feng.github.io/mcmc-demo/app.html

Compare how HMC and random walk MH sample for the standard (multivariate normal) distribution, and for one of the weirder distributions, like the banana or the donut. What do you notice?

You might want to read this blog post by Richard McElreath, which makes the case for HMC over Gibbs sampling: http://elevanth.org/blog/2017/11/28/build-a-better-markov-chain/

Stan uses an adaptive variant of HMC called NUTS (the No-U-Turn Sampler):

https://mc-stan.org/docs/2_19/reference-manual/hmc-chapter.html

https://arxiv.org/abs/1701.02434

MCMC diagnostics

What do we want our Markov chains to do? How can we assess if they did it, or check for problems?

Let’s fit a simple Stan model with two parameters.

library(rstan)
options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)

set.seed(20190502)
true_mu <- 6
true_sigma <- 3
sample_size <- 100
sims <- rnorm(sample_size, true_mu, true_sigma)
fit <- stan("stan/normal.stan", data = list(y = sims, N = sample_size))
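The contents of stan/normal.stan aren’t shown in these notes. A minimal model consistent with the data list above (y, N) and the parameters in the output below (mu, sigma), using Stan’s default flat priors, might look something like this:

data {
  int<lower=1> N;
  vector[N] y;
}
parameters {
  real mu;
  real<lower=0> sigma;
}
model {
  y ~ normal(mu, sigma);
}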

Representativeness

We want our chains to fully “explore” the posterior, spending most of their time in the highest-probability areas. We don’t want them to be too influenced by where they started, or to get stuck in one region of the parameter space.

We’re looking for chains that overlap and that look stationary, after they warm up.

We can visualize this using trace plots:

stan_trace(fit)

One metric that can flag situations where convergence has failed is the Gelman-Rubin diagnostic (“R-hat”):

stan_rhat(fit)
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

R-hat should be close to 1. If it’s too far above 1, you should worry.

(How far is too far? People used to say 1.1, now they say 1.01.)
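If you want the numbers rather than a histogram, they’re also in the fit summary (summary() on a stanfit returns a list whose $summary element is a matrix with one row per parameter):

# R-hat for each parameter (plus lp__)
summary(fit)$summary[, "Rhat"]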

Accuracy and stability

The effective sample size (ESS) takes into account autocorrelation between our samples, to tell us how much independent information they’re giving us about each parameter.

stan_ess(fit)
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

The Monte Carlo standard error (MCSE) combines the ESS and the standard deviation of the samples (\(\text{MCSE} \approx \text{SD}/\sqrt{\text{ESS}}\)) to tell us how stably we’re estimating the posterior mean of each parameter. (So smaller MCSE is better.)

stan_mcse(fit)
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

If those diagnostics turned out okay, we can probably start to trust our estimates! So now it’s safe to look at them. And having looked at the diagnostics individually, the printed output of a Stan model should make a bit more sense:

fit
## Inference for Stan model: normal.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##          mean se_mean   sd    2.5%     25%     50%     75%   97.5% n_eff
## mu       6.49    0.01 0.35    5.79    6.25    6.49    6.72    7.16  3660
## sigma    3.44    0.00 0.24    3.01    3.27    3.43    3.60    3.96  3483
## lp__  -177.10    0.02 0.97 -179.76 -177.49 -176.80 -176.39 -176.11  1980
##       Rhat
## mu       1
## sigma    1
## lp__     1
## 
## Samples were drawn using NUTS(diag_e) at Thu May  2 13:15:47 2019.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

We can plot coefficient point estimates, intervals, and densities:

stan_plot(fit)
## ci_level: 0.8 (80% intervals)
## outer_level: 0.95 (95% intervals)

stan_dens(fit)

Parameters can be correlated, so it’s also worth plotting them in pairs:

# the quick way
pairs(fit, pars = c("mu", "sigma"))

# the pretty way
# permuted = TRUE returns a plain list of permuted draws; mcmc_pairs wants the
# iterations x chains x parameters array, so set permuted = FALSE
draws <- rstan::extract(fit, pars = c("mu", "sigma"), permuted = FALSE)
bayesplot::mcmc_pairs(draws)

Efficiency

Finally, we want MCMC to sample efficiently. If sampling is too slow, there are some things we might try to speed it up, a couple of which are sketched in the rstan call below:

  • Parallel chains
  • Tinkering with parameters of the sampler, like adaptation parameters or step size
  • Changing the parameterization of the model
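For instance, in rstan the first two bullets map onto arguments of stan() (or sampling()); the specific values below are placeholders for illustration, not recommendations:

# parallel chains plus some sampler tuning knobs
fit_tuned <- stan(
  "stan/normal.stan",
  data = list(y = sims, N = sample_size),
  chains = 4,
  cores = 4,                # run the chains in parallel
  control = list(
    adapt_delta   = 0.95,   # target acceptance rate (raising it shrinks the step size)
    max_treedepth = 12      # cap on the depth of each NUTS trajectory
  )
)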
sessionInfo()
## R version 3.5.3 (2019-03-11)
## Platform: x86_64-apple-darwin15.6.0 (64-bit)
## Running under: macOS High Sierra 10.13.2
## 
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] rstan_2.18.2       StanHeaders_2.18.1 bindrcpp_0.2.2    
##  [4] forcats_0.3.0      stringr_1.3.1      dplyr_0.7.5       
##  [7] purrr_0.2.5        readr_1.1.1        tidyr_0.8.1       
## [10] tibble_1.4.2       ggplot2_3.1.0      tidyverse_1.2.1   
## 
## loaded via a namespace (and not attached):
##  [1] Rcpp_1.0.0         lubridate_1.7.4    lattice_0.20-38   
##  [4] prettyunits_1.0.2  assertthat_0.2.0   rprojroot_1.3-2   
##  [7] digest_0.6.18      psych_1.8.4        utf8_1.1.4        
## [10] R6_2.3.0           cellranger_1.1.0   plyr_1.8.4        
## [13] ggridges_0.5.0     backports_1.1.2    stats4_3.5.3      
## [16] evaluate_0.10.1    httr_1.3.1         pillar_1.2.3      
## [19] rlang_0.3.0.1      lazyeval_0.2.1     readxl_1.1.0      
## [22] rstudioapi_0.7     callr_2.0.4        rmarkdown_1.10    
## [25] labeling_0.3       foreign_0.8-71     loo_2.1.0         
## [28] munsell_0.5.0      broom_0.4.4        compiler_3.5.3    
## [31] modelr_0.1.2       pkgconfig_2.0.1    mnormt_1.5-5      
## [34] pkgbuild_1.0.3     htmltools_0.3.6    tidyselect_0.2.4  
## [37] gridExtra_2.3      codetools_0.2-16   matrixStats_0.53.1
## [40] crayon_1.3.4       withr_2.1.2        grid_3.5.3        
## [43] nlme_3.1-137       jsonlite_1.6       gtable_0.2.0      
## [46] magrittr_1.5       scales_1.0.0       KernSmooth_2.23-15
## [49] cli_1.0.0          stringi_1.2.3      reshape2_1.4.3    
## [52] xml2_1.2.0         tools_3.5.3        glue_1.3.1        
## [55] hms_0.4.2          processx_3.1.0     parallel_3.5.3    
## [58] yaml_2.1.19        inline_0.3.15      colorspace_1.4-0  
## [61] bayesplot_1.6.0    rvest_0.3.2        knitr_1.20        
## [64] bindr_0.1.1        haven_1.1.1