|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
|
|
|
|
import torch |
|
|
from torch import Tensor |
|
|
|
|
|
from flow_matching.utils.manifolds import Manifold |
|
|
|
|
|
|
|
|
class FlatTorus(Manifold):
    r"""Flat torus on the :math:`[0, 2\pi]^D` subspace.

    Every coordinate is an angle, and the space is isometric to a product of
    1-D spheres (circles). Points are kept as angles reduced modulo
    :math:`2\pi`; tangent vectors are ordinary real vectors of the same shape.
    """

    def expmap(self, x: Tensor, u: Tensor) -> Tensor:
        """Step from ``x`` along tangent vector ``u`` and wrap the result.

        Args:
            x: Point(s) on the torus, given as angles.
            u: Tangent vector(s) at ``x``.

        Returns:
            ``x + u`` reduced elementwise modulo ``2 * pi``.
        """
        full_turn = 2 * math.pi
        shifted = x + u
        return shifted % full_turn

    def logmap(self, x: Tensor, y: Tensor) -> Tensor:
        """Return the signed angular displacement from ``x`` to ``y``.

        Args:
            x: Base point(s) on the torus, given as angles.
            y: Target point(s) on the torus, given as angles.

        Returns:
            Elementwise difference ``y - x`` wrapped to its principal angle,
            i.e. a value in ``[-pi, pi]``.
        """
        delta = y - x
        # atan2(sin, cos) discards whole turns from the raw difference, so
        # the result is the shortest signed displacement around each circle.
        return torch.atan2(torch.sin(delta), torch.cos(delta))

    def projx(self, x: Tensor) -> Tensor:
        """Wrap arbitrary angles onto the torus by reducing them modulo ``2 * pi``."""
        full_turn = 2 * math.pi
        return x % full_turn

    def proju(self, x: Tensor, u: Tensor) -> Tensor:
        """Return ``u`` unchanged; ``x`` is not used.

        The torus is flat, so any real vector is already a valid tangent
        vector and no projection is required.
        """
        return u
|
|
|