AlienChen's picture
Upload 72 files
3527383 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from torch import Tensor
from flow_matching.path.path import ProbPath
from flow_matching.path.path_sample import DiscretePathSample
from flow_matching.path.scheduler import ConvexScheduler
from flow_matching.utils import expand_tensor_like, unsqueeze_to_match
class MixtureDiscreteProbPath(ProbPath):
    r"""The ``MixtureDiscreteProbPath`` class defines a factorized discrete probability path.

    This path remains constant at the source data point :math:`X_0` until a random time, determined by the scheduler, when it flips to the target data point :math:`X_1`.
    The scheduler determines the flip probability using the parameter :math:`\sigma_t`, which is a function of time `t`. Specifically, :math:`\sigma_t` represents the probability of remaining at :math:`X_0`, while :math:`1 - \sigma_t` is the probability of flipping to :math:`X_1`:

    .. math::

        P(X_t = X_0) = \sigma_t \quad \text{and} \quad P(X_t = X_1) = 1 - \sigma_t,

    where :math:`\sigma_t` is provided by the scheduler.

    Example:

    Sampling draws fresh uniform noise on every call, so the outputs shown
    for intermediate times are illustrative only.

    .. code-block:: python

        >>> x_0 = torch.zeros((1, 3, 3))
        >>> x_1 = torch.ones((1, 3, 3))

        >>> path = MixtureDiscreteProbPath(PolynomialConvexScheduler(n=1.0))
        >>> result = path.sample(x_0, x_1, t=torch.tensor([0.1])).x_t

        >>> result
        tensor([[[0.0, 0.0, 0.0],
                 [0.0, 0.0, 1.0],
                 [0.0, 0.0, 0.0]]])

        >>> result = path.sample(x_0, x_1, t=torch.tensor([0.5])).x_t

        >>> result
        tensor([[[1.0, 0.0, 1.0],
                 [0.0, 1.0, 0.0],
                 [0.0, 1.0, 0.0]]])

        >>> result = path.sample(x_0, x_1, t=torch.tensor([1.0])).x_t

        >>> result
        tensor([[[1.0, 1.0, 1.0],
                 [1.0, 1.0, 1.0],
                 [1.0, 1.0, 1.0]]])

    Args:
        scheduler (ConvexScheduler): The scheduler that provides :math:`\sigma_t`.
    """

    def __init__(self, scheduler: ConvexScheduler):
        # Kept as an ``assert`` so callers still observe ``AssertionError``;
        # the message names this class (it previously said "ConvexProbPath").
        assert isinstance(
            scheduler, ConvexScheduler
        ), "Scheduler for MixtureDiscreteProbPath must be a ConvexScheduler."

        self.scheduler = scheduler

    def sample(self, x_0: Tensor, x_1: Tensor, t: Tensor) -> DiscretePathSample:
        r"""Sample from the mixture discrete probability path:

        | given :math:`(X_0,X_1) \sim \pi(X_0,X_1)` and a scheduler providing :math:`\sigma_t`.
        | return :math:`X_0, X_1, t`, and :math:`X_t \sim p_t`.

        Each element of :math:`X_t` is taken independently from :math:`X_0`
        with probability :math:`\sigma_t` and from :math:`X_1` otherwise.

        Args:
            x_0 (Tensor): source data point, shape (batch_size, ...).
            x_1 (Tensor): target data point, shape (batch_size, ...).
            t (Tensor): times in [0,1], shape (batch_size).

        Returns:
            DiscretePathSample: a conditional sample at :math:`X_t ~ p_t`.
        """
        # Shape validation is delegated to ``assert_sample_shape`` (not defined
        # in this class; presumably inherited from ``ProbPath``).
        self.assert_sample_shape(x_0=x_0, x_1=x_1, t=t)

        # Per-sample probability of staying at x_0, expanded to x_1's shape so
        # it can be compared element-wise with the uniform noise below.
        sigma_t = self.scheduler(t).sigma_t
        sigma_t = expand_tensor_like(input_tensor=sigma_t, expand_to=x_1)

        # True where the element keeps its source value (uniform draw < sigma_t).
        source_indices = torch.rand(size=x_1.shape, device=x_1.device) < sigma_t
        x_t = torch.where(condition=source_indices, input=x_0, other=x_1)

        return DiscretePathSample(x_t=x_t, x_1=x_1, x_0=x_0, t=t)

    def posterior_to_velocity(
        self, posterior_logits: Tensor, x_t: Tensor, t: Tensor
    ) -> Tensor:
        r"""Convert the factorized posterior to velocity.

        | given :math:`p(X_1|X_t)`. In the factorized case: :math:`\prod_i p(X_1^i | X_t)`.
        | return :math:`u_t`.

        The velocity is computed as

        .. math::

            u_t = \frac{\dot{\kappa}_t}{1 - \kappa_t} \left( p(X_1|X_t) - \delta_{X_t} \right),

        where :math:`\kappa_t` and :math:`\dot{\kappa}_t` are the scheduler's
        ``alpha_t`` and ``d_alpha_t``, and :math:`\delta_{X_t}` is the one-hot
        encoding of ``x_t``. The factor is undefined wherever
        :math:`\kappa_t = 1` (division by zero); callers should avoid
        evaluating at such times.

        Args:
            posterior_logits (Tensor): logits of the x_1 posterior conditional on x_t, shape (..., vocab size).
            x_t (Tensor): path sample at time t, shape (...). Passed to ``F.one_hot``, so it must hold integer class indices.
            t (Tensor): time in [0,1].

        Returns:
            Tensor: velocity, with the vocabulary as its last dimension.
        """
        # Normalize logits into probabilities over the vocabulary (last dim).
        posterior = torch.softmax(posterior_logits, dim=-1)
        vocabulary_size = posterior.shape[-1]

        # One-hot encode the current state so it can be subtracted from the
        # posterior; this adds a trailing vocabulary dimension to x_t.
        x_t = F.one_hot(x_t, num_classes=vocabulary_size)

        # Add trailing singleton dims to t so scheduler outputs broadcast
        # against the one-hot tensor.
        t = unsqueeze_to_match(source=t, target=x_t)

        scheduler_output = self.scheduler(t)

        kappa_t = scheduler_output.alpha_t
        d_kappa_t = scheduler_output.d_alpha_t

        return (d_kappa_t / (1 - kappa_t)) * (posterior - x_t)