from typing import List, Dict
from ditk import logging
import numpy as np
import torch
import pickle
try:
    from sklearn.svm import SVC
except ImportError:
    SVC = None
from ding.torch_utils import cov
from ding.utils import REWARD_MODEL_REGISTRY, one_time_warning
from .base_reward_model import BaseRewardModel

@REWARD_MODEL_REGISTRY.register('pdeil')
class PdeilRewardModel(BaseRewardModel):
| """ | |
| Overview: | |
| The Pdeil reward model class (https://arxiv.org/abs/2112.06746) | |
| Interface: | |
| ``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_date``, \ | |
| ``__init__``, ``_train``, ``_batch_mn_pdf`` | |
    Config:
        == ==================== ===== ============= ======================================= =======================
        ID Symbol               Type  Default Value Description                             Other(Shape)
        == ==================== ===== ============= ======================================= =======================
        1  ``type``             str   pdeil         | Reward model register name, refer     |
                                                    | to registry ``REWARD_MODEL_REGISTRY`` |
        2  | ``expert_data_``   str   expert_data.  | Path to the expert dataset            | Should be a '.pkl'
           | ``path``                 .pkl          |                                       | file
        3  | ``discrete_``      bool  False         | Whether the action is discrete        |
           | ``action``
        4  | ``alpha``          float 0.5           | coefficient for Probability           |
                                                    | Density Estimator                     |
        5  | ``clear_buffer``   int   1             | clear buffer per fixed iters          | make sure replay
           | ``_per_iters``                                                                 | buffer's data count
                                                                                            | isn't too few.
                                                                                            | (code work in entry)
        == ==================== ===== ============= ======================================= =======================
    """
    config = dict(
        # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
        type='pdeil',
        # (str) Path to the expert dataset.
        # expert_data_path='expert_data.pkl',
        # (bool) Whether the action is discrete.
        discrete_action=False,
        # (float) Coefficient for the probability density estimator.
        # alpha + beta = 1, alpha is in [0, 1];
        # when alpha is close to 0, the estimator has high variance and low bias;
        # when alpha is close to 1, the estimator has high bias and low variance.
        alpha=0.5,
        # (int) Clear buffer per fixed iters.
        clear_buffer_per_iters=1,
    )
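
    # Illustrative note (not part of the original source): with the notation used in
    # ``estimate`` below, where rho_1 = p_expert(s), rho_2 = p_policy(s) and
    # rho_3 = p_expert(a | s), the estimated reward is
    #     r(s, a) = rho_1 * rho_3 / (alpha * rho_1 + (1 - alpha) * rho_2),
    # so a smaller ``alpha`` gives the current policy's state density more weight in
    # the denominator (low bias, high variance), while a larger ``alpha`` leans more
    # on the expert state density (high bias, low variance).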

    def __init__(self, cfg: dict, device, tb_logger: 'SummaryWriter') -> None:  # noqa
        """
        Overview:
            Initialize ``self.`` See ``help(type(self))`` for accurate signature.
            Some rules in naming the attributes of ``self.``:
            - ``e_`` : expert values
            - ``_sigma_`` : standard deviation values
            - ``p_`` : current policy values
            - ``_s_`` : states
            - ``_a_`` : actions
        Arguments:
            - cfg (:obj:`Dict`): Training config
            - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
            - tb_logger (:obj:`SummaryWriter`): Logger for model summary
        """
        super(PdeilRewardModel, self).__init__()
        try:
            import scipy.stats as stats
            self.stats = stats
        except ImportError:
            import sys
            logging.warning("Please install scipy first, such as `pip3 install scipy`.")
            sys.exit(1)
        self.cfg: dict = cfg
        self.e_u_s = None
        self.e_sigma_s = None
        if cfg.discrete_action:
            self.svm = None
        else:
            self.e_u_s_a = None
            self.e_sigma_s_a = None
        self.p_u_s = None
        self.p_sigma_s = None
        self.expert_data = None
        self.train_data: list = []
        assert device in ["cpu", "cuda"] or "cuda" in device
        # PDEIL uses the cpu device by default.
        self.device = 'cpu'

        self.load_expert_data()
        states: list = []
        actions: list = []
        for item in self.expert_data:
            states.append(item['obs'])
            actions.append(item['action'])
        states: torch.Tensor = torch.stack(states, dim=0)
        actions: torch.Tensor = torch.stack(actions, dim=0)
        # Fit a Gaussian to the expert state distribution.
        self.e_u_s: torch.Tensor = torch.mean(states, axis=0)
        self.e_sigma_s: torch.Tensor = cov(states, rowvar=False)
        if self.cfg.discrete_action and SVC is None:
            one_time_warning("You are using discrete action while the SVC is not installed!")
        if self.cfg.discrete_action and SVC is not None:
            # Fit an SVM classifier approximating the expert action distribution p(a|s).
            self.svm: SVC = SVC(probability=True)
            self.svm.fit(states.cpu().numpy(), actions.cpu().numpy())
        else:
            # Concatenate states and actions and fit a joint Gaussian over (s, a).
            state_actions = torch.cat((states, actions.float()), dim=-1)
            self.e_u_s_a = torch.mean(state_actions, axis=0)
            self.e_sigma_s_a = cov(state_actions, rowvar=False)

    def load_expert_data(self) -> None:
        """
        Overview:
            Getting the expert data from ``config['expert_data_path']`` attribute in self.
        Effects:
            This is a side effect function which updates the expert data attribute (e.g. ``self.expert_data``)
        """
        expert_data_path: str = self.cfg.expert_data_path
        with open(expert_data_path, 'rb') as f:
            self.expert_data: list = pickle.load(f)
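
    # Illustrative note (not part of the original source): as consumed in ``__init__``,
    # the pickled expert dataset is assumed to be a list of transition dicts with at
    # least torch.Tensor fields ``obs`` and ``action``, e.g.
    #     [{'obs': obs_tensor, 'action': action_tensor}, ...]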

    def _train(self, states: torch.Tensor) -> None:
        """
        Overview:
            Helper function for ``train`` which fits the state distribution of the current policy.
        Arguments:
            - states (:obj:`torch.Tensor`): current policy states
        Effects:
            - Update attributes of ``p_u_s`` and ``p_sigma_s``
        """
        # We only need to estimate the current policy's state distribution (mean and covariance).
        self.p_u_s = torch.mean(states, axis=0)
        self.p_sigma_s = cov(states, rowvar=False)

    def train(self) -> None:
        """
        Overview:
            Training the Pdeil reward model.
        """
        states = torch.stack([item['obs'] for item in self.train_data], dim=0)
        self._train(states)

    def _batch_mn_pdf(self, x: np.ndarray, mean: np.ndarray, cov: np.ndarray) -> np.ndarray:
        """
        Overview:
            Get multivariate normal pdf of given np array.
        """
        return np.asarray(
            self.stats.multivariate_normal.pdf(x, mean=mean, cov=cov, allow_singular=False), dtype=np.float32
        )
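
    # Illustrative note (not part of the original source): for ``x`` of shape (N, d),
    # ``mean`` of shape (d,) and ``cov`` of shape (d, d), ``scipy.stats.multivariate_normal.pdf``
    # returns the N per-sample densities, so ``_batch_mn_pdf`` yields a float32 array of shape (N,).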

    def estimate(self, data: list) -> List[Dict]:
        """
        Overview:
            Estimate reward by rewriting the reward keys.
        Arguments:
            - data (:obj:`list`): the list of data used for estimation, \
                with at least ``obs`` and ``action`` keys.
        Effects:
            - This is a side effect function which updates the reward values in place.
        """
        # NOTE: deepcopying the reward part of the data is very important,
        # otherwise the reward of the data in the replay buffer will be incorrectly modified.
        train_data_augmented = self.reward_deepcopy(data)
        s = torch.stack([item['obs'] for item in train_data_augmented], dim=0)
        a = torch.stack([item['action'] for item in train_data_augmented], dim=0)
        if self.p_u_s is None:
            logging.warning("You need to train your reward model first.")
            for item in train_data_augmented:
                item['reward'].zero_()
        else:
            # rho_1: expert state density p_e(s); rho_2: current policy state density p_pi(s).
            rho_1 = self._batch_mn_pdf(s.cpu().numpy(), self.e_u_s.cpu().numpy(), self.e_sigma_s.cpu().numpy())
            rho_1 = torch.from_numpy(rho_1)
            rho_2 = self._batch_mn_pdf(s.cpu().numpy(), self.p_u_s.cpu().numpy(), self.p_sigma_s.cpu().numpy())
            rho_2 = torch.from_numpy(rho_2)
            if self.cfg.discrete_action:
                # rho_3: probability of each taken action under the expert SVM classifier.
                prob = self.svm.predict_proba(s.cpu().numpy())
                action_index = a.cpu().numpy().reshape(-1).astype(np.int64)
                rho_3 = torch.from_numpy(prob[np.arange(prob.shape[0]), action_index])
            else:
                # rho_3: expert joint density p_e(s, a), normalized by p_e(s) to get the conditional p_e(a|s).
                s_a = torch.cat([s, a.float()], dim=-1)
                rho_3 = self._batch_mn_pdf(
                    s_a.cpu().numpy(),
                    self.e_u_s_a.cpu().numpy(),
                    self.e_sigma_s_a.cpu().numpy()
                )
                rho_3 = torch.from_numpy(rho_3)
                rho_3 = rho_3 / rho_1
            alpha = self.cfg.alpha
            beta = 1 - alpha
            den = rho_1 * rho_3
            frac = alpha * rho_1 + beta * rho_2
            if frac.abs().max() < 1e-4:
                # The density mixture is numerically negligible, fall back to zero rewards.
                for item in train_data_augmented:
                    item['reward'].zero_()
            else:
                reward = den / frac
                reward = torch.chunk(reward, reward.shape[0], dim=0)
                for item, rew in zip(train_data_augmented, reward):
                    item['reward'] = rew
        return train_data_augmented

    def collect_data(self, item: list) -> None:
        """
        Overview:
            Collecting training data by iterating over the data items in the input list.
        Arguments:
            - item (:obj:`list`): Raw training data (e.g. some form of states, actions, obs, etc.)
        Effects:
            - This is a side effect function which updates the data attribute in ``self`` by \
                extending it with the input data items.
        """
        self.train_data.extend(item)

    def clear_data(self) -> None:
        """
        Overview:
            Clearing training data. \
                This is a side effect function which clears the data attribute in ``self``
        """
        self.train_data.clear()
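

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# Assumptions: ``easydict.EasyDict`` provides the attribute-style config access
# used above, the expert data is a pickled list of ``{'obs', 'action'}`` tensor
# dicts, and no TensorBoard logger is needed for this offline check. Because the
# module uses a relative import, run it with ``python -m`` from the package root.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    import os
    import tempfile
    from easydict import EasyDict

    # Dump a tiny synthetic "expert" dataset to a temporary .pkl file.
    expert = [{'obs': torch.randn(4), 'action': torch.randn(2)} for _ in range(64)]
    with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as f:
        pickle.dump(expert, f)
        expert_path = f.name

    cfg = EasyDict(
        type='pdeil',
        expert_data_path=expert_path,
        discrete_action=False,
        alpha=0.5,
        clear_buffer_per_iters=1,
    )
    reward_model = PdeilRewardModel(cfg, device='cpu', tb_logger=None)

    # Collect some "current policy" transitions, fit their state density and
    # rewrite the rewards with the estimated PDEIL reward.
    policy_data = [
        {'obs': torch.randn(4), 'action': torch.randn(2), 'reward': torch.zeros(1)} for _ in range(32)
    ]
    reward_model.collect_data(policy_data)
    reward_model.train()
    estimated = reward_model.estimate(policy_data)
    print(estimated[0]['reward'])

    reward_model.clear_data()
    os.remove(expert_path)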