Mean-Field Variational Inference Made Easy


I had the hardest time trying to understand variational inference. All of the presentations I’ve seen (MacKay, Bishop, Wikipedia, Gelman’s draft for the third edition of Bayesian Data Analysis) are deeply tied up with the details of a particular model being fit. I wanted to see the algorithm and get the big picture before being overwhelmed with multivariate exponential family gymnastics.

Bayesian Posterior Inference

In the Bayesian setting (see my earlier post, What is Bayesian Inference?), we have a joint probability model p(y,\theta) for data y and parameters \theta, usually factored as the product of a likelihood and a prior term, p(y,\theta) = p(y|\theta) p(\theta). Given some observed data y, Bayesian predictive inference is based on the posterior density p(\theta|y) \propto p(\theta, y) of the unknown parameter vector \theta given the observed data vector y. Thus we need to be able to estimate the posterior density p(\theta|y) to carry out Bayesian inference. Note that the posterior is a whole density function: we’re not just after a point estimate as in maximum likelihood estimation.
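To make "a whole density function" concrete, here is a minimal sketch that evaluates a posterior on a grid. The model is an assumption chosen only for illustration (k successes in n binomial trials with a uniform prior on \theta), and the function names `grid_posterior` and `posterior_mean` are hypothetical.

```python
import math


def grid_posterior(k, n, grid_size=1000):
    """Normalized posterior density of theta on a midpoint grid over (0, 1).

    Assumed model (illustration only): k successes in n trials,
    binomial likelihood, uniform prior on theta.
    """
    width = 1.0 / grid_size
    thetas = [(i + 0.5) * width for i in range(grid_size)]
    # log p(y, theta) up to an additive constant: log likelihood + log prior (= 0).
    log_joint = [k * math.log(t) + (n - k) * math.log(1.0 - t) for t in thetas]
    # Subtract the maximum before exponentiating for numerical stability.
    top = max(log_joint)
    unnorm = [math.exp(v - top) for v in log_joint]
    # Normalize so the density integrates to one over the grid.
    total = sum(unnorm) * width
    density = [u / total for u in unnorm]
    return thetas, density


def posterior_mean(thetas, density):
    """Approximate E[theta | y] by a midpoint-rule sum over the grid."""
    width = thetas[1] - thetas[0]
    return sum(t * d for t, d in zip(thetas, density)) * width
```

The grid only ever touches the joint p(y,\theta); the normalizing constant falls out of the sum. That works in one dimension but not in many, which is why approximations such as variational inference are needed at all.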

Mean-Field Approximation

Variational inference approximates the Bayesian posterior density p(\theta|y) with a (simpler) density g(\theta|\phi) parameterized by some new parameters \phi. The mean-field form of variational inference factors the approximating density g by component of \theta = (\theta_1,\ldots,\theta_J), as

g(\theta|\phi) = \prod_{j=1}^J g_j(\theta_j|\phi_j).

I’m going to put off actually defining the terms g_j until we see how they’re used in the variational inference algorithm.

What Variational Inference Does

The variational inference algorithm finds the value \phi^* for the parameters \phi of the approximation which minimizes the Kullback-Leibler divergence of g(\theta|\phi) from p(\theta|y),

\phi^* = \mbox{arg min}_{\phi} \ \mbox{KL}[ g(\theta|\phi) \ || \ p(\theta|y) ].

The key idea here is that variational inference reduces posterior estimation to an optimization problem. Optimization is typically much faster than approaches to posterior estimation such as Markov chain Monte Carlo (MCMC).
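It may help to see why this optimization can be attacked without knowing the normalizing constant p(y). The expansion below is the standard identity, not something specific to this post, and the name ELBO (evidence lower bound) does not appear elsewhere here. Writing \mathbb{E}_g for expectation with respect to g(\theta|\phi),

\mbox{KL}[ g(\theta|\phi) \ || \ p(\theta|y) ] = \mathbb{E}_{g}[\log g(\theta|\phi)] - \mathbb{E}_{g}[\log p(\theta|y)] = \mathbb{E}_{g}[\log g(\theta|\phi)] - \mathbb{E}_{g}[\log p(y,\theta)] + \log p(y).

Because \log p(y) does not depend on \phi, minimizing the divergence is equivalent to maximizing

\mbox{ELBO}(\phi) = \mathbb{E}_{g}[\log p(y,\theta)] - \mathbb{E}_{g}[\log g(\theta|\phi)],

which involves only the joint density p(y,\theta) and the approximation g, both of which we can evaluate.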

The main disadvantage of variational inference is that the posterior is only approximated (though as MacKay points out, just about any approximation is better than a delta function at a point estimate!). In particular, variational methods systematically underestimate posterior variance because of the direction of the KL divergence that is minimized. Expectation propagation (EP) also converts posterior fitting to optimization of KL divergence, but EP uses the opposite direction of KL divergence, which leads to overestimation of posterior variance.
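The variance underestimation is easy to see in a toy case. The sketch below assumes the target is a bivariate Gaussian with correlation rho, and takes as given the standard result (worked out in Bishop's book) that the optimal mean-field factor for a Gaussian target has variance 1/\Lambda_{jj}, where \Lambda is the precision matrix. The function name `mean_field_vs_marginal` is hypothetical.

```python
def mean_field_vs_marginal(sigma1, sigma2, rho):
    """Compare true marginal variances with mean-field factor variances.

    Assumes a bivariate Gaussian target with standard deviations sigma1 and
    sigma2 and correlation rho (|rho| < 1). The mean-field variances use the
    known result var_j = 1 / Lambda_jj for a Gaussian target.
    """
    # Entries of the covariance matrix Sigma.
    s11 = sigma1 ** 2
    s22 = sigma2 ** 2
    s12 = rho * sigma1 * sigma2
    det = s11 * s22 - s12 ** 2
    # Diagonal of the precision matrix Lambda = Sigma^{-1} (2x2 inverse).
    lam11 = s22 / det
    lam22 = s11 / det
    return {
        "marginal": (s11, s22),
        "mean_field": (1.0 / lam11, 1.0 / lam22),
    }
```

With unit standard deviations and rho = 0.9, the true marginal variances are 1 while the mean-field variances are 1 - rho^2, about 0.19. The stronger the posterior correlation, the worse the underestimate.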

Variational Inference Algorithm

Given the Bayesian model p(y,\theta), observed data y, and functional terms g_j making up the approximation of the posterior p(\theta|y), the variational inference algorithm is:

  • \phi \leftarrow \mbox{random legal initialization}
  • \mbox{repeat}
    • \phi_{\mbox{\footnotesize old}} \leftarrow \phi
    • \mbox{for } j \mbox{ in } 1:J
      • \mbox{set } \phi_j \mbox{ such that } \log g_j(\theta_j|\phi_j) = \mathbb{E}_{g_{-j}}[\log p(\theta|y)] + \mbox{const}
  • \mbox{until } ||\phi - \phi_{\mbox{\footnotesize old}}|| < \epsilon
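Here is a minimal sketch of that loop for one concrete conjugate model. The model is not from this post: it is the textbook case of normal data with unknown mean \mu and precision \tau, with assumed priors \mu|\tau \sim \mbox{Normal}(\mu_0, 1/(\lambda_0 \tau)) and \tau \sim \mbox{Gamma}(a_0, b_0), and a factored approximation g(\mu,\tau) = g_1(\mu) g_2(\tau). The function name `cavi_normal_gamma` and all default hyperparameters are hypothetical, and the initialization is fixed rather than random to keep the example reproducible.

```python
import random


def cavi_normal_gamma(x, mu0=0.0, lam0=1.0, a0=1.0, b0=1.0,
                      eps=1e-8, max_iter=1000):
    """Coordinate-wise mean-field updates for y_n ~ Normal(mu, 1/tau).

    Assumed prior: mu | tau ~ Normal(mu0, 1/(lam0*tau)), tau ~ Gamma(a0, rate=b0).
    Approximation: g_1(mu) = Normal(mu_n, 1/lam_n), g_2(tau) = Gamma(a_n, rate=b_n).
    """
    n = len(x)
    xbar = sum(x) / n
    # These two parameters have closed forms that do not change across sweeps.
    mu_n = (lam0 * mu0 + n * xbar) / (lam0 + n)
    a_n = a0 + (n + 1) / 2.0
    # Arbitrary legal (positive) initialization of the remaining parameters.
    lam_n, b_n = 1.0, 1.0
    for _ in range(max_iter):
        lam_old, b_old = lam_n, b_n
        # Update g_1(mu): its precision uses E[tau] = a_n / b_n under g_2.
        lam_n = (lam0 + n) * a_n / b_n
        # Update g_2(tau): expectations under g_1 add the variance 1 / lam_n.
        sq_data = sum((xi - mu_n) ** 2 for xi in x) + n / lam_n
        sq_prior = lam0 * ((mu_n - mu0) ** 2 + 1.0 / lam_n)
        b_n = b0 + 0.5 * (sq_data + sq_prior)
        # Stop when the variational parameters phi stop moving.
        if abs(lam_n - lam_old) + abs(b_n - b_old) < eps:
            break
    return mu_n, lam_n, a_n, b_n


if __name__ == "__main__":
    random.seed(0)
    data = [random.gauss(5.0, 2.0) for _ in range(500)]
    mu_n, lam_n, a_n, b_n = cavi_normal_gamma(data)
    print("E[mu] =", mu_n, " E[tau] =", a_n / b_n)
```

Each pass through the loop body is one sweep of the inner "for j" step: the update for g_1 needs an expectation under g_2, and the update for g_2 needs expectations under g_1, so the factors are refined in turn until nothing changes.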

The inner expectation is a function of \theta_j returning a single real value (an expected log density, so it can be negative), defined by

\mathbb{E}_{g_{-j}}[\log p(\theta|y)]

\begin{array}{l}  \mbox{ } = \int_{\theta_1} \cdots \int_{\theta_{j-1}} \int_{\theta_{j+1}} \cdots \int_{\theta_J}  \\[8pt] \mbox{ } \hspace*{0.4in}  g_1(\theta_1|\phi_1) \times \cdots \times g_{j-1}(\theta_{j-1}|\phi_{j-1}) \times  g_{j+1}(\theta_{j+1}|\phi_{j+1}) \times \cdots \times g_J(\theta_J|\phi_J)  \times  \log p(\theta|y)  \\[8pt] \mbox{ } \hspace*{0.2in}  d\theta_J \cdots d\theta_{j+1} \ d\theta_{j-1} \cdots d\theta_1  \end{array}

Despite the suggestive factorization of g and the coordinate-wise nature of the algorithm, variational inference does not simply approximate the posterior marginals p(\theta_j|y) independently.
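A small worked case shows both the expectation and the coupling between factors. Assume, purely for illustration, that J = 2 and the posterior is Gaussian with mean \mu and precision matrix \Lambda (these symbols are not used elsewhere in this post). Then

\log p(\theta|y) = -\frac{1}{2} \Lambda_{11} (\theta_1 - \mu_1)^2 - \Lambda_{12} (\theta_1 - \mu_1)(\theta_2 - \mu_2) - \frac{1}{2} \Lambda_{22} (\theta_2 - \mu_2)^2 + \mbox{const}.

Taking the expectation over \theta_2 under g_2 and dropping everything that does not depend on \theta_1 gives

\mathbb{E}_{g_{-1}}[\log p(\theta|y)] = -\frac{1}{2} \Lambda_{11} \theta_1^2 + \theta_1 \left( \Lambda_{11} \mu_1 - \Lambda_{12} (\mathbb{E}_{g_2}[\theta_2] - \mu_2) \right) + \mbox{const}.

This is quadratic in \theta_1, so the matching factor g_1 is a normal density with precision \Lambda_{11} and mean

m_1 = \mu_1 - \Lambda_{11}^{-1} \Lambda_{12} (\mathbb{E}_{g_2}[\theta_2] - \mu_2).

The mean of g_1 depends on the current mean of g_2 (and symmetrically for g_2), which is why the factors have to be updated in turn rather than fit independently.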

Defining the Approximating Densities

The trick is to choose the approximating factors so that we can compute parameter values \phi_j such that \log g_j(\theta_j|\phi_j) = \mathbb{E}_{g_{-j}}[\log p(\theta|y)] + \mbox{const}, where the constant does not depend on \theta_j. Finding such approximating terms g_j(\theta_j|\phi_j) given a posterior p(\theta|y) is an art form unto itself. It’s much easier for models with conjugate priors. Bishop’s and MacKay’s books and Wikipedia present calculations for a wide range of exponential-family models.

What if My Model is not Conjugate?

Unfortunately, I almost never work with conjugate priors (and even if I did, I’m not much of a hand at exponential-family algebra). Therefore, the following paper just got bumped to the top of my must-understand queue:

  • Wang, Chong and David M. Blei. 2012–2013. Variational Inference in Nonconjugate Models. arXiv 1209.4360.

It’s great having Dave down the hall on sabbatical this year; one couldn’t ask for a better stand-in for Matt Hoffman. They are both insanely good at on-the-fly explanations at the blackboard (I love that we still have real chalk and high-quality boards).
