Super-duper online matrix derivative calculator vs. the matrix normal (for Stan)

I’m implementing the matrix normal distribution for Stan, which provides a multivariate density for a matrix with covariance factored into row and column covariances.

The motivation

A new colleague of mine at Flatiron’s Center for Comp Bio, Jamie Morton, is using the matrix normal to model the ocean biome. A few years ago, folks in Richard Bonneau’s group at Flatiron managed to extend the state of the art from a handful of dimensions to a few dozen dimensions using Stan, as described here.

That’s a couple orders of magnitude short of the models Jamie would like to fit. So instead of Stan, which he’s used before and is using for other smaller-scale problems, he’s turning to the approach in the following paper:

I’d just like to add matrix normal to Stan and see if we can scale up Äijö et al.’s results a bit.

The matrix normal

The density is defined for an

  • N × P observation matrix Y

with parameters

  • N × P mean matrix M,
  • positive-definite N × N row covariance matrix U, and
  • positive-definite P × P column covariance matrix V

\displaystyle \textrm{matrix\_normal}(Y \mid M, U, V) = \frac{\displaystyle \exp\left(-\frac{1}{2} \cdot \textrm{tr}\left(V^{-1} \cdot (Y - M)^{\top} \cdot U^{-1} \cdot (Y - M)\right)\right)}{\displaystyle (2 \cdot \pi)^{N \cdot P / 2} \cdot \textrm{det}(V)^{N / 2} \cdot \textrm{det}(U)^{P / 2}}.

It relates to the multivariate normal through vectorization (stacking the columns of a matrix) and Kronecker products as

\textrm{matrix\_normal}(Y \mid M, U, V) = \textrm{multi\_normal}(\textrm{vec}(Y) \mid \textrm{vec}(M), V \otimes U).
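This identity is easy to check numerically. Here’s a NumPy sketch (my own check, not part of the Stan implementation; all names are local to the example) that evaluates both log densities directly:

```python
# Check: matrix_normal(Y | M, U, V) == multi_normal(vec(Y) | vec(M), kron(V, U)).
import numpy as np

rng = np.random.default_rng(0)
N, P = 3, 2
Y = rng.normal(size=(N, P))
M = rng.normal(size=(N, P))
A = rng.normal(size=(N, N)); U = A @ A.T + N * np.eye(N)  # N x N row covariance
B = rng.normal(size=(P, P)); V = B @ B.T + P * np.eye(P)  # P x P column covariance

def matrix_normal_lpdf(Y, M, U, V):
    N, P = Y.shape
    E = Y - M
    quad = np.trace(np.linalg.inv(V) @ E.T @ np.linalg.inv(U) @ E)
    ldV = np.linalg.slogdet(V)[1]
    ldU = np.linalg.slogdet(U)[1]
    return -0.5 * (quad + N * P * np.log(2 * np.pi) + N * ldV + P * ldU)

def multi_normal_lpdf(y, mu, Sigma):
    d = y - mu
    ld = np.linalg.slogdet(Sigma)[1]
    return -0.5 * (d @ np.linalg.solve(Sigma, d) + len(y) * np.log(2 * np.pi) + ld)

# vec() stacks columns, i.e., column-major (Fortran) order in NumPy.
lhs = matrix_normal_lpdf(Y, M, U, V)
rhs = multi_normal_lpdf(Y.flatten('F'), M.flatten('F'), np.kron(V, U))
print(np.allclose(lhs, rhs))  # True
```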

After struggling with tensor differentials for the derivatives of the matrix normal with respect to its covariance matrix parameters (long enough to get a handle on how to do it, but not long enough to get good at it or trust my results), I thought I’d try Wolfram Alpha. I could not make that work.

To the rescue

After a bit more struggling, I entered the query [matrix derivative software] into Google and the first hit was a winner:

This beautiful piece of online software has a 1990s interface and 2020s functionality. I helped out by doing the conversion to log scale and dropping constant terms,

\begin{array}{l}\displaystyle \log \textrm{matrix\_normal}(Y \mid M, U, V) \\[8pt] \qquad = -\frac{1}{2} \cdot \left(\textrm{tr}\!\left[V^{-1} \cdot (Y - M)^{\top} \cdot U^{-1} \cdot (Y - M)\right] + N \cdot \log \textrm{det}(V) + P \cdot \log \textrm{det}(U)\right) + \textrm{const}.\end{array}

Some of these terms have surprisingly simple derivatives, like \frac{\partial}{\partial U} \log \textrm{det}(U) = U^{-1} (for symmetric U). Nevertheless, when you push the differential through, you wind up with tensor derivatives inside that trace that are difficult to manipulate without guesswork at my level of matrix algebra skill.
Unlike me, the matrix calculus site chewed this up and spit it out in neatly formatted LaTeX in a split second:

Despite the simplicity, this is a really beautiful interface. You enter a formula and it parses out the variables and asks you for their shape (scalar, vector, row vector, or matrix). Another win for declaring data types: it lets the interface resolve all the symbols. There may be a way to do that in Wolfram Alpha, but it’s not baked into their interface.
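Back to that simple derivative: \partial/\partial U \log \textrm{det}(U) = U^{-1} for symmetric positive-definite U is easy to spot-check with central finite differences. A NumPy sketch of my own (not the site’s output):

```python
# Finite-difference check that d/dU log det(U) = inv(U)' ,
# which reduces to inv(U) when U is symmetric positive-definite.
import numpy as np

rng = np.random.default_rng(1)
A = rng.normal(size=(4, 4))
U = A @ A.T + 4 * np.eye(4)  # symmetric positive-definite

eps = 1e-6
grad_fd = np.zeros_like(U)
for i in range(4):
    for j in range(4):
        Up = U.copy(); Up[i, j] += eps
        Um = U.copy(); Um[i, j] -= eps
        grad_fd[i, j] = (np.linalg.slogdet(Up)[1] - np.linalg.slogdet(Um)[1]) / (2 * eps)

print(np.allclose(grad_fd, np.linalg.inv(U).T, atol=1e-5))  # True
```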

The paper

There’s a very nice paper behind the site that explains what they did and how it relates to autodiff.

I haven’t digested it all, but as you may suspect, they implement a tensor algebra for derivatives. Here’s the meaty part of the abstract.

The framework can be used for symbolic as well as for forward and reverse mode automatic differentiation. Experiments show a speedup of up to two orders of magnitude over state-of-the-art frameworks when evaluating higher order derivatives on CPUs and a speedup of about three orders of magnitude on GPUs.

As far as I can tell, the tool is only first order, but that’s all we need for specialized forward- and reverse-mode implementations in Stan. The higher-order derivatives start from reverse mode and then nest in one or more forward-mode instances. I’m wondering if this will give us a better way to specialize fvar<var> in Stan (the type used for second derivatives).

Some suggestions for the tool

It wouldn’t be Andrew’s blog without suggestions about improving interfaces. My suggestions are to

  1. check the “Common subexpressions” checkbox by default (the alternative is nice for converting to code),
  2. stick to a simple checkbox interface indicating with a check that common subexpression sharing is on (as is, the text says “On” with an empty checkbox when it’s off and “Off” with an empty checkbox when it’s on),
  3. get rid of the extra vertical whitespace at the bottom of the return box,
  4. provide a way to make the text entry box bigger and accept multiple lines (as is, I composed in emacs and cut-and-pasted into the interface), and
  5. allow standard multi-character identifiers (it seems to only allow single-character variable names).

I didn’t try the Python output, but that’s a great idea if it produces code to compute both the function and partials efficiently.

Translating to the Stan math library

With the gradient in hand, it’s straightforward to define efficient forward-mode and reverse-mode autodiff in Stan using our general operands-and-partials builder structure. But now that I look at our code, I see our basic multivariate normal isn’t even using that efficient code. So I’ll have to fix that, too, which should be a win for everyone.

What I’m actually going to do is define the matrix normal in terms of Cholesky factors. That has the huge advantage of not needing to put a whole covariance matrix together only to factor it again (we need a log determinant). Sticking to Cholesky factors the whole way is also much more arithmetically stable, and it replaces a cubic-time factorization with quadratic-time triangular solves.
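That’s part of why Cholesky factors pay off: for example, the log determinant comes straight off the factor’s diagonal, with no factorization needed. A NumPy sketch (illustration only):

```python
# With Sigma = L L' (L a lower-triangular Cholesky factor),
# log det(Sigma) = 2 * sum(log(diag(L))).
import numpy as np

rng = np.random.default_rng(2)
A = rng.normal(size=(5, 5))
Sigma = A @ A.T + 5 * np.eye(5)
L = np.linalg.cholesky(Sigma)

lhs = np.linalg.slogdet(Sigma)[1]       # cubic-time route from Sigma
rhs = 2 * np.sum(np.log(np.diag(L)))    # linear-time route from L
print(np.allclose(lhs, rhs))  # True
```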

In terms of using the matrix derivative site, just replace U and V with (L_U * L_U’) and (L_V * L_V’) in the formulas and it should be good to go. Actually, it won’t be, because the parser apparently requires single-letter variables. So I wound up just using (U * U’), which does work.

Here’s the formula I used for the log density up to a constant:

-1/2 * (N * log(det(V * V')) + P * log(det(U * U')) + tr([inv(V * V') * (Y - M)'] * [inv(U * U') * (Y - M)]))

Because there’s no way to tell it that U and V are Cholesky factors, and hence that U * U’ is symmetric and positive definite, I had to do some reduction by hand, such as

inv(U * U') * U = inv(U')

That yields a whole lot of common subexpressions between the function and its gradient, and I think I can go a bit further by noting (U * U’) = (U * U’)’.

matrix_normal_lpdf(Y | M, U, V)
= -1/2 * (N * log(det(V * V'))
  + P * log(det(U * U'))
  + tr([inv(V * V') * (Y - M)'] * [inv(U * U') * (Y - M)]))
  + const

d/d.Y: -[inv(U * U') * (Y - M)] * inv(V * V')

d/d.M: [inv(U * U') * (Y - M)] * inv(V * V')

d/d.U: -P * inv(U')
       + [inv(U * U') * (Y - M)] * [inv(V * V') * (Y - M)'] * inv(U')

d/d.V: -N * inv(V')
       + [inv(V * V') * (Y - M)'] * [inv(U * U') * (Y - M)] * inv(V')
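Before trusting any of this in Stan, the gradients are worth verifying against finite differences. Here’s a NumPy sanity check of my own (note that the quadratic-form terms in d/d.M, d/d.U, and d/d.V come out with a positive sign):

```python
# Finite-difference check of the matrix normal gradients in the
# Cholesky parameterization (U, V lower-triangular factors).
import numpy as np

rng = np.random.default_rng(3)
N, P = 4, 3
Y = rng.normal(size=(N, P))
M = rng.normal(size=(N, P))
# Well-conditioned lower-triangular factors with positive diagonals.
U = np.tril(0.3 * rng.normal(size=(N, N)), -1) + np.diag(1 + rng.uniform(size=N))
V = np.tril(0.3 * rng.normal(size=(P, P)), -1) + np.diag(1 + rng.uniform(size=P))

def lp(Y, M, U, V):
    # Log density up to the additive constant, as in the formula above.
    E = Y - M
    Su, Sv = U @ U.T, V @ V.T
    return -0.5 * (N * np.linalg.slogdet(Sv)[1]
                   + P * np.linalg.slogdet(Su)[1]
                   + np.trace(np.linalg.inv(Sv) @ E.T @ np.linalg.inv(Su) @ E))

def fd_grad(f, X, eps=1e-6):
    # Central finite differences, one entry at a time.
    G = np.zeros_like(X)
    for idx in np.ndindex(*X.shape):
        Xp, Xm = X.copy(), X.copy()
        Xp[idx] += eps
        Xm[idx] -= eps
        G[idx] = (f(Xp) - f(Xm)) / (2 * eps)
    return G

E = Y - M
iSu, iSv = np.linalg.inv(U @ U.T), np.linalg.inv(V @ V.T)
iUt, iVt = np.linalg.inv(U).T, np.linalg.inv(V).T

dY = -iSu @ E @ iSv
dM = iSu @ E @ iSv
dU = -P * iUt + iSu @ E @ iSv @ E.T @ iUt
dV = -N * iVt + iSv @ E.T @ iSu @ E @ iVt

assert np.allclose(dY, fd_grad(lambda Y_: lp(Y_, M, U, V), Y), atol=1e-5)
assert np.allclose(dM, fd_grad(lambda M_: lp(Y, M_, U, V), M), atol=1e-5)
assert np.allclose(dU, fd_grad(lambda U_: lp(Y, M, U_, V), U), atol=1e-5)
assert np.allclose(dV, fd_grad(lambda V_: lp(Y, M, U, V_), V), atol=1e-5)
```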

This is totally going in the book

Next up, I’d like to add all the multivariate densities to the following work in progress.

There’s a draft up on GitHub with all the introductory material and a reference C++ implementation and lots of matrix derivatives and even algebraic solvers and HMMs. Lots still to do, though.

We haven’t done the main densities yet, so we can start with multivariate normal and T with all four parameterizations (covariance, precision and Cholesky factors of same), along with the (inverse) Wisharts and LKJ we need for priors. This matrix calculus site’s going to make it easy to deal with all the Cholesky-based parameterizations.

If you’d like to help implementing these in Stan or to join the effort for the handbook, let me know. I recently moved from Columbia University to the Flatiron Institute, and while I’m moving to a new email address, I didn’t change GitHub handles. Flatiron is great, by the way, and I’m still working with the Stan dev team, including Andrew.