| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the CC-by-NC license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
| from torch import Tensor | |
def categorical(probs: Tensor) -> Tensor:
    r"""Categorical sampler according to weights in the last dimension of ``probs`` using :func:`torch.multinomial`.

    Draws one category index per leading ("batch") position of ``probs``.

    Args:
        probs (Tensor): weights of shape ``(..., K)`` with at least one
            dimension. Each slice along the last dimension is one categorical
            distribution over ``K`` categories. Per :func:`torch.multinomial`,
            the weights need not sum to one but must be non-negative, finite,
            and have a non-zero sum.

    Returns:
        Tensor: integer samples of shape ``probs.shape[:-1]`` with values in
        ``[0, K)``. A 1-D ``probs`` yields a 0-dim tensor.
    """
    num_categories = probs.shape[-1]
    # torch.multinomial only accepts 1-D or 2-D input, so collapse every
    # leading dimension into one row axis. Unlike ``flatten(0, -2)``,
    # ``reshape`` also handles 1-D ``probs`` (it becomes a single row).
    flat_probs = probs.reshape(-1, num_categories)
    samples = torch.multinomial(flat_probs, 1, replacement=True)
    # Drop the trailing "one sample per row" axis and restore the batch shape.
    return samples.reshape(probs.shape[:-1])