AlienChen's picture
Upload 72 files
3527383 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable
import torch
from torch import Tensor
from flow_matching.utils.manifolds import Manifold
def geodesic(
    manifold: "Manifold", start_point: Tensor, end_point: Tensor
) -> Callable[[Tensor], Tensor]:
    """Generate parameterized function for geodesic curve.

    The shooting tangent vector ``v = manifold.logmap(start_point, end_point)``
    is computed once. The returned function scales ``v`` by ``t`` and maps it
    back onto the manifold with ``manifold.expmap`` based at ``start_point``.

    Args:
        manifold (Manifold): the manifold to compute geodesic on.
        start_point (Tensor): point on the manifold at :math:`t=0`.
        end_point (Tensor): point on the manifold at :math:`t=1`.

    Returns:
        Callable[[Tensor], Tensor]: a function that takes in :math:`t` and
        outputs the geodesic at time :math:`t`.
    """
    shooting_tangent_vec = manifold.logmap(start_point, end_point)

    def path(t: Tensor) -> Tensor:
        """Evaluate the geodesic at time(s) ``t``.

        Args:
            t (Tensor): either a 1-D tensor of ``T`` times, or a 0-dim tensor /
                Python number giving a single time. Times that are not
                floating point are cast to the dtype of the shooting tangent
                vector, and ``t`` is moved to that vector's device.

        Returns:
            Tensor: for 1-D ``t``, ``manifold.expmap`` evaluated with a new time
            axis inserted before the last dimension, i.e. shape ``(..., T, K)``
            when the shooting tangent vector has shape ``(..., K)``. For scalar
            ``t``, ``manifold.expmap(start_point, t * v)``.
        """
        vec = shooting_tangent_vec
        if not isinstance(t, Tensor):
            # Accept plain Python numbers as times.
            t = torch.as_tensor(t, dtype=vec.dtype, device=vec.device)
        elif not t.is_floating_point():
            # Integer/bool times are cast so the scaled tangents share the
            # tangent vector's floating-point dtype.
            t = t.to(dtype=vec.dtype)
        # Co-locate t with the tangent vector (a no-op when already there).
        t = t.to(device=vec.device)

        if t.dim() == 0:
            # Single time: scale the tangent vector directly, no time axis.
            return manifold.expmap(start_point, t * vec)

        # Outer product over times: (T,) x (..., K) -> (..., T, K).
        tangent_vecs = torch.einsum("i,...k->...ik", t, vec)
        # Insert a time axis so start_point broadcasts against every time.
        points_at_time_t = manifold.expmap(start_point.unsqueeze(-2), tangent_vecs)
        return points_at_time_t

    return path