|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass, field |
|
|
|
|
|
from torch import Tensor |
|
|
|
|
|
|
|
|
@dataclass
class PathSample:
    r"""One draw from a probability path generated by a conditional flow.

    Plain record type: it only groups the two endpoints, the sampled time, the
    point on the path and the conditional target belonging to that point.

    Attributes:
        x_1 (Tensor): target sample :math:`X_1`.
        x_0 (Tensor): source sample :math:`X_0`.
        t (Tensor): time sample :math:`t`.
        x_t (Tensor): path samples :math:`X_t \sim p_t(X_t)`,
            shape (batch_size, ...).
        dx_t (Tensor): conditional target
            :math:`\frac{\partial X}{\partial t}`, shape (batch_size, ...).
    """

    # Every field is required (no defaults); the ``help`` metadata carries the
    # human-readable description of each tensor.
    x_1: Tensor = field(
        metadata={"help": "target samples X_1 (batch_size, ...)."},
    )
    x_0: Tensor = field(
        metadata={"help": "source samples X_0 (batch_size, ...)."},
    )
    t: Tensor = field(
        metadata={"help": "time samples t (batch_size, ...)."},
    )
    x_t: Tensor = field(
        metadata={"help": "samples x_t ~ p_t(X_t), shape (batch_size, ...)."},
    )
    dx_t: Tensor = field(
        metadata={"help": "conditional target dX_t, shape: (batch_size, ...)."},
    )
|
|
|
|
|
|
|
|
@dataclass
class DiscretePathSample:
    r"""Represents a sample of a conditional-flow generated discrete probability path.

    Plain record type grouping the two endpoints, the sampled time and the
    corresponding point on the path. All fields are required.

    Attributes:
        x_1 (Tensor): the target sample :math:`X_1`.
        x_0 (Tensor): the source sample :math:`X_0`.
        t (Tensor): the time sample :math:`t`.
        x_t (Tensor): the sample along the path :math:`X_t \sim p_t`.
    """

    # NOTE: the docstring above must stay a raw string -- ``\sim`` would
    # otherwise be parsed as an invalid ``\s`` escape sequence.
    x_1: Tensor = field(metadata={"help": "target samples X_1 (batch_size, ...)."})
    x_0: Tensor = field(metadata={"help": "source samples X_0 (batch_size, ...)."})
    t: Tensor = field(metadata={"help": "time samples t (batch_size, ...)."})
    x_t: Tensor = field(
        metadata={"help": "samples X_t ~ p_t(X_t), shape (batch_size, ...)."}
    )
|
|
|