AlienChen's picture
Upload 72 files
3527383 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import Tensor
from flow_matching.utils.manifolds import Manifold
class Sphere(Manifold):
    r"""Represents the unit hypersphere :math:`S^{D-1}` embedded in :math:`R^D`.

    Points are vectors of unit norm along the last dimension; tangent vectors
    at ``x`` are vectors orthogonal to ``x``. Every operation reduces over the
    last dimension only, so leading dimensions follow torch broadcasting.
    """

    # Per-dtype threshold below which a tangent-vector norm (expmap) or a
    # geodesic distance (logmap) is treated as zero, and the closed-form
    # formula is replaced by a numerically safe first-order fallback.
    EPS = {torch.float32: 1e-4, torch.float64: 1e-7}

    def expmap(self, x: Tensor, u: Tensor) -> Tensor:
        r"""Follow the geodesic from ``x`` in direction ``u`` for length ``||u||``.

        Computes :math:`\cos(\|u\|)\,x + \sin(\|u\|)\,u / \|u\|`. When
        ``||u|| <= EPS`` the retraction ``projx(x + u)`` is returned instead;
        for a tangent ``u`` it agrees with the exponential map to first order.

        Args:
            x: Point(s) on the sphere (unit norm along the last dimension).
            u: Tangent vector(s) at ``x``.

        Returns:
            Point(s) on the sphere.
        """
        norm_u = u.norm(dim=-1, keepdim=True)
        eps = self.EPS[norm_u.dtype]
        # Clamp only the divisor: wherever the exact branch is selected,
        # norm_u > eps so the value is unchanged, but unselected entries no
        # longer compute 0/0, whose NaN would otherwise leak into gradients
        # through torch.where.
        exp = x * torch.cos(norm_u) + u * torch.sin(norm_u) / norm_u.clamp_min(eps)
        retr = self.projx(x + u)
        return torch.where(norm_u > eps, exp, retr)

    def logmap(self, x: Tensor, y: Tensor) -> Tensor:
        """Return the tangent vector at ``x`` that ``expmap`` sends to ``y``.

        The chord ``y - x`` is projected onto the tangent space at ``x`` and
        rescaled to length ``dist(x, y)``. When that distance is at or below
        ``EPS`` the unscaled projection is returned (first-order approximation).

        Args:
            x: Base point(s) on the sphere.
            y: Target point(s) on the sphere.

        Returns:
            Tangent vector(s) at ``x``.
        """
        u = self.proju(x, y - x)
        dist = self.dist(x, y, keepdim=True)
        eps = self.EPS[x.dtype]
        # clamp_min guards the division. For antipodal inputs the projected
        # chord is (approximately) zero, so the result is (approximately) zero.
        scaled = u * dist / u.norm(dim=-1, keepdim=True).clamp_min(eps)
        return torch.where(dist.gt(eps), scaled, u)

    def projx(self, x: Tensor) -> Tensor:
        """Project ``x`` onto the sphere by normalising along the last dimension.

        A zero vector has no projection; it yields NaN (0 / 0).
        """
        return x / x.norm(dim=-1, keepdim=True)

    def proju(self, x: Tensor, u: Tensor) -> Tensor:
        """Project ``u`` onto the tangent space at ``x``.

        Subtracts the component of ``u`` along ``x``. This is an orthogonal
        projection only when ``x`` has unit norm.
        """
        return u - (x * u).sum(dim=-1, keepdim=True) * x

    def dist(self, x: Tensor, y: Tensor, *, keepdim: bool = False) -> Tensor:
        """Geodesic (great-circle) distance between points ``x`` and ``y``.

        Args:
            x: Point(s) on the sphere.
            y: Point(s) on the sphere.
            keepdim: If True, keep the reduced last dimension with size 1.

        Returns:
            Angles in ``[0, pi]``.
        """
        inner = (x * y).sum(-1, keepdim=keepdim)
        # Rounding can push the inner product of (nearly) equal or antipodal
        # unit vectors just outside [-1, 1], where acos returns NaN.
        return torch.acos(inner.clamp(-1.0, 1.0))