File size: 8,605 Bytes
3527383 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Optional, Sequence, Tuple, Union
import torch
from torch import Tensor
from torchdiffeq import odeint
from flow_matching.solver.solver import Solver
from flow_matching.utils import gradient, ModelWrapper
class ODESolver(Solver):
    """A class to solve ordinary differential equations (ODEs) using a specified velocity model.

    This class utilizes a velocity field model to solve ODEs over a given time grid using numerical ode solvers.

    Args:
        velocity_model (Union[ModelWrapper, Callable]): a velocity field model receiving :math:`(x,t)` and returning :math:`u_t(x)`
    """

    def __init__(self, velocity_model: Union[ModelWrapper, Callable]):
        super().__init__()
        self.velocity_model = velocity_model

    def sample(
        self,
        x_init: Tensor,
        step_size: Optional[float],
        method: str = "euler",
        atol: float = 1e-5,
        rtol: float = 1e-5,
        time_grid: Tensor = torch.tensor([0.0, 1.0]),
        return_intermediates: bool = False,
        enable_grad: bool = False,
        **model_extras,
    ) -> Union[Tensor, Sequence[Tensor]]:
        r"""Solve the ODE with the velocity field.

        Example:

        .. code-block:: python

            import torch
            from flow_matching.utils import ModelWrapper
            from flow_matching.solver import ODESolver

            class DummyModel(ModelWrapper):
                def __init__(self):
                    super().__init__(None)

                def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
                    return torch.ones_like(x) * 3.0 * t**2

            velocity_model = DummyModel()
            solver = ODESolver(velocity_model=velocity_model)
            x_init = torch.tensor([0.0, 0.0])
            step_size = 0.001
            time_grid = torch.tensor([0.0, 1.0])

            result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid)

        Args:
            x_init (Tensor): initial conditions (e.g., source samples :math:`X_0 \sim p`). Shape: [batch_size, ...].
            step_size (Optional[float]): The step size. Must be None for adaptive step solvers.
            method (str): A method supported by torchdiffeq. Defaults to "euler". Other commonly used solvers are "dopri5", "midpoint" and "heun3". For a complete list, see torchdiffeq.
            atol (float): Absolute tolerance, used for adaptive step solvers.
            rtol (float): Relative tolerance, used for adaptive step solvers.
            time_grid (Tensor): The process is solved in the interval [min(time_grid), max(time_grid)] and if step_size is None then time discretization is set by the time grid. May specify a descending time_grid to solve in the reverse direction. Defaults to torch.tensor([0.0, 1.0]).
            return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Defaults to False.
            enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False.
            **model_extras: Additional input for the model.

        Returns:
            Union[Tensor, Sequence[Tensor]]: The last timestep when return_intermediates=False, otherwise all values specified in time_grid.
        """
        # The default time grid is created on CPU; move it next to the state.
        time_grid = time_grid.to(x_init.device)

        def ode_func(t, x):
            return self.velocity_model(x=x, t=t, **model_extras)

        # Only pass a step size through when one was given.
        ode_opts = {"step_size": step_size} if step_size is not None else {}

        with torch.set_grad_enabled(enable_grad):
            # Approximate ODE solution with numerical ODE solver
            sol = odeint(
                ode_func,
                x_init,
                time_grid,
                method=method,
                options=ode_opts,
                atol=atol,
                rtol=rtol,
            )

        if return_intermediates:
            return sol
        else:
            return sol[-1]

    def compute_likelihood(
        self,
        x_1: Tensor,
        log_p0: Callable[[Tensor], Tensor],
        step_size: Optional[float],
        method: str = "euler",
        atol: float = 1e-5,
        rtol: float = 1e-5,
        time_grid: Tensor = torch.tensor([1.0, 0.0]),
        return_intermediates: bool = False,
        exact_divergence: bool = False,
        enable_grad: bool = False,
        **model_extras,
    ) -> Union[Tuple[Tensor, Tensor], Tuple[Sequence[Tensor], Tensor]]:
        r"""Solve for log likelihood given a target sample at :math:`t=1`.

        Works similarly to sample, but solves the ODE in reverse to compute the log-likelihood. The velocity model must be differentiable with respect to x.
        The function assumes log_p0 is the log probability of the source distribution at :math:`t=0`.

        Args:
            x_1 (Tensor): target sample (e.g., samples :math:`X_1 \sim p_1`). Shape: [batch_size, ...].
            log_p0 (Callable[[Tensor], Tensor]): Log probability function of the source distribution.
            step_size (Optional[float]): The step size. Must be None for adaptive step solvers.
            method (str): A method supported by torchdiffeq. Defaults to "euler". Other commonly used solvers are "dopri5", "midpoint" and "heun3". For a complete list, see torchdiffeq.
            atol (float): Absolute tolerance, used for adaptive step solvers.
            rtol (float): Relative tolerance, used for adaptive step solvers.
            time_grid (Tensor): If step_size is None then time discretization is set by the time grid. Must start at 1.0 and end at 0.0, otherwise the likelihood computation is not valid. Defaults to torch.tensor([1.0, 0.0]).
            return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Otherwise only return the final sample. Defaults to False.
            exact_divergence (bool): Whether to compute the exact divergence or use the Hutchinson estimator.
            enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False.
            **model_extras: Additional input for the model.

        Returns:
            Union[Tuple[Tensor, Tensor], Tuple[Sequence[Tensor], Tensor]]: Samples at time_grid and log likelihood values of given x_1.

        Raises:
            AssertionError: If time_grid does not start at 1.0 and end at 0.0.
        """
        assert (
            time_grid[0] == 1.0 and time_grid[-1] == 0.0
        ), f"Time grid must start at 1.0 and end at 0.0. Got {time_grid}"

        # Keep the time grid on the same device as the state, as `sample` does.
        time_grid = time_grid.to(x_1.device)

        # Fix the random projection for the Hutchinson divergence estimator.
        # The sign of a standard normal draw gives entries in {-1, +1}.
        if not exact_divergence:
            z = (torch.randn_like(x_1) < 0) * 2.0 - 1.0

        def ode_func(x, t):
            return self.velocity_model(x=x, t=t, **model_extras)

        def dynamics_func(t, states):
            xt = states[0]
            # Gradients w.r.t. xt are needed for the divergence even when the
            # outer solve runs with gradients disabled.
            with torch.set_grad_enabled(True):
                xt.requires_grad_()
                ut = ode_func(xt, t)

                if exact_divergence:
                    # Compute exact divergence: sum_i d(ut_i)/d(xt_i) over all
                    # non-batch components. Both the velocity and its gradient
                    # are indexed through a flattened view so inputs with more
                    # than one non-batch dimension are handled.
                    ut_flat = ut.flatten(start_dim=1)
                    div = ut_flat.new_zeros(ut_flat.shape[0])
                    for i in range(ut_flat.shape[1]):
                        # assumes `gradient` returns a tensor shaped like xt -- TODO confirm
                        grad_i = gradient(ut_flat[:, i], xt, create_graph=True)
                        div = div + grad_i.flatten(start_dim=1)[:, i]
                else:
                    # Compute Hutchinson divergence estimator E[z^T D_x(ut) z]
                    z_flat = z.flatten(start_dim=1)
                    ut_dot_z = torch.einsum(
                        "ij,ij->i", ut.flatten(start_dim=1), z_flat
                    )
                    grad_ut_dot_z = gradient(ut_dot_z, xt)
                    div = torch.einsum(
                        "ij,ij->i",
                        grad_ut_dot_z.flatten(start_dim=1),
                        z_flat,
                    )

            # NOTE(review): both outputs are detached here, so enable_grad does
            # not propagate gradients through the returned dynamics -- confirm
            # this is intended.
            return ut.detach(), div.detach()

        # Augmented state: the sample and one accumulated divergence per batch element.
        y_init = (x_1, torch.zeros(x_1.shape[0], device=x_1.device))
        ode_opts = {"step_size": step_size} if step_size is not None else {}

        with torch.set_grad_enabled(enable_grad):
            sol, log_det = odeint(
                dynamics_func,
                y_init,
                time_grid,
                method=method,
                options=ode_opts,
                atol=atol,
                rtol=rtol,
            )

        # The last grid point is t=0, where the source density is evaluated.
        x_source = sol[-1]
        source_log_p = log_p0(x_source)

        if return_intermediates:
            return sol, source_log_p + log_det[-1]
        else:
            return sol[-1], source_log_p + log_det[-1]
|