from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Tuple, Dict, Optional, Union

import torch
import numpy as np
import torch.nn.functional as F


class Pd(object):
    """
    Overview:
        Abstract class for parameterizable probability distributions and sampling functions.
    Interfaces:
        ``neglogp``, ``entropy``, ``noise_mode``, ``mode``, ``sample``

    .. tip::
        In derived classes, ``logits`` should be stored as an attribute of the class.
    """

    def neglogp(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Calculate the cross-entropy between the input x and the logits.
        Arguments:
            - x (:obj:`torch.Tensor`): the input tensor
        Returns:
            - cross_entropy (:obj:`torch.Tensor`): the calculated cross-entropy loss
        """
        raise NotImplementedError

    def entropy(self) -> torch.Tensor:
        """
        Overview:
            Calculate the softmax entropy of the logits.
        Returns:
            - entropy (:obj:`torch.Tensor`): the calculated entropy
        """
        raise NotImplementedError

    def noise_mode(self):
        """
        Overview:
            Add noise to the logits. This method is designed for stochastic
            (exploratory) action selection.
        """
        raise NotImplementedError

    def mode(self):
        """
        Overview:
            Return the argmax of the logits. This method is designed for
            deterministic action selection.
        """
        raise NotImplementedError

    def sample(self):
        """
        Overview:
            Sample from the softmax distribution over the logits (multinomial sampling).
        """
        raise NotImplementedError


class CategoricalPd(Pd):
    """
    Overview:
        Categorical probability distribution sampler.
    Interfaces:
        ``__init__``, ``neglogp``, ``entropy``, ``noise_mode``, ``mode``, ``sample``
    """

    def __init__(self, logits: Optional[torch.Tensor] = None) -> None:
        """
        Overview:
            Initialize the Pd with logits.
        Arguments:
            - logits (:obj:`torch.Tensor`): logits to sample from
        """
        self.update_logits(logits)

    def update_logits(self, logits: torch.Tensor) -> None:
        """
        Overview:
            Update the stored logits.
        Arguments:
            - logits (:obj:`torch.Tensor`): logits to update
        """
        self.logits = logits

    def neglogp(self, x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
        """
        Overview:
            Calculate the cross-entropy between the input x and the logits.
        Arguments:
            - x (:obj:`torch.Tensor`): the input tensor
            - reduction (:obj:`str`): support ['none', 'mean'], default set to 'mean' \
                (``F.cross_entropy`` expects the string 'none', not ``None``)
        Returns:
            - cross_entropy (:obj:`torch.Tensor`): the calculated cross-entropy loss
        """
        return F.cross_entropy(self.logits, x, reduction=reduction)

    def entropy(self, reduction: Optional[str] = 'mean') -> torch.Tensor:
        """
        Overview:
            Calculate the softmax entropy of the logits.
        Arguments:
            - reduction (:obj:`str`): support [None, 'mean'], default set to 'mean'
        Returns:
            - entropy (:obj:`torch.Tensor`): the calculated entropy
        """
        # Numerically stable softmax entropy: shift the logits by their max so that
        # exp() cannot overflow, then use H = -sum(p * log p) = sum(p * (log z - a)),
        # where a is the shifted logits and z is the partition function.
        a = self.logits - self.logits.max(dim=-1, keepdim=True)[0]
        ea = torch.exp(a)
        z = ea.sum(dim=-1, keepdim=True)
        p = ea / z
        entropy = (p * (torch.log(z) - a)).sum(dim=-1)
        assert reduction in [None, 'mean']
        if reduction is None:
            return entropy
        elif reduction == 'mean':
            return entropy.mean()

    def noise_mode(self, viz: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, np.ndarray]]]:
        """
        Overview:
            Add noise to the logits and take the argmax.
        Arguments:
            - viz (:obj:`bool`): whether to also return the numpy form of the logits, \
                noise and noised logits; short for ``visualize`` (tensors cannot be \
                visualized directly in TensorBoard or a text log)
        Returns:
            - result (:obj:`torch.Tensor`): the noised-logits argmax
            - viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray-type data for visualization
        """
        # Gumbel-max trick: adding Gumbel(0, 1) noise (-log(-log(U)), U ~ Uniform(0, 1))
        # to the logits and taking the argmax is equivalent to sampling from the softmax.
        u = torch.rand_like(self.logits)
        u = -torch.log(-torch.log(u))
        noise_logits = self.logits + u
        result = noise_logits.argmax(dim=-1)
        if viz:
            viz_feature = {}
            viz_feature['logits'] = self.logits.data.cpu().numpy()
            viz_feature['noise'] = u.data.cpu().numpy()
            viz_feature['noise_logits'] = noise_logits.data.cpu().numpy()
            return result, viz_feature
        else:
            return result

    def mode(self, viz: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, np.ndarray]]]:
        """
        Overview:
            Return the argmax of the logits.
        Arguments:
            - viz (:obj:`bool`): whether to also return the numpy form of the logits; \
                short for ``visualize`` (tensors cannot be visualized directly in \
                TensorBoard or a text log)
        Returns:
            - result (:obj:`torch.Tensor`): the logits argmax result
            - viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray-type data for visualization
        """
        result = self.logits.argmax(dim=-1)
        if viz:
            viz_feature = {}
            viz_feature['logits'] = self.logits.data.cpu().numpy()
            return result, viz_feature
        else:
            return result

    def sample(self, viz: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, np.ndarray]]]:
        """
        Overview:
            Sample from the softmax distribution over the logits.
        Arguments:
            - viz (:obj:`bool`): whether to also return the numpy form of the logits; \
                short for ``visualize`` (tensors cannot be visualized directly in \
                TensorBoard or a text log)
        Returns:
            - result (:obj:`torch.Tensor`): the sampled result
            - viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray-type data for visualization
        """
        # Normalize along the last (action) dimension for consistency with the other
        # methods, then draw one multinomial sample per batch element.
        p = torch.softmax(self.logits, dim=-1)
        result = torch.multinomial(p, 1).squeeze(1)
        if viz:
            viz_feature = {}
            viz_feature['logits'] = self.logits.data.cpu().numpy()
            return result, viz_feature
        else:
            return result
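

# A minimal usage sketch of CategoricalPd (not part of the original module); the
# shapes, values and the helper name ``_demo_categorical_pd`` are illustrative
# assumptions, showing the typical loss/entropy/sampling round trip.
def _demo_categorical_pd() -> None:
    logits = torch.randn(4, 6)  # hypothetical (batch, n_actions) scores
    pd = CategoricalPd(logits)
    actions = torch.randint(0, 6, (4, ))  # hypothetical target actions
    loss = pd.neglogp(actions)  # scalar cross-entropy (reduction='mean')
    ent = pd.entropy()  # scalar mean entropy, e.g. for an entropy bonus
    greedy = pd.mode()  # deterministic argmax action per batch element
    stochastic = pd.sample()  # one multinomial sample per batch element
    noisy = pd.noise_mode()  # Gumbel-max: same distribution as sample()
    print(loss.item(), ent.item(), greedy.shape, stochastic.shape, noisy.shape)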


class CategoricalPdPytorch(torch.distributions.Categorical):
    """
    Overview:
        Wrapped ``torch.distributions.Categorical``.
    Interfaces:
        ``__init__``, ``update_logits``, ``update_probs``, ``sample``, ``neglogp``, ``mode``, ``entropy``
    """

    def __init__(self, probs: Optional[torch.Tensor] = None) -> None:
        """
        Overview:
            Initialize the CategoricalPdPytorch object.
        Arguments:
            - probs (:obj:`torch.Tensor`): the tensor of probabilities
        """
        if probs is not None:
            self.update_probs(probs)

    def update_logits(self, logits: torch.Tensor) -> None:
        """
        Overview:
            Update the distribution with new logits.
        Arguments:
            - logits (:obj:`torch.Tensor`): logits to update
        """
        super().__init__(logits=logits)

    def update_probs(self, probs: torch.Tensor) -> None:
        """
        Overview:
            Update the distribution with new probs.
        Arguments:
            - probs (:obj:`torch.Tensor`): probs to update
        """
        super().__init__(probs=probs)

    def sample(self) -> torch.Tensor:
        """
        Overview:
            Sample from the wrapped categorical distribution.
        Returns:
            - result (:obj:`torch.Tensor`): the sampled result
        """
        return super().sample()

    def neglogp(self, actions: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
        """
        Overview:
            Calculate the cross-entropy between the input actions and the logits.
        Arguments:
            - actions (:obj:`torch.Tensor`): the input action tensor
            - reduction (:obj:`str`): support ['none', 'mean'], default set to 'mean'
        Returns:
            - cross_entropy (:obj:`torch.Tensor`): the calculated cross-entropy loss
        """
        # Negate log_prob: the negative log-likelihood is the cross-entropy loss.
        neglogp = -super().log_prob(actions)
        assert reduction in ['none', 'mean']
        if reduction == 'none':
            return neglogp
        elif reduction == 'mean':
            return neglogp.mean(dim=0)

    def mode(self) -> torch.Tensor:
        """
        Overview:
            Return the argmax of the probs.
        Returns:
            - result (:obj:`torch.Tensor`): the probs argmax result
        """
        return self.probs.argmax(dim=-1)

    def entropy(self, reduction: Optional[str] = None) -> torch.Tensor:
        """
        Overview:
            Calculate the entropy of the distribution.
        Arguments:
            - reduction (:obj:`str`): support [None, 'mean'], default set to None
        Returns:
            - entropy (:obj:`torch.Tensor`): the calculated entropy
        """
        entropy = super().entropy()
        assert reduction in [None, 'mean']
        if reduction is None:
            return entropy
        elif reduction == 'mean':
            return entropy.mean()
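

# A minimal usage sketch of the wrapper (not part of the original module); the
# shapes and the helper name ``_demo_categorical_pd_pytorch`` are illustrative
# assumptions.
def _demo_categorical_pd_pytorch() -> None:
    pd = CategoricalPdPytorch()
    pd.update_logits(torch.randn(4, 6))  # hypothetical (batch, n_actions) logits
    actions = pd.sample()  # one action index per batch element
    nll = pd.neglogp(actions)  # mean negative log-likelihood of the sampled actions
    greedy = pd.mode()  # argmax over the stored probs
    ent = pd.entropy(reduction='mean')  # scalar mean entropy
    print(nll.item(), greedy.shape, ent.item())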