# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
from torch import Tensor


def unsqueeze_to_match(source: Tensor, target: Tensor, how: str = "suffix") -> Tensor:
    """
    Unsqueeze the source tensor to match the dimensionality of the target tensor.

    Args:
        source (Tensor): The source tensor to be unsqueezed.
        target (Tensor): The target tensor whose dimensionality to match.
        how (str, optional): Whether to unsqueeze the source tensor at the beginning
            ("prefix") or end ("suffix"). Defaults to "suffix".

    Returns:
        Tensor: The unsqueezed source tensor.
    """
    assert how in (
        "prefix",
        "suffix",
    ), f"{how} is not supported, only 'prefix' and 'suffix' are supported."

    dim_diff = target.dim() - source.dim()

    # Add one singleton dimension per missing dimension, at the chosen end.
    for _ in range(dim_diff):
        if how == "prefix":
            source = source.unsqueeze(0)
        elif how == "suffix":
            source = source.unsqueeze(-1)

    return source


def expand_tensor_like(input_tensor: Tensor, expand_to: Tensor) -> Tensor:
    """Given `input_tensor`, a 1d vector of length equal to the batch size of
    `expand_to`, expand `input_tensor` to the same shape as `expand_to` along
    all remaining dimensions.

    Args:
        input_tensor (Tensor): (batch_size,).
        expand_to (Tensor): (batch_size, ...).

    Returns:
        Tensor: (batch_size, ...).
    """
    assert input_tensor.ndim == 1, "Input tensor must be a 1d vector."
    assert (
        input_tensor.shape[0] == expand_to.shape[0]
    ), f"The first (batch_size) dimension must match. Got shape {input_tensor.shape} and {expand_to.shape}."

    dim_diff = expand_to.ndim - input_tensor.ndim

    # Reshape to (batch_size, 1, ..., 1) so the expand broadcasts along the
    # trailing dimensions of `expand_to`.
    t_expanded = input_tensor.clone()
    t_expanded = t_expanded.reshape(-1, *([1] * dim_diff))

    return t_expanded.expand_as(expand_to)


def gradient(
    output: Tensor,
    x: Tensor,
    grad_outputs: Optional[Tensor] = None,
    create_graph: bool = False,
) -> Tensor:
    """
    Compute the gradient of the inner product of `output` and `grad_outputs`
    with respect to :math:`x`.

    Args:
        output (Tensor): [N, D] Output of the function.
        x (Tensor): [N, d_1, d_2, ...] Input.
        grad_outputs (Optional[Tensor]): [N, D] Gradient of outputs. If `None`,
            a tensor of ones is used.
        create_graph (bool): If True, the graph of the derivative is constructed,
            allowing higher-order derivative products to be computed.
            Defaults to False.

    Returns:
        Tensor: [N, d_1, d_2, ...]. The gradient w.r.t. :math:`x`.
    """
    if grad_outputs is None:
        grad_outputs = torch.ones_like(output).detach()
    grad = torch.autograd.grad(
        output, x, grad_outputs=grad_outputs, create_graph=create_graph
    )[0]
    return grad
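

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, self-contained
# demonstration of the three helpers above. The shapes and values here are
# illustrative assumptions only, e.g. `t` standing in for per-sample timesteps
# and `x` for a batch of image-like tensors.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    t = torch.rand(8)  # hypothetical per-sample scalars, shape (8,)
    x = torch.randn(8, 3, 32, 32)  # hypothetical batch of inputs

    # unsqueeze_to_match: (8,) -> (8, 1, 1, 1), ready to broadcast against x.
    t_b = unsqueeze_to_match(t, x)
    assert t_b.shape == (8, 1, 1, 1)
    scaled = t_b * x  # broadcasting works without manual reshapes

    # expand_tensor_like: (8,) -> (8, 3, 32, 32), a view fully matching x.
    t_full = expand_tensor_like(t, x)
    assert t_full.shape == x.shape

    # gradient: with the default grad_outputs of ones, this returns the
    # gradient of output.sum() w.r.t. x; for output = (x**2).sum over the
    # non-batch dims, that gradient is 2 * x.
    x_g = torch.randn(8, 3, 32, 32, requires_grad=True)
    out = (x_g**2).sum(dim=(1, 2, 3))  # shape (8,)
    g = gradient(out, x_g)
    assert g.shape == x_g.shape
    assert torch.allclose(g, 2 * x_g)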