Reverse-time Euler-Maruyama Method
In the forward SDE,
$$
dx = f(x(t), t) \ dt + L(t) \ d\beta(t),
$$
we simulate forward in time using the Euler-Maruyama method. Note that the diffusion coefficient $L(t)$ does not depend on the state $x(t)$. With step size $\Delta t$ and independent draws $\xi_n \sim \mathcal{N}(0, 1)$ at each step,
$$
x_{n + 1} = x_n + f(x_n, t_n) \Delta t + L(t_n) \sqrt{\Delta t} \ \xi_n.
$$
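The forward update above fits in a few lines of standard-library Python. This is a minimal sketch, not a reference implementation: `euler_maruyama` and its arguments are illustrative names, and the Ornstein-Uhlenbeck parameters used to exercise it ($\lambda = 1$, $\sigma = 1$, $T = 1$, $x(0) = 0$) are assumptions chosen so the empirical variance can be compared with the closed-form OU variance given later in this section.

```python
import math
import random
import statistics


def euler_maruyama(f, L, x0, t0, t1, n_steps, rng):
    """Simulate one path of dx = f(x, t) dt + L(t) dbeta(t) from t0 to t1."""
    dt = (t1 - t0) / n_steps
    x, t = x0, t0
    for _ in range(n_steps):
        xi = rng.gauss(0.0, 1.0)  # xi_n ~ N(0, 1), independent across steps
        x = x + f(x, t) * dt + L(t) * math.sqrt(dt) * xi
        t += dt
    return x


# Ornstein-Uhlenbeck example: dx = -lam * x dt + sigma dbeta, x(0) = 0.
lam, sigma, T = 1.0, 1.0, 1.0
rng = random.Random(0)
samples = [
    euler_maruyama(lambda x, t: -lam * x, lambda t: sigma, 0.0, 0.0, T, 200, rng)
    for _ in range(4000)
]
exact_var = sigma**2 / (2 * lam) * (1 - math.exp(-2 * lam * T))
print(statistics.fmean(samples), statistics.variance(samples), exact_var)
```

The loop runs one path at a time for clarity; a NumPy version would advance all paths at once.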
For the reverse SDE¹,
$$
dx = [f(x(t), t) - L(t)^2 \nabla_x \log p(x(t), t)] \ dt + L(t) \ d\bar{\beta}(t),
$$
where $\bar{\beta}(t)$ is a Brownian motion running backward in time, we simulate backward in time with
$$
x_{n - 1} = x_n - [f(x_n, t_n) - L(t_n)^2 \nabla_x \log p(x_n, t_n)] \Delta t - L(t_n) \sqrt{\Delta t} \ \xi_n.
$$
The score function $\nabla_x \log p(x(t), t)$ must either be known in closed form or be estimated (e.g., by a neural network).
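The backward update can be sketched the same way, with the score passed in as a callable. This is a minimal sketch under assumptions: `reverse_euler_maruyama` is an illustrative name, and the test case (Brownian motion $dx = d\beta$ started from $x(0) \sim \mathcal{N}(0, s_0^2)$, so that $p(x, t) = \mathcal{N}(0, s_0^2 + t)$ and the score is $-x/(s_0^2 + t)$) is chosen only because its score is known exactly. Starting from the exact marginal at $t = T$, the reversed samples should end up with variance close to $s_0^2$.

```python
import math
import random
import statistics


def reverse_euler_maruyama(f, L, score, x_T, T, n_steps, rng):
    """Integrate the reverse SDE from t = T down to t = 0 for one path.

    Update: x_{n-1} = x_n - [f(x_n, t_n) - L(t_n)^2 * score(x_n, t_n)] dt
                          - L(t_n) * sqrt(dt) * xi_n,  with t_n = n * dt.
    """
    dt = T / n_steps
    x = x_T
    for n in range(n_steps, 0, -1):
        t = n * dt
        drift = f(x, t) - L(t) ** 2 * score(x, t)
        x = x - drift * dt - L(t) * math.sqrt(dt) * rng.gauss(0.0, 1.0)
    return x


# Test case: Brownian motion dx = dbeta started from x(0) ~ N(0, s0^2), so
# p(x, t) = N(0, s0^2 + t) and the score is -x / (s0^2 + t).
s0, T = 0.5, 1.0
rng = random.Random(1)
samples = []
for _ in range(4000):
    x_T = rng.gauss(0.0, math.sqrt(s0**2 + T))  # start from the exact marginal at T
    samples.append(reverse_euler_maruyama(
        f=lambda x, t: 0.0,
        L=lambda t: 1.0,
        score=lambda x, t: -x / (s0**2 + t),
        x_T=x_T, T=T, n_steps=200, rng=rng,
    ))
print(statistics.variance(samples))  # should be close to s0**2 = 0.25
```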
Reverse Euler-Maruyama for Ornstein-Uhlenbeck Process
Consider the forward SDE,
$$
dx = -\lambda x(t) \ dt + \sigma \ d\beta(t), \quad x(0) = 0.
$$
The marginal distribution $p(x(t), t)$ is Gaussian with zero mean and variance
$$
\mathrm{Var}[x(t)] = \frac{\sigma^2}{2 \lambda}(1 - e^{-2 \lambda t}).
$$
For $t > 0$ (so that $\mathrm{Var}[x(t)] > 0$), the score of the process is
$$
\begin{align*}
\nabla_x \log p(x(t), t) & = \nabla_x \log \left[\frac{1}{\sqrt{2 \pi \mathrm{Var}[x(t)]}} e^{-\frac{x(t)^2}{2 \mathrm{Var}[x(t)]}} \right] \newline
& = \nabla_x \left[-\frac{1}{2} \log(2 \pi) - \frac{1}{2} \log(\mathrm{Var}[x(t)]) - \frac{x(t)^2}{2 \mathrm{Var}[x(t)]} \right] \newline
& = -\frac{1}{2 \mathrm{Var}[x(t)]} \underbrace{\nabla_x \mathrm{Var}[x(t)]}_{= 0} - \frac{x(t)}{\mathrm{Var}[x(t)]} \underbrace{\nabla_x x(t)}_{= 1} \newline
& = -\frac{x(t)}{\mathrm{Var}[x(t)]}.
\end{align*}
$$
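The closed-form score can be cross-checked against a central finite difference of the log-density. A minimal sketch; the helper names and the parameter values ($\lambda = 0.7$, $\sigma = 1.3$, $t = 0.9$, $x = 0.4$) are arbitrary illustrative choices. Because $\log p$ is quadratic in $x$, the central difference matches $-x/\mathrm{Var}[x(t)]$ up to round-off.

```python
import math


def ou_variance(t, lam, sigma):
    """Var[x(t)] for dx = -lam x dt + sigma dbeta with x(0) = 0."""
    return sigma**2 / (2 * lam) * (1 - math.exp(-2 * lam * t))


def ou_log_density(x, t, lam, sigma):
    """log p(x, t) of the zero-mean Gaussian marginal."""
    v = ou_variance(t, lam, sigma)
    return -0.5 * math.log(2 * math.pi * v) - x**2 / (2 * v)


def ou_score(x, t, lam, sigma):
    """Closed-form score -x / Var[x(t)] derived above."""
    return -x / ou_variance(t, lam, sigma)


lam, sigma, t, x, h = 0.7, 1.3, 0.9, 0.4, 1e-5
numeric = (ou_log_density(x + h, t, lam, sigma) - ou_log_density(x - h, t, lam, sigma)) / (2 * h)
print(numeric, ou_score(x, t, lam, sigma))
```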
We can plug this into the reverse-time Euler-Maruyama method, with $f(x, t) = -\lambda x$ and $L(t) = \sigma$:
$$
x_{n - 1} = x_n - \left[-\lambda x_n + \frac{\sigma^2 x_n}{\mathrm{Var}[x(t_n)]} \right] \Delta t - \sigma \sqrt{\Delta t} \ \xi_n.
$$
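A minimal sketch of this OU-specific update, assuming $\lambda = \sigma = T = 1$ and 200 steps (illustrative values), with $x(T)$ drawn from the exact marginal. The loop stops at $t_1 = \Delta t$ so that $\mathrm{Var}[x(t_n)]$ is never zero. Since the forward process starts at $x(0) = 0$, the reversed samples should collapse towards 0; the remaining spread comes from the last noise increment, roughly $\sigma \sqrt{\Delta t}$.

```python
import math
import random
import statistics


def ou_reverse_sample(lam, sigma, T, n_steps, rng):
    """Reverse-time Euler-Maruyama for the OU process, using the exact score."""

    def var(t):
        return sigma**2 / (2 * lam) * (1 - math.exp(-2 * lam * t))

    dt = T / n_steps
    x = rng.gauss(0.0, math.sqrt(var(T)))  # x(T) drawn from the exact marginal
    for n in range(n_steps, 0, -1):        # t_n = n * dt > 0, so Var[x(t_n)] > 0
        t = n * dt
        drift = -lam * x + sigma**2 * x / var(t)
        x = x - drift * dt - sigma * math.sqrt(dt) * rng.gauss(0.0, 1.0)
    return x


rng = random.Random(2)
finals = [ou_reverse_sample(1.0, 1.0, 1.0, 200, rng) for _ in range(2000)]
# Spread at t = 0 should be about sigma * sqrt(dt), far below sqrt(Var[x(T)]).
print(statistics.pstdev(finals))
```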
Figure: Reverse Euler-Maruyama for the Ornstein-Uhlenbeck process.
Denoising Diffusion Probabilistic Models
Figure: The forward and backward processes of the diffusion model.
How do we design a scalar SDE of the form
$$
dx(t) = f(x(t), t) \ dt + L(t) \ d\beta(t),
$$
that produces a process $x(t)$ with the following properties?
- Mean: $\mathbb{E}[x(t)] = \alpha(t) x_0$,
- Variance: $\mathrm{Var}[x(t)] = \sigma(t)^2$,
- Initial conditions: $x(0) = x_0$, $\alpha(0) = 1$, $\sigma(0) = 0$.
Let's start with the mean condition. Recall that the mean $m(t) = \mathbb{E}[x(t)]$ satisfies
$$
\frac{dm(t)}{dt} = \frac{d\mathbb{E}[x(t)]}{dt} = \mathbb{E}[f(x(t), t)]. \tag{1}
$$
From the mean condition,
$$
\mathbb{E}[x(t)] = \alpha(t) x_0 \implies \frac{d \mathbb{E}[x(t)]}{dt} = \dot{\alpha}(t) x_0. \tag{2}
$$
A natural choice for $f(x(t), t)$ is an affine function of the form
$$
f(x(t), t) = a(t) x(t) + b(t).
$$
Taking the expectation yields
$$
\begin{align*}
\mathbb{E}[f(x(t), t)] & = a(t) \mathbb{E}[x(t)] + b(t) \newline
& = a(t) \alpha(t) x_0 + b(t).
\end{align*}
$$
If we set, e.g., $b(t) = 0$, then we can solve for $a(t)$ from Equations (1) and (2) (assuming $\alpha(t) \neq 0$) to get
$$
\begin{align*}
a(t) \alpha(t) x_0 & = \dot{\alpha}(t) x_0 \newline
a(t) & = \frac{\dot{\alpha}(t)}{\alpha(t)} \left(= \frac{d}{dt} \log(\alpha(t))\right).
\end{align*}
$$
Thus, we have the linear SDE
$$
dx(t) = \left(\frac{\dot{\alpha}(t)}{\alpha(t)} x(t) \right) dt + L(t) \ d\beta(t).
$$
One can check that this SDE satisfies the mean condition. To also satisfy the variance condition, we rescale the diffusion term by $\alpha(t)$; this leaves the mean unchanged, since by Equation (1) the mean depends only on the drift. This gives the linear SDE
$$
dx(t) = \left(\frac{\dot{\alpha}(t)}{\alpha(t)} x(t) \right) dt + \underbrace{\alpha(t)}_{\text{new}} L(t) \ d\beta(t).
$$
The solution to this SDE is
$$
x(t) = \alpha(t) \underbrace{\left(x_0 + \int_0^t L(s) \ d\beta(s) \right)}_{y(t)}.
$$
To see this, note that $y(t)$ satisfies
$$
dy(t) = L(t) \ d\beta(t), \quad y(0) = x_0,
$$
and, since $\alpha(t)$ is deterministic, the product rule gives $d(\alpha(t) y(t)) = \dot{\alpha}(t) y(t) \ dt + \alpha(t) L(t) \ d\beta(t) = \frac{\dot{\alpha}(t)}{\alpha(t)} x(t) \ dt + \alpha(t) L(t) \ d\beta(t)$, with $x(0) = \alpha(0) y(0) = x_0$.
By definition of the variance, and using the Itô isometry in the last step,
$$
\begin{align*}
\mathrm{Var}[x(t)] & = \mathbb{E}[(\underbrace{x(t)}_{= \alpha(t) y(t)} - \underbrace{\mathbb{E}[x(t)]}_{= \alpha(t) x_0})^2] \newline
& = \mathbb{E}[(\alpha(t) y(t) - \alpha(t) x_0)^2] \newline
& = \underbrace{\alpha(t)^2}_{\text{deterministic}} \mathbb{E}\left[\left(\int_0^t L(s) \ d\beta(s)\right)^2\right] \newline
& = \alpha(t)^2 \int_0^t L(s)^2 \ ds.
\end{align*}
$$
In our case, we want the variance to be $\sigma(t)^2$, so we set
$$
\begin{align*}
\alpha(t)^2 \int_0^t L(s)^2 \ ds & = \sigma(t)^2 \newline
\int_0^t L(s)^2 \ ds & = \frac{\sigma(t)^2}{\alpha(t)^2} \newline
L(t)^2 & = \frac{d}{dt} \left(\frac{\sigma(t)^2}{\alpha(t)^2}\right) \newline
L(t) & = \sqrt{\frac{d}{dt} \left(\frac{\sigma(t)^2}{\alpha(t)^2}\right)},
\end{align*}
$$
where the third line follows by differentiating both sides with respect to $t$. For $L(t)$ to be real, $\sigma(t)^2 / \alpha(t)^2$ must be non-decreasing in $t$.
Summary
To construct an SDE with
- Mean: $\mathbb{E}[x(t)] = \alpha(t) x_0$,
- Variance: $\mathrm{Var}[x(t)] = \sigma(t)^2$,

we use the SDE
$$
dx(t) = \left(\frac{\dot{\alpha}(t)}{\alpha(t)} x(t) \right) dt + \alpha(t) \sqrt{\frac{d}{dt} \left(\frac{\sigma(t)^2}{\alpha(t)^2}\right)} \ d\beta(t).
$$
Since the SDE is linear and $x_0$ is deterministic, this guarantees that the process $x(t)$ is Gaussian with the desired mean and variance.
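The construction can be turned into a small helper that returns the drift $\frac{\dot{\alpha}(t)}{\alpha(t)} x$ and the diffusion $\alpha(t) L(t)$, with $L(t)$ as derived above. Expanding the derivative gives $\alpha^2 \frac{d}{dt}\left(\frac{\sigma^2}{\alpha^2}\right) = 2\sigma\dot{\sigma} - 2\frac{\dot{\alpha}}{\alpha}\sigma^2$, which the sketch uses directly. The helper name and the requirement that derivatives be supplied by the caller are assumptions of this sketch. As a cross-check, $\alpha(t) = e^{-t}$, $\sigma(t)^2 = 1 - e^{-2t}$ should recover the OU process $dx = -x \ dt + \sqrt{2} \ d\beta$, whose variance from the OU formula (with $\lambda = 1$, $\sigma = \sqrt{2}$) is exactly $1 - e^{-2t}$.

```python
import math


def linear_sde_coefficients(alpha, dalpha, sigma, dsigma):
    """Build drift and diffusion for the summary SDE from alpha(t) and sigma(t).

    alpha, dalpha, sigma, dsigma: callables t -> float giving alpha(t), its
    derivative, sigma(t), and its derivative. Returns (drift, diffusion) with
    dx = drift(x, t) dt + diffusion(t) dbeta(t).
    """

    def drift(x, t):
        return dalpha(t) / alpha(t) * x

    def diffusion(t):
        # alpha^2 * d/dt(sigma^2 / alpha^2) = 2 sigma sigma' - 2 (alpha' / alpha) sigma^2
        g2 = 2 * sigma(t) * dsigma(t) - 2 * dalpha(t) / alpha(t) * sigma(t) ** 2
        if g2 < 0:
            raise ValueError("sigma(t)^2 / alpha(t)^2 must be non-decreasing")
        return math.sqrt(g2)

    return drift, diffusion


# alpha(t) = exp(-t), sigma(t)^2 = 1 - exp(-2t) should give dx = -x dt + sqrt(2) dbeta.
drift, diffusion = linear_sde_coefficients(
    alpha=lambda t: math.exp(-t),
    dalpha=lambda t: -math.exp(-t),
    sigma=lambda t: math.sqrt(1 - math.exp(-2 * t)),
    dsigma=lambda t: math.exp(-2 * t) / math.sqrt(1 - math.exp(-2 * t)),
)
print(drift(1.0, 0.3), diffusion(0.3))  # approximately -1.0 and sqrt(2)
```

The derivatives are passed in explicitly rather than approximated numerically, so the only error is floating-point round-off; note that $\dot{\sigma}$ in the example is singular at $t = 0$, so it should be evaluated for $t > 0$.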
Given an SDE and target properties
Suppose we are given an SDE of the form
$$
dx(t) = \frac{\dot{\alpha}(t)}{\alpha(t)} x(t) \ dt + \alpha(t) L(t) \ d\beta(t), \quad x(0) = x_0,
$$
with
- Mean: $\mathbb{E}[x(t)] = (1 - t) x_0$,
- Target variance: $\mathrm{Var}[x(t)] = t^2$.
Using $\alpha(t) = 1 - t$, we have
$$
\dot{\alpha}(t) = -1, \quad \frac{\dot{\alpha}(t)}{\alpha(t)} = -\frac{1}{1 - t}.
$$
Thus, the SDE becomes
$$
dx(t) = -\frac{1}{1 - t} x(t) \ dt + (1 - t) L(t) \ d\beta(t).
$$
Applying the same construction as before, $L(t)^2 = \frac{d}{dt}\left(\frac{\sigma(t)^2}{\alpha(t)^2}\right)$ with $\sigma(t)^2 = t^2$, we find
$$
\begin{align*}
L^2(t) & = \frac{d}{dt} \left(\frac{t^2}{(1 - t)^2}\right) \newline
& = \frac{2t(1 - t)^2 + 2t^2(1 - t)}{(1 - t)^4} \newline
& = \frac{2t[(1 - t)^2 + t(1 - t)]}{(1 - t)^4} \newline
& = \frac{2t(1 - t)}{(1 - t)^4} \newline
& = \frac{2t}{(1 - t)^3}.
\end{align*}
$$
So
$$
L(t) = \sqrt{\frac{2t}{(1 - t)^3}},
$$
and the diffusion term simplifies to
$$
(1 - t) L(t) = (1 - t) \sqrt{\frac{2t}{(1 - t)^3}} = \sqrt{\frac{2t}{1 - t}}.
$$
So the SDE becomes
$$
dx(t) = -\frac{1}{1 - t} x(t) \ dt + \sqrt{\frac{2t}{1 - t}} \ d\beta(t).
$$
This is the fully simplified form, valid for $t < 1$, and it produces
- Mean: $\mathbb{E}[x(t)] = (1 - t) x_0$,
- Variance: $\mathrm{Var}[x(t)] = t^2$.

The process is Gaussian and fully characterized by these two moments.
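A Monte Carlo check of this example, as a minimal sketch: simulate the simplified SDE with Euler-Maruyama up to $T = 0.5$ and compare the empirical moments with $(1 - T)x_0$ and $T^2$. The starting point $x_0 = 1.5$, the horizon, the step count, and the number of paths are illustrative assumptions; $T$ is kept away from 1, where the drift and diffusion blow up.

```python
import math
import random
import statistics


def simulate_example(x0, T, n_steps, rng):
    """Euler-Maruyama for dx = -x/(1-t) dt + sqrt(2t/(1-t)) dbeta on [0, T], T < 1."""
    dt = T / n_steps
    x, t = x0, 0.0
    for _ in range(n_steps):
        noise = math.sqrt(2 * t / (1 - t)) * math.sqrt(dt) * rng.gauss(0.0, 1.0)
        x += -x / (1 - t) * dt + noise
        t += dt
    return x


x0, T = 1.5, 0.5
rng = random.Random(3)
samples = [simulate_example(x0, T, 200, rng) for _ in range(4000)]
print(statistics.fmean(samples), (1 - T) * x0)   # empirical vs. target mean
print(statistics.variance(samples), T**2)        # empirical vs. target variance
```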
Flow and Diffusion Models
Consider the SDE in $\mathbb{R}^d$,
$$
d\mathbf{x}(t) = \mathbf{f}^{\theta}(\mathbf{x}(t), t) \ dt + \sigma_t \ d\beta(t),
$$
where $\mathbf{f}^{\theta}$ is a neural network with parameters $\theta$. Our goal is to choose it so that, given $\mathbf{x}(0) \sim p_{\text{init}}(\mathbf{x}(0))$, we obtain $\mathbf{x}(1) \sim p_{\text{data}}(\mathbf{x}(1))$.
Note (Terminology)
- $\sigma_t = 0$, flow model: the SDE reduces to an ODE (e.g., Flow Matching).
- $\sigma_t \neq 0$, diffusion model: the SDE is a genuinely stochastic process (e.g., Denoising Diffusion Probabilistic Models).
Figure: The forward and backward processes of the diffusion model.
Suppose $\mathbf{x}(0) \sim p_0 = p_{\text{init}}$ at time $t = 0$, where $p_{\text{init}}$ is a simple distribution such as a Gaussian. We would like $\mathbf{x}(1) \sim p_1 = p_{\text{data}}$ at time $t = 1$, where $p_{\text{data}}$ is some complicated data distribution.
We'll study how to
- construct an (ideal) vector field $\mathbf{f}^{\text{target}}$ and diffusion coefficient $\sigma_t$,
- train $\mathbf{f}^{\theta}$ such that $\mathbf{f}^{\theta} \approx \mathbf{f}^{\text{target}}$.

We can then draw a sample from $p_{\text{init}}$, run it through the SDE, and obtain a sample from $p_{\text{data}}$.
Method Overview — 4 Steps
1. Construct a probability path $p_t(\cdot | \mathbf{z})$ for individual data samples $\mathbf{z}$ that interpolates from noise, $p_0(\cdot | \mathbf{z}) = p_{\text{init}}(\cdot)$, to the data sample itself, $p_1(\cdot | \mathbf{z}) = \delta_{\mathbf{z}}(\cdot)$.
2. Marginalize over the data distribution to obtain
   $$
   p_t(\mathbf{x}) = \int p_t(\mathbf{x} | \mathbf{z}) p_{\text{data}}(\mathbf{z}) \ d\mathbf{z}.
   $$
3. Construct an (ideal) SDE that simulates $p_t$.
4. Train a neural network that approximates this SDE.
Step 1: Constructing the Probability Path
Start by constructing
$$
p_t(\cdot | \mathbf{z}),
$$
for individual samples $\mathbf{z}$ from the data distribution, subject to the conditions
$$
p_0(\cdot | \mathbf{z}) = p_{\text{init}}(\cdot) \text{ and } p_1(\cdot | \mathbf{z}) = \delta_{\mathbf{z}}(\cdot) \text{ for all } \mathbf{z} \in \mathbb{R}^d.
$$
Such a path $p_t(\cdot | \mathbf{z})$ is called a conditional probability path.
Step 2: Marginalization
The conditional paths induce a marginal probability path,
$$
p_t(\mathbf{x}) = \int p_t(\mathbf{x} | \mathbf{z}) p_{\text{data}}(\mathbf{z}) \ d\mathbf{z}.
$$
We can sample from $p_t$ via
$$
\mathbf{z} \sim p_{\text{data}}, \quad \mathbf{x} \sim p_t(\cdot | \mathbf{z}) \implies \mathbf{x} \sim p_t.
$$
Note that we cannot, in general, evaluate the density $p_t(\mathbf{x})$, since the integral over $p_{\text{data}}$ is intractable.
The marginal probability path $p_t(\mathbf{x})$ interpolates between $p_{\text{init}}$ and $p_{\text{data}}$,
$$
p_0(\mathbf{x}) = p_{\text{init}}(\mathbf{x}) \text{ and } p_1(\mathbf{x}) = p_{\text{data}}(\mathbf{x}).
$$
Tip (Example: Gaussian) Define the path
$$
p_t(\cdot | \mathbf{z}) = \mathcal{N}(\alpha_t \mathbf{z}, \sigma_t^2 \mathbf{I}_d),
$$
with $\alpha_0 = \sigma_1 = 0$ and $\alpha_1 = \sigma_0 = 1$; in other words, we move from $\mathcal{N}(\mathbf{0}, \mathbf{I}_d)$ to $\mathcal{N}(\mathbf{z}, \mathbf{0}) = \delta_{\mathbf{z}}$.
Here $\alpha_t$ and $\sigma_t$ should be continuously differentiable functions of $t$.
Sampling from the marginal path $p_t$:
$$
\mathbf{z} \sim p_{\text{data}}, \ \epsilon \sim p_{\text{init}} = \mathcal{N}(\mathbf{0}, \mathbf{I}_d) \implies \mathbf{x}(t) = \alpha_t \mathbf{z} + \sigma_t \epsilon \sim p_t.
$$
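The two-stage sampling above can be sketched in standard-library Python. This is a 1-D illustration under assumptions not in the text: the schedule $\alpha_t = t$, $\sigma_t = 1 - t$ (one valid choice satisfying the boundary conditions) and a hypothetical two-point dataset $\{-2, 2\}$; `sample_marginal_path` is an illustrative name. For this setup $\mathrm{Var}[x(t)] = 4t^2 + (1 - t)^2$, which is $1.25$ at $t = 0.5$, and at $t = 1$ every sample is exactly a data point.

```python
import random
import statistics


def sample_marginal_path(t, alpha, sigma, data, rng):
    """Draw x(t) = alpha(t) z + sigma(t) eps with z ~ p_data and eps ~ N(0, 1)."""
    z = rng.choice(data)          # z ~ p_data (here: a hypothetical finite dataset)
    eps = rng.gauss(0.0, 1.0)     # eps ~ p_init = N(0, 1)
    return alpha(t) * z + sigma(t) * eps


alpha = lambda t: t        # alpha_0 = 0, alpha_1 = 1
sigma = lambda t: 1.0 - t  # sigma_0 = 1, sigma_1 = 0
data = [-2.0, 2.0]
rng = random.Random(4)
xs = [sample_marginal_path(0.5, alpha, sigma, data, rng) for _ in range(20000)]
print(statistics.variance(xs))  # close to 0.5**2 * 4 + 0.5**2 = 1.25
```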
Step 3: Constructing the SDE
First, we analytically construct an SDE that simulates the conditional probability path.
Then, via the so-called "marginalization trick", we obtain an SDE for the marginal probability path.
Recall our notation for an SDE, $d\mathbf{x}(t) = \mathbf{f}(\mathbf{x}(t), t) \ dt + L(t) \ d\beta(t)$; in the Gaussian examples below, $\sigma_t$ denotes the standard deviation of the path, not the diffusion coefficient.
Tip (Example: Gaussian) For the path
$$
p_t(\cdot | \mathbf{z}) = \mathcal{N}(\alpha_t \mathbf{z}, \sigma_t^2 \mathbf{I}_d),
$$
with $\alpha_0 = \sigma_1 = 0$ and $\alpha_1 = \sigma_0 = 1$, the ODE given by
$$
\mathbf{f}(\mathbf{x}(t), t) = \left(\dot{\alpha}_t - \frac{\dot{\sigma}_t}{\sigma_t} \alpha_t\right) \mathbf{z} + \frac{\dot{\sigma}_t}{\sigma_t} \mathbf{x}(t) \quad \text{and} \quad L(t) = 0
$$
simulates the conditional probability path, in the sense that if $\mathbf{x}(0) \sim p_{\text{init}} = \mathcal{N}(\mathbf{0}, \mathbf{I}_d)$, then $\mathbf{x}(t) \sim \mathcal{N}(\alpha_t \mathbf{z}, \sigma_t^2 \mathbf{I}_d)$ for all $t$ with $\sigma_t > 0$.
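As a numerical sanity check, the sketch below integrates this ODE with explicit Euler steps, assuming $d = 1$ and the linear schedule $\alpha_t = t$, $\sigma_t = 1 - t$ (an assumption; the text does not fix a schedule at this point). For this schedule the conditional field reduces to $(\mathbf{z} - \mathbf{x})/(1 - t)$, the exact trajectories are the straight lines $\mathbf{x}(t) = t\mathbf{z} + (1 - t)\mathbf{x}(0)$, and the Euler steps reproduce them up to round-off. The last step evaluates the field at $t = 1 - h$, where $\sigma_t = h > 0$, so the division by $\sigma_t$ is safe. The function names are hypothetical.

```python
def conditional_vector_field(x, t, z, alpha, dalpha, sigma, dsigma):
    """f(x, t | z) = (alpha' - sigma'/sigma * alpha) z + sigma'/sigma * x, for sigma(t) > 0."""
    ratio = dsigma(t) / sigma(t)
    return (dalpha(t) - ratio * alpha(t)) * z + ratio * x


def euler_ode(x0, z, n_steps, alpha, dalpha, sigma, dsigma):
    """Explicit Euler from t = 0 to t = 1; returns the states x_0, ..., x_K."""
    h = 1.0 / n_steps
    xs = [x0]
    for k in range(n_steps):  # evaluates f at t_k = k * h <= 1 - h, where sigma > 0
        t = k * h
        f = conditional_vector_field(xs[-1], t, z, alpha, dalpha, sigma, dsigma)
        xs.append(xs[-1] + h * f)
    return xs


schedule = dict(alpha=lambda t: t, dalpha=lambda t: 1.0,
                sigma=lambda t: 1.0 - t, dsigma=lambda t: -1.0)
eps, z = 0.3, 2.0  # starting point x(0) = eps ~ p_init, and a data point z
xs = euler_ode(eps, z, 100, **schedule)
print(xs[50], 0.5 * z + 0.5 * eps)  # both approximately 1.15
print(xs[-1])                        # approximately z = 2.0
```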
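We can check this in two ways: by computing the mean and variance of the (linear) SDE from their defining differential equations, or by verifying that the density satisfies the Fokker-Planck-Kolmogorov equation.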
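For the first route, here is a sketch assuming the standard moment equations for a linear SDE $d\mathbf{x} = (a(t)\mathbf{x} + \mathbf{b}(t)) \ dt + L(t) \ d\beta(t)$, namely $\dot{\mathbf{m}} = a(t)\mathbf{m} + \mathbf{b}(t)$ for the mean and $\dot{P} = 2a(t)P + L(t)^2$ for the per-coordinate variance. Here $a(t) = \frac{\dot{\sigma}_t}{\sigma_t}$, $\mathbf{b}(t) = \left(\dot{\alpha}_t - \frac{\dot{\sigma}_t}{\sigma_t}\alpha_t\right)\mathbf{z}$ and $L(t) = 0$, and we check that $\mathbf{m} = \alpha_t \mathbf{z}$ and $P = \sigma_t^2$ solve these equations:

$$
\begin{align*}
\frac{d}{dt}(\alpha_t \mathbf{z}) = \dot{\alpha}_t \mathbf{z}
&= \frac{\dot{\sigma}_t}{\sigma_t} \alpha_t \mathbf{z} + \left(\dot{\alpha}_t - \frac{\dot{\sigma}_t}{\sigma_t} \alpha_t\right) \mathbf{z}
= a(t) \, \mathbf{m} + \mathbf{b}(t), \newline
\frac{d}{dt}(\sigma_t^2) = 2 \sigma_t \dot{\sigma}_t
&= 2 \frac{\dot{\sigma}_t}{\sigma_t} \sigma_t^2 + 0
= 2 a(t) P + L(t)^2.
\end{align*}
$$

The initial values also match, $\mathbf{m}(0) = \alpha_0 \mathbf{z} = \mathbf{0}$ and $P(0) = \sigma_0^2 = 1$, i.e., $p_{\text{init}} = \mathcal{N}(\mathbf{0}, \mathbf{I}_d)$. Since the ODE is linear in $\mathbf{x}$, a Gaussian initial distribution stays Gaussian, so $\mathbf{x}(t) \sim \mathcal{N}(\alpha_t \mathbf{z}, \sigma_t^2 \mathbf{I}_d)$.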
Marginalization Trick
What is the SDE for the marginal path?
Note (Theorem) If the ODE with vector field $\mathbf{f}^{\text{target}}(\mathbf{x}(t), t | \mathbf{z})$ yields the conditional path $p_t(\cdot | \mathbf{z})$, then the marginal vector field
$$
\mathbf{f}^{\text{target}}(\mathbf{x}(t), t) = \int \mathbf{f}^{\text{target}}(\mathbf{x}(t), t | \mathbf{z}) \underbrace{\frac{p_t(\mathbf{x}(t) | \mathbf{z}) p_{\text{data}}(\mathbf{z})}{p_t(\mathbf{x}(t))}}_{p_t(\mathbf{z} | \mathbf{x}(t))} \ d\mathbf{z},
$$
yields the marginal path $p_t(\mathbf{x}(t))$.
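Hence, simulating the ODE with $\mathbf{f}^{\text{target}}(\mathbf{x}(t), t)$ from $\mathbf{x}(0) \sim p_{\text{init}}$ yields $\mathbf{x}(1) \sim p_{\text{data}}$: it turns noise into data! Sadly, this field is unknown in practice, because it depends on the intractable density $p_t(\mathbf{x}(t))$.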
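For a finite dataset, however, the posterior weights $p_t(\mathbf{z} | \mathbf{x})$ can be computed, so the marginal field can be evaluated exactly. A minimal sketch, assuming $d = 1$, a uniform $p_{\text{data}}$ over a hypothetical three-point dataset, and the schedule $\alpha_t = t$, $\sigma_t = 1 - t$, for which $\mathbf{f}^{\text{target}}(x, t | z) = (z - x)/(1 - t)$. The names `marginal_vector_field` and `generate` are illustrative. Integrating the marginal ODE from $\mathcal{N}(0, 1)$ noise should land every sample on one of the data points.

```python
import math
import random


def marginal_vector_field(x, t, data):
    """f_target(x, t) = sum_i f(x, t | z_i) p_t(z_i | x) for a uniform finite p_data.

    Assumed schedule alpha_t = t, sigma_t = 1 - t, so that
    f(x, t | z) = (z - x) / (1 - t) and p_t(x | z) = N(x; t z, (1 - t)^2).
    """
    s = 1.0 - t
    logits = [-(x - t * z) ** 2 / (2 * s * s) for z in data]
    m = max(logits)  # subtract the max before exponentiating (log-sum-exp trick)
    weights = [math.exp(l - m) for l in logits]
    total = sum(weights)
    return sum(w / total * (z - x) / s for w, z in zip(weights, data))


def generate(data, n_steps, rng):
    """Euler integration of dx/dt = f_target(x, t) from x(0) ~ N(0, 1) to t = 1."""
    h = 1.0 / n_steps
    x = rng.gauss(0.0, 1.0)
    for k in range(n_steps):  # t_k = k * h <= 1 - h, so 1 - t stays positive
        x += h * marginal_vector_field(x, k * h, data)
    return x


data = [-2.0, 1.0, 3.0]  # hypothetical toy dataset
rng = random.Random(5)
outs = [generate(data, 200, rng) for _ in range(300)]
print(sorted({round(o, 3) for o in outs}))  # should list only values at the data points
```

For real datasets the sum runs over every training example and the data are high-dimensional, which is why the field is instead learned by a network.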
The proof follows from the Fokker-Planck-Kolmogorov equation.
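With $L(t) = 0$, the Fokker-Planck-Kolmogorov equation reduces to the continuity equation $\partial_t p_t = -\nabla_{\mathbf{x}} \cdot (\mathbf{f} \, p_t)$. A proof sketch, assuming enough regularity to differentiate under the integral sign and that each conditional path satisfies the continuity equation for its field $\mathbf{f}^{\text{target}}(\cdot, t | \mathbf{z})$:

$$
\begin{align*}
\partial_t p_t(\mathbf{x}) &= \int \partial_t p_t(\mathbf{x} | \mathbf{z}) \, p_{\text{data}}(\mathbf{z}) \ d\mathbf{z} \newline
&= -\int \nabla_{\mathbf{x}} \cdot \left[\mathbf{f}^{\text{target}}(\mathbf{x}, t | \mathbf{z}) \, p_t(\mathbf{x} | \mathbf{z})\right] p_{\text{data}}(\mathbf{z}) \ d\mathbf{z} \newline
&= -\nabla_{\mathbf{x}} \cdot \left[\int \mathbf{f}^{\text{target}}(\mathbf{x}, t | \mathbf{z}) \, p_t(\mathbf{x} | \mathbf{z}) \, p_{\text{data}}(\mathbf{z}) \ d\mathbf{z}\right] \newline
&= -\nabla_{\mathbf{x}} \cdot \left[\mathbf{f}^{\text{target}}(\mathbf{x}, t) \, p_t(\mathbf{x})\right].
\end{align*}
$$

So $p_t$ satisfies the continuity equation of the marginal field with $p_0 = p_{\text{init}}$; assuming that equation has a unique solution, this is exactly the path the marginal ODE generates.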
Step 4: Training the Neural Network
A possible loss function is
$$
\mathcal{L}_{FM} = \mathbb{E}_{t \sim \text{Unif}, \mathbf{x}(t) \sim p_t} \left[ \Vert \mathbf{f}^{\theta}(\mathbf{x}(t), t) - \mathbf{f}^{\text{target}}(\mathbf{x}(t), t) \Vert^2 \right].
$$
However, $\mathbf{f}^{\text{target}}(\mathbf{x}(t), t)$ is not explicitly available.
Note (Theorem) For
$$
\begin{align*}
\mathcal{L}_{FM} & = \mathbb{E}_{t \sim \text{Unif}, \mathbf{x}(t) \sim p_t} \left[ \Vert \mathbf{f}^{\theta}(\mathbf{x}(t), t) - \mathbf{f}^{\text{target}}(\mathbf{x}(t), t) \Vert^2 \right], \newline
\mathcal{L}_{CFM} & = \mathbb{E}_{t \sim \text{Unif}, \mathbf{z} \sim p_{\text{data}}, \mathbf{x}(t) \sim p_t(\cdot | \mathbf{z})} \left[ \Vert \mathbf{f}^{\theta}(\mathbf{x}(t), t) - \mathbf{f}^{\text{target}}(\mathbf{x}(t), t | \mathbf{z}) \Vert^2 \right],
\end{align*}
$$
it holds that $\nabla_{\theta} \mathcal{L}_{FM} = \nabla_{\theta} \mathcal{L}_{CFM}$.
Hence, we can use $\mathcal{L}_{CFM}$, which only requires the known conditional vector field, to train $\mathbf{f}^{\theta}(\mathbf{x}(t), t)$ (this is the main takeaway from the original flow matching paper²).
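A short sketch of why the theorem holds, for each fixed $t$ and assuming all expectations are finite: expand both squared norms. The $\Vert \mathbf{f}^{\theta} \Vert^2$ terms agree, because $\mathbf{x}(t)$ has marginal $p_t$ under both sampling schemes; the $\Vert \mathbf{f}^{\text{target}} \Vert^2$ terms do not depend on $\theta$; and the cross terms agree by the posterior formula from the marginalization trick:

$$
\begin{align*}
\mathbb{E}_{\mathbf{x} \sim p_t}\left[\mathbf{f}^{\theta}(\mathbf{x}, t)^\top \mathbf{f}^{\text{target}}(\mathbf{x}, t)\right]
&= \int p_t(\mathbf{x}) \, \mathbf{f}^{\theta}(\mathbf{x}, t)^\top \int \mathbf{f}^{\text{target}}(\mathbf{x}, t | \mathbf{z}) \frac{p_t(\mathbf{x} | \mathbf{z}) \, p_{\text{data}}(\mathbf{z})}{p_t(\mathbf{x})} \ d\mathbf{z} \ d\mathbf{x} \newline
&= \mathbb{E}_{\mathbf{z} \sim p_{\text{data}}, \mathbf{x} \sim p_t(\cdot | \mathbf{z})}\left[\mathbf{f}^{\theta}(\mathbf{x}, t)^\top \mathbf{f}^{\text{target}}(\mathbf{x}, t | \mathbf{z})\right].
\end{align*}
$$

Therefore $\mathcal{L}_{FM} = \mathcal{L}_{CFM} + C$ with $C$ independent of $\theta$, and the gradients coincide.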
Tip (Example: Gaussian) The path $p_t(\cdot | \mathbf{z}) = \mathcal{N}(\alpha_t \mathbf{z}, \sigma_t^2 \mathbf{I}_d)$ is simulated via the ODE with vector field
$$
\mathbf{f}^{\text{target}}(\mathbf{x}(t), t | \mathbf{z}) = \left(\dot{\alpha}_t - \frac{\dot{\sigma}_t}{\sigma_t} \alpha_t\right) \mathbf{z} + \frac{\dot{\sigma}_t}{\sigma_t} \mathbf{x}(t).
$$
A natural choice is $\alpha_t = t$ and $\sigma_t = 1 - t$, which gives $\dot{\alpha}_t = 1$ and $\dot{\sigma}_t = -1$, so that
$$
\mathbf{f}^{\text{target}}(\mathbf{x}(t), t | \mathbf{z}) = \left(1 + \frac{t}{1 - t}\right) \mathbf{z} - \frac{\mathbf{x}(t)}{1 - t} = \frac{\mathbf{z} - \mathbf{x}(t)}{1 - t}.
$$
This choice yields the loss function
$$
\begin{align*}
\mathcal{L}_{CFM} & = \mathbb{E}_{t \sim \text{Unif}, \mathbf{z} \sim p_{\text{data}}, \mathbf{x}(t) \sim \mathcal{N}(\alpha_t \mathbf{z}, \sigma_t^2 \mathbf{I}_d)} \left[ \Vert \mathbf{f}^{\theta}(\mathbf{x}(t), t) - \mathbf{f}^{\text{target}}(\mathbf{x}(t), t | \mathbf{z}) \Vert^2 \right] \newline
& = \mathbb{E}_{t \sim \text{Unif}, \mathbf{z} \sim p_{\text{data}}, \epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I}_d)} \left[ \Vert \mathbf{f}^{\theta}(t \mathbf{z} + (1 - t) \epsilon, t) - (\mathbf{z} - \epsilon) \Vert^2 \right],
\end{align*}
$$
where we have used $\mathbf{x}(t) = t \mathbf{z} + (1 - t) \epsilon$, so that $\frac{\mathbf{z} - \mathbf{x}(t)}{1 - t} = \mathbf{z} - \epsilon$.
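A Monte Carlo estimate of this loss is short enough to sketch directly. This is a minimal sketch, assuming 1-D data, Python's `random` module, and an arbitrary callable `model(x, t)` standing in for the neural network $\mathbf{f}^{\theta}$; in practice one would vectorize over a minibatch and differentiate with an autodiff framework. `cfm_loss` is a hypothetical name. As a check, for a single data point $z_0$ the marginal and conditional fields coincide, $(z_0 - x)/(1 - t)$, so using that field as the model should give a loss of zero up to round-off.

```python
import random


def cfm_loss(model, data, n_samples, rng):
    """Monte Carlo estimate of L_CFM for alpha_t = t, sigma_t = 1 - t in 1-D.

    model(x, t) is any callable approximating f_theta; data is a list of samples z.
    """
    total = 0.0
    for _ in range(n_samples):
        t = rng.random()             # t ~ Unif[0, 1), so 1 - t > 0
        z = rng.choice(data)         # z ~ p_data
        eps = rng.gauss(0.0, 1.0)    # eps ~ N(0, 1)
        x = t * z + (1.0 - t) * eps  # x(t) ~ p_t(. | z)
        target = z - eps             # f_target(x(t), t | z)
        total += (model(x, t) - target) ** 2
    return total / n_samples


z0 = 1.7
oracle = lambda x, t: (z0 - x) / (1.0 - t)  # exact marginal field for p_data = delta_{z0}
print(cfm_loss(oracle, [z0], 1000, random.Random(6)))  # zero up to round-off
```

By contrast, a model that always outputs zero has expected loss $\mathbb{E}[(z_0 - \epsilon)^2] = z_0^2 + 1$.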
We have constructed an ODE for a conditional probability path. One can also construct an SDE, and the same methodology applies.