import torch
import torch.nn.functional as F

from torch import Tensor

from flow_matching.path.path import ProbPath
from flow_matching.path.path_sample import DiscretePathSample
from flow_matching.path.scheduler import ConvexScheduler
from flow_matching.utils import expand_tensor_like, unsqueeze_to_match


class MixtureDiscreteProbPath(ProbPath):
    r"""The ``MixtureDiscreteProbPath`` class defines a factorized discrete probability path.

    This path remains constant at the source data point :math:`X_0` until a random time, determined by the scheduler, when it flips to the target data point :math:`X_1`.
    The scheduler determines the flip probability using the parameter :math:`\sigma_t`, which is a function of time :math:`t`. Specifically, :math:`\sigma_t` represents the probability of remaining at :math:`X_0`, while :math:`1 - \sigma_t` is the probability of flipping to :math:`X_1`:

    .. math::

        P(X_t = X_0) = \sigma_t \quad \text{and} \quad P(X_t = X_1) = 1 - \sigma_t,

    where :math:`\sigma_t` is provided by the scheduler.

    Example:

    .. code-block:: python

        >>> x_0 = torch.zeros((1, 3, 3))
        >>> x_1 = torch.ones((1, 3, 3))

        >>> path = MixtureDiscreteProbPath(PolynomialConvexScheduler(n=1.0))
        >>> result = path.sample(x_0, x_1, t=torch.tensor([0.1])).x_t
        >>> result
        tensor([[[0.0, 0.0, 0.0],
                 [0.0, 0.0, 1.0],
                 [0.0, 0.0, 0.0]]])

        >>> result = path.sample(x_0, x_1, t=torch.tensor([0.5])).x_t
        >>> result
        tensor([[[1.0, 0.0, 1.0],
                 [0.0, 1.0, 0.0],
                 [0.0, 1.0, 0.0]]])

        >>> result = path.sample(x_0, x_1, t=torch.tensor([1.0])).x_t
        >>> result
        tensor([[[1.0, 1.0, 1.0],
                 [1.0, 1.0, 1.0],
                 [1.0, 1.0, 1.0]]])

    Args:
        scheduler (ConvexScheduler): The scheduler that provides :math:`\sigma_t`.
    """

    def __init__(self, scheduler: ConvexScheduler):
        assert isinstance(
            scheduler, ConvexScheduler
        ), "Scheduler for MixtureDiscreteProbPath must be a ConvexScheduler."

        self.scheduler = scheduler

    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> DiscretePathSample:
        r"""Sample from the mixture discrete probability path:

        | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)` and a convex scheduler providing :math:`\sigma_t`,
        | return :math:`X_0`, :math:`X_1`, :math:`t`, and :math:`X_t \sim p_t`.

        Args:
            x_0 (Tensor): source data point, shape (batch_size, ...).
            x_1 (Tensor): target data point, shape (batch_size, ...).
            t (Tensor): times in [0,1], shape (batch_size).

        Returns:
            DiscretePathSample: a conditional sample :math:`X_t \sim p_t`.
        """
        self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t)

        sigma_t = self.scheduler(t).sigma_t
        sigma_t = expand_tensor_like(input_tensor=sigma_t, expand_to=x_1)

        # Each coordinate independently keeps the source value with probability
        # sigma_t and flips to the target value otherwise.
        source_indices = torch.rand(size=x_1.shape, device=x_1.device) < sigma_t
        x_t = torch.where(condition=source_indices, input=x_0, other=x_1)

        return DiscretePathSample(x_t=x_t, x_1=x_1, x_0=x_0, t=t)

    def posterior_to_velocity(
        self, posterior_logits: Tensor, x_t: Tensor, t: Tensor
    ) -> Tensor:
        r"""Convert the factorized posterior to velocity.

        | given :math:`p(X_1|X_t)`; in the factorized case, :math:`\prod_i p(X_1^i | X_t)`,
        | return :math:`u_t`.
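
        Example (an illustrative sketch; the shapes below are assumptions, and the
        random logits stand in for the output of a trained posterior model):

        .. code-block:: python

            >>> path = MixtureDiscreteProbPath(PolynomialConvexScheduler(n=1.0))
            >>> # batch of 2 sequences of length 5 over a vocabulary of size 10
            >>> posterior_logits = torch.randn(2, 5, 10)
            >>> x_t = torch.randint(0, 10, (2, 5))
            >>> u_t = path.posterior_to_velocity(posterior_logits, x_t, torch.tensor([0.5, 0.5]))
            >>> u_t.shape
            torch.Size([2, 5, 10])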

        Args:
            posterior_logits (Tensor): logits of the :math:`X_1` posterior conditioned on :math:`X_t`, shape (..., vocab size).
            x_t (Tensor): path sample at time :math:`t`, shape (...).
            t (Tensor): time in [0,1].

        Returns:
            Tensor: velocity :math:`u_t`.
        """
        posterior = torch.softmax(posterior_logits, dim=-1)
        vocabulary_size = posterior.shape[-1]
        x_t = F.one_hot(x_t, num_classes=vocabulary_size)
        t = unsqueeze_to_match(source=t, target=x_t)

        scheduler_output = self.scheduler(t)

        kappa_t = scheduler_output.alpha_t
        d_kappa_t = scheduler_output.d_alpha_t

        # u_t = (d_kappa_t / (1 - kappa_t)) * (posterior - one_hot(x_t))
        return (d_kappa_t / (1 - kappa_t)) * (posterior - x_t)
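

# An illustrative end-to-end sketch, not part of the library API: it wires
# ``sample`` and ``posterior_to_velocity`` together. The PolynomialConvexScheduler
# comes from the class docstring example; the Linear layer is a hypothetical
# stand-in for a trained posterior model.
if __name__ == "__main__":
    from flow_matching.path.scheduler import PolynomialConvexScheduler

    torch.manual_seed(0)
    vocab_size, batch_size, seq_len = 10, 4, 8

    path = MixtureDiscreteProbPath(PolynomialConvexScheduler(n=1.0))

    # Random source/target token grids and per-sample times in [0, 1).
    x_0 = torch.randint(vocab_size, (batch_size, seq_len))
    x_1 = torch.randint(vocab_size, (batch_size, seq_len))
    t = torch.rand(batch_size)

    # Each coordinate of x_t keeps x_0 with probability sigma_t, else flips to x_1.
    sample = path.sample(x_0=x_0, x_1=x_1, t=t)

    # Hypothetical posterior model: logits over the vocabulary at every position.
    model = torch.nn.Linear(vocab_size, vocab_size)
    logits = model(F.one_hot(sample.x_t, num_classes=vocab_size).float())

    u_t = path.posterior_to_velocity(posterior_logits=logits, x_t=sample.x_t, t=t)
    print(u_t.shape)  # torch.Size([4, 8, 10])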