Variational Inference

1. Why approximate inference?

Analytical inference ("inference" here refers to computing the posterior distribution, not the forward pass at test time that the term usually means in the deep learning literature) is easy when conjugate priors exist and hard otherwise. So, in practice, the posterior is approximated. This can be a problem at times, because the approximation tends to underestimate the uncertainty.
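
Concretely, the object we want is the posterior over the latent variables z given the data x, and the integral in the denominator (the evidence) is what usually makes it intractable. As a reminder, with z and x as generic names:

```latex
p(z \mid x) \;=\; \frac{p(x \mid z)\, p(z)}{p(x)},
\qquad
p(x) \;=\; \int p(x \mid z)\, p(z)\, dz .
```

With a conjugate prior the posterior keeps the functional form of the prior and p(x) is available in closed form; otherwise the integral generally has no closed form and is expensive to approximate numerically, which is why we resort to approximate inference.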


 
Variational inference (a form of approximate inference) requires the following steps:
  1. Select a family of distributions Q (called the variational family).
  2. Find the best approximation within Q by minimizing the KL divergence to the true posterior.
In variational inference the choice of Q leads to different results: a bigger Q gives more accuracy but is often difficult to compute with (in the limit, the set of all possible distributions), while a smaller Q gives less accurate results but is often easy to compute with. Ideally the true posterior should lie in Q, but it is often difficult to know whether it does. Note that in the KL divergence the posterior can be an unnormalized distribution, as the evidence (normalizing term) plays no role in the optimization of the KL divergence, as shown below.
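
To see why, write the KL divergence between the approximation q and the true posterior; the evidence log p(x) only enters as an additive constant that does not depend on q (notation as above, latent variables z, data x):

```latex
\mathrm{KL}\big(q(z)\,\|\,p(z \mid x)\big)
\;=\; \mathbb{E}_{q}\!\left[\log \frac{q(z)}{p(x, z)}\right] \;+\; \log p(x)
\;=\; -\,\mathrm{ELBO}(q) \;+\; \log p(x).
```

Minimizing the KL over q is therefore equivalent to maximizing the ELBO, which only involves the unnormalized joint p(x, z).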

 
2. What is the mean field model?

In mean field variational inference the variational family is chosen so that the distribution over the latent variables is a product (i.e. it factorizes) of local approximations, one distribution per individual latent variable. An example is given in slide 2, where the posterior is a normal distribution with a full covariance matrix. With the mean field approximation the covariance matrix is restricted to be diagonal, and therefore the approximation introduces errors.
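
The slide itself is not reproduced here, but the error is easy to check numerically. The sketch below assumes a 2-D Gaussian posterior with strong correlation and uses the standard closed-form result that the factorized Gaussian minimizing KL(q || p) has precisions equal to the diagonal of the true precision matrix; the numbers (e.g. the 0.9 correlation) are made up for illustration:

```python
import numpy as np

# True posterior: a 2-D Gaussian with a full covariance matrix.
Sigma = np.array([[1.0, 0.9],
                  [0.9, 1.0]])
Lambda = np.linalg.inv(Sigma)   # precision matrix of the true posterior

# Best factorized (mean field) Gaussian under KL(q || p): each factor q_i
# has precision Lambda_ii, i.e. variance 1 / Lambda_ii.
mf_var = 1.0 / np.diag(Lambda)

print("true marginal variances:      ", np.diag(Sigma))  # [1.0, 1.0]
print("mean field marginal variances:", mf_var)          # ~[0.19, 0.19]
# The mean field variances are much smaller than the true marginals:
# the diagonal restriction underestimates the uncertainty.
```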

See slide 3 for the formulation of the optimization problem, where the approximating distribution is a product of factors, one over each latent variable. The optimization is done by coordinate descent: find the q1 that minimizes the KL divergence while the other factors are held fixed, then q2, q3 and so on, repeating until convergence is reached.

As shown in slide 4, the analytical solution for this coordinate update can be derived; just follow through the derivation, as it is not difficult. The final solution can also be found on that slide.
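
For reference, the standard form of that solution (presumably what slide 4 arrives at) is the coordinate update in which the optimal j-th factor is the exponentiated expectation of the log joint under all the other factors:

```latex
\log q_j^{*}(z_j) \;=\; \mathbb{E}_{q_{-j}}\!\big[\log p(x, z)\big] \;+\; \text{const},
\qquad
q_j^{*}(z_j) \;\propto\; \exp\!\Big( \mathbb{E}_{i \neq j}\big[\log p(x, z)\big] \Big),
```

where the constant is fixed by normalizing q_j over z_j.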


Follow the Ising model example in the slide below to get more insight.
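
That slide is not included in these notes, but a minimal sketch of what the mean field updates look like for a pairwise Ising model used for image denoising is given below. The coupling strength J and the observation strength L are made-up values, and the updates are applied to all sites at once rather than one coordinate at a time, which is a common simplification:

```python
import numpy as np

def mean_field_ising(y, J=1.0, L=2.0, iters=20):
    """Mean field inference for an Ising prior with noisy observations y in {-1, +1}.

    Each latent spin x_i gets a factor q_i(x_i) over {-1, +1} parameterized by its
    mean mu_i; the update is mu_i = tanh(J * sum of neighbouring means + L * y_i).
    """
    mu = y.astype(float).copy()            # initialize the means at the observations
    for _ in range(iters):
        nbr = np.zeros_like(mu)            # sum of the 4 neighbours' means (zero padded)
        nbr[1:, :] += mu[:-1, :]
        nbr[:-1, :] += mu[1:, :]
        nbr[:, 1:] += mu[:, :-1]
        nbr[:, :-1] += mu[:, 1:]
        mu = np.tanh(J * nbr + L * y)      # mean field update for every site
    return mu                              # posterior means; np.sign(mu) is the denoised image

# Toy usage: a mostly all-(+1) image with roughly 20% of the pixels flipped by noise.
rng = np.random.default_rng(0)
y = np.where(rng.random((8, 8)) < 0.2, -1, 1)
print(np.sign(mean_field_ising(y)))        # should recover mostly +1
```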

 

3. Variational Bayes EM.

The EM algorithm already involves minimizing a KL divergence in the E-step (the approximating distribution is set equal to the posterior over the latent variables). Whenever the full posterior is difficult to compute, variational inference can be used to approximate the true posterior, which gives variational Bayes EM. Note that full inference is by far the most accurate, but it is also the slowest. The accuracy/speed trade-off results in the following order (from most accurate and slowest to fastest and least accurate): full inference - mean field approximation - EM algorithm - variational EM - crisp (hard) EM.
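
The connection comes from the usual decomposition of the log-likelihood: for any distribution q over the latent variables z and parameters θ,

```latex
\log p(x \mid \theta)
\;=\; \underbrace{\mathbb{E}_{q}\!\left[\log \frac{p(x, z \mid \theta)}{q(z)}\right]}_{\mathrm{ELBO}(q,\,\theta)}
\;+\;
\mathrm{KL}\big(q(z)\,\|\,p(z \mid x, \theta)\big).
```

The exact E-step sets q(z) = p(z | x, θ) and drives the KL term to zero; variational Bayes EM instead minimizes the KL only over a restricted family Q (e.g. mean field), and the M-step maximizes the ELBO over θ as usual.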
