Faster than ever before: Hamiltonian Monte Carlo using an adjoint-differentiated Laplace approximation

Charles Margossian, Aki Vehtari, Daniel Simpson, and Raj Agrawal write:

Gaussian latent variable models are a key class of Bayesian hierarchical models with applications in many fields. Performing Bayesian inference on such models can be challenging, as Markov chain Monte Carlo algorithms struggle with the geometry of the resulting posterior distribution and can be prohibitively slow. An alternative is to use a Laplace approximation to marginalize out the latent Gaussian variables and then integrate out the remaining hyperparameters using dynamic Hamiltonian Monte Carlo, a gradient-based Markov chain Monte Carlo sampler. To implement this scheme efficiently, we derive a novel adjoint method that propagates the minimal information needed to construct the gradient of the approximate marginal likelihood. This strategy yields a scalable method that is orders of magnitude faster than state-of-the-art techniques when the hyperparameters are high dimensional. We prototype the method in the probabilistic programming framework Stan and test the utility of the embedded Laplace approximation on several models, including one where the dimension of the hyperparameter is ∼6,000. Depending on the case, the benefits are either a dramatic speed-up or an alleviation of the geometric pathologies that frustrate Hamiltonian Monte Carlo.

They conclude:

Our next step is to further develop the prototype for Stan. We are also aiming to incorporate features that allow for a high-performance implementation, as seen in the packages INLA, TMB, and GPstuff. Examples include support for the sparse matrices required to fit latent Markov random fields, parallelization, and GPU support. We also want to improve the flexibility of the method by allowing users to specify their own likelihood. In this respect, the implementation in TMB is exemplary. It is in principle possible to use automatic differentiation to obtain the required higher-order derivatives, and most libraries, including Stan, support this; but, along with feasibility, there is a question of efficiency and practicality. The added flexibility also puts a greater burden on robustly diagnosing the error induced by the approximation. There is an extensive literature on log-concave likelihoods, but much less for general likelihoods. Future work will investigate diagnostics such as importance sampling, leave-one-out cross-validation, and simulation-based calibration.
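
For readers who want to see what is being approximated, here is the generic latent Gaussian model setup (this is the standard formulation, as in Gaussian process classification; the paper’s construction may differ in details such as a nonzero prior mean). The hyperparameters φ get a prior p(φ), the latent variables θ are Gaussian given φ with covariance K(φ), and the data y enter through a likelihood p(y | θ, φ). The object the Laplace approximation targets is the marginal likelihood

\[
p(y \mid \phi) \;=\; \int p(y \mid \theta, \phi)\, \mathrm{N}\!\left(\theta \mid 0,\, K(\phi)\right) d\theta .
\]

Approximating the integrand by a Gaussian centered at the conditional mode \(\hat\theta(\phi) = \arg\max_\theta \{\log p(y \mid \theta, \phi) + \log \mathrm{N}(\theta \mid 0, K(\phi))\}\) gives

\[
\log p(y \mid \phi) \;\approx\; \log p(y \mid \hat\theta, \phi)
\;-\; \tfrac{1}{2}\, \hat\theta^{\top} K(\phi)^{-1} \hat\theta
\;-\; \tfrac{1}{2} \log \det\!\left(I + K(\phi)\, W\right),
\qquad
W = -\nabla^2_{\theta} \log p(y \mid \hat\theta, \phi).
\]

Dynamic HMC then samples φ alone, targeting the approximate posterior proportional to this quantity times p(φ). The adjoint method is about computing the gradient of the approximate log marginal with respect to φ efficiently; that gradient is the delicate part because \(\hat\theta(\phi)\) is only defined implicitly, through an inner optimization.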

One thing I can’t quite figure out from skimming the paper is whether the method helps for regular old multilevel linear and logistic regressions: no fancy Gaussian processes, just batches of varying intercepts and maybe varying slopes. I’d guess the method works in such examples; it’s just not clear to me how much of a speed improvement you’d get. This is an important question to me because I see these sorts of problems all the time.
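
To be concrete, here’s the sort of model I have in mind, written as an ordinary Stan program (plain full HMC, not the embedded Laplace prototype; the variable names and priors are just illustrative). As I read the paper, the varying intercepts alpha would be the latent Gaussian variables that the approximation marginalizes out (the regression coefficients, with their normal priors, could be folded into that block too), and the group-level scale tau would be among the hyperparameters left for HMC:

data {
  int<lower=1> N;                        // observations
  int<lower=1> J;                        // groups
  array[N] int<lower=1, upper=J> group;  // group membership
  vector[N] x;                           // a single predictor
  array[N] int<lower=0, upper=1> y;      // binary outcome
}
parameters {
  real beta0;                            // intercept
  real beta1;                            // slope
  vector[J] alpha;                       // varying intercepts: the latent Gaussian block
  real<lower=0> tau;                     // group-level scale: a hyperparameter
}
model {
  beta0 ~ normal(0, 2);
  beta1 ~ normal(0, 2);
  tau ~ normal(0, 1);                    // half-normal via the lower bound
  alpha ~ normal(0, tau);                // latent Gaussians given the hyperparameter
  y ~ bernoulli_logit(beta0 + beta1 * x + alpha[group]);
}

The geometric pathology the abstract mentions is presumably the usual funnel between alpha and tau in a model like this, so marginalizing out alpha should at least help there; what I can’t tell is how big the speedup would be when the number of groups is only moderate.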

I’m also wondering if some of the computation could be improved by including stronger priors on the hyperparameters. Again, that’s an idea that’s been coming up a lot lately, in a wide range of applications.

Finally, I’m wondering how much parallelization is going on. Is this new algorithm faster because it requires fewer computations, or because it is more parallelizable, so that you can get wall-time improvements by plugging in more processors? Either way is fine; I’d just like to have a better sense of how the method is working and where the speedup is coming from:

The y-axis of this graph should be on the log scale, but whatever.
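
For what it’s worth, the general cheap-gradient story for adjoint (reverse-mode) methods, which is not specific to this paper, is that the gradient of a scalar function costs a small constant multiple of one function evaluation no matter how many inputs there are, whereas finite differences or forward-mode differentiation need roughly one extra pass per input:

\[
\underbrace{(\dim \phi + 1) \times \text{cost of one evaluation}}_{\text{finite differences / forward mode}}
\qquad \text{vs.} \qquad
\underbrace{c \times \text{cost of one evaluation},\ c \text{ a small constant}}_{\text{adjoint / reverse mode}}
\]

If that’s where the “orders of magnitude” in the abstract comes from, the gain would be algorithmic rather than a matter of throwing more processors at the problem, but I’d still like to see the accounting spelled out.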