Rao-Blackwellization and discrete parameters in Stan

I’m reading a really dense and beautifully written survey of Monte Carlo gradient estimation for machine learning by Shakir Mohamed, Mihaela Rosca, Michael Figurnov, and Andriy Mnih. It has great explanations of everything, including variance reduction techniques like coupling, control variates, and Rao-Blackwellization. The last of these is the topic of today’s post, as it relates directly to current Stan practices.

Expectations of interest

In Bayesian inference, parameter estimates and event probabilities and predictions can all be formulated as expectations of functions of parameters conditioned on observed data. In symbols, that’s

$\displaystyle \mathbb{E}[f(\Theta) \mid Y = y] = \int f(\theta) \cdot p(\theta \mid y) \, \textrm{d}\theta$

for a model with parameter vector $\Theta$ and data $Y = y$. In this post and most writing about probability theory, random variables are capitalized and bound variables are not.

Partitioning variables

Suppose we have two random variables $A, B$ and want to compute an expectation $\mathbb{E}[f(A, B)]$. In the Bayesian setting, this means splitting our parameters $\Theta = (A, B)$ into two groups and suppressing the conditioning on $Y = y$ in the notation.

Full sampling-based estimate of expectations

There are two unbiased approaches to computing the expectation $\mathbb{E}[f(A, B)]$ using sampling. The first one is traditional, with all random variables in the expectation being sampled.

Draw $(a^{(m)}, b^{(m)}) \sim p_{A,B}(a, b)$ for $m \in 1:M$ and estimate the expectation as

$\displaystyle \mathbb{E}[f(A, B)] \approx \frac{1}{M} \sum_{m=1}^M f(a^{(m)}, b^{(m)}).$
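Here is a minimal sketch of the full sampling estimator. It assumes a hypothetical toy joint distribution that does not come from any particular model: $B \sim \textrm{normal}(0, 1)$, $A \mid B = b \sim \textrm{bernoulli}(\textrm{logit}^{-1}(b))$, and $f(a, b) = a$. The logistic function is symmetric, so $\mathbb{E}[f(A, B)] = \Pr[A = 1] = 0.5$ exactly, which gives a known answer to check against. The function names are made up for illustration.

```python
import math
import random
from typing import Tuple


def inv_logit(x: float) -> float:
    """Logistic function, mapping the real line to (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def draw_joint(rng: random.Random) -> Tuple[int, float]:
    """Draw (a, b) from the toy joint p(a, b) = p(b) * p(a | b)."""
    b = rng.gauss(0.0, 1.0)                      # b ~ normal(0, 1)
    a = 1 if rng.random() < inv_logit(b) else 0  # a | b ~ bernoulli(inv_logit(b))
    return a, b


def f(a: int, b: float) -> float:
    """Function whose expectation we want; f(a, b) = a, so E[f(A, B)] = Pr[A = 1]."""
    return float(a)


def full_sampling_estimate(M: int, seed: int = 1234) -> float:
    """Average f over M independent joint draws (a^(m), b^(m))."""
    rng = random.Random(seed)
    return sum(f(*draw_joint(rng)) for _ in range(M)) / M


print(full_sampling_estimate(100_000))  # should land near 0.5
```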

Marginalized sampling-based estimate of expectations

The so-called Rao-Blackwellized estimator of the expectation involves marginalizing $A$ out of $p_{A,B}(a, b)$ and sampling $b^{(m)} \sim p_{B}(b)$ for $m \in 1:M$. The expectation is then estimated as

$\displaystyle \mathbb{E}[f(A, B)] \approx \frac{1}{M} \sum_{m=1}^M \mathbb{E}[f(A, b^{(m)})].$

For this estimator to be efficiently computable, the nested expectation must itself be efficiently computable,

$\displaystyle \mathbb{E}[f(A, b^{(m)})] = \int f(a, b^{(m)}) \cdot p(a \mid b^{(m)}) \, \textrm{d}a.$
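When $A$ is binary, the integral over $a$ in the nested expectation becomes a two-term sum that can be computed exactly. Here is a minimal sketch of the Rao-Blackwellized estimator under a hypothetical toy model: $B \sim \textrm{normal}(0, 1)$, $A \mid B = b \sim \textrm{bernoulli}(\textrm{logit}^{-1}(b))$, and $f(a, b) = a$. In this model the true value is $0.5$. The function names are illustrative only.

```python
import math
import random


def inv_logit(x: float) -> float:
    """Logistic function, mapping the real line to (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def f(a: int, b: float) -> float:
    """Function whose expectation we want; here f(a, b) = a."""
    return float(a)


def conditional_expectation(b: float) -> float:
    """Exact E[f(A, b)] = sum over a in {0, 1} of p(a | b) * f(a, b)."""
    p1 = inv_logit(b)  # Pr[A = 1 | B = b] in the toy model
    return (1.0 - p1) * f(0, b) + p1 * f(1, b)


def rao_blackwell_estimate(M: int, seed: int = 1234) -> float:
    """Average the exact inner expectation over M draws b^(m) ~ p_B(b)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(M):
        b = rng.gauss(0.0, 1.0)  # b^(m) ~ normal(0, 1); A is never sampled
        total += conditional_expectation(b)
    return total / M


print(rao_blackwell_estimate(100_000))  # should also land near 0.5
```

The only random quantity left is $b^{(m)}$. All of the variability from drawing $a$ given $b$ has been integrated out.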

The Rao-Blackwell theorem

The Rao-Blackwell theorem states that the marginalized estimator has variance less than or equal to that of the direct estimator. In practice, the difference can be enormous. How large it is depends on how much variance we would add by estimating $\mathbb{E}[f(A, b^{(m)})]$ with draws $a^{(n)} \sim p_{A \mid B}(a \mid b^{(m)})$ instead of computing it exactly,

$\displaystyle \mathbb{E}[f(A, b^{(m)})] \approx \frac{1}{N} \sum_{n = 1}^N f(a^{(n)}, b^{(m)}).$
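The comparison can be seen empirically with a small sketch. It uses a hypothetical toy model ($B \sim \textrm{normal}(0, 1)$, $A \mid B = b \sim \textrm{bernoulli}(\textrm{logit}^{-1}(b))$, $f(a, b) = a$) and computes the per-draw terms of the estimator in two ways. One replaces the inner expectation with an average of $N$ sampled $a^{(n)}$. The other uses the exact conditional expectation. With $N = 1$ this is just the full sampling estimator. The function and variable names are made up.

```python
import math
import random
import statistics
from typing import List, Optional


def inv_logit(x: float) -> float:
    """Logistic function, mapping the real line to (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def per_draw_terms(M: int, N: Optional[int], seed: int = 2024) -> List[float]:
    """Return the M per-draw terms of an estimator of E[A] in the toy model.

    N=None uses the exact inner expectation Pr[A = 1 | b] (Rao-Blackwellized).
    N=n replaces it with the average of n draws a^(n) ~ p(a | b^(m)); N=1 is
    the same as sampling (a, b) jointly.
    """
    rng = random.Random(seed)
    terms = []
    for _ in range(M):
        b = rng.gauss(0.0, 1.0)
        p1 = inv_logit(b)
        if N is None:
            terms.append(p1)
        else:
            hits = sum(1 for _ in range(N) if rng.random() < p1)
            terms.append(hits / N)  # inner Monte Carlo estimate of E[A | b]
    return terms


M = 20_000
for N in (1, 10, None):
    terms = per_draw_terms(M, N)
    print(f"N={N}: mean={statistics.fmean(terms):.3f}, "
          f"per-draw variance={statistics.variance(terms):.4f}")
```

All three versions estimate the same mean. The per-draw variance shrinks as $N$ grows, and the Rao-Blackwellized terms have the smallest variance. The draws here are independent, so each estimator's variance is its per-draw variance divided by $M$.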

Discrete variables in Stan

Stan does not have a sampler for discrete variables. Instead, Rao-Blackwellized estimators must be used, which essentially means marginalizing out the discrete parameters. If $A$ is the vector of discrete parameters in a model, $B$ the vector of continuous parameters, and $y$ the vector of observed data, then the model posterior is $p_{A, B \mid Y}(a, b \mid y)$. Stan samples only the continuous parameters, from the marginal posterior $p_{B \mid Y}(b \mid y) = \sum_{a} p_{A, B \mid Y}(a, b \mid y)$.

With a sampler that can efficiently make Gibbs draws (e.g., BUGS or PyMC3), it is tempting to try to compute posterior expectations by sampling,

$\displaystyle \mathbb{E}[f(A, B) \mid y] \approx \frac{1}{M} \sum_{m=1}^M f(a^{(m)}, b^{(m)})$, where $(a^{(m)}, b^{(m)}) \sim p_{A,B \mid Y}(a, b \mid y)$.

This is almost always a bad idea if it is possible to efficiently calculate the inner Rao-Blackwellization expectation, $\mathbb{E}[f(A, b^{(m)})]$. With discrete variables, the formula is just

$\displaystyle \mathbb{E}[f(A, b^{(m)})] = \sum_{a \in A} p(a \mid b^{(m)}) \cdot f(a, b^{(m)}).$
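In practice the conditional probabilities $p(a \mid b)$ are usually available only through an unnormalized log density. They are recovered by normalizing over the finite support of $A$ with the log-sum-exp trick, the same trick Stan programs use (via `log_sum_exp`) when marginalizing. This is a minimal sketch. The helper names and the three-component example, in which $A$ has equal prior weights and $B \mid A = a \sim \textrm{normal}(\mu_a, 1)$ with $\mu = (-1, 0, 1)$, are hypothetical.

```python
import math
from typing import Callable, Sequence


def log_sum_exp(xs: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def conditional_expectation(
    log_joint: Callable[[int, float], float],
    support: Sequence[int],
    f: Callable[[int, float], float],
    b: float,
) -> float:
    """E[f(A, b)] = sum over a of p(a | b) * f(a, b) for discrete A.

    p(a | b) is obtained by normalizing log p(a, b) over the finite support of A,
    so log_joint only needs to be known up to an additive constant.
    """
    lps = [log_joint(a, b) for a in support]
    log_norm = log_sum_exp(lps)  # log p(b), up to the same constant
    return sum(math.exp(lp - log_norm) * f(a, b) for a, lp in zip(support, lps))


# Hypothetical example: A in {0, 1, 2} with equal prior weights and
# B | A = a ~ normal(MU[a], 1).
MU = (-1.0, 0.0, 1.0)


def log_joint(a: int, b: float) -> float:
    # log p(a) + log normal(b | MU[a], 1), dropping constants shared by all a
    return math.log(1.0 / 3.0) - 0.5 * (b - MU[a]) ** 2


print(conditional_expectation(log_joint, range(3), lambda a, b: float(a), 0.0))
```

At $b = 0$, components $0$ and $2$ are equally likely, so by symmetry the printed $\mathbb{E}[A \mid b = 0]$ is $1$ up to floating-point rounding.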

The summation can usually be done efficiently in mixture models, where the discrete variables are tied to individual data points, and in state-space models like HMMs, where the discrete parameters can be marginalized using the forward algorithm. It is not so easy with missing count data or variable selection problems, where the posterior combinatorics are intractable.
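The forward algorithm mentioned above sums over all $K^T$ hidden state paths of an HMM in $O(T K^2)$ time. This is a generic log-space sketch for a hypothetical two-state HMM with binary observations. It is not Stan's implementation, and the parameter values are invented. The forward result is checked against brute-force enumeration of every path.

```python
import itertools
import math
from typing import Sequence


def log_sum_exp(xs: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def forward_log_marginal(y, log_init, log_trans, log_emit):
    """log p(y) for an HMM, summing over all hidden state paths in O(T K^2).

    alpha[k] holds log p(y[0..t], z_t = k) after processing observation t.
    """
    K = len(log_init)
    alpha = [log_init[k] + log_emit[k][y[0]] for k in range(K)]
    for t in range(1, len(y)):
        alpha = [
            log_sum_exp([alpha[j] + log_trans[j][k] for j in range(K)]) + log_emit[k][y[t]]
            for k in range(K)
        ]
    return log_sum_exp(alpha)


def brute_force_log_marginal(y, log_init, log_trans, log_emit):
    """Same quantity by enumerating all K^T hidden paths (exponential cost)."""
    K = len(log_init)
    path_lps = []
    for z in itertools.product(range(K), repeat=len(y)):
        lp = log_init[z[0]] + log_emit[z[0]][y[0]]
        for t in range(1, len(y)):
            lp += log_trans[z[t - 1]][z[t]] + log_emit[z[t]][y[t]]
        path_lps.append(lp)
    return log_sum_exp(path_lps)


# Hypothetical 2-state HMM with binary observations.
log_init = [math.log(p) for p in (0.6, 0.4)]
log_trans = [[math.log(p) for p in row] for row in ((0.9, 0.1), (0.2, 0.8))]
log_emit = [[math.log(p) for p in row] for row in ((0.7, 0.3), (0.1, 0.9))]
y = [0, 1, 1, 0, 1]

print(forward_log_marginal(y, log_init, log_trans, log_emit))
print(brute_force_log_marginal(y, log_init, log_trans, log_emit))
```

The two printed values agree. Only the forward version stays cheap as the sequence grows.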

Gains from marginalizing discrete parameters

The gains to be had from marginalizing discrete parameters are enormous. This is true even of models coded in BUGS or PyMC3. Cole Monnahan, James Thorson, and Trevor Branch wrote a nice survey of the advantages for some ecology models that compares marginalized HMC with Stan to JAGS with discrete sampling and JAGS with marginalization. The takeaway here isn’t that HMC is faster than JAGS, but that JAGS with marginalization is a lot faster than JAGS without.

The other place to see the effects of marginalization is the Stan User’s Guide chapter on latent discrete parameters. Its first example, a change-point model, shows how much more efficient marginalization is by comparing it directly with estimates generated from exact sampling of the discrete parameters conditioned on the continuous ones. The difference is especially large for tail statistics, which can’t be estimated at all with MCMC draws of the discrete parameters because too many draws would be required. I had the same experience coding the Dawid-Skene model of noisy data coding, which was my gateway to Bayesian inference. I had coded it with discrete sampling in BUGS, but BUGS took forever (24 hours, compared to 20 minutes for Stan on my real data) and kept crashing on trivial examples during my tutorials.
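A small sketch shows why tail statistics are a problem for sampled discrete parameters. It uses a hypothetical toy model, not the User's Guide example, in which $\Pr[A = 1]$ is tiny (roughly $5.5 \times 10^{-4}$): $B \sim \textrm{normal}(0, 1)$ and $A \mid B = b \sim \textrm{bernoulli}(\textrm{logit}^{-1}(b - 8))$. The sketch uses the same 1,000 draws of $b$ for both estimates. Averaging sampled indicators can only produce multiples of $1/1000$ and usually yields zero or one hit. Averaging the exact conditional probabilities $\Pr[A = 1 \mid b^{(m)}]$ gives a usable estimate.

```python
import math
import random
from typing import Tuple


def inv_logit(x: float) -> float:
    """Logistic function, mapping the real line to (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def tail_estimates(M: int, seed: int = 7) -> Tuple[float, float]:
    """Estimate the small probability Pr[A = 1] two ways from the same draws of b.

    Toy model (hypothetical): B ~ normal(0, 1), A | B = b ~ bernoulli(inv_logit(b - 8)).
    Returns (estimate from sampled a^(m), estimate from averaged Pr[A = 1 | b^(m)]).
    """
    rng = random.Random(seed)
    sampled_hits = 0  # number of sampled a^(m) equal to 1 (discrete sampling)
    rb_total = 0.0    # sum of exact conditional probabilities (marginalization)
    for _ in range(M):
        b = rng.gauss(0.0, 1.0)
        p1 = inv_logit(b - 8.0)
        rb_total += p1
        if rng.random() < p1:
            sampled_hits += 1
    return sampled_hits / M, rb_total / M


sampled, rb = tail_estimates(1_000)
print(f"sampled discrete draws: {sampled:.2e}   marginalized: {rb:.2e}")
```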

Marginalization calculations can be found in the MLE literature

The other place marginalization of discrete parameters comes up is in maximum likelihood estimation. For example, Dawid and Skene’s original approach to their coding model used the expectation maximization (EM) algorithm for maximum marginal likelihood estimation. The E-step does the marginalization, and it’s exactly the same marginalization as required in Stan for discrete parameters. You can find the marginalization for HMMs in the literature on calculating maximum likelihood estimates of HMMs (in computer science, electrical engineering, etc.) and in the ecology literature for the Cormack-Jolly-Seber model. These marginalizations are also in the Stan User’s Guide.
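A sketch of that correspondence follows. It assumes a hypothetical two-component normal mixture; the Dawid-Skene model itself has more structure. The E-step computes each observation's posterior component-membership probabilities from per-component log densities using log-sum-exp. Summing those log-sum-exp terms over observations gives the marginal log likelihood. A marginalized Stan mixture model adds the same quantity to `target`, not counting priors. The function names are illustrative.

```python
import math
from typing import List, Sequence, Tuple


def log_sum_exp(xs: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def normal_lpdf(y: float, mu: float, sigma: float) -> float:
    """Log density of normal(mu, sigma) at y."""
    return -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - 0.5 * ((y - mu) / sigma) ** 2


def e_step(
    ys: Sequence[float],
    weights: Sequence[float],
    mus: Sequence[float],
    sigmas: Sequence[float],
) -> Tuple[List[List[float]], float]:
    """Return (responsibilities, marginal log likelihood) for a normal mixture.

    responsibilities[n][k] = Pr[z_n = k | y_n, parameters]; the log likelihood is
    the sum over n of log_sum_exp over k, i.e., the discrete z_n marginalized out.
    """
    resp = []
    total = 0.0
    for y in ys:
        lps = [math.log(w) + normal_lpdf(y, m, s) for w, m, s in zip(weights, mus, sigmas)]
        lz = log_sum_exp(lps)
        total += lz
        resp.append([math.exp(lp - lz) for lp in lps])
    return resp, total


resp, loglik = e_step([0.0, 2.5, -1.3], [0.5, 0.5], [-1.0, 1.0], [1.0, 1.0])
print(resp, loglik)
```

An M-step would then update the weights, means, and scales from these responsibilities. In Stan, the marginal log likelihood is used directly and the responsibilities are only needed in generated quantities.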