Part 10 & 11 - Variational Inference

Introduction

Last time, we introduced Monte Carlo methods for (approximate) inference.

In this part, we will introduce Variational Inference, which is another popular method for approximate inference.

Notation

We will use the following notation,

  • $\mathbf{z}$: set of latent variables and parameters.
  • $\mathbf{x}$: set of observed variables (data).

As last time, we are given a probabilistic model that specifies the joint distribution $p(\mathbf{x}, \mathbf{z})$. The goal is to find an approximation of the posterior $p(\mathbf{z} \mid \mathbf{x})$.

Deterministic Approximate Inference

Intuition (Deterministic Approximate Inference)

The idea is to approximate a complex posterior distribution $p(\mathbf{z} \mid \mathbf{x})$ by a tractable distribution $q(\mathbf{z}) \in \Omega$ that is close to it. Here, $\Omega$ is a tractable family of densities over the latent variables $\mathbf{z}$, so each $q(\mathbf{z}) \in \Omega$ is a candidate approximation to the true posterior $p(\mathbf{z} \mid \mathbf{x})$.

But how do we identify the best candidate (i.e., the one closest to $p(\mathbf{z} \mid \mathbf{x})$)? Given a measure of “discrepancy” between $q(\mathbf{z})$ and $p(\mathbf{z} \mid \mathbf{x})$, we choose the free parameters of $q(\mathbf{z})$ to minimize this discrepancy.

The Kullback-Leibler (KL) Divergence

Definition 1 (Kullback-Leibler (KL) Divergence)

The Kullback-Leibler (KL) divergence measures how much a probability distribution $q(\mathbf{x})$ diverges from a second, reference probability distribution $p(\mathbf{x})$. It is defined as,

$$
\begin{aligned}
\mathrm{KL}(p(\mathbf{x}) \mid \mid q(\mathbf{x})) & = \int p(\mathbf{x}) \log \left(\frac{p(\mathbf{x})}{q(\mathbf{x})}\right) \ d\mathbf{x} \\
& = \mathbb{E}_{p(\mathbf{x})} \left[\log \left(\frac{p(\mathbf{x})}{q(\mathbf{x})}\right)\right] \\
& = -\int p(\mathbf{x}) \log \left(\frac{q(\mathbf{x})}{p(\mathbf{x})}\right) \ d\mathbf{x}
\end{aligned}
$$

The KL divergence has the following properties,

  1. $\mathrm{KL}(p(\mathbf{x}) \mid \mid q(\mathbf{x})) \geq 0$ (non-negativity).

  2. $\mathrm{KL}(p(\mathbf{x}) \mid \mid q(\mathbf{x})) = 0$ if and only if $p(\mathbf{x}) = q(\mathbf{x})$ almost everywhere.

  3. It is not symmetric, i.e., in general $\mathrm{KL}(p(\mathbf{x}) \mid \mid q(\mathbf{x})) \neq \mathrm{KL}(q(\mathbf{x}) \mid \mid p(\mathbf{x}))$.
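These properties are easy to check numerically for discrete distributions, where the integral becomes a sum. The following is a minimal sketch; the two probability vectors are arbitrary illustrative values.

```python
import math

def kl(p, q):
    """KL(p || q) for discrete distributions given as lists of probabilities.
    Terms with p_i = 0 contribute 0 (convention 0 log 0 = 0); assumes q_i > 0 wherever p_i > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

p = [0.5, 0.4, 0.1]
q = [0.3, 0.3, 0.4]
print(kl(p, q), kl(q, p), kl(p, p))
```

Here `kl(p, q)` is about 0.232 while `kl(q, p)` is about 0.315, which illustrates the asymmetry, and `kl(p, p)` is exactly 0.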

Deterministic Approximate Inference (Contd.) and Variational Inference

Intuition (Deterministic Approximate Inference (Contd.))

Since the KL divergence is not symmetric, we have two possibilities.

  • Variational Inference (one important, more recent application is the Variational Autoencoder): minimize the reverse KL divergence,
$$q^{\star}(\mathbf{z}) = \underset{q(\mathbf{z}) \in \Omega}{\arg\min} \ \mathrm{KL}(q(\mathbf{z}) \mid \mid p(\mathbf{z} \mid \mathbf{x}))$$
  • Expectation Propagation: minimize the forward KL divergence,
$$q^{\star}(\mathbf{z}) = \underset{q(\mathbf{z}) \in \Omega}{\arg\min} \ \mathrm{KL}(p(\mathbf{z} \mid \mathbf{x}) \mid \mid q(\mathbf{z}))$$
Figure: Purple: bimodal (target) distribution. Green: single Gaussian approximation $q(\mathbf{z})$. Left: forward KL divergence; middle and right: reverse KL divergence.
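The behaviour in the figure can be reproduced with a small 1-D experiment. This is a sketch under illustrative assumptions. The target is an equal mixture of $\mathcal{N}(-2, 0.5^2)$ and $\mathcal{N}(2, 0.5^2)$, integrals are approximated by a Riemann sum on $[-8, 8]$, and the best single Gaussian is found by brute-force grid search rather than by a proper optimizer.

```python
import math

def normal_logpdf(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))

# Assumed bimodal target: equal mixture of N(-2, 0.5^2) and N(2, 0.5^2)
def target_pdf(x):
    return 0.5 * math.exp(normal_logpdf(x, -2.0, 0.5)) + 0.5 * math.exp(normal_logpdf(x, 2.0, 0.5))

DX = 0.05
XS = [-8.0 + i * DX for i in range(321)]          # integration grid on [-8, 8]
LOG_PS = [math.log(target_pdf(x)) for x in XS]

def forward_kl(mu, sigma):
    # KL(p || q) = integral of p (log p - log q), as a Riemann sum on the grid
    return DX * sum(math.exp(lp) * (lp - normal_logpdf(x, mu, sigma)) for x, lp in zip(XS, LOG_PS))

def reverse_kl(mu, sigma):
    # KL(q || p) = integral of q (log q - log p); exp(lq) may underflow to 0, which is harmless
    total = 0.0
    for x, lp in zip(XS, LOG_PS):
        lq = normal_logpdf(x, mu, sigma)
        total += math.exp(lq) * (lq - lp)
    return DX * total

def best_gaussian(kl):
    # Brute-force grid search: mu in [-3, 3] (step 0.25), sigma in [0.3, 3.0] (step 0.1)
    candidates = [(i / 4, s / 10) for i in range(-12, 13) for s in range(3, 31)]
    return min(candidates, key=lambda c: kl(*c))

mu_f, sigma_f = best_gaussian(forward_kl)
mu_r, sigma_r = best_gaussian(reverse_kl)
print("forward KL fit:", mu_f, sigma_f)
print("reverse KL fit:", mu_r, sigma_r)
```

The forward-KL fit has mean 0 and a standard deviation of about 2. It covers both modes, which is mass-covering behaviour. The reverse-KL fit sits on one of the two modes with standard deviation 0.5, which is mode-seeking behaviour. Which mode it picks is arbitrary because the target is symmetric.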
Intuition (Variational Inference)

So, minimizing the reverse KL divergence corresponds to Variational Inference. However, this is still not tractable, as it requires knowledge of the true posterior $p(\mathbf{z} \mid \mathbf{x})$.

But we can rewrite $\mathrm{KL}(q(\mathbf{z}) \mid \mid p(\mathbf{z} \mid \mathbf{x}))$ as,

$$
\begin{aligned}
\mathrm{KL}(q(\mathbf{z}) \mid \mid p(\mathbf{z} \mid \mathbf{x})) & = - \int q(\mathbf{z}) \log \left(\frac{p(\mathbf{z} \mid \mathbf{x})}{q(\mathbf{z})}\right) \ d\mathbf{z} \\
& = - \int q(\mathbf{z}) \log \left(\frac{p(\mathbf{x}, \mathbf{z})}{q(\mathbf{z}) p(\mathbf{x})}\right) \ d\mathbf{z} \\
& = \log p(\mathbf{x}) - \underbrace{\int q(\mathbf{z}) \log \left(\frac{p(\mathbf{x}, \mathbf{z})}{q(\mathbf{z})}\right) \ d\mathbf{z}}_{\eqqcolon \mathcal{L}(q)} \\
& = \log p(\mathbf{x}) - \mathcal{L}(q)
\end{aligned}
$$

The third line uses $\int q(\mathbf{z}) \ d\mathbf{z} = 1$. Rearranging gives the decomposition

$$\log p(\mathbf{x}) = \mathcal{L}(q) + \mathrm{KL}(q(\mathbf{z}) \mid \mid p(\mathbf{z} \mid \mathbf{x}))$$

and, since the KL divergence is non-negative, $\mathcal{L}(q) \leq \log p(\mathbf{x})$.

Alternatively, applying Jensen's inequality (the logarithm is concave) gives a lower bound on the log-evidence directly,

$$
\begin{aligned}
\log p(\mathbf{x}) & = \log \int p(\mathbf{x}, \mathbf{z}) \ d\mathbf{z} \\
& = \log \int q(\mathbf{z}) \frac{p(\mathbf{x}, \mathbf{z})}{q(\mathbf{z})} \ d\mathbf{z} \\
& = \log \left(\mathbb{E}_{q(\mathbf{z})} \left[\frac{p(\mathbf{x}, \mathbf{z})}{q(\mathbf{z})}\right]\right) \\
& \geq \mathbb{E}_{q(\mathbf{z})} \left[\log \left(\frac{p(\mathbf{x}, \mathbf{z})}{q(\mathbf{z})}\right)\right] \quad \text{(Jensen's inequality)} \\
& = \int q(\mathbf{z}) \log \left(\frac{p(\mathbf{x}, \mathbf{z})}{q(\mathbf{z})}\right) \ d\mathbf{z} \\
& \triangleq \mathcal{L}(q)
\end{aligned}
$$

$\mathcal{L}(q)$ is called the Evidence Lower Bound (ELBO), as it provides a lower bound on the log-evidence $\log p(\mathbf{x})$.
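To see the bound in action, the following sketch estimates the ELBO by Monte Carlo, $\mathcal{L}(q) \approx \frac{1}{S} \sum_{s} [\log p(x, z_s) - \log q(z_s)]$ with $z_s \sim q$. It uses an assumed 1-D conjugate model ($z \sim \mathcal{N}(0, 1)$, $x \mid z \sim \mathcal{N}(z, 1)$) in which $\log p(x)$ and the exact posterior $\mathcal{N}(x/2, 1/2)$ are known in closed form. The model, the sample size, and the function names are illustrative choices.

```python
import math
import random

def log_normal(v, mean, var):
    return -0.5 * math.log(2 * math.pi * var) - 0.5 * (v - mean) ** 2 / var

# Assumed toy model: z ~ N(0, 1), x | z ~ N(z, 1), so p(x) = N(x | 0, 2)
# and the exact posterior is p(z | x) = N(z | x / 2, 1 / 2).
def log_joint(x, z):
    return log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)

def log_evidence(x):
    return log_normal(x, 0.0, 2.0)

def mc_elbo(x, q_mean, q_var, n=10_000, seed=0):
    """Monte Carlo estimate of L(q) = E_q[log p(x, z) - log q(z)] for q = N(q_mean, q_var)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        z = rng.gauss(q_mean, math.sqrt(q_var))
        total += log_joint(x, z) - log_normal(z, q_mean, q_var)
    return total / n

x = 1.0
print(log_evidence(x))          # log p(x)
print(mc_elbo(x, 0.5, 0.5))     # q = exact posterior: matches log p(x) up to rounding
print(mc_elbo(x, 0.0, 1.0))     # q = prior: below log p(x)
```

When $q$ equals the exact posterior, every sample yields exactly $\log p(x)$, so the bound is tight. For any other $q$, the estimate falls below $\log p(x)$ by approximately $\mathrm{KL}(q(z) \mid \mid p(z \mid x))$.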

Thus, solving,

$$q^{\star}(\mathbf{z}) = \underset{q(\mathbf{z}) \in \Omega}{\arg\min} \ \mathrm{KL}(q(\mathbf{z}) \mid \mid p(\mathbf{z} \mid \mathbf{x}))$$

is equivalent to solving the problem below. This holds because $\log p(\mathbf{x}) = \mathcal{L}(q) + \mathrm{KL}(q(\mathbf{z}) \mid \mid p(\mathbf{z} \mid \mathbf{x}))$ and $\log p(\mathbf{x})$ does not depend on $q$:

$$q^{\star}(\mathbf{z}) = \underset{q(\mathbf{z}) \in \Omega}{\arg\max} \ \mathcal{L}(q) \triangleq \int q(\mathbf{z}) \log \left(\frac{p(\mathbf{x}, \mathbf{z})}{q(\mathbf{z})}\right) \ d\mathbf{z}$$

which in general is (still) intractable!

But we can choose a parametric distribution $q(\mathbf{z} \mid \boldsymbol{\omega})$ that is tractable, yet rich enough to provide a good approximation of the true posterior. In this setting, $\mathcal{L}(q)$ becomes a function of $\boldsymbol{\omega}$, i.e., $\mathcal{L}(\boldsymbol{\omega})$. We can therefore use standard nonlinear optimization methods to find the optimal parameters.
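Here is a minimal sketch of this parametric approach. It assumes a 1-D toy model ($z \sim \mathcal{N}(0, 1)$, $x \mid z \sim \mathcal{N}(z, 1)$) and the family $q(z \mid \boldsymbol{\omega}) = \mathcal{N}(z \mid m, s^2)$ with $\boldsymbol{\omega} = (m, \rho)$ and $s = e^{\rho}$. In this toy case $\mathcal{L}(\boldsymbol{\omega})$ has a closed form, so plain gradient ascent with hand-derived gradients stands in for a general-purpose optimizer. In more complex models the expectations typically have to be estimated, e.g., by Monte Carlo.

```python
import math

# Assumed toy model: z ~ N(0, 1), x | z ~ N(z, 1); exact posterior N(x / 2, 1 / 2).
# Variational family q(z | omega) = N(z | m, s^2) with omega = (m, rho), s = exp(rho) > 0.

def elbo(x, m, rho):
    s2 = math.exp(2 * rho)
    e_log_prior = -0.5 * math.log(2 * math.pi) - 0.5 * (m ** 2 + s2)      # E_q[log p(z)]
    e_log_lik = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s2)  # E_q[log p(x | z)]
    entropy = 0.5 * math.log(2 * math.pi * math.e * s2)                   # -E_q[log q(z)]
    return e_log_prior + e_log_lik + entropy

def elbo_grad(x, m, rho):
    s2 = math.exp(2 * rho)
    return x - 2 * m, 1 - 2 * s2   # (dL/dm, dL/drho), derived by hand from elbo()

def fit(x, lr=0.1, steps=500):
    m, rho = 0.0, 0.0
    for _ in range(steps):
        g_m, g_rho = elbo_grad(x, m, rho)
        m, rho = m + lr * g_m, rho + lr * g_rho   # plain gradient ascent on L(omega)
    return m, rho

x = 1.0
m, rho = fit(x)
print(m, math.exp(2 * rho), elbo(x, m, rho))
```

The exact posterior $\mathcal{N}(x/2, 1/2)$ lies inside the family. For $x = 1$ the iterates therefore converge to $m = 0.5$ and $s^2 = 0.5$, and the ELBO reaches $\log p(x)$. Optimizing $\rho = \log s$ instead of $s$ keeps the standard deviation positive without explicit constraints.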

Alternatively, we can restrict $q(\mathbf{z})$ to factorize as,

$$q(\mathbf{z}) = \prod_{i = 1}^M q_i(\mathbf{z}_i)$$

where $\mathbf{z}_1, \ldots, \mathbf{z}_M$ are disjoint groups of variables that together partition $\mathbf{z}$.

Intuition (Mean-Field Variational Inference)

This is called Mean-Field Variational Inference. In this case, we solve the optimization problem,

$$\underset{q_1, \ldots, q_M}{\max} \ \mathcal{L}(q)$$

i.e., amongst all $q(\mathbf{z}) = \prod_{i = 1}^M q_i(\mathbf{z}_i)$, we want to find the distribution with the largest $\mathcal{L}(q)$.

We optimize the ELBO using coordinate ascent, i.e., we optimize one factor $q_j(\mathbf{z}_j)$ at a time while keeping the others fixed.

Derivation (Solving Mean-Field Variational Inference with Coordinate Ascent)

For a fixed index $j$, we solve,

$$q^{\star}_j(\mathbf{z}_j) = \underset{q_j}{\arg\max} \ \mathcal{L}(q) \triangleq \int q(\mathbf{z}) \log \left(\frac{p(\mathbf{x}, \mathbf{z})}{q(\mathbf{z})}\right) \ d\mathbf{z}$$

with,

$$q(\mathbf{z}) = \prod_{i = 1}^M q_i(\mathbf{z}_i)$$

where all factors $q_i$ with $i \neq j$ are held fixed.

Singling out the terms that involve $q_j(\mathbf{z}_j)$, we have,

$$
\begin{aligned}
\mathcal{L}(q) & = \int \prod_i q_i \left(\log p(\mathbf{x}, \mathbf{z}) - \sum_k \log q_k(\mathbf{z}_k)\right) \ d\mathbf{z} \\
& = \left(\int \prod_i q_i \log p(\mathbf{x}, \mathbf{z}) \ d\mathbf{z}\right) - \left(\int \prod_i q_i \left(\sum_k \log q_k \right) \ d\mathbf{z}\right) \\
& = \left(\int \prod_i q_i \log p(\mathbf{x}, \mathbf{z}) \ d\mathbf{z}\right) - \left(\int q_j \log q_j \ d\mathbf{z}_j\right) - \left(\int \prod_i q_i \left(\sum_{k \neq j} \log q_k \right) \ d\mathbf{z}\right)
\end{aligned}
$$

Let’s go through the terms one by one. First term,

$$
\begin{aligned}
\int \prod_i q_i \log p(\mathbf{x}, \mathbf{z}) \ d\mathbf{z} & = \int q_j(\mathbf{z}_j) \left(\int \log p(\mathbf{x}, \mathbf{z}) \prod_{i \neq j} q_i(\mathbf{z}_i) \ d\mathbf{z}_{i \neq j} \right) \ d\mathbf{z}_j \\
& = \int q_j \, \mathbb{E}_{\{\mathbf{z}_i\}_{i \neq j} \sim \prod_{i \neq j} q_i(\mathbf{z}_i)} [\log p(\mathbf{x}, \mathbf{z})] \ d\mathbf{z}_j \\
& = \int q_j \, \mathbb{E}_{i \neq j} [\log p(\mathbf{x}, \mathbf{z})] \ d\mathbf{z}_j
\end{aligned}
$$

where $\mathbb{E}_{i \neq j}$ is shorthand for the expectation with respect to $\prod_{i \neq j} q_i(\mathbf{z}_i)$.

For the second term,

$$
\begin{aligned}
\int \prod_i q_i \log q_j \ d\mathbf{z} & = \int q_j \log q_j \prod_{i \neq j} q_i \ d\mathbf{z}_j \ d\mathbf{z}_{i \neq j} \\
& = \left(\int q_j \log q_j \ d\mathbf{z}_j\right) \left(\int \prod_{i \neq j} q_i \ d\mathbf{z}_{i \neq j}\right) \\
& = \int q_j \log q_j \ d\mathbf{z}_j
\end{aligned}
$$

Lastly, the third term,

$$
\begin{aligned}
\int \prod_i q_i \left(\sum_{k \neq j} \log q_k \right) \ d\mathbf{z} & = \int q_j \prod_{i \neq j} q_i \left(\sum_{k \neq j} \log q_k \right) \ d\mathbf{z}_j \ d\mathbf{z}_{i \neq j} \\
& = \left(\int q_j \ d\mathbf{z}_j\right) \left(\int \prod_{i \neq j} q_i \left(\sum_{k \neq j} \log q_k \right) \ d\mathbf{z}_{i \neq j}\right) \\
& = \int \prod_{i \neq j} q_i \left(\sum_{k \neq j} \log q_k \right) \ d\mathbf{z}_{i \neq j}
\end{aligned}
$$

Note that here we are left with a constant (w.r.t. $q_j$)!

Thus, putting everything together, we have,

$$
\begin{aligned}
\mathcal{L}(q) & = \int q_j \, \mathbb{E}_{i \neq j} [\log p(\mathbf{x}, \mathbf{z})] \ d\mathbf{z}_j - \int q_j \log q_j \ d\mathbf{z}_j + \text{const} \\
& = \int q_j \log \tilde{p}(\mathbf{x}, \mathbf{z}_j) \ d\mathbf{z}_j - \int q_j \log q_j \ d\mathbf{z}_j + \text{const} \\
& = \int q_j \log \left(\frac{\tilde{p}(\mathbf{x}, \mathbf{z}_j)}{q_j}\right) \ d\mathbf{z}_j + \text{const} \\
& = - \mathrm{KL}(q_j(\mathbf{z}_j) \mid \mid \tilde{p}(\mathbf{x}, \mathbf{z}_j)) + \text{const}
\end{aligned}
$$

where $\log \tilde{p}(\mathbf{x}, \mathbf{z}_j) \coloneqq \mathbb{E}_{i \neq j} [\log p(\mathbf{x}, \mathbf{z})] + \text{const}$. The constant is chosen so that $\tilde{p}(\mathbf{x}, \mathbf{z}_j)$ is a normalized distribution over $\mathbf{z}_j$; this only shifts the other constant.

Thus, if we go back to our optimization problem,

$$
\begin{aligned}
q^{\star}_j(\mathbf{z}_j) & = \underset{q_j}{\arg\max} \ \mathcal{L}(q) \\
& = \underset{q_j}{\arg\max} \ - \mathrm{KL}(q_j(\mathbf{z}_j) \mid \mid \tilde{p}(\mathbf{x}, \mathbf{z}_j)) + \text{const} \\
& = \underset{q_j}{\arg\min} \ \mathrm{KL}(q_j(\mathbf{z}_j) \mid \mid \tilde{p}(\mathbf{x}, \mathbf{z}_j)) \\
& = \tilde{p}(\mathbf{x}, \mathbf{z}_j) \\
& = \exp \left(\mathbb{E}_{i \neq j} [\log p(\mathbf{x}, \mathbf{z})] + \text{const}\right)
\end{aligned}
$$

The fourth line holds because the KL divergence attains its minimum, zero, when its two arguments coincide.

or, equivalently,

$$\log q^{\star}_j(\mathbf{z}_j) = \mathbb{E}_{i \neq j} [\log p(\mathbf{x}, \mathbf{z})] + \text{const}$$
Algorithm (Mean-Field Variational Inference with Coordinate Ascent)
  • Initialization: Set all factors $\{q_i(\mathbf{z}_i)\}_{i=1}^M$.
  • For $\ell = 1, \ldots, \ell_{\text{max}}$:
    • For $j = 1, \ldots, M$:
      • Fix $\{q_i(\mathbf{z}_i)\}_{i \neq j}$ to their last estimated values $q_i^{\star}(\mathbf{z}_i)$.
      • Update $q_j^{\star}(\mathbf{z}_j)$ as,
$$q_j^{\star}(\mathbf{z}_j) = \exp \left(\mathbb{E}_{i \neq j} [\log p(\mathbf{x}, \mathbf{z})] + \text{const}\right)$$
      • Normalize $q_j^{\star}(\mathbf{z}_j)$.
    • Stop early once the ELBO $\mathcal{L}(q)$ has converged.
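The algorithm can be tried on a case where everything is in closed form. The sketch below assumes a correlated 2-D Gaussian target $p(\mathbf{z}) = \mathcal{N}(\mathbf{z} \mid \boldsymbol{\mu}, \boldsymbol{\Lambda}^{-1})$ (the numbers are arbitrary) and the mean-field factorization $q(\mathbf{z}) = q_1(z_1) q_2(z_2)$. Working out $\mathbb{E}_{i \neq j}[\log p]$ shows that each optimal factor is Gaussian. The "normalize" step therefore amounts to writing down a Gaussian with the right mean and variance.

```python
import math

# Assumed target for illustration: a correlated 2-D Gaussian p(z) = N(z | MU, LAM^{-1}),
# specified through its precision matrix LAM. It is normalized, so log Z = 0.
MU = (1.0, -1.0)
LAM = ((2.0, 1.2), (1.2, 1.0))

def elbo(m, v):
    """L(q) = E_q[log p(z)] + H[q] for q(z) = N(z1 | m1, v1) N(z2 | m2, v2)."""
    d = (m[0] - MU[0], m[1] - MU[1])
    quad = sum(LAM[i][k] * d[i] * d[k] for i in range(2) for k in range(2))
    log_det = math.log(LAM[0][0] * LAM[1][1] - LAM[0][1] * LAM[1][0])
    e_log_p = (-math.log(2 * math.pi) + 0.5 * log_det
               - 0.5 * (quad + LAM[0][0] * v[0] + LAM[1][1] * v[1]))
    entropy = sum(0.5 * math.log(2 * math.pi * math.e * vi) for vi in v)
    return e_log_p + entropy

def cavi(sweeps=100):
    m, v = [0.0, 0.0], [1.0, 1.0]        # initialization of q_1, q_2
    history = [elbo(m, v)]
    for _ in range(sweeps):             # outer loop over l
        for j in range(2):              # inner loop over factors j
            k = 1 - j
            # log q_j*(z_j) = E_{q_k}[log p(z)] + const is a Gaussian in z_j with
            # precision LAM[j][j] and mean MU[j] - LAM[j][k] / LAM[j][j] * (m[k] - MU[k])
            v[j] = 1.0 / LAM[j][j]
            m[j] = MU[j] - LAM[j][k] / LAM[j][j] * (m[k] - MU[k])
            history.append(elbo(m, v))
    return m, v, history

m, v, history = cavi()
print(m, v, history[-1])
```

The means converge to $\boldsymbol{\mu}$, and the recorded ELBO never decreases, as expected from coordinate ascent. The factor variances $1/\Lambda_{jj}$ are smaller than the true marginal variances $(\boldsymbol{\Lambda}^{-1})_{jj}$. This reflects the tendency of mean-field (reverse-KL) approximations to underestimate posterior variance for correlated targets. The final ELBO, $\tfrac{1}{2} \log 0.28 \approx -0.64$, stays below $\log Z = 0$.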

Variational Linear Regression

Example 1 (Variational Linear Regression)

We have previously used probabilistic models and joint distributions to solve the linear regression problem. Here, we will use Variational Inference to solve the same problem.

Recall (Predictive Distribution in Bayesian Linear Regression)

Recall that the predictive distribution has the form,

$$p(y \mid \mathcal{D}, \mathbf{x}, \beta) = \int p(\mathbf{w} \mid \mathcal{D}, \beta) \, p(y \mid \mathbf{x}, \mathbf{w}, \beta) \ d\mathbf{w}$$

Here, we additionally treat the prior precision $\alpha$ as unknown. The prior is $p(\mathbf{w} \mid \alpha) = \mathcal{N}(\mathbf{w} \mid \mathbf{0}, \alpha^{-1} \mathbf{I})$ with $\mathbf{w} \in \mathbb{R}^M$, the hyperprior is $p(\alpha) = \mathrm{Gam}(\alpha \mid a_0, b_0)$, and the noise precision $\beta$ is known. The posterior over the unknowns is then $p(\mathbf{w}, \alpha \mid \mathcal{D}, \beta)$, which we write as $p(\mathbf{w}, \alpha \mid \mathcal{D})$ since $\beta$ is fixed. Finding an approximation of it is precisely a variational inference problem.

We approximate the posterior, $p(\mathbf{w}, \alpha \mid \mathcal{D}) \approx q(\mathbf{w}, \alpha)$, by a distribution that factorizes as,

$$q(\mathbf{w}, \alpha) = q(\mathbf{w}) q(\alpha)$$

where $q(\mathbf{w})$ approximates $p(\mathbf{w} \mid \mathcal{D})$ and $q(\alpha)$ approximates $p(\alpha \mid \mathcal{D})$. Thus, our goal is (again) to maximize the ELBO.

Derivation (Variational Linear Regression)

We need to iterate the equations,

$$
\begin{aligned}
\log q^{\star}(\alpha) & = \mathbb{E}_{q(\mathbf{w})} [\log p(\mathbf{y}_{\mathcal{D}}, \mathbf{w}, \alpha)] + \text{const} \\
\log q^{\star}(\mathbf{w}) & = \mathbb{E}_{q(\alpha)} [\log p(\mathbf{y}_{\mathcal{D}}, \mathbf{w}, \alpha)] + \text{const}
\end{aligned}
$$

where $p(\mathbf{y}_{\mathcal{D}}, \mathbf{w}, \alpha) = p(\mathbf{y}_{\mathcal{D}} \mid \mathbf{w}) \, p(\mathbf{w} \mid \alpha) \, p(\alpha)$. Thus,

$$
\begin{aligned}
\log q^{\star}(\alpha) & = \mathbb{E}_{q(\mathbf{w})} [\log p(\mathbf{y}_{\mathcal{D}}, \mathbf{w}, \alpha)] + \text{const} \\
& = \mathbb{E}_{q(\mathbf{w})} [\log p(\mathbf{w} \mid \alpha) + \log p(\alpha)] + \text{const} \\
& = \log p(\alpha) + \mathbb{E}_{q(\mathbf{w})} [\log p(\mathbf{w} \mid \alpha)] + \text{const} \\
& = (a_0 - 1) \log \alpha - b_0 \alpha + \frac{M}{2} \log \alpha - \frac{\alpha}{2} \mathbb{E}_{q(\mathbf{w})} [\mathbf{w}^T \mathbf{w}] + \text{const}
\end{aligned}
$$

In the second line, $\mathbb{E}_{q(\mathbf{w})}[\log p(\mathbf{y}_{\mathcal{D}} \mid \mathbf{w})]$ does not depend on $\alpha$ and is absorbed into the constant.

which is a Gamma distribution,

$$q^{\star}(\alpha) = \mathrm{Gam}(\alpha \mid a_N, b_N), \quad a_N = a_0 + \frac{M}{2}, \quad b_N = b_0 + \frac{1}{2} \mathbb{E}_{q(\mathbf{w})} [\mathbf{w}^T \mathbf{w}]$$

Similarly, for $q^{\star}(\mathbf{w})$ (now $\log p(\alpha)$ is the term that does not depend on $\mathbf{w}$),

$$
\begin{aligned}
\log q^{\star}(\mathbf{w}) & = \mathbb{E}_{q(\alpha)} [\log p(\mathbf{y}_{\mathcal{D}}, \mathbf{w}, \alpha)] + \text{const} \\
& = \mathbb{E}_{q(\alpha)} [\log p(\mathbf{y}_{\mathcal{D}} \mid \mathbf{w}) + \log p(\mathbf{w} \mid \alpha)] + \text{const} \\
& = \log p(\mathbf{y}_{\mathcal{D}} \mid \mathbf{w}) + \mathbb{E}_{q(\alpha)} [\log p(\mathbf{w} \mid \alpha)] + \text{const} \\
& = - \frac{\beta}{2} \sum_{i = 1}^N (y_i - \mathbf{w}^T \boldsymbol{\phi}(\mathbf{x}_i))^2 - \frac{1}{2} \mathbb{E}_{q(\alpha)} [\alpha] \mathbf{w}^T \mathbf{w} + \text{const} \\
& = -\frac{1}{2} \mathbf{w}^T \left(\mathbb{E}_{q(\alpha)} [\alpha] \mathbf{I} + \beta \boldsymbol{\Phi}^T \boldsymbol{\Phi}\right) \mathbf{w} + \beta \mathbf{w}^T \boldsymbol{\Phi}^T \mathbf{y}_{\mathcal{D}} + \text{const}
\end{aligned}
$$

where $\boldsymbol{\Phi}$ is the $N \times M$ design matrix with rows $\boldsymbol{\phi}(\mathbf{x}_i)^T$.

which is a Gaussian distribution,

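$$q^{\star}(\mathbf{w}) = \mathcal{N}(\mathbf{w} \mid \mathbf{m}_N, \mathbf{S}_N), \quad \mathbf{S}_N = \left(\mathbb{E}_{q(\alpha)} [\alpha] \mathbf{I} + \beta \boldsymbol{\Phi}^T \boldsymbol{\Phi}\right)^{-1}, \quad \mathbf{m}_N = \beta \mathbf{S}_N \boldsymbol{\Phi}^T \mathbf{y}_{\mathcal{D}}$$

The two updates are coupled through the moments $\mathbb{E}_{q(\alpha)}[\alpha] = a_N / b_N$ and $\mathbb{E}_{q(\mathbf{w})}[\mathbf{w}^T \mathbf{w}] = \mathbf{m}_N^T \mathbf{m}_N + \mathrm{Tr}(\mathbf{S}_N)$. We therefore alternate between them until convergence.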
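The two updates can be iterated in a few lines. This sketch makes several assumptions: 1-D inputs with features $\boldsymbol{\phi}(x) = (1, x)^T$ (so $M = 2$), a known noise precision $\beta$, Gamma shape/rate hyperparameters $a_0 = b_0 = 10^{-3}$, and synthetic data. It also evaluates the approximate predictive distribution $\mathcal{N}(y \mid \mathbf{m}_N^T \boldsymbol{\phi}(x), \beta^{-1} + \boldsymbol{\phi}(x)^T \mathbf{S}_N \boldsymbol{\phi}(x))$, which follows from using $q^{\star}(\mathbf{w})$ in place of the posterior in the predictive integral. The function names are hypothetical, and the 2x2 inverse is written out by hand to stay dependency-free.

```python
import random

def variational_linear_regression(xs, ys, beta, a0=1e-3, b0=1e-3, iters=50):
    """Mean-field VI for y = w0 + w1 * x + noise with features phi(x) = (1, x), so M = 2.
    Assumed model: known noise precision beta, p(w | alpha) = N(0, alpha^{-1} I),
    p(alpha) = Gam(a0, b0) with shape/rate parameterization."""
    M = 2
    phis = [(1.0, x) for x in xs]
    PtP = [[sum(p[r] * p[c] for p in phis) for c in range(M)] for r in range(M)]  # Phi^T Phi
    Pty = [sum(p[r] * y for p, y in zip(phis, ys)) for r in range(M)]              # Phi^T y
    a_N = a0 + M / 2      # constant across iterations
    E_alpha = 1.0         # initial guess for E_q[alpha]
    for _ in range(iters):
        # Update q(w) = N(m_N, S_N) given E[alpha]; 2x2 inverse written out explicitly
        A = [[(E_alpha if r == c else 0.0) + beta * PtP[r][c] for c in range(M)] for r in range(M)]
        det = A[0][0] * A[1][1] - A[0][1] * A[1][0]
        S = [[A[1][1] / det, -A[0][1] / det], [-A[1][0] / det, A[0][0] / det]]
        m = [beta * sum(S[r][c] * Pty[c] for c in range(M)) for r in range(M)]
        # Update q(alpha) = Gam(a_N, b_N) given E[w^T w] = m^T m + Tr(S)
        b_N = b0 + 0.5 * (m[0] ** 2 + m[1] ** 2 + S[0][0] + S[1][1])
        E_alpha = a_N / b_N   # mean of a Gamma with shape a_N and rate b_N
    return m, S, a_N, b_N

def predictive(x, m, S, beta):
    """Approximate predictive N(y | m^T phi(x), 1/beta + phi(x)^T S phi(x))."""
    phi = (1.0, x)
    mean = m[0] * phi[0] + m[1] * phi[1]
    var = 1.0 / beta + sum(phi[r] * S[r][c] * phi[c] for r in range(2) for c in range(2))
    return mean, var

rng = random.Random(0)
xs = [i / 10 for i in range(20)]
ys = [0.5 + 2.0 * x + rng.gauss(0.0, 0.2) for x in xs]   # synthetic data, noise sd 0.2
beta = 1.0 / 0.2 ** 2
m, S, a_N, b_N = variational_linear_regression(xs, ys, beta)
print("m_N:", m, "E[alpha]:", a_N / b_N)
print("predictive at x = 1:", predictive(1.0, m, S, beta))
```

$a_N$ is fixed by the model size. Only $b_N$ (through $\mathbb{E}[\mathbf{w}^T \mathbf{w}]$) and $\mathbf{S}_N, \mathbf{m}_N$ (through $\mathbb{E}[\alpha]$) change between iterations. A convergence check on the ELBO or on $\mathbf{m}_N$ could replace the fixed iteration count.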

Expectation Propagation

Intuition (Expectation Propagation)

Consider,

$$p(\mathcal{D}, \boldsymbol{\theta}) \coloneqq \prod_{i=1}^I f_i(\boldsymbol{\theta})$$

where the $f_i(\boldsymbol{\theta})$ are factors (e.g., likelihoods, priors, etc.). Our goal is to evaluate the posterior $p(\boldsymbol{\theta} \mid \mathcal{D})$,

$$p(\boldsymbol{\theta} \mid \mathcal{D}) = \frac{1}{p(\mathcal{D})} \prod_{i=1}^I f_i(\boldsymbol{\theta})$$

We approximate $p(\boldsymbol{\theta} \mid \mathcal{D})$ with a tractable distribution $q(\boldsymbol{\theta}) \in \Omega$ that factorizes in the same way,

$$q(\boldsymbol{\theta}) \coloneqq \frac{1}{Z} \prod_{i=1}^I q_i(\boldsymbol{\theta})$$

where $Z$ is a normalizing constant.

The factors are often assumed to come from the exponential family, e.g., Gaussian factors,

$$q(\boldsymbol{\theta}) = \frac{1}{Z} \prod_{i=1}^I \mathcal{N}(\boldsymbol{\theta} \mid \boldsymbol{\mu}_i, \boldsymbol{\Sigma}_i)$$

in which case $q(\boldsymbol{\theta})$ is itself Gaussian.

Thus, we want to find the $q(\boldsymbol{\theta})$ that minimizes the forward KL divergence,

$$q^{\star}(\boldsymbol{\theta}) = \underset{q(\boldsymbol{\theta}) \in \Omega}{\arg\min} \ \mathrm{KL}(p(\boldsymbol{\theta} \mid \mathcal{D}) \mid \mid q(\boldsymbol{\theta}))$$

However, this is still intractable, as it requires knowledge of the true posterior $p(\boldsymbol{\theta} \mid \mathcal{D})$.

The idea is instead to optimize each factor in turn (keeping others constant).

Algorithm (Expectation Propagation)
  1. Initialize the factors $q_i(\boldsymbol{\theta})$.

  2. Until convergence, cycle through the factors $q_j(\boldsymbol{\theta})$ and update each as,

$$q_j^{\text{new}}(\boldsymbol{\theta}) = \underset{q_j(\boldsymbol{\theta}) \in \Omega}{\arg\min} \ \mathrm{KL}\left(\frac{1}{Z_j} f_j(\boldsymbol{\theta}) \prod_{i \neq j} q_i^{\text{old}}(\boldsymbol{\theta}) \mid \mid \frac{1}{Z} q_j(\boldsymbol{\theta}) \prod_{i \neq j} q_i^{\text{old}}(\boldsymbol{\theta})\right)$$

where $Z_j = \int f_j(\boldsymbol{\theta}) \prod_{i \neq j} q_i^{\text{old}}(\boldsymbol{\theta}) \ d\boldsymbol{\theta}$ normalizes the first argument. In other words, the true factor $f_j$ temporarily replaces its approximation while all other approximate factors are kept fixed.
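When the approximating family is Gaussian, the KL minimization in step 2 reduces to matching the mean and variance of the normalized "tilted" density $f_j(\boldsymbol{\theta}) \prod_{i \neq j} q_i^{\text{old}}(\boldsymbol{\theta})$. The new factor is then the resulting Gaussian divided by the "cavity" $\prod_{i \neq j} q_i^{\text{old}}(\boldsymbol{\theta})$. The sketch below applies this to a 1-D "clutter" problem, a common toy example for EP. The model constants, the data, and the use of grid quadrature for the tilted moments (instead of closed-form expressions) are all illustrative assumptions. The prior factor is already Gaussian, so it is kept exact and only the likelihood factors are approximated. Because $\theta$ is 1-D, the exact posterior can also be computed by brute force for comparison.

```python
import math
import random

# Assumed "clutter" model (illustrative): prior theta ~ N(0, 100); each observation is
# N(x | theta, 1) with probability 1 - W, or clutter N(x | 0, 10) with probability W.
W, CLUTTER_VAR, PRIOR_VAR = 0.25, 10.0, 100.0
GRID = [-15.0 + i * 0.02 for i in range(1501)]   # quadrature grid for theta on [-15, 15]

def normal_pdf(x, mean, var):
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2 * math.pi * var)

def factor(x, theta):
    # Exact likelihood factor f_n(theta)
    return (1 - W) * normal_pdf(x, theta, 1.0) + W * normal_pdf(x, 0.0, CLUTTER_VAR)

def grid_moments(weights):
    # Mean and variance of an unnormalized density tabulated on GRID
    z = sum(weights)
    mean = sum(w * t for w, t in zip(weights, GRID)) / z
    return mean, sum(w * (t - mean) ** 2 for w, t in zip(weights, GRID)) / z

def ep(xs, sweeps=8):
    # Gaussian sites stored in natural parameters (precision, precision * mean)
    site_prec, site_shift = [0.0] * len(xs), [0.0] * len(xs)
    prec, shift = 1.0 / PRIOR_VAR, 0.0   # q starts as the (already Gaussian) prior factor
    for _ in range(sweeps):
        for n, x in enumerate(xs):
            cav_prec, cav_shift = prec - site_prec[n], shift - site_shift[n]   # remove site n
            if cav_prec <= 0:
                continue   # skip updates that would give an improper cavity
            cav_mean, cav_var = cav_shift / cav_prec, 1.0 / cav_prec
            # Forward-KL projection onto Gaussians = match mean/variance of the tilted density
            tilted = [factor(x, t) * normal_pdf(t, cav_mean, cav_var) for t in GRID]
            mean, var = grid_moments(tilted)
            prec, shift = 1.0 / var, mean / var
            site_prec[n], site_shift[n] = prec - cav_prec, shift - cav_shift   # q_new / cavity
    return shift / prec, 1.0 / prec

def exact_posterior(xs):
    # Reference answer by brute-force quadrature (feasible only because theta is 1-D)
    log_w = [-0.5 * t * t / PRIOR_VAR + sum(math.log(factor(x, t)) for x in xs) for t in GRID]
    top = max(log_w)
    return grid_moments([math.exp(lw - top) for lw in log_w])

rng = random.Random(1)
xs = [rng.gauss(0.0, math.sqrt(CLUTTER_VAR)) if rng.random() < W else rng.gauss(2.0, 1.0)
      for _ in range(30)]
print("EP:   ", ep(xs))
print("exact:", exact_posterior(xs))
```

Sites are stored in natural parameters (precision and precision times mean) because dividing Gaussians then becomes subtraction. Individual site precisions may turn negative, which EP allows. The code only skips an update when the cavity itself would be improper.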