Spaces:
Sleeping
Sleeping
| from typing import Callable, Any, List, Dict, Optional, Union, TYPE_CHECKING | |
| import copy | |
| import numpy as np | |
| import torch | |
| from ding.utils import SumSegmentTree, MinSegmentTree | |
| from ding.data.buffer.buffer import BufferedData | |
| if TYPE_CHECKING: | |
| from ding.data.buffer.buffer import Buffer | |
class PriorityExperienceReplay:
    """
    Overview:
        The middleware that implements priority experience replay (PER).
        Every pushed item is assigned a slot (``priority_idx``) in a sum segment tree (and, when importance
        sampling is enabled, a min segment tree). Sampling draws slots proportionally to the stored weights
        ``priority ** priority_power_factor`` and translates them to buffer indices through ``buffer_idx``.
    """

    # Keys used in ``state_dict`` mapped to the attribute names that actually hold the values.
    _STATE_KEY_ALIASES = {'sumtree': 'sum_tree', 'mintree': 'min_tree'}

    def __init__(
            self,
            buffer: 'Buffer',
            IS_weight: bool = True,
            priority_power_factor: float = 0.6,
            IS_weight_power_factor: float = 0.4,
            IS_weight_anneal_train_iter: int = int(1e5),
    ) -> None:
        """
        Arguments:
            - buffer (:obj:`Buffer`): The buffer to use PER.
            - IS_weight (:obj:`bool`): Whether use importance sampling or not.
            - priority_power_factor (:obj:`float`): The factor that adjust the sensitivity between\
                the sampling probability and the priority level.
            - IS_weight_power_factor (:obj:`float`): The factor that adjust the sensitivity between\
                the sample rarity and sampling probability in importance sampling.
            - IS_weight_anneal_train_iter (:obj:`int`): The number of ``sample`` calls over which\
                ``IS_weight_power_factor`` is linearly increased to 1.0.
        """
        self.buffer = buffer
        # Maps a tree slot (``priority_idx``) to the index the wrapped buffer assigned to that item.
        self.buffer_idx = {}
        self.buffer_size = buffer.size
        self.IS_weight = IS_weight
        self.priority_power_factor = priority_power_factor
        self.IS_weight_power_factor = IS_weight_power_factor
        self.IS_weight_anneal_train_iter = IS_weight_anneal_train_iter
        # Max priority till now, it's used to initialize data's priority if "priority" is not passed in with the data.
        self.max_priority = 1.0
        self._build_trees()
        if self.IS_weight:
            # Per-sample increment that brings ``IS_weight_power_factor`` up to 1.0.
            self.delta_anneal = (1 - self.IS_weight_power_factor) / self.IS_weight_anneal_train_iter
        # Next tree slot to write; advances as a ring over ``buffer_size`` slots.
        self.pivot = 0

    def _build_trees(self) -> None:
        """
        Overview:
            Create empty segment trees sized for ``buffer_size``. The min tree is only created when
            importance sampling is enabled.
        """
        # Capacity needs to be the power of 2.
        capacity = int(np.power(2, np.ceil(np.log2(self.buffer_size))))
        self.sum_tree = SumSegmentTree(capacity)
        if self.IS_weight:
            self.min_tree = MinSegmentTree(capacity)

    def push(self, chain: Callable, data: Any, meta: Optional[dict] = None, *args, **kwargs) -> 'BufferedData':
        """
        Overview:
            Attach a priority and a tree slot to the item, record its weight in the trees, then forward the
            push to the next handler in the chain.
        Arguments:
            - chain (:obj:`Callable`): The next handler; called as ``chain(data, meta=meta, ...)``.
            - data (:obj:`Any`): The item to push. If ``meta`` is None and ``data`` contains a ``'priority'``\
                key, that entry is popped from ``data`` and used as the priority.
            - meta (:obj:`Optional[dict]`): Meta information; ``'priority'`` defaults to ``max_priority``\
                when missing, and ``'priority_idx'`` is always overwritten with the assigned slot.
        Returns:
            - buffered (:obj:`BufferedData`): Whatever the chain returned for this push.
        """
        if meta is None:
            if 'priority' in data:
                meta = {'priority': data.pop('priority')}
            else:
                meta = {'priority': self.max_priority}
        elif 'priority' not in meta:
            meta['priority'] = self.max_priority
        slot = self.pivot
        meta['priority_idx'] = slot
        self._update_tree(meta['priority'], slot)
        buffered = chain(data, meta=meta, *args, **kwargs)
        self.buffer_idx[slot] = buffered.index
        self.pivot = (slot + 1) % self.buffer_size
        return buffered

    def sample(self, chain: Callable, size: int, *args,
               **kwargs) -> Union[List['BufferedData'], List[List['BufferedData']]]:
        """
        Overview:
            Draw ``size`` tree slots with stratified proportional sampling, fetch the matching items through
            the chain and, when importance sampling is enabled, attach normalized IS weights to them.
        Arguments:
            - chain (:obj:`Callable`): The next handler; called as ``chain(indices=indices, ...)``.
            - size (:obj:`int`): The number of items to sample.
        Returns:
            - data (:obj:`Union[List[BufferedData], List[List[BufferedData]]]`): The sampled items as\
                returned by the chain. With ``IS_weight`` enabled each item gets ``meta['priority_IS']`` and\
                ``data['priority_IS']`` (a float tensor holding one value).
        """
        # Sum of all stored weights (root value of the sum tree).
        sum_tree_root = self.sum_tree.reduce()
        # Divide [0, 1) into size intervals on average
        intervals = np.arange(size) / size
        # Uniformly sample within each interval
        mass = intervals + np.random.uniform(size=(size, )) * 1. / size
        # Rescale to [0, S), where S is the sum of all datas' priority (root value of sum tree)
        mass *= sum_tree_root
        slots = [self.sum_tree.find_prefixsum_idx(m) for m in mass]
        indices = [self.buffer_idx[slot] for slot in slots]
        # Sample with indices
        data = chain(indices=indices, *args, **kwargs)
        if self.IS_weight:
            # Calculate max weight for normalizing IS
            p_min = self.min_tree.reduce() / sum_tree_root
            buffer_count = self.buffer.count()
            beta = self.IS_weight_power_factor
            max_weight = (buffer_count * p_min) ** (-beta)
            # NOTE(review): this loop expects a flat list of items exposing ``.meta`` and ``.data``;
            # confirm how nested results (List[List[BufferedData]]) are supposed to be handled.
            for item in data:
                meta = item.meta
                p_sample = self.sum_tree[meta['priority_idx']] / sum_tree_root
                weight = (buffer_count * p_sample) ** (-beta)
                meta['priority_IS'] = weight / max_weight
                item.data['priority_IS'] = torch.as_tensor([meta['priority_IS']]).float()  # for compatibility
            # Anneal the exponent once per sample call, capped at 1.0.
            self.IS_weight_power_factor = min(1.0, self.IS_weight_power_factor + self.delta_anneal)
        return data

    def update(self, chain: Callable, index: str, data: Any, meta: Any, *args, **kwargs) -> None:
        """
        Overview:
            Forward the update to the chain and, if it reports success, write the new priority from
            ``meta`` into the trees and refresh ``max_priority``.
        Arguments:
            - chain (:obj:`Callable`): The next handler; called as ``chain(index, data, meta, ...)``.
            - index (:obj:`str`): Index of the item to update.
            - data (:obj:`Any`): New data passed through to the chain.
            - meta (:obj:`Any`): Must be a dict holding ``'priority'`` (non-negative) and ``'priority_idx'``.
        """
        update_flag = chain(index, data, meta, *args, **kwargs)
        if update_flag:  # when update succeed
            assert meta is not None, "Please indicate dict-type meta in priority update"
            new_priority, idx = meta['priority'], meta['priority_idx']
            assert new_priority >= 0, "new_priority should greater than 0, but found {}".format(new_priority)
            new_priority += 1e-5  # Add epsilon to avoid priority == 0
            self._update_tree(new_priority, idx)
            self.max_priority = max(self.max_priority, new_priority)

    def delete(self, chain: Callable, index: str, *args, **kwargs) -> None:
        """
        Overview:
            Reset the tree slots of the items about to be deleted, drop their ``buffer_idx`` entries, then
            forward the delete to the chain. Items whose index does not match are left untouched.
        Arguments:
            - chain (:obj:`Callable`): The next handler; called as ``chain(index, ...)``.
            - index (:obj:`str`): Index of the item to delete; an iterable of indices is also accepted.
        Returns:
            - result: Whatever the chain returned.
        """
        if isinstance(index, str):
            targets = {index}
        else:
            try:
                targets = set(index)
            except TypeError:
                # Single non-iterable index.
                targets = {index}
        for item in self.buffer.storage:
            if item.index not in targets:
                continue
            priority_idx = item.meta['priority_idx']
            self.sum_tree[priority_idx] = self.sum_tree.neutral_element
            if self.IS_weight:
                # The min tree only exists when importance sampling is enabled.
                self.min_tree[priority_idx] = self.min_tree.neutral_element
            self.buffer_idx.pop(priority_idx, None)
        return chain(index, *args, **kwargs)

    def clear(self, chain: Callable) -> None:
        """
        Overview:
            Reset priorities, trees, slot mapping and pivot to their initial state, then forward the clear
            to the chain.
        Arguments:
            - chain (:obj:`Callable`): The next handler; called without arguments.
        """
        self.max_priority = 1.0
        self._build_trees()
        self.buffer_idx = {}
        self.pivot = 0
        chain()

    def _update_tree(self, priority: float, idx: int) -> None:
        """
        Overview:
            Store ``priority ** priority_power_factor`` at slot ``idx`` of the sum tree, and of the min tree
            when importance sampling is enabled.
        Arguments:
            - priority (:obj:`float`): Raw priority of the item.
            - idx (:obj:`int`): Tree slot to write.
        """
        weight = priority ** self.priority_power_factor
        self.sum_tree[idx] = weight
        if self.IS_weight:
            self.min_tree[idx] = weight

    def state_dict(self) -> Dict:
        """
        Overview:
            Collect the state needed to restore this middleware.
        Returns:
            - state (:obj:`Dict`): Holds ``max_priority``, ``IS_weight_power_factor``, ``sumtree``,\
                ``buffer_idx`` and ``pivot``; ``mintree`` is included only when importance sampling is enabled.
        """
        state = {
            'max_priority': self.max_priority,
            'IS_weight_power_factor': self.IS_weight_power_factor,
            'sumtree': self.sum_tree,
        }
        if self.IS_weight:
            state['mintree'] = self.min_tree
        state['buffer_idx'] = self.buffer_idx
        # The slot mapping and trees are indexed by slot, so the next write position must be kept too.
        state['pivot'] = self.pivot
        return state

    def load_state_dict(self, _state_dict: Dict, deepcopy: bool = False) -> None:
        """
        Overview:
            Restore attributes from a dict produced by ``state_dict``. The ``sumtree``/``mintree`` keys are
            written to ``sum_tree``/``min_tree``; every other key is set as an attribute of the same name.
        Arguments:
            - _state_dict (:obj:`Dict`): The state to load.
            - deepcopy (:obj:`bool`): Whether to deep-copy each value before assigning it.
        """
        for key, value in _state_dict.items():
            attr = self._STATE_KEY_ALIASES.get(key, key)
            setattr(self, attr, copy.deepcopy(value) if deepcopy else value)

    def __call__(self, action: str, chain: Callable, *args, **kwargs) -> Any:
        """
        Overview:
            Middleware entry point: route supported actions to the method of the same name and pass any
            other action straight to the chain.
        Arguments:
            - action (:obj:`str`): Name of the buffer action.
            - chain (:obj:`Callable`): The next handler in the chain.
        Returns:
            - result (:obj:`Any`): The return value of the handling method or of the chain.
        """
        if action in ["push", "sample", "update", "delete", "clear"]:
            return getattr(self, action)(chain, *args, **kwargs)
        return chain(*args, **kwargs)