Part 5 - Flow and Diffusion models

Reverse-time Euler-Maruyama Method

In the forward SDE,

$$dx = f(x(t), t) \ dt + L(t) \ d \beta(t),$$

we simulate forward in time using the Euler-Maruyama method (note that $L(t)$ does not depend on the state $x(t)$, so the noise is additive),

$$x_{n + 1} = x_n + f(x_n, t_n) \Delta t + L(t_n) \sqrt{\Delta t} \ \xi_n, \quad \xi_n \sim \mathcal{N}(0, 1).$$
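A minimal NumPy sketch of the forward scheme, vectorized over many independent sample paths. The function name `euler_maruyama` and the test case are illustrative choices: we use $f(x, t) = -\lambda x$ with constant $L(t) = \sigma$ (an Ornstein-Uhlenbeck process started at $x(0) = 0$), whose exact variance at time $T$ is $\frac{\sigma^2}{2\lambda}(1 - e^{-2\lambda T})$, so the sample variance gives a quick sanity check.

```python
import numpy as np


def euler_maruyama(x0, f, L, t0, t1, n_steps, rng):
    """Simulate dx = f(x, t) dt + L(t) dbeta(t) forward in time from t0 to t1.

    x0 is an array of initial states; every entry is an independent path.
    """
    dt = (t1 - t0) / n_steps
    x = np.array(x0, dtype=float)
    for n in range(n_steps):
        t_n = t0 + n * dt
        xi = rng.standard_normal(x.shape)  # xi_n ~ N(0, 1), one draw per path
        x = x + f(x, t_n) * dt + L(t_n) * np.sqrt(dt) * xi
    return x


# Test case (an assumption for illustration): f(x, t) = -lam * x, L(t) = sigma, x(0) = 0.
lam, sigma, T = 1.0, 1.0, 1.0
rng = np.random.default_rng(0)
samples = euler_maruyama(np.zeros(20_000), lambda x, t: -lam * x,
                         lambda t: sigma, 0.0, T, 1_000, rng)
var_exact = sigma**2 / (2 * lam) * (1 - np.exp(-2 * lam * T))
print(samples.mean(), samples.var(), var_exact)
```

Because $L(t)$ does not depend on $x$, each step only needs one Gaussian draw per path; no derivative of the diffusion term (as in Milstein-type schemes) is required.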

For the reverse SDE [1],

$$dx = [f(x(t), t) - L(t)^2 \nabla_x \log p(x(t), t)] dt + L(t) \ d \bar{\beta}(t),$$

we simulate backward in time with,

$$x_{n - 1} = x_n - [f(x_n, t_n) - L(t_n)^2 \nabla_x \log p(x_n, t_n)] \Delta t - L(t_n) \sqrt{\Delta t} \ \xi_n.$$

The score function $\nabla_x \log p(x(t), t)$ must be estimated or known.
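One way to implement the backward recursion, as a hedged NumPy sketch. The `score` argument is a placeholder for whatever known or estimated score function is available (for example a trained network); the function name and signature are illustrative, not from the text. The check at the end sets $L = 0$, where the reverse scheme reduces to integrating the ODE $dx/dt = -\lambda x$ backward, so $x(0) \approx x(T) e^{\lambda T}$.

```python
import numpy as np


def reverse_euler_maruyama(x_end, f, L, score, t_grid, rng):
    """Simulate the reverse-time SDE backward along an increasing time grid.

    Starts from x_end at t_grid[-1] and repeatedly applies
    x_{n-1} = x_n - [f(x_n, t_n) - L(t_n)^2 score(x_n, t_n)] dt - L(t_n) sqrt(dt) xi_n.
    Returns the states at t_grid[0].
    """
    x = np.array(x_end, dtype=float)
    for n in range(len(t_grid) - 1, 0, -1):
        t_n = t_grid[n]
        dt = t_grid[n] - t_grid[n - 1]
        xi = rng.standard_normal(x.shape)  # xi_n ~ N(0, 1)
        drift = f(x, t_n) - L(t_n) ** 2 * score(x, t_n)
        x = x - drift * dt - L(t_n) * np.sqrt(dt) * xi
    return x


# Sanity check: with L = 0 the scheme integrates dx/dt = -lam * x backward in time,
# so x(0) should be close to x(T) * exp(lam * T).
lam, T = 1.0, 1.0
rng = np.random.default_rng(0)
t_grid = np.linspace(0.0, T, 1_001)
x0 = reverse_euler_maruyama(np.array([1.0]), lambda x, t: -lam * x, lambda t: 0.0,
                            lambda x, t: np.zeros_like(x), t_grid, rng)
print(x0, np.exp(lam * T))
```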

Reverse Euler-Maruyama for Ornstein-Uhlenbeck Process

Consider the forward SDE,

$$dx = -\lambda x(t) \ dt + \sigma \ d \beta(t), \quad x(0) = 0.$$

The marginal distribution $p(x(t), t)$ is Gaussian with zero mean and variance,

$$\mathrm{Var}[x(t)] = \frac{\sigma^2}{2 \lambda}(1 - e^{-2 \lambda t}).$$

The score of the process is,

$$\begin{align*} \nabla_x \log p(x(t), t) & = \nabla_x \log \left[\frac{1}{\sqrt{2 \pi \mathrm{Var}[x(t)]}} e^{-\frac{x(t)^2}{2 \mathrm{Var}[x(t)]}} \right] \newline & = \nabla_x \left[-\frac{1}{2} \log(2 \pi) - \frac{1}{2} \log(\mathrm{Var}[x(t)]) - \frac{x(t)^2}{2 \mathrm{Var}[x(t)]} \right] \newline & = -\frac{1}{2 \mathrm{Var}[x(t)]} \underbrace{\nabla_x \mathrm{Var}[x(t)]}_{= 0} - \frac{x(t)}{\mathrm{Var}[x(t)]} \underbrace{\nabla_x x(t)}_{= 1} \newline & = -\frac{x(t)}{\mathrm{Var}[x(t)]}. \end{align*}$$
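A quick numerical check of this derivation (a sketch; the parameter values $\lambda = \sigma = 1$, $t = 0.7$ are arbitrary): the central finite difference of the Gaussian log-density in $x$ should match $-x / \mathrm{Var}[x(t)]$.

```python
import numpy as np


def ou_variance(t, lam, sigma):
    """Var[x(t)] = sigma^2 / (2 lam) * (1 - exp(-2 lam t)) for the OU process with x(0) = 0."""
    return sigma**2 / (2 * lam) * (1 - np.exp(-2 * lam * t))


def log_density(x, t, lam, sigma):
    """log p(x, t) for the zero-mean Gaussian marginal."""
    v = ou_variance(t, lam, sigma)
    return -0.5 * np.log(2 * np.pi * v) - x**2 / (2 * v)


def ou_score(x, t, lam, sigma):
    """Closed-form score: -x / Var[x(t)]."""
    return -x / ou_variance(t, lam, sigma)


lam, sigma, t, h = 1.0, 1.0, 0.7, 1e-5
xs = np.linspace(-2.0, 2.0, 9)
finite_diff = (log_density(xs + h, t, lam, sigma) - log_density(xs - h, t, lam, sigma)) / (2 * h)
max_err = np.max(np.abs(finite_diff - ou_score(xs, t, lam, sigma)))
print(max_err)
```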

We can plug this score into the reverse-time Euler-Maruyama method,

$$x_{n - 1} = x_n - \left[-\lambda x_n + \frac{\sigma^2 x_n}{\mathrm{Var}[x(t_n)]} \right] \Delta t - \sigma \sqrt{\Delta t} \ \xi_n,$$

where $\xi_n \sim \mathcal{N}(0, 1)$ and $\mathrm{Var}[x(t_n)] = \frac{\sigma^2}{2 \lambda}(1 - e^{-2 \lambda t_n})$.

[Figure: Reverse Euler-Maruyama for Ornstein-Uhlenbeck Process]
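A minimal NumPy sketch of this scheme, under assumed values $\lambda = \sigma = 1$: start from the exact marginal $x(T) \sim \mathcal{N}(0, \mathrm{Var}[x(T)])$ and step backward with the closed-form score. Because $\mathrm{Var}[x(t)] \to 0$ as $t \to 0$, the score blows up there, so the sketch stops at a small $t_{\min} > 0$ (an assumption, not from the text). If the scheme is correct, the sample variance at every grid time should track $\mathrm{Var}[x(t)]$.

```python
import numpy as np

lam, sigma = 1.0, 1.0                                   # assumed OU parameters
t_min, T, n_steps, n_samples = 0.1, 2.1, 2_000, 20_000


def ou_var(t):
    """Var[x(t)] = sigma^2 / (2 lam) * (1 - exp(-2 lam t))."""
    return sigma**2 / (2 * lam) * (1 - np.exp(-2 * lam * t))


rng = np.random.default_rng(0)
t_grid = np.linspace(t_min, T, n_steps + 1)
dt = t_grid[1] - t_grid[0]

x = np.sqrt(ou_var(T)) * rng.standard_normal(n_samples)  # x(T) ~ N(0, Var[x(T)])
var_path = np.empty(n_steps + 1)
var_path[-1] = x.var()
for n in range(n_steps, 0, -1):
    t_n = t_grid[n]
    # f - L^2 * score = -lam x - sigma^2 * (-x / Var[x(t_n)])
    drift = -lam * x + sigma**2 * x / ou_var(t_n)
    xi = rng.standard_normal(n_samples)
    x = x - drift * dt - sigma * np.sqrt(dt) * xi
    var_path[n - 1] = x.var()                             # sample variance at t_grid[n - 1]

print(np.max(np.abs(var_path - ou_var(t_grid))))
```

Stopping at $t_{\min}$ keeps the ratio $\sigma^2 \Delta t / \mathrm{Var}[x(t_n)]$ small, which is what keeps the explicit step stable near $t = 0$.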

Denoising Diffusion Probabilistic Models

[Figure: The forward and backward processes of the diffusion model.]

How do you design a scalar SDE of the form,

$$dx(t) = f(x(t), t) \ dt + L(t) \ d \beta(t),$$

that produces a process $x(t)$ with,

  1. Mean: $\mathbb{E}[x(t)] = \alpha(t) x_0$,
  2. Variance: $\mathrm{Var}[x(t)] = \sigma(t)^2$,
  3. Initial Conditions: $x(0) = x_0, \alpha(0) = 1, \sigma(0) = 0$?

Let’s start with the mean condition. Recall that,

$$\frac{dm(t)}{dt} = \frac{d\mathbb{E}[x(t)]}{dt} = \mathbb{E}[f(x(t), t)]. \tag{1}$$

Thus, from our ansatz,

$$\mathbb{E}[x(t)] = \alpha(t) x_0 \implies \frac{d \mathbb{E}[x(t)]}{dt} = \dot{\alpha}(t) x_0. \tag{2}$$

A natural choice for $f(x(t), t)$ is a simple affine function of the form,

$$f(x(t), t) = a(t) x(t) + b(t).$$

Taking the expectation yields,

$$\begin{align*} \mathbb{E}[f(x(t), t)] & = a(t) \mathbb{E}[x(t)] + b(t) \newline & = a(t) \alpha(t) x_0 + b(t). \end{align*}$$

If we set, e.g., $b(t) = 0$, we can solve for $a(t)$ by equating Equation (1) and Equation (2) to get,

$$\begin{align*} a(t) \alpha(t) x_0 & = \dot{\alpha}(t) x_0 \newline a(t) & = \frac{\dot{\alpha}(t)}{\alpha(t)} \left(= \frac{d}{dt} \log(\alpha(t))\right). \end{align*}$$

Thus, we have the linear SDE of the form,

$$dx(t) = \left(\frac{\dot{\alpha}(t)}{\alpha(t)} x(t) \right) dt + L(t) \ d \beta(t).$$

One can check that this SDE satisfies the mean condition. To also satisfy the variance condition, we scale the diffusion term by $\alpha(t)$ and consider the linear SDE of the form,

$$dx(t) = \left(\frac{\dot{\alpha}(t)}{\alpha(t)} x(t) \right) dt + \underbrace{\alpha(t)}_{\text{new}} L(t) \ d \beta(t).$$

The solution to this SDE is,

$$x(t) = \alpha(t) \underbrace{\left(x_0 + \int_0^t L(s) \ d \beta(s) \right)}_{y(t)}.$$

Thus,

$$d(y(t)) = L(t) \ d \beta(t).$$

By definition of the variance,

$$\begin{align*} \mathrm{Var}[x(t)] & = \mathbb{E}[(\underbrace{x(t)}_{= \alpha(t) y(t)} - \underbrace{\mathbb{E}[x(t)]}_{= \alpha(t) x_0})^2] \newline & = \mathbb{E}[(\alpha(t) y(t) - \alpha(t) x_0)^2] \newline & = \mathbb{E}[((\underbrace{\alpha(t)}_{\text{deterministic}})(y(t) - \underbrace{x_0}_{\text{deterministic}}))^2] \newline & = \alpha(t)^2 \, \mathbb{E}\left[\left(\int_0^t L(s) \ d \beta(s)\right)^2\right] \newline & = \alpha(t)^2 \int_0^t L(s)^2 \ ds, \end{align*}$$

where the last step follows from the Itô isometry.

We want the variance to be $\sigma(t)^2$, so we set,

$$\begin{align*} \alpha(t)^2 \int_0^t L(s)^2 \ ds & = \sigma(t)^2 \newline \int_0^t L(s)^2 \ ds & = \frac{\sigma(t)^2}{\alpha(t)^2} \newline L(t)^2 & = \frac{d}{dt} \left(\frac{\sigma(t)^2}{\alpha(t)^2}\right) \newline L(t) & = \sqrt{\frac{d}{dt} \left(\frac{\sigma(t)^2}{\alpha(t)^2}\right)} \end{align*}$$

Summary

To construct an SDE with,

  1. Mean: $\mathbb{E}[x(t)] = \alpha(t) x_0$,
  2. Variance: $\mathrm{Var}[x(t)] = \sigma(t)^2$,

We use the SDE,

$$dx(t) = \left(\frac{\dot{\alpha}(t)}{\alpha(t)} x(t) \right) dt + \alpha(t) \sqrt{\frac{d}{dt} \left(\frac{\sigma(t)^2}{\alpha(t)^2}\right)} \ d \beta(t).$$

This guarantees that the process $x(t)$ is Gaussian with the desired mean and variance.
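The construction can be sketched numerically. The helper `make_sde` below is hypothetical: it takes callables $\alpha(t)$ and $\sigma(t)$, approximates the time derivatives by central finite differences, and returns the drift $\frac{\dot{\alpha}(t)}{\alpha(t)} x$ and the diffusion coefficient $\alpha(t) L(t)$ with $L(t) = \sqrt{\frac{d}{dt}(\sigma(t)^2/\alpha(t)^2)}$. It assumes $\sigma(t)^2/\alpha(t)^2$ is nondecreasing, since otherwise the square root is undefined. The usage lines compare against $\alpha(t) = 1 - t$, $\sigma(t) = t$, for which $\alpha(t) L(t) = \sqrt{2t/(1 - t)}$.

```python
import numpy as np


def make_sde(alpha, sigma, h=1e-6):
    """Hypothetical helper: drift and diffusion of a linear SDE with
    E[x(t)] = alpha(t) x0 and Var[x(t)] = sigma(t)^2.

    Derivatives are approximated with central finite differences of step h.
    """
    def d_dt(g, t):
        return (g(t + h) - g(t - h)) / (2 * h)

    def ratio(s):
        return sigma(s) ** 2 / alpha(s) ** 2

    def drift(x, t):
        # (alpha'(t) / alpha(t)) * x
        return d_dt(alpha, t) / alpha(t) * x

    def diffusion(t):
        # alpha(t) * L(t) with L(t)^2 = d/dt (sigma(t)^2 / alpha(t)^2); assumes the ratio is nondecreasing
        return alpha(t) * np.sqrt(d_dt(ratio, t))

    return drift, diffusion


drift, diffusion = make_sde(lambda t: 1 - t, lambda t: t)
ts = np.array([0.1, 0.3, 0.5, 0.7])
print(diffusion(ts))                 # numerical alpha(t) L(t)
print(np.sqrt(2 * ts / (1 - ts)))    # closed form for this example
print(drift(1.0, 0.5))               # should be close to -1 / (1 - 0.5) = -2
```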

Given an SDE and target properties

Suppose we are given an SDE of the form,

$$dx(t) = \frac{\dot{\alpha}(t)}{\alpha(t)} x(t) \ dt + \alpha(t) L(t) \ d \beta(t), \quad x(0) = x_0,$$

with,

  1. Mean: $\mathbb{E}[x(t)] = (1 - t) x_0$,
  2. Target Variance: $\mathrm{Var}[x(t)] = t^2$.

Using $\alpha(t) = 1 - t$ we have,

$$\dot{\alpha}(t) = -1, \quad \frac{\dot{\alpha}(t)}{\alpha(t)} = -\frac{1}{1 - t}.$$

Thus, the SDE becomes,

$$dx(t) = -\frac{1}{1 - t} x(t) \ dt + (1 - t) L(t) \ d \beta(t).$$

By doing the same “trick” as before, we find that,

$$\begin{align*} L^2(t) & = \frac{d}{dt} \left(\frac{t^2}{(1 - t)^2}\right) \newline & = \frac{2t(1 - t)^2 + 2t^2(1 - t)}{(1 - t)^4} \newline & = \frac{2t[(1 - t)^2 + t(1 - t)]}{(1 - t)^4} \newline & = \frac{2t(1 - t)}{(1 - t)^4} \newline & = \frac{2t}{(1 - t)^3}. \end{align*}$$

So,

$$L(t) = \sqrt{\frac{2t}{(1 - t)^3}},$$

we can also further simplify the diffusion term to,

$$(1 - t) L(t) = (1 - t) \sqrt{\frac{2t}{(1 - t)^3}} = \sqrt{\frac{2t}{(1 - t)}}.$$

So the SDE becomes,

$$dx(t) = -\frac{1}{1 - t} x(t) \ dt + \sqrt{\frac{2t}{(1 - t)}} \ d \beta(t).$$

This is a fully simplified form, valid for $t < 1$, which produces,

  1. Mean: $\mathbb{E}[x(t)] = (1 - t) x_0$,
  2. Variance: $\mathrm{Var}[x(t)] = t^2$.

The process is Gaussian, fully characterized by these expressions.
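A Monte Carlo check of this example, as a minimal sketch with assumed values $x_0 = 1$ and final time $t = 0.9$ (the drift is singular at $t = 1$, so we stop short of it): simulate the simplified SDE with the forward Euler-Maruyama scheme and compare the sample mean and variance with $(1 - t)x_0$ and $t^2$.

```python
import numpy as np


def simulate_example_sde(x0, t_end, n_steps, n_samples, rng):
    """Euler-Maruyama for dx = -x / (1 - t) dt + sqrt(2t / (1 - t)) dbeta, from t = 0 to t_end < 1."""
    dt = t_end / n_steps
    x = np.full(n_samples, float(x0))
    for k in range(n_steps):
        t = k * dt
        drift = -x / (1.0 - t)
        diffusion = np.sqrt(2.0 * t / (1.0 - t))
        x = x + drift * dt + diffusion * np.sqrt(dt) * rng.standard_normal(n_samples)
    return x


rng = np.random.default_rng(0)
x0, t_end = 1.0, 0.9                       # assumed values for the check
x = simulate_example_sde(x0, t_end, n_steps=900, n_samples=20_000, rng=rng)
print(x.mean(), (1 - t_end) * x0)          # sample mean vs (1 - t) x0
print(x.var(), t_end**2)                   # sample variance vs t^2
```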

Flow and Diffusion Models

Consider the SDE in $\mathbb{R}^d$,

$$d \mathbf{x}(t) = \mathbf{f}^{\theta}(\mathbf{x}(t), t) \ dt + \sigma_t \ d \beta(t),$$

Our goal is, given $\mathbf{x}(0) \sim p_{\text{init}}(\mathbf{x}(0))$, to make sure that $\mathbf{x}(1) \sim p_{\text{data}}(\mathbf{x}(1))$.

Note (Terminology)

  1. $\sigma_t = 0$, flow model: the SDE reduces to an ODE (e.g., Flow Matching).
  2. $\sigma_t \neq 0$, diffusion model: the SDE is a stochastic process (e.g., Denoising Diffusion Probabilistic Models).

[Figure: The forward and backward processes of the diffusion model.]


Suppose $\mathbf{x}(0) \sim p_0 = p_{\text{init}}$ at time $t = 0$, where $p_{\text{init}}$ is a simple distribution like a Gaussian.

We would like $\mathbf{x}(1) \sim p_1 = p_{\text{data}}$ at time $t = 1$, where $p_{\text{data}}$ is some complicated data distribution.

We’ll study how to,

  1. Construct an (ideal) vector field $\mathbf{f}^{\text{target}}$ and $\sigma_t$.
  2. Train such that $\mathbf{f}^{\theta} \approx \mathbf{f}^{\text{target}}$.

We can then draw a sample from $p_{\text{init}}$, run it through the SDE, and obtain a sample from $p_{\text{data}}$.

Method Overview — 4 Steps

  1. Start by constructing a probability path $p_t(\cdot | \mathbf{z})$ for individual samples $\mathbf{z}$ which interpolates from noise, $p_0(\cdot | \mathbf{z}) = p_{\text{init}}(\cdot)$, to the data sample $\mathbf{z}$, $p_1(\cdot | \mathbf{z}) = \delta_{\mathbf{z}}(\cdot)$.
  2. Marginalization gives us $p_t(\mathbf{x}) = \int p_t(\mathbf{x} | \mathbf{z}) p_{\text{data}}(\mathbf{z}) \ d \mathbf{z}$.
  3. Construct an (ideal) SDE which simulates $p_t$.
  4. Train a neural network that approximates this SDE.

Step 1: Constructing the Probability Path

Start by constructing,

$$p_t(\cdot | \mathbf{z}),$$

for individual samples $\mathbf{z}$ from the data distribution, with conditions,

$$p_0(\cdot | \mathbf{z}) = p_{\text{init}}(\cdot) \text{ and } p_1(\cdot | \mathbf{z}) = \delta_{\mathbf{z}}(\cdot) \text{ for all } \mathbf{z} \in \mathbb{R}^d.$$

Such a path $p_t(\cdot | \mathbf{z})$ is called a conditional probability path.

Step 2: Marginalization

Each conditional path induces a marginal probability path,

$$p_t(\mathbf{x}) = \int p_t(\mathbf{x} | \mathbf{z}) p_{\text{data}}(\mathbf{z}) \ d \mathbf{z}.$$

We can sample from $p_t$ via,

$$\mathbf{z} \sim p_{\text{data}}, \quad \mathbf{x} \sim p_t(\cdot | \mathbf{z}) \implies \mathbf{x} \sim p_t.$$

Note that we cannot evaluate the density $p_t(\mathbf{x})$ itself, since the integral is intractable.

The marginal probability path $p_t(\mathbf{x})$ interpolates between $p_{\text{init}}$ and $p_{\text{data}}$,

$$p_0(\mathbf{x}) = p_{\text{init}}(\mathbf{x}) \text{ and } p_1(\mathbf{x}) = p_{\text{data}}(\mathbf{x}).$$

Tip (Example: Gaussian)

Define the path,

$$p_t(\cdot | \mathbf{z}) = \mathcal{N}(\alpha_t \mathbf{z}, \sigma_t^2 \mathbf{I}_d),$$

with $\alpha_0 = \sigma_1 = 0$ and $\alpha_1 = \sigma_0 = 1$; in other words, we move from $\mathcal{N}(\mathbf{0}, \mathbf{I}_d)$ at $t = 0$ to the point mass $\delta_{\mathbf{z}}$ (a degenerate $\mathcal{N}(\mathbf{z}, \mathbf{0})$) at $t = 1$.

Note that $\alpha_t$ and $\sigma_t$ should be continuous and differentiable functions of $t$.

Sampling from the marginal path $p_t$,

$$\mathbf{z} \sim p_{\text{data}}, \epsilon \sim p_{\text{init}} = \mathcal{N}(0, \mathbf{I}_d) \implies \mathbf{x}(t) = \alpha_t \mathbf{z} + \sigma_t \epsilon \sim p_t.$$
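A minimal NumPy sketch of this sampling procedure. Everything concrete here is an assumption for illustration: the schedule $\alpha_t = t$, $\sigma_t = 1 - t$ (which satisfies the boundary conditions) and a toy two-dimensional data set standing in for $p_{\text{data}}$. At $t = 0$ the samples are pure noise; at $t = 1$ they coincide with the data samples.

```python
import numpy as np


def sample_marginal_path(z, t, rng, alpha=lambda t: t, sigma=lambda t: 1.0 - t):
    """Draw x(t) = alpha_t z + sigma_t eps with eps ~ N(0, I_d), one draw per row of z.

    z: array of shape (n, d) holding samples z ~ p_data; t: scalar time in [0, 1].
    """
    eps = rng.standard_normal(z.shape)
    return alpha(t) * z + sigma(t) * eps


rng = np.random.default_rng(0)
# Hypothetical p_data: two Gaussian blobs in R^2 centered at (-2, 0) and (2, 0)
centers = np.array([[-2.0, 0.0], [2.0, 0.0]])
z = centers[rng.integers(0, 2, size=10_000)] + 0.1 * rng.standard_normal((10_000, 2))

x0 = sample_marginal_path(z, 0.0, rng)   # distributed as p_init = N(0, I_2)
x1 = sample_marginal_path(z, 1.0, rng)   # equals z, i.e. distributed as p_data
print(x0.mean(axis=0), x0.var(axis=0))
```

Only samples of $p_t$ are needed for training later on, which is why the intractable density $p_t(\mathbf{x})$ is not a problem.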

Step 3: Constructing the SDE

First, we analytically construct an SDE that simulates the conditional probability path.

Then, via the so-called “marginalization trick”, we obtain an SDE for the marginal probability path.

Recall our notation for an SDE: $d\mathbf{x}(t) = \mathbf{f}(\mathbf{x}(t), t) \ dt + L(t) \ d \beta(t)$.

Tip (Example: Gaussian)

For the path,

$$p_t(\cdot | \mathbf{z}) = \mathcal{N}(\alpha_t \mathbf{z}, \sigma_t^2 \mathbf{I}_d),$$

with $\alpha_0 = \sigma_1 = 0$ and $\alpha_1 = \sigma_0 = 1$, the following ODE,

$$\mathbf{f}(\mathbf{x}(t), t) = (\dot{\alpha}_t - \frac{\dot{\sigma}_t}{\sigma_t} \alpha_t) \mathbf{z} + \frac{\dot{\sigma}_t}{\sigma_t} \mathbf{x}(t) \text{ and } L(t) = 0,$$

simulates the conditional probability path in the sense that $\mathbf{x}(t) \sim \mathcal{N}(\alpha_t \mathbf{z}, \sigma_t^2 \mathbf{I}_d)$ whenever $\mathbf{x}(0) \sim p_{\text{init}} = \mathcal{N}(\mathbf{0}, \mathbf{I}_d)$.

We can check this in two ways,

  1. Using the equations for the mean and variance of an SDE,
  2. Using the Fokker-Planck-Kolmogorov equation.
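A numerical version of this check, as a sketch: integrate the conditional ODE with explicit Euler from $\mathbf{x}(0) = \epsilon$ and compare with $\alpha_t \mathbf{z} + \sigma_t \epsilon$. The schedule $\alpha_t = \sin(\pi t / 2)$, $\sigma_t = \cos(\pi t / 2)$ is a hypothetical choice that satisfies the boundary conditions; the sketch stops at $t = 0.9$ because $\dot{\sigma}_t / \sigma_t$ diverges as $\sigma_t \to 0$.

```python
import numpy as np


# Hypothetical schedule with alpha_0 = sigma_1 = 0 and alpha_1 = sigma_0 = 1
def alpha(t):
    return np.sin(0.5 * np.pi * t)


def sigma(t):
    return np.cos(0.5 * np.pi * t)


def d_alpha(t):
    return 0.5 * np.pi * np.cos(0.5 * np.pi * t)


def d_sigma(t):
    return -0.5 * np.pi * np.sin(0.5 * np.pi * t)


def conditional_field(x, t, z):
    """f(x, t | z) = (alpha_t' - sigma_t'/sigma_t * alpha_t) z + sigma_t'/sigma_t * x."""
    ratio = d_sigma(t) / sigma(t)
    return (d_alpha(t) - ratio * alpha(t)) * z + ratio * x


def integrate_conditional_ode(x0, z, t_end, n_steps):
    """Explicit Euler for dx/dt = f(x, t | z) from t = 0 to t_end (< 1, since sigma_1 = 0)."""
    dt = t_end / n_steps
    x = np.array(x0, dtype=float)
    for k in range(n_steps):
        x = x + dt * conditional_field(x, k * dt, z)
    return x


rng = np.random.default_rng(0)
d, t_end = 3, 0.9
z = rng.standard_normal(d)       # a fixed (made-up) data point
eps = rng.standard_normal(d)     # x(0) = eps ~ p_init = N(0, I_d)
x_num = integrate_conditional_ode(eps, z, t_end, n_steps=10_000)
x_exact = alpha(t_end) * z + sigma(t_end) * eps
print(np.max(np.abs(x_num - x_exact)))
```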

Marginalization Trick

What is the SDE for the marginal path?

Note (Theorem)

If the ODE with vector field $\mathbf{f}^{\text{target}}(\mathbf{x}(t), t | \mathbf{z})$ yields the conditional path $p_t(\mathbf{x}(t) | \mathbf{z})$, then the marginal field,

$$\mathbf{f}^{\text{target}}(\mathbf{x}(t), t) = \int \mathbf{f}^{\text{target}}(\mathbf{x}(t), t | \mathbf{z}) \underbrace{\frac{p_t(\mathbf{x}(t) | \mathbf{z}) p_{\text{data}}(\mathbf{z})}{p_t(\mathbf{x}(t))}}_{p_t(\mathbf{z} | \mathbf{x}(t))} \ d \mathbf{z},$$

yields the marginal path $p_t(\mathbf{x}(t))$.

Hence, simulating the ODE with $\mathbf{f}^{\text{target}}(\mathbf{x}(t), t)$ from $\mathbf{x}(0) \sim p_{\text{init}}$ turns noise into data! Sadly, this field is unknown, since it depends on the intractable $p_t(\mathbf{x}(t))$.

The proof? Fokker-Planck-Kolmogorov equation!
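The marginal field is intractable for real data, but it can be computed exactly when $p_{\text{data}}$ is a finite set of equally weighted points, because the posterior $p_t(\mathbf{z} | \mathbf{x}(t))$ is then a softmax over the points. The sketch below assumes that toy setting (the three data points are made up) and the schedule $\alpha_t = t$, $\sigma_t = 1 - t$, for which the conditional field simplifies to $(\mathbf{z} - \mathbf{x}(t)) / (1 - t)$. Integrating the marginal ODE with Euler from Gaussian noise should carry almost every sample onto one of the data points, in roughly equal proportions.

```python
import numpy as np


def marginal_field(x, t, data):
    """Marginal vector field f^target(x, t) for an empirical p_data and alpha_t = t, sigma_t = 1 - t.

    x: (n, d) current states; data: (m, d) data points z_i with equal weights; 0 <= t < 1.
    """
    s = 1.0 - t
    diff = x[:, None, :] - t * data[None, :, :]          # x - alpha_t z_i, shape (n, m, d)
    logits = -np.sum(diff**2, axis=-1) / (2.0 * s**2)    # log p_t(x | z_i) up to a constant
    logits -= logits.max(axis=1, keepdims=True)          # stabilize the softmax
    w = np.exp(logits)
    w /= w.sum(axis=1, keepdims=True)                    # posterior p_t(z_i | x)
    z_hat = w @ data                                     # posterior mean E[z | x(t) = x]
    return (z_hat - x) / s                               # sum_i p_t(z_i | x) (z_i - x) / (1 - t)


rng = np.random.default_rng(0)
data = np.array([[-2.0, 0.0], [2.0, 0.0], [0.0, 3.0]])  # hypothetical toy data set
n_steps = 500
x = rng.standard_normal((1_000, 2))                      # x(0) ~ p_init = N(0, I_2)
for k in range(n_steps):                                 # explicit Euler on t = 0, 1/n, ..., (n-1)/n
    x = x + marginal_field(x, k / n_steps, data) / n_steps

dist = np.linalg.norm(x[:, None, :] - data[None, :, :], axis=-1)   # (n, m)
frac_close = np.mean(dist.min(axis=1) < 0.1)
counts = np.bincount(dist.argmin(axis=1), minlength=len(data))
print(frac_close, counts)
```

With this schedule the last Euler step lands exactly on the posterior mean at $t = 1 - 1/n$, which is already concentrated on a single data point.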

Step 4: Training the Neural Network

A possible loss function is,

$$\mathcal{L}_{FM} = \mathbb{E}_{t \sim \text{Unif}, \mathbf{x}(t) \sim p_t} \left[ \Vert \mathbf{f}^{\theta}(\mathbf{x}(t), t) - \mathbf{f}^{\text{target}}(\mathbf{x}(t), t) \Vert^2 \right]$$

However, $\mathbf{f}^{\text{target}}(\mathbf{x}(t), t)$ is not explicitly available.

Note (Theorem)

For,

$$\begin{align*} \mathcal{L}_{FM} & = \mathbb{E}_{t \sim \text{Unif}, \mathbf{x}(t) \sim p_t} \left[ \Vert \mathbf{f}^{\theta}(\mathbf{x}(t), t) - \mathbf{f}^{\text{target}}(\mathbf{x}(t), t) \Vert^2 \right], \newline \mathcal{L}_{CFM} & = \mathbb{E}_{t \sim \text{Unif}, \mathbf{z} \sim p_{\text{data}}, \mathbf{x} \sim p_t(\cdot | \mathbf{z})} \left[ \Vert \mathbf{f}^{\theta}(\mathbf{x}(t), t) - \mathbf{f}^{\text{target}}(\mathbf{x}(t), t | \mathbf{z}) \Vert^2 \right], \end{align*}$$

it holds that $\nabla_{\theta} \mathcal{L}_{FM} = \nabla_{\theta} \mathcal{L}_{CFM}$.

Hence, we can use $\mathcal{L}_{CFM}$ to train $\mathbf{f}^{\theta}(\mathbf{x}(t), t)$ (this is the main takeaway from the original flow matching paper [2]).
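The theorem can be illustrated numerically in a toy setting where both losses are computable (a sketch, not a proof). Assume a hypothetical one-dimensional $p_{\text{data}}$ with three equally weighted points and the path $\alpha_t = t$, $\sigma_t = 1 - t$; then the conditional target is $\mathbf{z} - \epsilon$ and the marginal target is $(\mathbb{E}[\mathbf{z} | \mathbf{x}(t)] - \mathbf{x}(t)) / (1 - t)$. For any candidate field, $\mathcal{L}_{CFM} - \mathcal{L}_{FM}$ should be the same constant up to Monte Carlo error, which is why the two gradients agree.

```python
import numpy as np

rng = np.random.default_rng(0)
data = np.array([-2.0, 0.5, 2.0])                 # hypothetical 1-D data set, equal weights
n = 200_000
t = rng.uniform(0.0, 1.0, n)                       # t ~ Unif[0, 1)
z = rng.choice(data, size=n)                       # z ~ p_data
eps = rng.standard_normal(n)                       # eps ~ N(0, 1)
x = t * z + (1.0 - t) * eps                        # x(t) ~ p_t(. | z)

u_cond = z - eps                                   # conditional target f(x, t | z)
s = (1.0 - t)[:, None]
logits = -(x[:, None] - t[:, None] * data[None, :]) ** 2 / (2.0 * s**2)
logits -= logits.max(axis=1, keepdims=True)
w = np.exp(logits)
w /= w.sum(axis=1, keepdims=True)                  # posterior p_t(z_i | x)
u_marg = (w @ data - x) / (1.0 - t)                # marginal target f(x, t)

# Two arbitrary stand-ins for f^theta evaluated at (x(t), t)
candidates = {"zero": np.zeros(n), "sin": np.sin(x) + t}
gaps = {}
for name, f_vals in candidates.items():
    l_fm = np.mean((f_vals - u_marg) ** 2)
    l_cfm = np.mean((f_vals - u_cond) ** 2)
    gaps[name] = l_cfm - l_fm
    print(name, l_fm, l_cfm, gaps[name])
```

The gap equals the expected conditional variance of the regression target, which does not depend on $\theta$.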

Tip (Example: Gaussian)

The path $p_t(\cdot | \mathbf{z}) = \mathcal{N}(\alpha_t \mathbf{z}, \sigma_t^2 \mathbf{I}_d)$ is simulated via the ODE with vector field,

$$\mathbf{f}^{\text{target}}(\mathbf{x}(t), t | \mathbf{z}) = (\dot{\alpha}_t - \frac{\dot{\sigma}_t}{\sigma_t} \alpha_t) \mathbf{z} + \frac{\dot{\sigma}_t}{\sigma_t} \mathbf{x}(t).$$

A natural choice is $\alpha(t) = t$ and $\sigma(t) = 1 - t$, which gives $\dot{\alpha}(t) = 1$ and $\dot{\sigma}(t) = -1$. It yields the loss function,

$$\begin{align*} \mathcal{L}_{CFM} & = \mathbb{E}_{t \sim \text{Unif}, \mathbf{z} \sim p_{\text{data}}, \mathbf{x} \sim \mathcal{N}(\alpha_t \mathbf{z}, \sigma_t^2 \mathbf{I}_d)} \left[ \Vert \mathbf{f}^{\theta}(\mathbf{x}(t), t) - \mathbf{f}^{\text{target}}(\mathbf{x}(t), t | \mathbf{z}) \Vert^2 \right] \newline & = \mathbb{E}_{t \sim \text{Unif}, \mathbf{z} \sim p_{\text{data}}, \epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I}_d)} \left[ \Vert \mathbf{f}^{\theta}(t \mathbf{z} + (1 - t) \epsilon, t) - (\mathbf{z} - \epsilon) \Vert^2 \right], \end{align*}$$

where we have used $\mathbf{x}(t) = t \mathbf{z} + (1 - t) \epsilon$.
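A minimal, self-contained NumPy sketch of training with $\mathcal{L}_{CFM}$ for this choice of path. All concrete choices are assumptions for illustration, not from the text: a toy one-dimensional mixture as $p_{\text{data}}$, a one-hidden-layer tanh MLP as $\mathbf{f}^{\theta}$ with hand-written gradients, and Adam as the optimizer. Training is plain regression of $\mathbf{z} - \epsilon$ onto the inputs $(\mathbf{x}(t), t)$; sampling then integrates the learned ODE with Euler from $\mathbf{x}(0) \sim \mathcal{N}(0, 1)$.

```python
import numpy as np


def sample_data(rng, n):
    """Hypothetical toy p_data: equal mixture of N(-2, 0.1^2) and N(2, 0.1^2)."""
    return rng.choice([-2.0, 2.0], size=n) + 0.1 * rng.standard_normal(n)


def cfm_batch(rng, n):
    """Inputs (x(t), t) and conditional targets z - eps for alpha_t = t, sigma_t = 1 - t."""
    z, eps, t = sample_data(rng, n), rng.standard_normal(n), rng.uniform(0.0, 1.0, n)
    x = t * z + (1.0 - t) * eps
    return np.stack([x, t], axis=1), z - eps


def init_params(rng, hidden=64):
    return {"W1": 0.5 * rng.standard_normal((2, hidden)), "b1": np.zeros(hidden),
            "W2": 0.1 * rng.standard_normal(hidden), "b2": np.zeros(())}


def forward(p, inputs):
    """One-hidden-layer tanh MLP standing in for f^theta(x, t)."""
    h = np.tanh(inputs @ p["W1"] + p["b1"])
    return h @ p["W2"] + p["b2"], h


def loss_and_grads(p, inputs, target):
    """Batch estimate of L_CFM and its gradient via manual backpropagation."""
    out, h = forward(p, inputs)
    r = out - target
    d_out = 2.0 * r / len(target)                       # dL/d(out)
    d_pre = np.outer(d_out, p["W2"]) * (1.0 - h**2)     # back through tanh
    grads = {"W1": inputs.T @ d_pre, "b1": d_pre.sum(axis=0),
             "W2": h.T @ d_out, "b2": d_out.sum()}
    return np.mean(r**2), grads


def train(p, rng, steps=2000, batch=256, lr=1e-2, beta1=0.9, beta2=0.999):
    """Minimize L_CFM with Adam."""
    m = {k: np.zeros_like(val) for k, val in p.items()}
    s = {k: np.zeros_like(val) for k, val in p.items()}
    for step in range(1, steps + 1):
        _, g = loss_and_grads(p, *cfm_batch(rng, batch))
        for k in p:
            m[k] = beta1 * m[k] + (1 - beta1) * g[k]
            s[k] = beta2 * s[k] + (1 - beta2) * g[k] ** 2
            m_hat, s_hat = m[k] / (1 - beta1**step), s[k] / (1 - beta2**step)
            p[k] = p[k] - lr * m_hat / (np.sqrt(s_hat) + 1e-8)
    return p


def sample(p, rng, n=2000, n_steps=100):
    """Euler integration of dx/dt = f^theta(x, t) from x(0) ~ N(0, 1) to t = 1."""
    x, dt = rng.standard_normal(n), 1.0 / n_steps
    for k in range(n_steps):
        out, _ = forward(p, np.stack([x, np.full(n, k * dt)], axis=1))
        x = x + dt * out
    return x


rng = np.random.default_rng(0)
params = init_params(rng)
eval_inputs, eval_target = cfm_batch(rng, 10_000)       # fixed held-out batch
loss_before, _ = loss_and_grads(params, eval_inputs, eval_target)
params = train(params, rng)
loss_after, _ = loss_and_grads(params, eval_inputs, eval_target)
samples = sample(params, rng)
print(loss_before, loss_after)
```

The loss does not go to zero: $\mathbf{z} - \epsilon$ is a noisy target, and the minimizer of $\mathcal{L}_{CFM}$ is its conditional mean given $(\mathbf{x}(t), t)$. In practice one would use an autodiff framework and a larger network; the manual backward pass only keeps the sketch dependency-free beyond NumPy.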

We have constructed an ODE for a conditional probability path. One can also construct an SDE and the same methodology is applicable.

Footnotes

  1. Reverse-Time Diffusion Equation Models by Anderson.

  2. Flow Matching for Generative Modeling by Lipman et al.