import inspect
from contextlib import nullcontext
from math import ceil
from typing import Callable, Optional, Union

import torch
from torch import Tensor
from torch.nn import functional as F

from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.solver.solver import Solver
from flow_matching.utils import categorical, ModelWrapper

from .utils import get_nearest_times
from ..utils.multi_guidance import *

try:
    from tqdm import tqdm

    TQDM_AVAILABLE = True
except ImportError:
    TQDM_AVAILABLE = False


class MixtureDiscreteEulerSolver(Solver):
    r"""Solver that simulates the CTMC process :math:`(X_t)_{t_{\text{init}}\leq t\leq t_{\text{final}}}` defined by :math:`p_t`, the marginal probability path of ``path``.

    Given :math:`X_t \sim p_t`, a single solver step from :math:`t` to :math:`t+h` updates the i-th coordinate as follows:

    .. math::

        \begin{align*}
        & X_1^i \sim p_{1|t}^i(\cdot|X_t)\\
        & \lambda^i \gets \sum_{x^i\ne X_t^i} u_t^i(x^i, X_t^i|X_1^i)\\
        & Z^i_{\text{change}} \sim U[0,1]\\
        & X_{t+h}^i \sim \begin{cases}
        \frac{u_t^i(\cdot, X_t^i|X_1^i)}{\lambda^i}(1-\delta_{X_t^i}(\cdot)) \text{ if $Z^i_{\text{change}}\le 1-e^{-h\lambda^i}$}\\
        \delta_{X_t^i}(\cdot) \text{ else }
        \end{cases}
        \end{align*}

    where :math:`p_{1|t}(\cdot|X_t)` is the output of ``model``, and the conditional probability velocity of the mixture probability path is:

    .. math::

        u_t^i(x^i, y^i|x_1^i) = \hat{u}_t^i(x^i, y^i|x_1^i) + c_{\text{div\_free}}\left[\hat{u}_t^i(x^i, y^i|x_1^i) - \check{u}_t^i(x^i, y^i|x_1^i) \right],

    where

    .. math::

        \hat{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{1-\kappa_t} \left[ \delta_{x_1^i}(x^i) - \delta_{y^i}(x^i) \right],

    and

    .. math::

        \check{u}_t^i(x^i, y^i|x_1^i) = \frac{\dot{\kappa}_t}{\kappa_t}\left[ \delta_{y^i}(x^i) - p(x^i) \right].

    The source distribution :math:`p(x^i)` is given by ``p``.
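
    A minimal numerical sketch of one coordinate's update under :math:`\hat{u}_t` (the values of :math:`\kappa_t`, :math:`\dot{\kappa}_t`, the step size, and the vocabulary size are illustrative placeholders):

    .. code-block:: python

        import torch
        from torch.nn import functional as F

        kappa_t, d_kappa_t, h, vocab_size = 0.3, 1.0, 0.01, 8
        x_t, x_1 = torch.tensor(5), torch.tensor(2)       # current state, sampled target
        delta_1 = F.one_hot(x_1, num_classes=vocab_size)  # one-hot delta at x_1
        rates = d_kappa_t / (1 - kappa_t) * delta_1       # \hat{u}_t rates toward x_1
        rates[x_t] = 0.0                                  # zero self-transition rate
        lam = rates.sum()                                 # total jump intensity
        if torch.rand(()) < 1 - torch.exp(-h * lam):      # jump w.p. 1 - e^{-h*lam}
            x_t = torch.multinomial(rates / lam, 1)[0]    # sample the new state

    In :meth:`sample`, this update is applied to all coordinates in parallel.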

    Args:
        model (ModelWrapper): trained with x-prediction, outputting posterior probabilities (in the range :math:`[0,1]`); the output shape must be [..., vocabulary_size].
        path (MixtureDiscreteProbPath): Probability path used for x-prediction training.
        vocabulary_size (int): size of the discrete vocabulary.
        source_distribution_p (Optional[Tensor], optional): Source distribution, must be of shape [vocabulary_size]. Required only when the divergence-free term of the probability velocity is non-zero. Defaults to None.
    """

    def __init__(
        self,
        model: ModelWrapper,
        path: MixtureDiscreteProbPath,
        vocabulary_size: int,
        source_distribution_p: Optional[Tensor] = None,
    ):
        super().__init__()
        self.model = model
        self.path = path
        self.vocabulary_size = vocabulary_size

        if source_distribution_p is not None:
            assert source_distribution_p.shape == torch.Size(
                [vocabulary_size]
            ), f"Source distribution p dimension must match the vocabulary size {vocabulary_size}. Got {source_distribution_p.shape}."

        self.source_distribution_p = source_distribution_p

    @torch.no_grad()
    def sample(
        self,
        x_init: Tensor,
        step_size: Optional[float],
        div_free: Union[float, Callable[[float], float]] = 0.0,
        dtype_categorical: torch.dtype = torch.float32,
        time_grid: Tensor = torch.tensor([0.0, 1.0]),
        return_intermediates: bool = False,
        verbose: bool = False,
        **model_extras,
    ) -> Tensor:
""" |
|
|
Sample a sequence of discrete values from the given model. |
|
|
|
|
|
.. code-block:: python |
|
|
|
|
|
import torch |
|
|
from flow_matching.utils import ModelWrapper |
|
|
from flow_matching.solver import MixtureDiscreteEulerSolver |
|
|
|
|
|
class DummyModel(ModelWrapper): |
|
|
def __init__(self): |
|
|
super().__init__(None) |
|
|
def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor: |
|
|
return ... |
|
|
|
|
|
model = DummyModel() |
|
|
solver = MixtureDiscreteEulerSolver(model=model) |
|
|
|
|
|
x_init = torch.LongTensor([122, 725]) |
|
|
step_size = 0.001 |
|
|
time_grid = torch.tensor([0.0, 1.0]) |
|
|
|
|
|
result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid) |
|
|
|
|
|
Args: |
|
|
x_init (Tensor): The initial state. |
|
|
step_size (Optional[float]): If float then time discretization is uniform with the given step size. If None then time discretization is set to be time_grid. |
|
|
div_free (Union[float, Callable[[float], float]]): The coefficient of the divergence-free term in the probability velocity. Can be either a float or a time dependent function. Defaults to 0.0. |
|
|
dtype_categorical (torch.dtype): Precision to use for categorical sampler. Defaults to torch.float32. |
|
|
time_grid (Tensor): The CTMC process is solved in the interval [time_grid[0], time_grid[-1]] and if step_size is None then time discretization is set by the time grid. Defaults to torch.tensor([0.0,1.0]). |
|
|
return_intermediates (bool): If True then return intermediate time steps according to time_grid. Defaults to False. |
|
|
verbose (bool): Whether to print progress bars. Defaults to False. |
|
|
**model_extras: Additional input for the model. |
|
|
|
|
|
Returns: |
|
|
Tensor: The sampled sequence of discrete values. |
|
|
|
|
|
Raises: |
|
|
ImportError: To run in verbose mode, tqdm must be installed. |
|
|
""" |
|
|
        if div_free != 0.0:
            assert (
                self.source_distribution_p is not None
            ), "Source distribution p must be specified in order to add a divergence-free term to the probability velocity."

        time_grid = time_grid.to(device=x_init.device)
        # Define the interval endpoints up front; they are also needed by the
        # verbose progress bar when step_size is None.
        t_init = time_grid[0].item()
        t_final = time_grid[-1].item()

        if step_size is None:
            # If step_size is None, the time discretization is given by time_grid.
            t_discretization = time_grid
            n_steps = len(time_grid) - 1
        else:
            # Otherwise, discretize [t_init, t_final] uniformly with step_size.
            assert (
                t_final - t_init
            ) > step_size, f"Time interval [time_grid[0], time_grid[-1]] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}."

            n_steps = ceil((t_final - t_init) / step_size)
            t_discretization = torch.tensor(
                [t_init + step_size * i for i in range(n_steps)] + [t_final],
                device=x_init.device,
            )

        if return_intermediates:
            # Snap the requested times to the nearest points on the discretization.
            order = torch.argsort(time_grid)
            time_grid = get_nearest_times(
                time_grid=time_grid, t_discretization=t_discretization
            )

        x_t = x_init.clone()
        steps_counter = 0
        res = []

        if return_intermediates:
            res = [x_init.clone()]

        if verbose:
            if not TQDM_AVAILABLE:
                raise ImportError(
                    "tqdm is required for verbose mode. Please install it."
                )
            ctx = tqdm(total=t_final, desc=f"NFE: {steps_counter}")
        else:
            ctx = nullcontext()

        with ctx:
            for i in range(n_steps):
                t = t_discretization[i : i + 1]
                h = t_discretization[i + 1 : i + 2] - t_discretization[i : i + 1]

                # Sample X_1 ~ p_{1|t}(. | X_t) from the model posterior.
                p_1t = self.model(x=x_t, t=t.repeat(x_t.shape[0]), **model_extras)
                x_1 = categorical(p_1t.to(dtype=dtype_categorical))

                if i == n_steps - 1:
                    # On the last step, jump directly to the sampled X_1.
                    x_t = x_1
                else:
                    scheduler_output = self.path.scheduler(t=t)

                    k_t = scheduler_output.alpha_t
                    d_k_t = scheduler_output.d_alpha_t

                    delta_1 = F.one_hot(x_1, num_classes=self.vocabulary_size).to(
                        k_t.dtype
                    )
                    # x-prediction velocity: rates toward the sampled X_1.
                    u = d_k_t / (1 - k_t) * delta_1

                    # Add the divergence-free term when its coefficient is non-zero.
                    div_free_t = div_free(t) if callable(div_free) else div_free
                    if div_free_t > 0:
                        p_0 = self.source_distribution_p[(None,) * x_t.dim()]
                        u = u + div_free_t * d_k_t / (k_t * (1 - k_t)) * (
                            (1 - k_t) * p_0 + k_t * delta_1
                        )

                    # Zero out the rate of staying in the current state.
                    delta_t = F.one_hot(x_t, num_classes=self.vocabulary_size)
                    u = torch.where(
                        delta_t.to(dtype=torch.bool), torch.zeros_like(u), u
                    )

                    # Each coordinate jumps with probability 1 - exp(-h * intensity);
                    # the new state is sampled from the normalized rates.
                    intensity = u.sum(dim=-1)
                    mask_jump = torch.rand(
                        size=x_t.shape, device=x_t.device
                    ) < 1 - torch.exp(-h * intensity)

                    if mask_jump.sum() > 0:
                        x_t[mask_jump] = categorical(
                            u[mask_jump].to(dtype=dtype_categorical)
                        )

                steps_counter += 1
                t = t + h

                if return_intermediates and (t in time_grid):
                    res.append(x_t.clone())

                if verbose:
                    ctx.n = t.item()
                    ctx.refresh()
                    ctx.set_description(f"NFE: {steps_counter}")

        if return_intermediates:
            if step_size is None:
                return torch.stack(res, dim=0)
            else:
                # Restore the order of the (possibly unsorted) requested time grid.
                return torch.stack(res, dim=0)[order]
        else:
            return x_t

    @torch.no_grad()
    def multi_guidance_sample(
        self,
        args,
        x_init: Tensor,
        step_size: Optional[float],
        div_free: Union[float, Callable[[float], float]] = 0.0,
        dtype_categorical: torch.dtype = torch.float32,
        time_grid: Tensor = torch.tensor([0.0, 1.0]),
        return_intermediates: bool = False,
        verbose: bool = False,
        score_models: Optional[list] = None,
        num_objectives: int = 1,
        weights: Optional[list] = None,
        **model_extras,
    ) -> Tensor:
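        """
        Sample a sequence of discrete values with multi-objective guidance.

        At each Euler step, candidate transitions are scored against the
        objectives in ``score_models`` under the preference weights ``w``
        (``guided_transition_scoring``), filtered by adaptive hypercone
        filtering (``adaptive_hypercone_filtering``), and the accepted
        transitions are applied (``euler_sample``).

        A sketch of a call, assuming ``args`` carries the fields read by this
        method and the guidance utilities (e.g. ``num_div``, ``Phi_init``) and
        ``score_a``/``score_b`` are placeholder score models:

        .. code-block:: python

            result = solver.multi_guidance_sample(
                args,
                x_init=x_init,
                step_size=0.001,
                score_models=[score_a, score_b],
                num_objectives=2,
                weights=[0.5, 0.5],
            )

        Args:
            args: Configuration namespace; ``args.num_div`` and ``args.Phi_init`` are read here, and further fields may be read by the guidance utilities.
            x_init (Tensor): The initial state.
            step_size (Optional[float]): If float, the time discretization is uniform with the given step size. If None, it is set by ``time_grid``.
            div_free (Union[float, Callable[[float], float]]): Must be 0.0; the divergence-free term is not implemented for guided sampling.
            dtype_categorical (torch.dtype): Precision to use for the categorical sampler. Defaults to torch.float32.
            time_grid (Tensor): The CTMC process is solved on the interval [time_grid[0], time_grid[-1]]. Defaults to torch.tensor([0.0, 1.0]).
            return_intermediates (bool): If True, return intermediate time steps according to time_grid. Defaults to False.
            verbose (bool): Whether to print progress bars. Defaults to False.
            score_models (Optional[list]): Score models, one per objective, used for guidance and per-step monitoring.
            num_objectives (int): Number of objectives. Defaults to 1.
            weights (Optional[list]): Preference weights over the objectives; if None, a weight vector is drawn via ``select_random_weight_vector``.
            **model_extras: Additional inputs for the model.

        Returns:
            Tensor: The sampled sequence of discrete values.
        """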
        if div_free != 0.0:
            raise NotImplementedError(
                "A divergence-free term is not supported for guided sampling."
            )
        assert (
            score_models is not None
        ), "score_models must be provided for multi-objective guided sampling."

        time_grid = time_grid.to(device=x_init.device)
        # Define the interval endpoints up front; they are also needed by the
        # verbose progress bar when step_size is None.
        t_init = time_grid[0].item()
        t_final = time_grid[-1].item()

        if step_size is None:
            # If step_size is None, the time discretization is given by time_grid.
            t_discretization = time_grid
            n_steps = len(time_grid) - 1
        else:
            # Otherwise, discretize [t_init, t_final] uniformly with step_size.
            assert (
                t_final - t_init
            ) > step_size, f"Time interval [time_grid[0], time_grid[-1]] must be larger than step_size. Got a time interval [{t_init}, {t_final}] and step_size {step_size}."

            n_steps = ceil((t_final - t_init) / step_size)
            t_discretization = torch.tensor(
                [t_init + step_size * i for i in range(n_steps)] + [t_final],
                device=x_init.device,
            )

        if return_intermediates:
            # Snap the requested times to the nearest points on the discretization.
            order = torch.argsort(time_grid)
            time_grid = get_nearest_times(
                time_grid=time_grid, t_discretization=t_discretization
            )

        x_t = x_init.clone()
        steps_counter = 0
        res = []

        if return_intermediates:
            res = [x_init.clone()]

        if verbose:
            if not TQDM_AVAILABLE:
                raise ImportError(
                    "tqdm is required for verbose mode. Please install it."
                )
            ctx = tqdm(total=t_final, desc=f"NFE: {steps_counter}")
        else:
            ctx = nullcontext()

        # Build the preference weight vector over the objectives.
        if weights is not None:
            w = torch.tensor(weights)
        else:
            w, _ = select_random_weight_vector(num_objectives, args.num_div)
        w = w.to(device=x_init.device)
        print(f"Weight Vector: {w}")

        # State for adaptive hypercone filtering.
        Phi = args.Phi_init
        ema_r_t = None

        with ctx:
            for i in range(n_steps):
                t = t_discretization[i : i + 1]
                h = t_discretization[i + 1 : i + 2] - t_discretization[i : i + 1]

                # Sample X_1 ~ p_{1|t}(. | X_t) from the model posterior.
                p_1t = self.model(x=x_t, t=t.repeat(x_t.shape[0]), **model_extras)
                x_1 = categorical(p_1t.to(dtype=dtype_categorical))

                # On the last step, x_t is left unchanged (no guided update is applied).
                if i != n_steps - 1:
                    scheduler_output = self.path.scheduler(t=t)
                    k_t = scheduler_output.alpha_t
                    d_k_t = scheduler_output.d_alpha_t
                    # Expected x-prediction velocity under the model posterior.
                    u_t = d_k_t / (1 - k_t) * p_1t

                    # Score candidate transitions against the objectives.
                    (
                        guided_u_t,
                        pos_indices,
                        cand_tokens,
                        improvement_values,
                        delta_S,
                    ) = guided_transition_scoring(
                        x_t, u_t, w, score_models, t, w, args
                    )

                    # Filter candidates through the adaptive hypercone.
                    (
                        best_candidate,
                        accepted_mask,
                        valid_mask,
                        Phi,
                        ema_r_t,
                    ) = adaptive_hypercone_filtering(
                        improvement_values,
                        cand_tokens,
                        delta_S,
                        w,
                        Phi,
                        args,
                        ema_r_t=ema_r_t,
                    )

                    # Apply the accepted transitions with an Euler step.
                    x_t = euler_sample(x_t, pos_indices, best_candidate, guided_u_t, h)

                steps_counter += 1
                t = t + h

                # Evaluate the current state under each score model (monitoring only).
                scores = []
                for s in score_models:
                    sig = (
                        inspect.signature(s.forward)
                        if hasattr(s, "forward")
                        else inspect.signature(s)
                    )
                    if "t" in sig.parameters:
                        candidate_scores = s(x_t, 1)
                    else:
                        candidate_scores = s(x_t)

                    if isinstance(candidate_scores, tuple):
                        for score in candidate_scores:
                            scores.append(score.item())
                    else:
                        scores.append(candidate_scores.item())
                print(f"Step {steps_counter} scores: {scores}")

                if return_intermediates and (t in time_grid):
                    res.append(x_t.clone())

                if verbose:
                    ctx.n = t.item()
                    ctx.refresh()
                    ctx.set_description(f"NFE: {steps_counter}")

        if return_intermediates:
            if step_size is None:
                return torch.stack(res, dim=0)
            else:
                # Restore the order of the (possibly unsorted) requested time grid.
                return torch.stack(res, dim=0)[order]
        else:
            return x_t