Reverse-time Euler-Maruyama Method
In the forward SDE,
$$ dx = f(x(t), t) \ dt + L(t) \ d \beta(t), $$
we simulate forward in time using the Euler-Maruyama method (note that $L(t)$ does not depend on the state $x(t)$),
$$ x_{n + 1} = x_n + f(x_n, t_n) \Delta t + L(t_n) \sqrt{\Delta t} \ \xi_n, $$
where the $\xi_n \sim \mathcal{N}(0, 1)$ are independent standard normal samples.
For the reverse SDE [1],
$$ dx = [f(x(t), t) - L(t)^2 \nabla_x \log p(x(t), t)] dt + L(t) \ d \bar{\beta}(t), $$
we simulate backward in time with,
$$ x_{n - 1} = x_n - [f(x_n, t_n) - L(t_n)^2 \nabla_x \log p(x_n, t_n)] \Delta t - L(t_n) \sqrt{\Delta t} \ \xi_n. $$
The score function $\nabla_x \log p(x(t), t)$ must be estimated or known.
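As a minimal sketch of a single reverse-time step, the following function assumes the user supplies three hypothetical callables, `drift` for $f(x, t)$, `diffusion` for $L(t)$, and `score` for $\nabla_x \log p(x, t)$:

```python
import numpy as np

def reverse_em_step(x, t, dt, drift, diffusion, score, rng):
    """One reverse-time Euler-Maruyama step from time t to t - dt.

    drift(x, t)  -> f(x, t)
    diffusion(t) -> L(t)   (state-independent)
    score(x, t)  -> grad_x log p(x, t)
    """
    L = diffusion(t)
    xi = rng.standard_normal(np.shape(x))          # xi_n ~ N(0, 1)
    reverse_drift = drift(x, t) - L**2 * score(x, t)
    return x - reverse_drift * dt - L * np.sqrt(dt) * xi
```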
Reverse Euler-Maruyama for Ornstein-Uhlenbeck Process
Consider the forward SDE,
$$ dx = -\lambda x(t) \ dt + \sigma \ d \beta(t), \quad x(0) = 0. $$
The marginal distribution $p(x(t), t)$ is Gaussian with zero mean and variance,
$$ \mathrm{Var}[x(t)] = \frac{\sigma^2}{2 \lambda}(1 - e^{-2 \lambda t}). $$
The score of the process is,
$$ \begin{align*} \nabla_x \log p(x(t), t) & = \nabla_x \log \left[\frac{1}{\sqrt{2 \pi \mathrm{Var}[x(t)]}} e^{-\frac{x(t)^2}{2 \mathrm{Var}[x(t)]}} \right] \newline & = \nabla_x \left[-\frac{1}{2} \log(2 \pi) - \frac{1}{2} \log(\mathrm{Var}[x(t)]) - \frac{x(t)^2}{2 \mathrm{Var}[x(t)]} \right] \newline & = -\frac{1}{2 \mathrm{Var}[x(t)]} \underbrace{\nabla_x \mathrm{Var}[x(t)]}_{= 0} - \frac{x(t)}{\mathrm{Var}[x(t)]} \underbrace{\nabla_x x(t)}_{= 1} \newline & = -\frac{x(t)}{\mathrm{Var}[x(t)]}. \end{align*} $$
We can plug this into the reverse-time Euler-Maruyama method,
$$ x_{n - 1} = x_n - \left[-\lambda x_n + \frac{\sigma^2 x_n}{\mathrm{Var}[x(t_n)]} \right] \Delta t - \sigma \sqrt{\Delta t} \ \xi_n. $$
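Since everything is available in closed form for the Ornstein-Uhlenbeck process, the whole forward/reverse pipeline fits in a short script. Here is a sketch; the parameter values, path count, and the cutoff near $t = 0$ (where the score diverges) are arbitrary choices:

```python
import numpy as np

lam, sigma = 1.0, 1.0                 # illustrative OU parameters
T, n_steps, n_paths = 2.0, 1000, 100_000
dt = T / n_steps
rng = np.random.default_rng(0)

def var(t):
    # Var[x(t)] = sigma^2 / (2 lam) * (1 - exp(-2 lam t))
    return sigma**2 / (2 * lam) * (1 - np.exp(-2 * lam * t))

def score(x, t):
    # grad_x log p(x, t) = -x / Var[x(t)]
    return -x / var(t)

# Forward Euler-Maruyama from x(0) = 0 up to t = T.
x = np.zeros(n_paths)
for n in range(n_steps):
    x = x + (-lam * x) * dt + sigma * np.sqrt(dt) * rng.standard_normal(n_paths)

# Reverse Euler-Maruyama from t = T back towards 0.
# Stop a few steps early, since the score diverges as Var[x(t)] -> 0.
n_min = 10
for n in range(n_steps, n_min, -1):
    t = n * dt
    x = x - (-lam * x - sigma**2 * score(x, t)) * dt - sigma * np.sqrt(dt) * rng.standard_normal(n_paths)

print("mean:", x.mean(), "var:", x.var(), "target var:", var(n_min * dt))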
Denoising Diffusion Probabilistic Models
How do you design a scalar SDE of the form,
$$ dx(t) = f(x(t), t) \ dt + L(t) \ d \beta(t), $$
that produces a process $x(t)$ with,
- Mean: $\mathbb{E}[x(t)] = \alpha(t) x_0$,
- Variance: $\mathrm{Var}[x(t)] = \sigma(t)^2$,
- Initial Conditions: $x(0) = x_0, \alpha(0) = 1, \sigma(0) = 0$?
Let’s start with the mean condition. Recall that,
$$ \begin{equation} \frac{dm(t)}{dt} = \frac{d\mathbb{E}[x(t)]}{dt} = \mathbb{E}[f(x(t), t)]. \end{equation} $$
Thus, from our ansatz,
$$ \begin{equation} \mathbb{E}[x(t)] = \alpha(t) x_0 \implies \frac{d \mathbb{E}[x(t)]}{dt} = \dot{\alpha}(t) x_0. \end{equation} $$
A natural choice for $f(x(t), t)$ is a simple affine function of the form,
$$ f(x(t), t) = a(t) x(t) + b(t). $$
Taking the expectation yields,
$$ \begin{align*} \mathbb{E}[f(x(t), t)] & = a(t) \mathbb{E}[x(t)] + b(t) \newline & = a(t) \alpha(t) x_0 + b(t). \end{align*} $$
If we, e.g., set $b(t) = 0$, then we can solve for $a(t)$ from Equations (1) and (2) to get,
$$ \begin{align*} a(t) \alpha(t) x_0 & = \dot{\alpha}(t) x_0 \newline a(t) & = \frac{\dot{\alpha}(t)}{\alpha(t)} \left(= \frac{d}{dt} \log(\alpha(t))\right). \end{align*} $$
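For example, taking $\alpha(t) = e^{-t}$ gives $a(t) = -1$, i.e., the drift $f(x(t), t) = -x(t)$, which is exactly the Ornstein-Uhlenbeck drift from the previous section with $\lambda = 1$.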
Thus, we have the linear SDE of the form,
$$ dx(t) = \left(\frac{\dot{\alpha}(t)}{\alpha(t)} x(t) \right) dt + L(t) \ d \beta(t). $$
One can check that this SDE satisfies the mean condition, but we also have a variance condition to satisfy. Thus, let’s consider the linear SDE of the form,
$$ dx(t) = \left(\frac{\dot{\alpha}(t)}{\alpha(t)} x(t) \right) dt + \underbrace{\alpha(t)}_{\text{new}} L(t) \ d \beta(t). $$
The solution to this SDE is,
$$ x(t) = \alpha(t) \underbrace{\left(x_0 + \int_0^t L(s) \ d \beta(s) \right)}_{y(t)}. $$
Thus,
$$ dy(t) = L(t) \ d \beta(t), $$
so $y(t) - x_0 = \int_0^t L(s) \ d \beta(s)$ and, by the Itô isometry, $\mathbb{E}[(y(t) - x_0)^2] = \int_0^t L(s)^2 \ ds$.
By definition of the variance,
$$ \begin{align*} \mathrm{Var}[x(t)] & = \mathbb{E}[(\underbrace{x(t)}_{= \alpha(t) y(t)} - \underbrace{\mathbb{E}[x(t)]}_{= \alpha(t) x_0})^2] \newline & = \mathbb{E}[(\alpha(t) y(t) - \alpha(t) x_0)^2] \newline & = \underbrace{\alpha(t)^2}_{\text{deterministic}} \mathbb{E}[(y(t) - x_0)^2] \newline & = \alpha(t)^2 \int_0^t L(s)^2 \ ds. \end{align*} $$
In our case, we want our variance to be $\sigma(t)^2$, thus we can set,
$$ \begin{align*} \alpha(t)^2 \int_0^t L(s)^2 \ ds & = \sigma(t)^2 \newline \int_0^t L(s)^2 \ ds & = \frac{\sigma(t)^2}{\alpha(t)^2} \newline L(t)^2 & = \frac{d}{dt} \left(\frac{\sigma(t)^2}{\alpha(t)^2}\right) \newline L(t) & = \sqrt{\frac{d}{dt} \left(\frac{\sigma(t)^2}{\alpha(t)^2}\right)} \newline \end{align*} $$
Summary
To construct an SDE with,
- Mean: $\mathbb{E}[x(t)] = \alpha(t) x_0$,
- Variance: $\mathrm{Var}[x(t)] = \sigma(t)^2$,
We use the SDE,
$$ dx(t) = \left(\frac{\dot{\alpha}(t)}{\alpha(t)} x(t) \right) dt + \alpha(t) \sqrt{\frac{d}{dt} \left(\frac{\sigma(t)^2}{\alpha(t)^2}\right)} \ d \beta(t). $$
This guarantees that the process $x(t)$ is Gaussian with the desired mean and variance.
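As a quick Monte Carlo sanity check of this recipe, here is a sketch using the hypothetical schedule $\alpha(t) = e^{-t}$, $\sigma(t)^2 = 1 - e^{-2t}$, for which the drift coefficient is $\dot{\alpha}(t)/\alpha(t) = -1$ and the diffusion coefficient is $\alpha(t) \sqrt{\frac{d}{dt}(\sigma(t)^2 / \alpha(t)^2)} = \sqrt{2}$:

```python
import numpy as np

# Hypothetical schedule: alpha(t) = exp(-t), sigma(t)^2 = 1 - exp(-2t),
# giving the constructed SDE dx = -x dt + sqrt(2) d beta(t).
x0, T, n_steps, n_paths = 2.0, 1.0, 1000, 100_000
dt = T / n_steps
rng = np.random.default_rng(0)

x = np.full(n_paths, x0)
for _ in range(n_steps):
    x = x + (-x) * dt + np.sqrt(2 * dt) * rng.standard_normal(n_paths)

print("mean:", x.mean(), "target:", np.exp(-T) * x0)      # ~ alpha(T) x0
print("var :", x.var(),  "target:", 1 - np.exp(-2 * T))   # ~ sigma(T)^2
```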
Given an SDE and target properties
Suppose we are given an SDE of the form,
$$ dx(t) = \frac{\dot{\alpha}(t)}{\alpha(t)} x(t) \ dt + \alpha(t) L(t) \ d \beta(t), \quad x(0) = x_0, $$
with,
- Mean: $\mathbb{E}[x(t)] = (1 - t) x_0$,
- Target Variance: $\mathrm{Var}[x(t)] = t^2$.
Using $\alpha(t) = 1 - t$ we have,
$$ \dot{\alpha}(t) = -1, \quad \frac{\dot{\alpha}(t)}{\alpha(t)} = -\frac{1}{1 - t}. $$
Thus, the SDE becomes,
$$ dx(t) = -\frac{1}{1 - t} x(t) \ dt + (1 - t) L(t) \ d \beta(t). $$
By doing the same “trick” as before, we find that,
$$ \begin{align*} L^2(t) & = \frac{d}{dt} \left(\frac{t^2}{(1 - t)^2}\right) \newline & = \frac{2t(1 - t)^2 + 2t^2(1 - t)}{(1 - t)^4} \newline & = \frac{2t[(1 - t)^2 + t(1 - t)]}{(1 - t)^4} \newline & = \frac{2t(1 - t)}{(1 - t)^4} \newline & = \frac{2t}{(1 - t)^3}. \end{align*} $$
So,
$$ L(t) = \sqrt{\frac{2t}{(1 - t)^3}}, $$
we can also further simplify the diffusion term to,
$$ (1 - t) L(t) = (1 - t) \sqrt{\frac{2t}{(1 - t)^3}} = \sqrt{\frac{2t}{(1 - t)}}. $$
So the SDE becomes,
$$ dx(t) = -\frac{1}{1 - t} x(t) \ dt + \sqrt{\frac{2t}{(1 - t)}} \ d \beta(t). $$
This is a fully simplified form, valid for $t < 1$, which produces,
- Mean: $\mathbb{E}[x(t)] = (1 - t) x_0$,
- Variance: $\mathrm{Var}[x(t)] = t^2$.
The process is Gaussian, fully characterized by these expressions.
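As a numerical check, one can simulate this SDE with Euler-Maruyama and compare the empirical mean and variance against $(1 - t) x_0$ and $t^2$. A sketch; the horizon $T = 0.9$, step count, and seed are arbitrary, and we stop before $t = 1$ where the drift blows up:

```python
import numpy as np

x0, T, n_steps, n_paths = 1.0, 0.9, 2000, 100_000   # stop before t = 1
dt = T / n_steps
rng = np.random.default_rng(0)

x = np.full(n_paths, x0)
for n in range(n_steps):
    t = n * dt
    drift = -x / (1 - t)
    diff = np.sqrt(2 * t / (1 - t))
    x = x + drift * dt + diff * np.sqrt(dt) * rng.standard_normal(n_paths)

print("mean:", x.mean(), "target:", (1 - T) * x0)   # ~ (1 - T) x0
print("var :", x.var(),  "target:", T**2)           # ~ T^2
```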
Flow and Diffusion Models
Consider the SDE in $\mathbb{R}^d$,
$$ d \mathbf{x}(t) = \mathbf{f}^{\theta}(\mathbf{x}(t), t) \ dt + \sigma_t \ d \beta(t), $$
Our goal is, given $\mathbf{x}(0) \sim p_{\text{init}}(\mathbf{x}(0))$, to make sure that $\mathbf{x}(1) \sim p_{\text{data}}(\mathbf{x}(1))$.
Terminology
- $\sigma_t = 0$ (flow model): the SDE reduces to an ODE (e.g., Flow Matching).
- $\sigma_t \neq 0$ (diffusion model): the SDE is a genuine stochastic process (e.g., Denoising Diffusion Probabilistic Models).
Suppose $\mathbf{x}(0) \sim p_0 = p_{\text{init}}$ at time $t = 0$ where $p_{\text{init}}$ is a simple distribution like a Gaussian.
We would like $\mathbf{x}(1) \sim p_1 = p_{\text{data}}$ at time $t = 1$ where $p_{\text{data}}$ is some complicated data distribution.
We’ll study how to,
- Construct an (ideal) vector field $\mathbf{f}^{\text{target}}$ and diffusion coefficient $\sigma_t$.
- Train a model such that $\mathbf{f}^{\theta} \approx \mathbf{f}^{\text{target}}$.
We can then sample from $p_{\text{init}}$, run it through the SDE and obtain a sample from $p_{\text{data}}$.
Method Overview — 4 Steps
- Start by constructing a probability path $p_t(\cdot | \mathbf{z})$ for individual samples $\mathbf{z}$, which interpolates from noise, $p_0(\cdot | \mathbf{z}) = p_{\text{init}}(\cdot)$, to the data sample $\mathbf{z}$, $p_1(\cdot | \mathbf{z}) = \delta_{\mathbf{z}}(\cdot)$.
- Marginalization gives us, $$ p_t(\mathbf{x}) = \int p_t(\mathbf{x} | \mathbf{z}) p_{\text{data}}(\mathbf{z}) \ d \mathbf{z}. $$
- Construct an (ideal) SDE which simulates $p_t$.
- Train a neural network that approximates this SDE.
Step 1: Constructing the Probability Path
Start by constructing,
$$ p_t(\cdot | \mathbf{z}), $$
for individual samples $\mathbf{z}$ from the data distribution, with the conditions,
$$ p_0(\cdot | \mathbf{z}) = p_{\text{init}}(\cdot) \text{ and } p_1(\cdot | \mathbf{z}) = \delta_{\mathbf{z}}(\cdot) \text{ for all } \mathbf{z} \in \mathbb{R}^d. $$
Such a path $p_t(\cdot | \mathbf{z})$ is called a conditional probability path.
Step 2: Marginalization
Each conditional path induces a marginal probability path,
$$ p_t(\mathbf{x}) = \int p_t(\mathbf{x} | \mathbf{z}) p_{\text{data}}(\mathbf{z}) \ d \mathbf{z}. $$
We can sample from $p_t$ via,
$$ \mathbf{z} \sim p_{\text{data}}, \quad \mathbf{x} \sim p_t(\cdot | \mathbf{z}) \implies \mathbf{x} \sim p_t. $$
Note we do not know the density values of $p_t(\mathbf{x})$ as the integral is intractable.
The marginal probability path $p_t(\mathbf{x})$ interpolates between $p_{\text{init}}$ and $p_{\text{data}}$,
$$ p_0(\mathbf{x}) = p_{\text{init}}(\mathbf{x}) \text{ and } p_1(\mathbf{x}) = p_{\text{data}}(\mathbf{x}). $$
Example: Gaussian
Define the path, $$ p_t(\cdot | \mathbf{z}) = \mathcal{N}(\alpha_t \mathbf{z}, \sigma_t^2 \mathbf{I}_d), $$
with $\alpha_0 = \sigma_1 = 0$ and $\alpha_1 = \sigma_0 = 1$; in other words, we move from $\mathcal{N}(\mathbf{0}, \mathbf{I}_d)$ to a point mass at $\mathbf{z}$ (informally, $\mathcal{N}(\mathbf{z}, 0)$).
Note, $\alpha_t$ and $\sigma_t$ should be continuous and differentiable functions in $t$.
Sampling from the marginal path $p_t$,
$$ \mathbf{z} \sim p_{\text{data}}, \epsilon \sim p_{\text{init}} = \mathcal{N}(0, \mathbf{I}_d) \implies \mathbf{x}(t) = \alpha_t \mathbf{z} + \sigma_t \epsilon \sim p_t. $$
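In code, drawing a sample from $p_t$ is just these two steps. A sketch, where `sample_data` is a placeholder for $\mathbf{z} \sim p_{\text{data}}$ (here a toy two-mode distribution) and $\alpha_t = t$, $\sigma_t = 1 - t$ is one admissible choice of schedule:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 2

def sample_data():
    # Placeholder for z ~ p_data: a toy mixture of two Gaussians in R^2.
    center = rng.choice([-3.0, 3.0])
    return center + 0.5 * rng.standard_normal(d)

def sample_pt(t, alpha=lambda t: t, sigma=lambda t: 1 - t):
    z = sample_data()                      # z ~ p_data
    eps = rng.standard_normal(d)           # eps ~ N(0, I_d)
    return alpha(t) * z + sigma(t) * eps   # x(t) = alpha_t z + sigma_t eps ~ p_t

print(sample_pt(0.5))
```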
Step 3: Constructing the SDE
First, we construct analytically an SDE for simulating the conditional probability path.
Then, via the so-called “marginalization trick”, we obtain an SDE for the marginal probability path.
Recall our notation for an SDE $d\mathbf{x}(t) = \mathbf{f}(\mathbf{x}(t), t) \ dt + L(t) \ d \beta(t)$.
Example: Gaussian
For the path, $$ p_t(\cdot | \mathbf{z}) = \mathcal{N}(\alpha_t \mathbf{z}, \sigma_t^2 \mathbf{I}_d), $$ with $\alpha_0 = \sigma_1 = 0$ and $\alpha_1 = \sigma_0 = 1$, the following ODE, $$ \mathbf{f}(\mathbf{x}(t), t) = (\dot{\alpha}_t - \frac{\dot{\sigma}_t}{\sigma_t} \alpha_t) \mathbf{z} + \frac{\dot{\sigma}_t}{\sigma_t} \mathbf{x}(t) \text{ and } L(t) = 0, $$
simulates the conditional probability path in the sense that $\mathbf{x}(t) \sim \mathcal{N}(\alpha_t \mathbf{z}, \sigma_t^2 \mathbf{I}_d)$.
We can check this with,
- By definition of mean and variance of an SDE,
- The Fokker-Planck-Kolmogorov equation.
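As a numerical complement to these checks, the following sketch (with the illustrative choice $\alpha_t = t$, $\sigma_t = 1 - t$ and a fixed $\mathbf{z}$) integrates the conditional ODE with forward Euler starting from $\mathbf{x}(0) = \epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I}_d)$ and compares the empirical mean and standard deviation of $\mathbf{x}(t)$ against $\alpha_t \mathbf{z}$ and $\sigma_t$:

```python
import numpy as np

alpha = lambda t: t
sigma = lambda t: 1 - t
dalpha = lambda t: 1.0
dsigma = lambda t: -1.0

def cond_field(x, t, z):
    # f(x, t | z) = (alpha' - sigma'/sigma * alpha) z + sigma'/sigma * x
    r = dsigma(t) / sigma(t)
    return (dalpha(t) - r * alpha(t)) * z + r * x

z = np.array([2.0, -1.0])                  # a fixed data point
T, n_steps, n_paths = 0.9, 1000, 50_000    # stop before t = 1 (sigma_1 = 0)
dt = T / n_steps
rng = np.random.default_rng(0)

# x(0) = eps ~ N(0, I_d), since alpha_0 = 0 and sigma_0 = 1.
x = rng.standard_normal((n_paths, z.size))
for n in range(n_steps):
    x = x + cond_field(x, n * dt, z) * dt

print("mean:", x.mean(axis=0), "target:", alpha(T) * z)   # ~ alpha_T z
print("std :", x.std(axis=0),  "target:", sigma(T))       # ~ sigma_T
```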
Marginalization Trick
What is the SDE for the marginal path?
Theorem
If the ODE with vector field $\mathbf{f}^{\text{target}}(\mathbf{x}(t), t | \mathbf{z})$ yields the conditional path $p_t(\mathbf{x}(t) | \mathbf{z})$, then the marginal field, $$ \mathbf{f}^{\text{target}}(\mathbf{x}(t), t) = \int \mathbf{f}^{\text{target}}(\mathbf{x}(t), t | \mathbf{z}) \underbrace{\frac{p_t(\mathbf{x}(t) | \mathbf{z}) p_{\text{data}}(\mathbf{z})}{p_t(\mathbf{x}(t))}}_{p_t(\mathbf{z} | \mathbf{x}(t))} \ d \mathbf{z}, $$
yields the marginal path $p_t(\mathbf{x}(t))$.
Hence, using $\mathbf{f}^{\text{target}}(\mathbf{x}(t), t)$ yields a good path — it turns noise into data! (sadly, it is unknown).
The proof? Fokker-Planck-Kolmogorov equation!
Step 4: Training the Neural Network
A possible loss function is, $$ \mathcal{L}_{FM} = \mathbb{E}_{t \sim \text{Unif}, \mathbf{x}(t) \sim p_t} \left[ \Vert \mathbf{f}^{\theta}(\mathbf{x}(t), t) - \mathbf{f}^{\text{target}}(\mathbf{x}(t), t) \Vert^2 \right] $$
However, $\mathbf{f}^{\text{target}}(\mathbf{x}(t), t)$ is not explicitly available.
Theorem
For, $$ \begin{align*} \mathcal{L}_{FM} & = \mathbb{E}_{t \sim \text{Unif}, \mathbf{x}(t) \sim p_t} \left[ \Vert \mathbf{f}^{\theta}(\mathbf{x}(t), t) - \mathbf{f}^{\text{target}}(\mathbf{x}(t), t) \Vert^2 \right], \newline \mathcal{L}_{CFM} & = \mathbb{E}_{t \sim \text{Unif}, \mathbf{z} \sim p_{\text{data}}, \mathbf{x} \sim p_t(\cdot | \mathbf{z})} \left[ \Vert \mathbf{f}^{\theta}(\mathbf{x}(t), t) - \mathbf{f}^{\text{target}}(\mathbf{x}(t), t | \mathbf{z}) \Vert^2 \right], \end{align*} $$ it holds that $\nabla_{\theta} \mathcal{L}_{FM} = \nabla_{\theta} \mathcal{L}_{CFM}$.
Hence, we can use $\mathcal{L}_{CFM}$ to train $\mathbf{f}^{\theta}(\mathbf{x}(t), t)$ (this is the main takeaway from the original flow matching paper [2]).
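A sketch of why the gradients coincide: expand both squared norms. The $\Vert \mathbf{f}^{\theta} \Vert^2$ terms are identical since $(t, \mathbf{x}(t))$ has the same joint distribution under both expectations, the terms not involving $\theta$ do not contribute to the gradient, and the cross terms agree because,
$$ \mathbb{E}_{\mathbf{x} \sim p_t}\left[ \langle \mathbf{f}^{\theta}(\mathbf{x}, t), \mathbf{f}^{\text{target}}(\mathbf{x}, t) \rangle \right] = \mathbb{E}_{\mathbf{x} \sim p_t}\left[ \left\langle \mathbf{f}^{\theta}(\mathbf{x}, t), \int \mathbf{f}^{\text{target}}(\mathbf{x}, t | \mathbf{z}) p_t(\mathbf{z} | \mathbf{x}) \ d \mathbf{z} \right\rangle \right] = \mathbb{E}_{\mathbf{z} \sim p_{\text{data}}, \mathbf{x} \sim p_t(\cdot | \mathbf{z})}\left[ \langle \mathbf{f}^{\theta}(\mathbf{x}, t), \mathbf{f}^{\text{target}}(\mathbf{x}, t | \mathbf{z}) \rangle \right], $$
by the definition of the marginal vector field above.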
Example: Gaussian
The path $p_t(\cdot | \mathbf{z}) = \mathcal{N}(\alpha_t \mathbf{z}, \sigma_t^2 \mathbf{I}_d)$ is simulated via the ODE with vector field, $$ \mathbf{f}^{\text{target}}(\mathbf{x}(t), t | \mathbf{z}) = (\dot{\alpha}_t - \frac{\dot{\sigma}_t}{\sigma_t} \alpha_t) \mathbf{z} + \frac{\dot{\sigma}_t}{\sigma_t} \mathbf{x}(t). $$ A natural choice is $\alpha_t = t$ and $\sigma_t = 1 - t$, which gives $\dot{\alpha}_t = 1$ and $\dot{\sigma}_t = -1$, so that $\mathbf{f}^{\text{target}}(\mathbf{x}(t), t | \mathbf{z}) = \mathbf{z} - \epsilon$ whenever $\mathbf{x}(t) = t \mathbf{z} + (1 - t) \epsilon$. This yields the loss function,
$$ \begin{align*} \mathcal{L}_{CFM} & = \mathbb{E}_{t \sim \text{Unif}, \mathbf{z} \sim p_{\text{data}}, \mathbf{x} \sim \mathcal{N}(\alpha_t \mathbf{z}, \sigma_t^2 \mathbf{I}_d)} \left[ \Vert \mathbf{f}^{\theta}(\mathbf{x}(t), t) - \mathbf{f}^{\text{target}}(\mathbf{x}(t), t | \mathbf{z}) \Vert^2 \right] \newline & = \mathbb{E}_{t \sim \text{Unif}, \mathbf{z} \sim p_{\text{data}}, \epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I}_d)} \left[ \Vert \mathbf{f}^{\theta}(t \mathbf{z} + (1 - t) \epsilon, t) - (\mathbf{z} - \epsilon) \Vert^2 \right], \newline \end{align*} $$
where we have used $\mathbf{x}(t) = t \mathbf{z} + (1 - t) \epsilon$.
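Putting it together, here is a minimal, self-contained training sketch (assuming PyTorch; the toy two-mode data distribution, MLP architecture, learning rate, and batch size are all arbitrary placeholders, not a prescribed recipe):

```python
import torch
import torch.nn as nn

# Conditional flow matching with alpha_t = t, sigma_t = 1 - t,
# so x(t) = t z + (1 - t) eps and f_target(x(t), t | z) = z - eps.
d = 2

def sample_data(n):
    # Placeholder for p_data: a toy two-mode Gaussian mixture in R^2.
    centers = torch.tensor([[-3.0, -3.0], [3.0, 3.0]])
    idx = torch.randint(0, 2, (n,))
    return centers[idx] + 0.5 * torch.randn(n, d)

# f_theta(x, t): a small MLP taking (x, t) concatenated.
model = nn.Sequential(nn.Linear(d + 1, 128), nn.SiLU(),
                      nn.Linear(128, 128), nn.SiLU(),
                      nn.Linear(128, d))
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(5_000):
    z = sample_data(256)                     # z ~ p_data
    eps = torch.randn_like(z)                # eps ~ N(0, I_d)
    t = torch.rand(z.shape[0], 1)            # t ~ Unif[0, 1]
    x_t = t * z + (1 - t) * eps              # x(t) ~ p_t(. | z)
    target = z - eps                         # f_target(x(t), t | z)
    pred = model(torch.cat([x_t, t], dim=1))
    loss = ((pred - target) ** 2).sum(dim=1).mean()   # Monte Carlo L_CFM
    opt.zero_grad()
    loss.backward()
    opt.step()

# Sampling: integrate dx/dt = f_theta(x, t) from t = 0 to t = 1 with Euler.
with torch.no_grad():
    x = torch.randn(1_000, d)                # x(0) ~ p_init = N(0, I_d)
    n_steps = 200
    for n in range(n_steps):
        t = torch.full((x.shape[0], 1), n / n_steps)
        x = x + model(torch.cat([x, t], dim=1)) / n_steps
```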
We have constructed an ODE for a conditional probability path. One can also construct an SDE and the same methodology is applicable.
Footnotes
1. Reverse-Time Diffusion Equation Models by Anderson.
2. Flow Matching for Generative Modeling by Lipman et al.