File size: 1,956 Bytes
3527383 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from torch import Tensor
@dataclass()
class PathSample:
    r"""Container for one draw from a conditional-flow probability path.

    All fields are described (see each field's ``help`` metadata) as batched
    tensors whose leading dimension is the batch, i.e. ``(batch_size, ...)``.
    Being a plain dataclass, no type or shape validation happens at runtime.

    Attributes:
        x_1 (Tensor): target sample :math:`X_1`.
        x_0 (Tensor): source sample :math:`X_0`.
        t (Tensor): time sample :math:`t`.
        x_t (Tensor): path sample :math:`X_t \sim p_t(X_t)`,
            shape (batch_size, ...).
        dx_t (Tensor): conditional target
            :math:`\frac{\partial X}{\partial t}`, shape (batch_size, ...).
    """

    # Declaration order fixes the positional order of the generated __init__;
    # none of the fields has a default, so all five are required.
    x_1: Tensor = field(metadata=dict(help="target samples X_1 (batch_size, ...)."))
    x_0: Tensor = field(metadata=dict(help="source samples X_0 (batch_size, ...)."))
    t: Tensor = field(metadata=dict(help="time samples t (batch_size, ...)."))
    x_t: Tensor = field(
        metadata=dict(help="samples x_t ~ p_t(X_t), shape (batch_size, ...).")
    )
    dx_t: Tensor = field(
        metadata=dict(help="conditional target dX_t, shape: (batch_size, ...).")
    )
@dataclass
class DiscretePathSample:
    # The docstring is a raw string because its LaTeX contains backslashes
    # (``\sim``); in a normal string ``\s`` is an invalid escape sequence that
    # triggers a DeprecationWarning/SyntaxWarning whenever the module compiles.
    r"""Represents a sample of a conditional-flow generated discrete probability path.

    Holds the same fields as ``PathSample`` except ``dx_t``. Each field's
    ``help`` metadata describes it as batched, i.e. ``(batch_size, ...)``.
    Being a plain dataclass, no type or shape validation happens at runtime.

    Attributes:
        x_1 (Tensor): the target sample :math:`X_1`.
        x_0 (Tensor): the source sample :math:`X_0`.
        t (Tensor): the time sample :math:`t`.
        x_t (Tensor): the sample along the path :math:`X_t \sim p_t`.
    """

    # Declaration order fixes the positional order of the generated __init__;
    # none of the fields has a default, so all four are required.
    x_1: Tensor = field(metadata={"help": "target samples X_1 (batch_size, ...)."})
    x_0: Tensor = field(metadata={"help": "source samples X_0 (batch_size, ...)."})
    t: Tensor = field(metadata={"help": "time samples t (batch_size, ...)."})
    x_t: Tensor = field(
        metadata={"help": "samples X_t ~ p_t(X_t), shape (batch_size, ...)."}
    )
|