A Tutorial on Variational Bayesian Inference

Charles Fox · Stephen Roberts

Charles W. Fox, Adaptive Behaviour Research Group, University of Sheffield, UK. charles.fox@sheffield.ac.uk
Stephen Roberts, Pattern Analysis and Machine Learning Research Group, Department of Engineering Science, University of Oxford, UK.

Abstract This tutorial describes the mean-field variational Bayesian approximation to inference in graphical models, using modern machine learning terminology rather than statistical physics concepts. It begins by seeking an approximate mean-field distribution close to the target joint in the KL-divergence sense. It then derives local node updates and reviews the recent Variational Message Passing framework.

Keywords Variational Bayes · mean-field · tutorial

1 Introduction

Variational methods have recently become popular in the context of inference problems [1,4]. Variational Bayes is a particular variational method which aims to find some approximate joint distribution $Q(x;\theta)$ over hidden variables $x$ to approximate the true joint $P(x)$, and defines 'closeness' as the KL divergence $\mathrm{KL}[Q(x;\theta)\,\|\,P(x)]$. The mean-field form of VB assumes that $Q$ factorises into single-variable factors, $Q(x) = \prod_i Q_i(x_i|\theta_i)$. The asymmetric $\mathrm{KL}[Q\|P]$ is chosen in VB principally to yield useful computational simplifications, but can be viewed as preferring an approximation that is accurate in areas of high $Q$, rather than in areas of high $P$. This is often useful because if we were to draw samples from, or integrate over, $Q$, then the areas used will be largely accurate (though they may well miss out areas of high $P$).

1.1 Problem statement

We wish to find a set of distributions $\{Q_i(x_i;\theta_i)\}$ to minimise the KL-divergence:

\[
\mathrm{KL}[Q(x)\,\|\,P(x|D)] = \int \mathrm{d}x\, Q(x) \ln \frac{Q(x)}{P(x|D)}
\]

where $D$ is the observed data, $x$ are the unobserved variables, and

\[
Q(x) = \prod_i Q_i(x_i|\theta_i).
\]

We sometimes omit the notational dependence of $Q_i$ on $\theta_i$ for clarity. As the $Q_i$ are approximate beliefs, they are subject to the normalisation constraints:

\[
\forall i.\quad \int \mathrm{d}x_i\, Q_i(x_i) = 1.
\]

1.2 VB approximates joints, not marginals

$Q$ approximates the joint, but the individual $Q_i(x_i)$ are poor approximations to the true marginals $P_i(x_i)$. The $Q_i(x_i)$ components should not be expected to resemble – even remotely – the true marginals, for example when inspecting the computational states of network nodes. This makes VB considerably harder to debug than algorithms whose node states do have some local interpretation: the optimal VB components can be highly counter-intuitive and make sense only in the global context.

[Fig. 1 (a) Graphical model for a population mean problem. Square nodes indicate observed variables. (b) True joint P and VB approximation Q.]

Fig. 1(a) shows a graphical model of an extreme example of mean-field VB breaking down, as a warning for the algorithms that follow. Suppose a factory produces screws of unknown diameter $\mu$. We know the machines are very accurate, so the precision $\gamma = 1/\sigma^2$ of the diameters is high. However, no-one has told us what the mean diameter is, except for a drunken engineer in a pub who said it was 5mm. We might have a prior belief $\pi$ that $\mu$ is, say, 5±4mm (the 4mm deviation reflecting our low confidence in the report). Suppose we are given a sealed box containing one screw. What is our belief in its diameter $x$? The exact marginal belief in $x$ should be almost identical to our belief in $\mu$, i.e. 5±4mm. Fig. 1(b) shows one standard deviation of the true Gaussian joint $P(\mu, x)$ and the best mean-field Gaussian approximate joint $Q(\mu, x)$. As discussed above, this $Q$ is chosen to minimise error in its own domain, so it appears as a tight Gaussian. Importantly, the marginal $Q_x(x)$ is now very narrow, and not at all equal to the marginal $P(x)$. We emphasise such cases because we found them to be the major cause of debugging time during development.
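As a concrete numerical illustration of this failure mode, the following minimal sketch (not part of the original paper) re-creates the screw example with made-up values for the prior and the machine precision. It relies on the standard closed-form result that the mean-field Gaussian approximation to a Gaussian joint matches the true means but uses the conditional precisions $\Lambda_{ii}$, so each factor's variance is $1/\Lambda_{ii}$ rather than the true marginal variance.

```python
import numpy as np

# Screw example (Fig. 1): prior mu ~ N(5, 4^2) mm, x | mu ~ N(mu, 1/gamma)
# with high machine precision gamma. Values are illustrative assumptions.
prior_mean, prior_sd = 5.0, 4.0     # drunken-engineer prior on mu
sigma = 0.01                        # machine tolerance, so gamma = 1/sigma^2 is large
gamma = 1.0 / sigma**2
beta = 1.0 / prior_sd**2            # prior precision on mu

# True joint P(mu, x) is Gaussian with precision matrix Lambda, read off from
# ln P = -0.5*beta*(mu - m)^2 - 0.5*gamma*(x - mu)^2 + const:
Lambda = np.array([[beta + gamma, -gamma],
                   [-gamma,        gamma]])
Sigma = np.linalg.inv(Lambda)       # true covariance of (mu, x)

# Standard result for a Gaussian target: the mean-field solution Q(mu)Q(x)
# keeps the true means but uses variances 1/Lambda_ii, i.e. the *conditional*
# variances rather than the marginal ones.
vb_var = 1.0 / np.diag(Lambda)

print("true marginal sd of x:", np.sqrt(Sigma[1, 1]))   # ~4 mm, like the prior on mu
print("VB   marginal sd of x:", np.sqrt(vb_var[1]))     # ~0.01 mm: far too confident
```

Running this prints a true marginal standard deviation for $x$ of about 4mm, as argued above, while the VB factor collapses to roughly the machine tolerance of 0.01mm: exactly the mismatch sketched in Fig. 1(b).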
We also want to find the model likelihood $P(D|M)$ for model comparison. This quantity will appear naturally as we try to minimise the KL divergence. There are many ways to think about the following derivations, which we will see involve a balance of three terms, called the lower bound, the energy and the entropy. The derivation presented here begins by aiming to minimise the above KL divergence; other, equivalent derivations may begin by aiming to maximise the lower bound.

1.3 Rewriting KL optimisation as an easier problem

We will rewrite the KL equation in terms that are more tractable. First we flip the numerator and denominator, and flip the sign:

\[
\mathrm{KL}[Q(x)\,\|\,P(x|D)] = \int \mathrm{d}x\, Q(x) \ln \frac{Q(x)}{P(x|D)} = -\int \mathrm{d}x\, Q(x) \ln \frac{P(x|D)}{Q(x)}.
\]

Next, replace the conditional $P(x|D)$ with the joint $P(x,D)$ and the evidence $P(D)$. The reason for making this rewrite is that, for Bayesian networks with exponential-family nodes, the $\ln P(x,D)$ term will be a very simple sum of node energy terms, whereas $\ln P(x|D)$ is more complicated. This will simplify later computations.

\[
P(x,D) = P(x|D)P(D) \quad\Rightarrow\quad \ln P(x|D) = \ln P(x,D) - \ln P(D)
\]

\[
\mathrm{KL}[Q(x)\,\|\,P(x|D)] = -\int \mathrm{d}x\, Q(x) \ln \frac{P(x,D)}{Q(x)} + \ln P(D).
\]

The $\ln P(D)$ term does not involve $Q$, so we can ignore it for the purposes of our minimisation. Finally, define

\[
L[Q(x)] = \int \mathrm{d}x\, Q(x) \ln \frac{P(x,D)}{Q(x)}
\]

so that

\[
\mathrm{KL}[Q(x)\,\|\,P(x|D)] = -L[Q(x)] + \ln P(D).
\]

So to minimise the KL divergence, we must maximise $L$. The maximisation is still subject to the normalisation constraints:

\[
\forall i.\quad \int \mathrm{d}x_i\, Q_i(x_i) = 1.
\]

$L[Q(x)]$ is a lower bound on the model log-likelihood $\ln P(D) = \ln P(D|M)$ (where we generally drop the $M$ notation as we are working with a single model only). The best bound is thus achieved when $L[Q(x)]$ is maximised over $Q$. The reason for $L$ being a lower bound is seen by rearranging as:

\[
\ln P(D) = L[Q(x)] + \mathrm{KL}[Q(x)\,\|\,P(x|D)].
\]

Thus when the KL-divergence is zero (a perfect fit), $L$ is equal to the model log-likelihood. When the fit is not perfect, the KL-divergence is positive and so $L[Q(x)] < \ln P(D)$. Another rearrangement gives

\[
L[Q(x)] = \ln P(D) - \mathrm{KL}[Q(x)\,\|\,P(x|D)],
\]

showing that the KL divergence is the error between $L$ and $\ln P(D)$.
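The decomposition $\ln P(D) = L[Q(x)] + \mathrm{KL}[Q(x)\,\|\,P(x|D)]$ is easy to check numerically. The following minimal sketch (with made-up numbers, not from the paper) verifies it exactly for a single three-state hidden variable, where all the integrals reduce to finite sums.

```python
import numpy as np

# A tiny discrete check of  ln P(D) = L[Q] + KL[Q || P(.|D)].
# Hypothetical numbers: one hidden variable x with 3 states, one fixed observation D.
prior = np.array([0.5, 0.3, 0.2])          # P(x)
lik   = np.array([0.10, 0.70, 0.40])       # P(D | x) for the observed D
joint = prior * lik                        # P(x, D)

log_evidence = np.log(joint.sum())         # ln P(D)
posterior = joint / joint.sum()            # P(x | D)

Q = np.array([0.2, 0.5, 0.3])              # any normalised approximate belief

L  = np.sum(Q * np.log(joint / Q))         # lower bound L[Q]
KL = np.sum(Q * np.log(Q / posterior))     # KL[Q || P(.|D)] >= 0

print(L + KL, log_evidence)                # identical (up to float error)
print(L <= log_evidence)                   # True: L is a lower bound
```

Whatever normalised $Q$ is chosen, $L + \mathrm{KL}$ reproduces $\ln P(D)$ and $L$ stays below it, which is precisely the sense in which $L$ is a lower bound.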
1.4 Solution of free energy optimisation

We now wish to find $Q$ to maximise the lower bound, subject to the normalisation constraints:

\[
L[Q(x)] = \int \mathrm{d}x\, Q(x) \ln \frac{P(x,D)}{Q(x)} = \int \mathrm{d}x\, Q(x) \ln P(x,D) - \int \mathrm{d}x\, Q(x) \ln Q(x) = \langle E(x,D) \rangle_{Q(x)} + H[Q(x)]
\]

where we define the energy as $E = \ln P$ and the entropy (Shannon entropy, used by convention) as $H[Q(x)] = -\int \mathrm{d}x\, Q(x) \ln Q(x)$. For exponential models, the energy will become a convenient sum of linear functions. By the mean-field assumption:

\[
L[Q(x)] = \int \mathrm{d}x \left( \prod_i Q_i(x_i) \right) E(x,D) - \int \mathrm{d}x \left( \prod_k Q_k(x_k) \right) \sum_i \ln Q_i(x_i). \tag{1}
\]

Consider the entropy (rightmost) term. We can bring out the sum:

\[
\sum_i \int \mathrm{d}x \left( \prod_k Q_k(x_k) \right) \ln Q_i(x_i).
\]

Consider the partitions $x = \{x_i, \bar{x}_i\}$ where $\bar{x}_i = x \setminus x_i$. Then the entropy term becomes

\[
\sum_i \int \mathrm{d}x_i\, \mathrm{d}\bar{x}_i\, Q(\bar{x}_i)\, Q_i(x_i) \ln Q_i(x_i)
= \sum_i \left\langle \int \mathrm{d}x_i\, Q_i(x_i) \ln Q_i(x_i) \right\rangle_{Q(\bar{x}_i)}
= \sum_i \int \mathrm{d}x_i\, Q_i(x_i) \ln Q_i(x_i).
\]

Substituting this into the right term of equation (1):

\[
L[Q(x)] = \int \mathrm{d}x \left( \prod_i Q_i(x_i) \right) E(x,D) - \sum_i \int \mathrm{d}x_i\, Q_i(x_i) \ln Q_i(x_i). \tag{2}
\]

Now look at and rearrange the energy (left) term of equation (2), again separating out one variable:

\[
\int \mathrm{d}x \left( \prod_i Q_i(x_i) \right) E(x,D)
= \int \mathrm{d}x_i\, Q_i(x_i) \int \mathrm{d}\bar{x}_i\, Q(\bar{x}_i)\, E(x,D)
= \int \mathrm{d}x_i\, Q_i(x_i) \langle E(x,D) \rangle_{Q(\bar{x}_i)}
\]
\[
= \int \mathrm{d}x_i\, Q_i(x_i) \ln \exp \langle E(x,D) \rangle_{Q(\bar{x}_i)}
= \int \mathrm{d}x_i\, Q_i(x_i) \ln Q^*_i(x_i) + \ln Z
\]

where we have defined $Q^*_i(x_i) = \frac{1}{Z} \exp \langle E(x,D) \rangle_{Q(\bar{x}_i)}$ and $Z$ normalises $Q^*_i(x_i)$. Substituting this new form of the energy back into equation (2) yields

\[
L[Q(x)] = \int \mathrm{d}x_i\, Q_i(x_i) \ln Q^*_i(x_i) - \sum_i \int \mathrm{d}x_i\, Q_i(x_i) \ln Q_i(x_i) + \ln Z.
\]

Separate out the entropy $H_i = H[Q_i(x_i)]$ from the rest of the entropy sum:

\[
L[Q(x)] = \left( \int \mathrm{d}x_i\, Q_i(x_i) \ln Q^*_i(x_i) - \int \mathrm{d}x_i\, Q_i(x_i) \ln Q_i(x_i) \right) + H[Q(\bar{x}_i)] + \ln Z.
\]

Consider the terms in the brackets:

\[
\int \mathrm{d}x_i\, Q_i(x_i) \ln Q^*_i(x_i) - \int \mathrm{d}x_i\, Q_i(x_i) \ln Q_i(x_i)
= \int \mathrm{d}x_i\, Q_i(x_i) \ln \frac{Q^*_i(x_i)}{Q_i(x_i)}
= -\mathrm{KL}[Q_i(x_i)\,\|\,Q^*_i(x_i)].
\]

What a lucky coincidence! Though we started by trying to minimise the KL-divergence between large joint distributions (which is hard), we have converted the problem to that of minimising KL-divergences between individual one-dimensional distributions (which is easier). Write:

\[
L[Q(x)] = -\mathrm{KL}[Q_i(x_i)\,\|\,Q^*_i(x_i)] + H[Q(\bar{x}_i)] + \ln Z.
\]

Thus $L$ depends on each individual $Q_i$ only through the KL term. We wish to maximise $L$ with respect to each $Q_i$, subject to the constraint that all $Q_i$ are normalised to unity. This could be achieved by Lagrange multipliers and functional differentiation:

\[
\frac{\delta}{\delta Q_i(x_i)} \left[ -\mathrm{KL}[Q_i(x_i)\,\|\,Q^*_i(x_i)] - \lambda_i \left( \int Q_i(x_i)\, \mathrm{d}x_i - 1 \right) \right] := 0.
\]

A long algebraic derivation would then eventually lead to a Gibbs distribution. However, thanks to the KL rearrangement we do not need to perform any of this, because we can see immediately that $L$ will be maximised when the KL divergence is zero, hence when

\[
Q_i(x_i) = Q^*_i(x_i)
\]

(the normalisation constraint on $Q_i$ is satisfied thanks to the inclusion of $Z$ in the previous definition of $Q^*_i$). Expanding back the definition gives the optimal $Q_i$ to be

\[
Q_i(x_i) = \frac{1}{Z} \exp \langle E(x_i, \bar{x}_i, D) \rangle_{Q(\bar{x}_i)}
\]

where $E(x_i, \bar{x}_i, D) = \ln P(x_i, \bar{x}_i, D)$ is the energy.

1.5 Solution as iterative update equations

Converting the equation above into an iterative update equation gives:

\[
Q_i(x_i) \leftarrow \frac{1}{Z} \exp \langle E(x_i, \bar{x}_i, D) \rangle_{Q(\bar{x}_i)}
\]

where $x_i$ is the hidden node to be updated, $D$ are the observed data, and $\bar{x}_i = x \setminus x_i$ are the other hidden nodes. These updates give us an EM-like algorithm, optimising one node at a time in the hope of converging to a local optimum.

Graphical models define conditional independence relationships between nodes, so for models having such structure there exists a Markov blanket set of nodes $\mathrm{mb}(x_i)$ for each node $x_i$ such that $P(x_i|\bar{x}_i) = P(x_i|\mathrm{mb}(x_i))$. For undirected graphical models, the Markov blanket is the set of neighbouring nodes of $x_i$; for directed Bayesian networks it is the neighbours of $x_i$ augmented with the set of co-parents of $x_i$. By the definition of Markov blankets, we may write

\[
\ln Q_i(x_i) \leftarrow \langle \ln P(x_i, \mathrm{mb}(x_i), D) \rangle_{Q(\mathrm{mb}(x_i))}
\]

where $\mathrm{mb}(x_i)$ is the Markov blanket of the node $x_i$ of interest.
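To make these updates concrete, here is a minimal coordinate-ascent sketch (not from the paper) for a toy joint over two discrete hidden variables, using a made-up $3 \times 3$ table for $\ln P(x_1, x_2, D)$. Each sweep applies $Q_i(x_i) \propto \exp\langle \ln P(x, D) \rangle_{Q(\bar{x}_i)}$ in turn; in the discrete case the expectations are exact sums, so each update is just a matrix–vector product followed by exponentiation and normalisation.

```python
import numpy as np

# Mean-field coordinate ascent on a toy joint P(x1, x2, D) over two discrete
# hidden variables (3 x 3 states). All numbers are made up; the point is the
# update Q_i(x_i) proportional to exp< ln P >_{Q(xbar_i)}.
rng = np.random.default_rng(0)
logP = np.log(rng.dirichlet(np.ones(9)).reshape(3, 3))   # ln P(x1, x2, D) for fixed D

Q1 = np.full(3, 1 / 3)    # Q1(x1), initialised uniform
Q2 = np.full(3, 1 / 3)    # Q2(x2)

def normalise_exp(log_q):
    q = np.exp(log_q - log_q.max())        # subtract max for numerical stability
    return q / q.sum()

for sweep in range(50):
    Q1 = normalise_exp(logP @ Q2)          # < ln P >_{Q2}(x1), then exponentiate
    Q2 = normalise_exp(Q1 @ logP)          # < ln P >_{Q1}(x2)

# Lower bound L[Q] = <ln P>_Q + H[Q]; coordinate ascent never decreases it.
Q = np.outer(Q1, Q2)
L = np.sum(Q * logP) - np.sum(Q * np.log(Q))
print("lower bound L[Q] =", L, " <= ln P(D) =", 0.0)     # the table sums to 1, so ln P(D) = 0
```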
1.6 Example: variational Bayesian mean update

Until very recently, variational implementations have consisted of calculating the iterative update equations by hand for each specific network model. These calculations required much algebra and were error-prone. We consider the simple graphical model shown in fig. 2: the task is to infer the mean $\mu$ component of the posterior joint of a Gaussian population of unknown precision $\gamma$, given a set of $M$ observations $D = \{D_i\}_{i=1}^M$ drawn from this population.

[Fig. 2 Graphical model for the variational example. Square nodes are observed.]

Making the mean-field approximation and assuming conjugate exponential priors we have:

\[
Q(x) = Q(\mu)Q(\gamma) = \mathcal{N}(\mu; m, \beta^{-1})\, \Gamma(\gamma; a, b)
\]

where $m$, $\beta$, $a$ and $b$ are constants specifying conjugate priors on the population parameters. The update equation is

\[
\ln Q(x_i) \leftarrow \langle \ln P(x_i, \mathrm{mb}(x_i), D) \rangle_{Q(\mathrm{mb}(x_i))}.
\]

Substituting the variables for the mean update, $x_i = \mu$:

\begin{align*}
\ln Q(\mu) &\leftarrow \int \mathrm{d}\gamma\, \Gamma(\gamma; a, b) \ln P(\mu, \gamma, \{D_i\}) \\
&= \int \mathrm{d}\gamma\, \Gamma(\gamma; a, b) \ln \left[ \mathcal{N}(\mu; m, \beta^{-1})\, \Gamma(\gamma; a, b) \prod_i \mathcal{N}(D_i; \mu, \gamma^{-1}) \right] \\
&= \int \mathrm{d}\gamma\, \Gamma(\gamma; a, b) \left[ \ln \mathcal{N}(\mu; m, \beta^{-1}) + \ln \Gamma(\gamma; a, b) + \sum_i \ln \mathcal{N}(D_i; \mu, \gamma^{-1}) \right] \\
&= \int \mathrm{d}\gamma\, \Gamma(\gamma; a, b) \ln \mathcal{N}(\mu; m, \beta^{-1}) + \int \mathrm{d}\gamma\, \Gamma(\gamma; a, b) \ln \Gamma(\gamma; a, b) + \sum_i \int \mathrm{d}\gamma\, \Gamma(\gamma; a, b) \ln \mathcal{N}(D_i; \mu, \gamma^{-1}).
\end{align*}

This simplifies greatly. First, integrands not dependent on $\mu$ can be discarded and replaced by a normalising $Z$, since we are only interested in the PDF for $Q(\mu)$:

\[
\ln Q(\mu) = \int \mathrm{d}\gamma\, \Gamma(\gamma; a, b) \ln \mathcal{N}(\mu; m, \beta^{-1}) + \sum_i \int \mathrm{d}\gamma\, \Gamma(\gamma; a, b) \ln \mathcal{N}(D_i; \mu, \gamma^{-1}) + \ln \frac{1}{Z}.
\]

The left term does depend on $\mu$ but not on $\gamma$, so the integral over the normalised $Q(\gamma)$ has no effect: its integral is simply its integrand, so we can write

\[
\ln Q(\mu) = \ln \mathcal{N}(\mu; m, \beta^{-1}) + \sum_i \int \mathrm{d}\gamma\, \Gamma(\gamma; a, b) \ln \mathcal{N}(D_i; \mu, \gamma^{-1}) + \ln \frac{1}{Z}.
\]

We next consider the integral containing the data $D_i$. Writing out its log-Gaussian energy expression in full, this term becomes

\[
\int \mathrm{d}\gamma\, \Gamma(\gamma; a, b) \ln \mathcal{N}(D_i; \mu, \gamma^{-1}) = \int \mathrm{d}\gamma\, \Gamma(\gamma; a, b) \left[ c - \tfrac{1}{2}(D_i - \mu)\,\gamma\,(D_i - \mu) \right]
\]

where $c$ collects terms independent of $\mu$. Keeping only the $\mu$-dependent part gives

\[
-\tfrac{1}{2}(D_i - \mu) \left[ \int \mathrm{d}\gamma\, \Gamma(\gamma; a, b)\, \gamma \right] (D_i - \mu).
\]

The required integral is now just the expectation (first moment) of a Gamma distribution, which can be found in standard statistics tables:

\[
\langle \gamma \rangle = \mathbb{E}[\Gamma(\gamma; a, b)] = ab^{-1}.
\]

(We notate pre- and post-multiplication by $(D_i - \mu)$, rather than a single multiplication by $(D_i - \mu)^2$, so that the multivariate Wishart case follows the same derivation and notation.) The term is now

\[
-\tfrac{1}{2}(D_i - \mu) \langle \gamma \rangle (D_i - \mu).
\]

Substituting this back into the equation for $\ln Q(\mu)$ gives

\[
\ln Q(\mu) = \ln \mathcal{N}(\mu; m, \beta^{-1}) - \tfrac{1}{2} \sum_i (D_i - \mu) \langle \gamma \rangle (D_i - \mu) + \ln \frac{1}{Z}
\]
\[
Q(\mu) = \frac{1}{Z} \mathcal{N}(\mu; m, \beta^{-1}) \exp \left\{ -\tfrac{1}{2} \sum_i (D_i - \mu) \langle \gamma \rangle (D_i - \mu) \right\}
= \frac{1}{Z} \mathcal{N}(\mu; m, \beta^{-1}) \prod_i \mathcal{N}(D_i; \mu, \langle \gamma \rangle^{-1}).
\]

Switching round the dependent variable we obtain the following. (For $\mu$ in a Gaussian distribution, the flipped result is still Gaussian; for other distributions it will be the conjugate form.)

\[
Q(\mu) = \frac{1}{Z} \mathcal{N}(\mu; m, \beta^{-1}) \prod_i \mathcal{N}(\mu; D_i, \langle \gamma \rangle^{-1}).
\]

Finally we can use the standard formula for a product of Gaussians to give

\[
Q(\mu) = \mathcal{N}(\mu; m', \beta'^{-1})
\]

with

\[
\beta' = \beta + M \langle \gamma \rangle, \qquad
m' = \beta'^{-1} \left( \beta m + \langle \gamma \rangle \sum_{i=1}^M D_i \right).
\]

The above illustrates how the general VB updates can be transformed into particular updates for graphical models. The $Q(\mu)$ component is meaningless by itself, as VB aims to approximate the full joint rather than local marginals, so a more useful analysis would repeat the above to obtain the $Q(\gamma)$ updates as well – requiring even more algebra. We see that this process is time-consuming and error-prone even for very simple networks such as the Gaussian population used here.
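The resulting updates are easily turned into code. The sketch below (an illustration, not from the paper) iterates the $Q(\mu)$ update just derived together with the companion $Q(\gamma)$ update; the latter is the standard conjugate result $a' = a + M/2$, $b' = b + \frac{1}{2}\sum_i \langle (D_i - \mu)^2 \rangle_{Q(\mu)}$, stated here as an assumption rather than derived, since the text leaves it as further algebra. A rate parameterisation $\Gamma(\gamma; a, b)$ is assumed, so that $\langle \gamma \rangle = a/b$; the data and priors are made up.

```python
import numpy as np

# Mean-field VB for the Gaussian population of section 1.6:
#   mu ~ N(m, 1/beta),  gamma ~ Gamma(a, b) (rate b),  D_i ~ N(mu, 1/gamma).
# The Q(mu) update is the one derived in the text; the Q(gamma) update is the
# standard conjugate companion (an assumption here, not derived in the paper).

rng = np.random.default_rng(1)
D = rng.normal(5.0, 0.5, size=20)          # synthetic data, made-up values
M = len(D)

m0, beta0, a0, b0 = 0.0, 1e-2, 1.0, 1.0    # illustrative priors

a_n, b_n = a0, b0                          # Q(gamma) = Gamma(a_n, b_n)
for _ in range(100):
    # Q(mu) = N(m_n, 1/beta_n), using <gamma> = a_n / b_n
    e_gamma = a_n / b_n
    beta_n = beta0 + M * e_gamma
    m_n = (beta0 * m0 + e_gamma * D.sum()) / beta_n

    # Q(gamma) update: <(D_i - mu)^2>_{Q(mu)} = (D_i - m_n)^2 + 1/beta_n
    a_n = a0 + 0.5 * M
    b_n = b0 + 0.5 * np.sum((D - m_n) ** 2 + 1.0 / beta_n)

print("Q(mu)    :", m_n, 1.0 / beta_n)     # approximate posterior mean and variance of mu
print("Q(gamma) :", a_n / b_n)             # approximate posterior mean of the precision
```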
1.7 Variational Bayes with message passing

The previous method of by-hand derivation is time-consuming and tedious, though until recently it was state-of-the-art. However, the recent variational message passing (VMP) algorithm [3,4] has shown how to automate these derivations in the case of conjugate-exponential networks. For such networks the updates all have a standard form, involving the sufficient statistics and natural parameters of the particular node types. For unavoidable non-conjugate-exponential nodes (such as mixture models in particular) it is possible to make further approximations to bring them into the standard form. For conjugate-exponential networks, VMP should now be the standard method for variational Bayesian inference, replacing derivations by hand. For non-conjugate-exponential networks, VMP may still be useful if fast approximations are required at the expense of accuracy.

1.8 VMP details

A standard theorem [2] about exponential-family distributions, written in the form $\ln P(x_i|\phi) = \phi \cdot u(x_i) + f(x_i) + g(\phi)$, shows that the expectations of the sufficient statistics are given simply by:

\[
\langle u(x_i) \rangle = -\nabla_\phi\, g(\phi)
\]

where the del symbol ($\nabla$) means we form a vector whose $r$th component is the derivative of $g$ with respect to the $r$th component of the vector $\phi$, evaluated at the current natural parameters.

We will write $x_i$'s set of parents as $\mathrm{pa}(x_i)$; its set of children as $\mathrm{ch}(x_i)$; its set of co-parents with respect to a particular child $ch$ as $\mathrm{cop}(x_i; ch)$; and its set of co-parents over all children as $\mathrm{cop}(x_i)$. We wish to compute the variational Bayesian update as in the previous section:

\[
\ln Q_i(x_i) \leftarrow \langle \ln P(x_i, \mathrm{mb}(x_i), D) \rangle_{Q(\mathrm{mb}(x_i))}.
\]

Assuming the effects of $D$ are already in the Markov blanket nodes, and separating, this becomes

\[
\langle \ln P(\mathrm{pa}(x_i)) + \ln P(\mathrm{cop}(x_i)) + \ln P(x_i|\mathrm{pa}(x_i)) + \ln P(\mathrm{ch}(x_i)|x_i, \mathrm{cop}(x_i)) \rangle_{Q(\mathrm{mb}(x_i))}.
\]

Simplifying and dropping the constant parent and co-parent terms:

\[
= \langle \ln P(x_i|\mathrm{pa}(x_i)) \rangle_{Q(\mathrm{pa}(x_i))} + \langle \ln P(\mathrm{ch}(x_i)|x_i, \mathrm{cop}(x_i)) \rangle_{Q(\mathrm{ch}(x_i), \mathrm{cop}(x_i))}.
\]

The children separate:

\[
= \langle \ln P(x_i|\mathrm{pa}(x_i)) \rangle_{Q(\mathrm{pa}(x_i))} + \sum_{ch \in \mathrm{ch}(x_i)} \langle \ln P(ch|x_i, \mathrm{cop}(x_i; ch)) \rangle_{Q(ch, \mathrm{cop}(x_i; ch))}.
\]

We will consider the two parts one at a time.

1.8.1 Messages from parents

A conjugate-exponential node $x_i$ is parametrised by a natural parameter vector $\phi_i$. By the definition of such nodes,

\[
\langle \ln P(x_i|\mathrm{pa}(x_i)) \rangle_{Q(\mathrm{pa}(x_i))} = \langle \phi_i \cdot u_i(x_i) + f_i(x_i) + g_i(\phi_i) \rangle_{Q(\mathrm{pa}(x_i))}
= \langle \phi_i \rangle_{Q(\mathrm{pa}(x_i))} \cdot u_i(x_i) + f_i(x_i) + \langle g_i(\phi_i) \rangle_{Q(\mathrm{pa}(x_i))}.
\]

As $\phi$ and $g$ are multi-linear functions of the parent sufficient statistics (by construction), and using the mean-field assumption, we can simply take their formulae (defined as conditional on single values of the parents) and substitute the expectations of the sufficient statistics, to obtain the expectation of the whole expression as required. So the parents of $x_i$ need only send their sufficient-statistic expectations to $x_i$ as messages. (A concrete example is shown in section 1.9.)

1.8.2 Messages to parents

A key property of the exponential family is that we can multiply (fuse) similar distributions by adding their natural parameter vectors $\phi$, up to normalisation:

\[
\exp\{\phi_1 \cdot u(x_i) + f(x_i) + g(\phi_1)\}\, \exp\{\phi_2 \cdot u(x_i) + f(x_i) + g(\phi_2)\} \propto \exp\{(\phi_1 + \phi_2) \cdot u(x_i) + f(x_i) + g(\phi_1 + \phi_2)\}.
\]

A second property is that, by conjugacy, $\phi$ and $g$ are also multi-linear in the parental sufficient statistics. So we can always rearrange the formula, by finding functions $\phi_{ij}$, $f_{ij}$, $g_{ij}$, to make it look like a function of a parent $x_j \in \mathrm{pa}(x_i)$:

\[
\langle \ln P(x_i|\mathrm{pa}(x_i)) \rangle_{Q(\mathrm{pa}(x_i))} = \langle \phi_{ij} \cdot u_j(x_j) + f_{ij}(x_j) + g_{ij}(\phi_{ij}) \rangle_{Q(\mathrm{pa}(x_i))}.
\]

As before, we may handle the expectation by using the multi-linear property to push all the expectations tight around the sufficient statistics. So, from the point of view of the parent, this is written in terms of the sufficient-statistic expectations of its child and co-parents. We can thus pass a likelihood message consisting of

\[
\phi_{ij}\left( \langle u(x_i) \rangle, \{\langle u(cop) \rangle\}_{cop \in \mathrm{cop}(x_j; x_i)} \right).
\]

The parent may then simply add this to its prior parameters, by the first property. Observed data nodes $D$ may be treated as delta distributions which send standard messages to their parents and children.
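As a bridge to the worked example of the next section, the following minimal sketch (not from the paper) shows message fusion in natural-parameter form for the mean node $\mu$ of the model of section 1.6. The prior contributes $\phi = (\beta m, -\beta/2)$ with respect to $u(\mu) = (\mu, \mu^2)$, and each observed Gaussian child is assumed to send the standard Gaussian-to-mean-parent message $(\langle\gamma\rangle D_i, -\langle\gamma\rangle/2)$; fusion is simply adding these vectors, and reading the result back recovers the hand-derived $\beta'$ and $m'$.

```python
import numpy as np

# VMP-style message fusion for the mean node mu of sections 1.6/1.9.
# Natural parameters of N(mu; m, 1/beta) w.r.t. u(mu) = (mu, mu^2) are (beta*m, -beta/2).
# Each observed child D_i ~ N(mu, 1/gamma) sends the message (<gamma>*D_i, -<gamma>/2);
# fusion is addition of natural parameters. Numbers below are illustrative assumptions.

D = np.array([4.9, 5.1, 5.0, 4.8, 5.2])    # observed data
m0, beta0 = 0.0, 1e-2                      # prior on mu
e_gamma = 2.0                              # <gamma> received from the precision node

phi_prior = np.array([beta0 * m0, -beta0 / 2.0])
phi_children = np.array([[e_gamma * d, -e_gamma / 2.0] for d in D])
phi_post = phi_prior + phi_children.sum(axis=0)          # fuse by adding

# Read back conventional parameters: phi_post = (beta' * m', -beta'/2)
beta_post = -2.0 * phi_post[1]
m_post = phi_post[0] / beta_post
print("beta' =", beta_post, " m' =", m_post)

# Matches the hand-derived update: beta' = beta + M<gamma>,
# m' = (beta*m + <gamma>*sum(D)) / beta'.
print(beta0 + len(D) * e_gamma,
      (beta0 * m0 + e_gamma * D.sum()) / (beta0 + len(D) * e_gamma))
```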
1.9 Example: mean, precision and data using VMP

To demonstrate the power of the VMP formalism, we consider the same scenario as was used to hand-derive the VB updates in section 1.6. Using VMP we can substitute the particular Gaussian and Gamma distributions into the VMP update equations and quickly obtain all network messages for the mean, precision and data nodes, as follows. (Messages to $D$ are not required in this particular example, but are useful in general for making inferences about unobserved data.) The exponential forms used here are standard [2].

1.9.1 Mean node (with known prior parameters)

Beginning with the conjugate-exponential form of the Gaussian distribution:

\[
\ln P(\mu|m, \beta) = \begin{pmatrix} m\beta \\ -\beta/2 \end{pmatrix} \cdot \begin{pmatrix} \mu \\ \mu^2 \end{pmatrix} + g(m, \beta)
\]

with

\[
g(m, \beta) = \frac{1}{2}\left( \ln \beta - \beta m^2 - \ln 2\pi \right).
\]

The message to the child $D$ is the expectation of the sufficient statistics:

\[
\left\langle \begin{pmatrix} \mu \\ \mu^2 \end{pmatrix} \right\rangle = -\nabla_\phi\, g(\phi) = \begin{pmatrix} m \\ m^2 + \beta^{-1} \end{pmatrix}.
\]

1.10 Computing the log-likelihood bound using VMP

VMP also makes computation of the log-likelihood bound simple. Recall that

\[
\ln P(D|M) \ge L[Q(x)]
\]
\[
L[Q(x)] = \int \mathrm{d}x\, Q(x) \ln \frac{P(x,D)}{Q(x)} = \langle \ln P(x,D) \rangle_{Q(x)} - \langle \ln Q(x) \rangle_{Q(x)}.
\]

Writing this as a sum of individual node contributions over the universe of all (data and hidden) nodes $\Omega = x \cup D$:

\[
= \sum_{\omega \in \Omega} \left[ \langle \ln P(\omega|\mathrm{pa}(\omega)) \rangle_{Q(\omega, \mathrm{pa}(\omega))} - \langle \ln Q(\omega) \rangle_{Q(\omega)} \right] = \sum_{\omega \in \Omega} L_\omega
\]

with

\[
L_\omega = \langle \ln P(\omega|\mathrm{pa}(\omega)) \rangle_{Q(\omega, \mathrm{pa}(\omega))} - \langle \ln Q(\omega) \rangle_{Q(\omega)}
= \langle \phi^\pi_\omega \cdot u(\omega) + f(\omega) + g(\phi^\pi_\omega) \rangle_{Q(\omega, \mathrm{pa}(\omega))} - \langle \phi^*_\omega \cdot u(\omega) + f(\omega) + g(\phi^*_\omega) \rangle_{Q(\omega)}
\]

where $\phi^\pi$ are the prior natural parameters, conditioned only on the parents, and $\phi^*$ are the posterior parameters after the child messages have been fused. $Q(\omega)$ uses the posterior parameters. By multi-linearity we can push in the expectations and simplify to obtain:

\[
L_\omega = \left( \langle \phi^\pi_\omega \rangle_{Q(\mathrm{pa}(\omega))} - \phi^*_\omega \right) \cdot \langle u(\omega) \rangle_{Q(\omega)} + \langle g(\phi^\pi_\omega) \rangle_{Q(\mathrm{pa}(\omega))} - g(\phi^*_\omega)
\]

which is simple to compute locally from each node's received messages. (The term $\langle \phi^\pi_\omega \rangle_{Q(\mathrm{pa}(\omega))}$ is computed by substituting the received parent sufficient-statistic expectations into the conjugate-exponential formula for $\phi_\omega$; by multi-linearity, the expectation can be pushed onto the sufficient statistics.) The global model log-likelihood bound is computed by summing the contributions from all nodes, both hidden and observed.

1.11 Summary

In this tutorial we have seen how variational methods may be used to approximate joint posteriors with mean-field distributions. A long-hand calculation of variational inference was shown, then the more general variational message passing framework was introduced, which greatly simplifies calculations for conjugate-exponential networks. Variational methods are not appropriate when the marginals or the correlation structure of the joint are required, but they are useful in model comparison tasks where the joint is integrated out.

References

1. Hagai Attias. A variational Bayesian framework for graphical models. In Advances in Neural Information Processing Systems. MIT Press, 2000.
2. J. M. Bernardo and A. F. M. Smith. Bayesian Theory. Wiley, 2000.
3. C. M. Bishop, J. M. Winn, and D. Spiegelhalter. VIBES: A variational inference engine for Bayesian networks. In Advances in Neural Information Processing Systems, 2002.
4. J. Winn and C. M. Bishop. Variational message passing. Journal of Machine Learning Research, 6:661–694, 2005.