I’m reading a really dense and beautifully written survey of Monte Carlo gradient estimation for machine learning by Shakir Mohamed, Mihaela Rosca, Michael Figurnov, and Andriy Mnih. There are great explanations of everything including variance reduction techniques like coupling, control variates, and Rao-Blackwellization. The latter’s the topic of today’s post, as it relates directly to current Stan practices.
Expectations of interest
In Bayesian inference, parameter estimates, event probabilities, and predictions can all be formulated as expectations of functions of parameters conditioned on observed data. In symbols, that’s

$$\mathbb{E}[f(\Theta) \mid Y = y] = \int f(\theta) \, p(\theta \mid y) \, \mathrm{d}\theta$$

for a model with parameter vector $\Theta$ and data $Y$. In this post and most writing about probability theory, random variables are capitalized and bound variables are not.
Suppose we have two random variables $A$ and $B$ and want to compute an expectation $\mathbb{E}[f(A, B)]$. In the Bayesian setting, this means splitting our parameters into two groups, $\Theta = (A, B)$, and suppressing the conditioning on $Y = y$ in the notation.
Full sampling-based estimate of expectations
There are two unbiased approaches to computing the expectation using sampling. The first one is traditional, with all random variables in the expectation being sampled.
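Draw $(a^{(m)}, b^{(m)}) \sim p(a, b)$ for $m \in 1{:}M$ and estimate the expectation as

$$\mathbb{E}[f(A, B)] \approx \frac{1}{M} \sum_{m=1}^{M} f(a^{(m)}, b^{(m)}).$$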
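Here’s a minimal sketch of the full sampling estimator in Python (standard library only). The joint distribution is an assumption made purely for illustration: $B \sim \textrm{normal}(0, 1)$, $A \mid B \sim \textrm{normal}(B, 1)$, and $f(a, b) = a^2$, so the true value is $\mathbb{E}[A^2] = \mathrm{Var}[A] = 2$. The function names are hypothetical.

```python
import random

def draw_joint(rng):
    """One draw (a, b) from the toy joint: B ~ normal(0, 1), A | B ~ normal(B, 1)."""
    b = rng.gauss(0.0, 1.0)
    a = rng.gauss(b, 1.0)
    return a, b

def f(a, b):
    """Function whose expectation we want; E[f(A, B)] = E[A^2] = 2 for this joint."""
    return a * a

def full_sampling_estimate(num_draws, rng):
    """Average f over joint draws (a^(m), b^(m)) for m = 1..M."""
    total = 0.0
    for _ in range(num_draws):
        a, b = draw_joint(rng)
        total += f(a, b)
    return total / num_draws

print(full_sampling_estimate(100_000, random.Random(1234)))  # close to 2
```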
Marginalized sampling-based estimate of expectations
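The so-called Rao-Blackwellized estimator of the expectation involves marginalizing out $A$ and sampling $b^{(m)} \sim p(b)$ for $m \in 1{:}M$. The expectation is then estimated as

$$\mathbb{E}[f(A, B)] \approx \frac{1}{M} \sum_{m=1}^{M} \mathbb{E}[f(A, B) \mid B = b^{(m)}].$$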
For this estimator to be efficiently computable, the nested expectation must be efficiently computable,

$$\mathbb{E}[f(A, B) \mid B = b^{(m)}] = \int f(a, b^{(m)}) \, p(a \mid b^{(m)}) \, \mathrm{d}a.$$
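To make this concrete, assume (purely for illustration) the toy joint $B \sim \textrm{normal}(0, 1)$, $A \mid B \sim \textrm{normal}(B, 1)$, with $f(a, b) = a^2$. Then the nested expectation has a closed form, $\mathbb{E}[A^2 \mid B = b] = \mathrm{Var}[A \mid B = b] + \mathbb{E}[A \mid B = b]^2 = 1 + b^2$, so a sketch of the Rao-Blackwellized estimator only needs draws of $B$. Function names are hypothetical.

```python
import random

def conditional_expectation_f(b):
    """E[f(A, B) | B = b] for f(a, b) = a^2 with A | B = b ~ normal(b, 1):
    Var[A | B = b] + E[A | B = b]^2 = 1 + b^2, computed exactly (no draws of A)."""
    return 1.0 + b * b

def rao_blackwellized_estimate(num_draws, rng):
    """Draw only b^(m) ~ normal(0, 1) and average the exact inner expectation."""
    total = 0.0
    for _ in range(num_draws):
        b = rng.gauss(0.0, 1.0)
        total += conditional_expectation_f(b)
    return total / num_draws

print(rao_blackwellized_estimate(100_000, random.Random(1234)))  # close to 2
```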
The Rao-Blackwell theorem
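The Rao-Blackwell theorem states that the marginalization approach has variance less than or equal to that of the direct approach. With independent draws, this follows from the law of total variance,

$$\mathrm{Var}[f(A, B)] = \mathbb{E}\big[\mathrm{Var}[f(A, B) \mid B]\big] + \mathrm{Var}\big[\mathbb{E}[f(A, B) \mid B]\big] \geq \mathrm{Var}\big[\mathbb{E}[f(A, B) \mid B]\big].$$

In practice, this difference can be enormous. Its size depends on how hard it would be to estimate $\mathbb{E}[f(A, B) \mid B = b^{(m)}]$ by sampling $a^{(n)} \sim p(a \mid b^{(m)})$.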
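A quick way to see the gap is to compare the per-draw variance of the terms each estimator averages. Assuming, for illustration only, the toy joint $B \sim \textrm{normal}(0, 1)$, $A \mid B \sim \textrm{normal}(B, 1)$ and $f(a, b) = a^2$, we have $A \sim \textrm{normal}(0, \sqrt{2})$, so $\mathrm{Var}[A^2] = 8$, whereas $\mathrm{Var}[1 + B^2] = 2$. With independent draws, the Rao-Blackwellized estimator then needs roughly a quarter as many draws for the same standard error. This sketch uses simulated independent draws, not MCMC output.

```python
import random
import statistics

def per_draw_variances(num_draws, seed=0):
    """Sample variances of the summands of the two estimators under the toy joint
    B ~ normal(0, 1), A | B ~ normal(B, 1), f(a, b) = a^2."""
    rng = random.Random(seed)
    full_terms = []  # f(a^(m), b^(m)) from joint draws
    rb_terms = []    # E[f(A, B) | B = b^(m)] = 1 + (b^(m))^2
    for _ in range(num_draws):
        b = rng.gauss(0.0, 1.0)
        a = rng.gauss(b, 1.0)
        full_terms.append(a * a)
        rb_terms.append(1.0 + b * b)
    return statistics.variance(full_terms), statistics.variance(rb_terms)

var_full, var_rb = per_draw_variances(200_000)
# For independent draws, each estimator's variance is its per-draw variance / M.
print(var_full, var_rb, var_full / var_rb)  # roughly 8, 2, and 4
```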
Discrete variables in Stan
Stan does not have a sampler for discrete variables. Instead, Rao-Blackwellized estimators must be used, which essentially means marginalizing out the discrete parameters. Thus if $Z$ is the vector of discrete parameters in a model, $\Theta$ the vector of continuous parameters, and $Y$ the vector of observed data, then the model posterior is $p(\theta, z \mid y)$, and Stan samples from the marginal posterior

$$p(\theta \mid y) = \sum_{z} p(\theta, z \mid y).$$
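For a sense of what marginalizing looks like in code, here is a Python sketch of the marginal log likelihood for a hypothetical two-component normal mixture with mixing proportion $\lambda$, locations $\mu_0, \mu_1$, and unit scale. Because each indicator $z_n$ touches only $y_n$, the sum over all $2^N$ assignments factors into a product of $N$ two-term sums, each computed on the log scale with log-sum-exp. In a Stan program this is the job of the built-in `log_sum_exp` or `log_mix` functions; the Python helpers here are illustrative stand-ins, not Stan's implementation.

```python
import math

def log_sum_exp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))

def normal_lpdf(y, mu, sigma):
    """Log density of normal(mu, sigma) at y."""
    return -0.5 * ((y - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)

def marginal_log_lik(y, lam, mu, sigma=1.0):
    """log p(y | theta) for a two-component normal mixture, with each
    indicator z_n in {0, 1} summed out one observation at a time."""
    total = 0.0
    for y_n in y:
        total += log_sum_exp([
            math.log(lam) + normal_lpdf(y_n, mu[0], sigma),    # z_n = 0
            math.log1p(-lam) + normal_lpdf(y_n, mu[1], sigma),  # z_n = 1
        ])
    return total

print(marginal_log_lik([-1.2, 0.3, 2.5], 0.3, (-1.0, 2.0)))
```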
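With a sampler that can efficiently make Gibbs draws (e.g., BUGS or PyMC3), it is tempting to try to compute posterior expectations by sampling $(\theta^{(m)}, z^{(m)}) \sim p(\theta, z \mid y)$ and averaging,

$$\mathbb{E}[f(\Theta, Z) \mid Y = y] \approx \frac{1}{M} \sum_{m=1}^{M} f(\theta^{(m)}, z^{(m)}).$$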
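This is almost always a bad idea if it is possible to efficiently calculate the inner Rao-Blackwellization expectation, $\mathbb{E}[f(\Theta, Z) \mid \Theta = \theta^{(m)}, Y = y]$. With discrete variables, the formula is just

$$\mathbb{E}[f(\Theta, Z) \mid \Theta = \theta^{(m)}, Y = y] = \sum_{z} f(\theta^{(m)}, z) \, p(z \mid \theta^{(m)}, y),$$

and the posterior expectation is estimated by averaging this quantity over draws $\theta^{(m)} \sim p(\theta \mid y)$.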
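As an illustration, suppose the discrete parameter is a mixture indicator $z_n$ and $f(\theta, z) = \mathrm{I}[z_n = 0]$, so the posterior expectation is the membership probability $\Pr[Z_n = 0 \mid Y = y]$. The inner sum then reduces to $p(z_n = 0 \mid \theta^{(m)}, y_n)$, obtained by normalizing the two component terms, and the sketch averages that exact probability over draws $\theta^{(m)}$ instead of averaging sampled indicators. The model (two-component normal mixture with unit scale), the function names, and the draws are all hypothetical.

```python
import math

def normal_lpdf(y, mu, sigma):
    """Log density of normal(mu, sigma) at y."""
    return -0.5 * ((y - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)

def membership_probs(y_n, lam, mu, sigma=1.0):
    """p(z_n = k | theta, y_n) for k in {0, 1}, normalized on the log scale."""
    lp = [math.log(lam) + normal_lpdf(y_n, mu[0], sigma),
          math.log1p(-lam) + normal_lpdf(y_n, mu[1], sigma)]
    m = max(lp)
    w = [math.exp(x - m) for x in lp]
    s = sum(w)
    return [x / s for x in w]

def rao_blackwellized_membership(y_n, theta_draws):
    """Estimate Pr[Z_n = 0 | y] as the average over draws theta^(m) = (lam, mu)
    of the exact conditional probability p(z_n = 0 | theta^(m), y_n)."""
    probs = [membership_probs(y_n, lam, mu)[0] for lam, mu in theta_draws]
    return sum(probs) / len(probs)

# Hypothetical posterior draws of (lambda, (mu_0, mu_1)).
draws = [(0.30, (-1.0, 2.0)), (0.35, (-0.9, 2.1)), (0.28, (-1.1, 1.9))]
print(rao_blackwellized_membership(0.5, draws))
```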
Usually the summation can be done efficiently, as in mixture models, where each discrete variable is tied to an individual data point, or in state-space models like HMMs, where the discrete parameters can be marginalized using the forward algorithm. It is not so easy with missing count data or variable selection problems, where the posterior combinatorics are intractable.
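Here’s a sketch of the forward algorithm in Python, on the log scale, for a generic HMM with $K$ hidden states and $T$ observations. It returns $\log p(y \mid \theta)$ with every hidden state path summed out in $\mathcal{O}(T K^2)$ time rather than by enumerating all $K^T$ paths. The argument layout (emission log densities already evaluated at the observations) and the function names are assumptions for this sketch; Stan’s built-in `hmm_marginal` function does the same job with its own argument conventions.

```python
import math

def log_sum_exp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))

def hmm_forward_log_lik(log_init, log_trans, log_emit):
    """Marginal log p(y_1..T) of an HMM via the forward algorithm.

    log_init[k]     = log p(z_1 = k)
    log_trans[j][k] = log p(z_t = k | z_{t-1} = j)
    log_emit[t][k]  = log p(y_t | z_t = k), evaluated at the observed y_t
    """
    K = len(log_init)
    # alpha[k] = log p(y_1..t, z_t = k), starting at t = 1
    alpha = [log_init[k] + log_emit[0][k] for k in range(K)]
    for t in range(1, len(log_emit)):
        alpha = [
            log_sum_exp([alpha[j] + log_trans[j][k] for j in range(K)]) + log_emit[t][k]
            for k in range(K)
        ]
    return log_sum_exp(alpha)  # sum out the final state z_T

log_init = [math.log(0.6), math.log(0.4)]
log_trans = [[math.log(0.7), math.log(0.3)], [math.log(0.2), math.log(0.8)]]
log_emit = [[math.log(0.9), math.log(0.1)],
            [math.log(0.5), math.log(0.5)],
            [math.log(0.2), math.log(0.7)]]
print(hmm_forward_log_lik(log_init, log_trans, log_emit))
```

Enumerating all $2^3 = 8$ state paths by brute force gives the same number, which is a handy unit test for an implementation.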
Gains from marginalizing discrete parameters
The gains to be had from marginalizing discrete parameters are enormous. This is true even of models coded in BUGS or PyMC3. Cole Monnahan, James Thorson, and Trevor Branch wrote a nice survey of the advantages for some ecology models that compares marginalized HMC in Stan to JAGS with discrete sampling and to JAGS with marginalization. The takeaway here isn’t that HMC is faster than JAGS, but that JAGS with marginalization is a lot faster than JAGS without.
The other place to see the effects of marginalization is in the Stan User’s Guide chapter on latent discrete parameters. The first change-point example shows how much more efficient marginalization is by comparing it directly with estimates generated from exact sampling of the discrete parameters conditioned on the continuous ones. The difference is particularly stark for tail statistics, which can’t be estimated at all by sampling the discrete parameters because too many draws would be required. I had the same experience coding the Dawid-Skene model of noisy data coding, which was my gateway to Bayesian inference. I had coded it with discrete sampling in BUGS, but BUGS took forever (24 hours compared to 20 minutes for Stan on my real data) and kept crashing on trivial examples during my tutorials.
Marginalization calculations can be found in the MLE literature
The other place marginalization of discrete parameters comes up is in maximum likelihood estimation. For example, Dawid and Skene’s original approach to their coding model used the expectation maximization (EM) algorithm for maximum marginal likelihood estimation. The E-step does the marginalization, and it’s exactly the same marginalization as required in Stan for discrete parameters. You can find the marginalization for HMMs in the literature on calculating maximum likelihood estimates of HMMs (in computer science, electrical engineering, etc.) and in the ecology literature for the Cormack-Jolly-Seber model. They’re also in the Stan User’s Guide.
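To make the E-step connection concrete, here is a sketch of EM for a hypothetical two-component normal mixture with unit scales (a toy model, not Dawid and Skene’s coding model). The E-step computes the responsibilities $p(z_n = 0 \mid y_n, \theta)$, the same per-observation normalization used when marginalizing $z_n$, and the M-step re-estimates $\lambda, \mu_0, \mu_1$ from them. The data and function names are made up for illustration.

```python
import math

def normal_lpdf(y, mu, sigma=1.0):
    """Log density of normal(mu, sigma) at y."""
    return -0.5 * ((y - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)

def component_terms(y_n, lam, mu):
    """Log joint terms log p(y_n, z_n = k | theta) for k = 0, 1."""
    return (math.log(lam) + normal_lpdf(y_n, mu[0]),
            math.log1p(-lam) + normal_lpdf(y_n, mu[1]))

def e_step(y, lam, mu):
    """Responsibilities r_n = p(z_n = 0 | y_n, theta), normalized on the log scale."""
    resp = []
    for y_n in y:
        l0, l1 = component_terms(y_n, lam, mu)
        m = max(l0, l1)
        resp.append(math.exp(l0 - m) / (math.exp(l0 - m) + math.exp(l1 - m)))
    return resp

def m_step(y, resp):
    """Maximize the expected complete-data log likelihood (unit scales fixed)."""
    n0 = sum(resp)
    n1 = len(y) - n0
    lam = n0 / len(y)
    mu0 = sum(r * v for r, v in zip(resp, y)) / n0
    mu1 = sum((1.0 - r) * v for r, v in zip(resp, y)) / n1
    return lam, (mu0, mu1)

def marginal_log_lik(y, lam, mu):
    """log p(y | theta) with each z_n summed out (what EM climbs)."""
    total = 0.0
    for y_n in y:
        l0, l1 = component_terms(y_n, lam, mu)
        m = max(l0, l1)
        total += m + math.log(math.exp(l0 - m) + math.exp(l1 - m))
    return total

y = [-2.1, -1.7, -2.4, 1.9, 2.2, 2.6, 1.5]
lam, mu = 0.5, (-0.5, 0.5)
for _ in range(50):
    lam, mu = m_step(y, e_step(y, lam, mu))
print(lam, mu, marginal_log_lik(y, lam, mu))
```

An exact EM iteration never decreases the marginal log likelihood, which makes a useful sanity check on an implementation like this one.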