|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import abc |
|
|
|
|
|
import torch.nn as nn |
|
|
from torch import Tensor |
|
|
|
|
|
|
|
|
class Manifold(nn.Module, metaclass=abc.ABCMeta):
    """Abstract base class for manifolds.

    Subclasses must provide the exponential and logarithmic maps, plus
    projections of points onto the manifold and of vectors onto tangent
    spaces. Every hook below is declared abstract, so this class itself
    cannot be instantiated.
    """

    @abc.abstractmethod
    def expmap(self, x: Tensor, u: Tensor) -> Tensor:
        r"""Apply the exponential map :math:`\exp_x(u)`.

        Args:
            x (Tensor): base point on the manifold.
            u (Tensor): tangent vector at :math:`x`.

        Returns:
            Tensor: point on the manifold obtained by following :math:`u`
            from :math:`x`.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def logmap(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Apply the logarithmic map :math:`\log_x(y)`.

        Args:
            x (Tensor): base point on the manifold.
            y (Tensor): target point on the manifold.

        Returns:
            Tensor: tangent vector at :math:`x` pointing towards :math:`y`.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def projx(self, x: Tensor) -> Tensor:
        """Map an arbitrary point :math:`x` onto the manifold.

        Args:
            x (Tensor): point to project.

        Returns:
            Tensor: the projected point, lying on the manifold.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def proju(self, x: Tensor, u: Tensor) -> Tensor:
        """Map a vector :math:`u` onto the tangent space at :math:`x`.

        Args:
            x (Tensor): point on the manifold whose tangent space is used.
            u (Tensor): vector to project.

        Returns:
            Tensor: the projected tangent vector.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError()
|
|
|
|
|
|
|
|
class Euclidean(Manifold):
    """Flat Euclidean space, where every manifold operation is trivial.

    Points and tangent vectors share the same ambient space. The exponential
    and logarithmic maps therefore reduce to vector addition and subtraction,
    and both projections are the identity.
    """

    def expmap(self, x: Tensor, u: Tensor) -> Tensor:
        """Move from ``x`` along ``u`` by plain addition."""
        moved = x + u
        return moved

    def logmap(self, x: Tensor, y: Tensor) -> Tensor:
        """Return the displacement vector from ``x`` to ``y``."""
        displacement = y - x
        return displacement

    def projx(self, x: Tensor) -> Tensor:
        """Return ``x`` untouched; every point already lies on the manifold."""
        return x

    def proju(self, x: Tensor, u: Tensor) -> Tensor:
        """Return ``u`` untouched; the tangent space at ``x`` is the whole space."""
        return u
|
|
|