Spaces:
Sleeping
Sleeping
| from abc import abstractmethod, ABC | |
| from typing import Any, List, Optional, Union, Callable | |
| import copy | |
| from dataclasses import dataclass | |
| from functools import wraps | |
| from ding.utils import fastcopy | |
| def apply_middleware(func_name: str): | |
| def wrap_func(base_func: Callable): | |
| def handler(buffer, *args, **kwargs): | |
| """ | |
| Overview: | |
| The real processing starts here, we apply the middleware one by one, | |
| each middleware will receive next `chained` function, which is an executor of next | |
| middleware. You can change the input arguments to the next `chained` middleware, and you | |
| also can get the return value from the next middleware, so you have the | |
| maximum freedom to choose at what stage to implement your method. | |
| """ | |
| def wrap_handler(middleware, *args, **kwargs): | |
| if len(middleware) == 0: | |
| return base_func(buffer, *args, **kwargs) | |
| def chain(*args, **kwargs): | |
| return wrap_handler(middleware[1:], *args, **kwargs) | |
| func = middleware[0] | |
| return func(func_name, chain, *args, **kwargs) | |
| return wrap_handler(buffer._middleware, *args, **kwargs) | |
| return handler | |
| return wrap_func | |
| class BufferedData: | |
| data: Any | |
| index: str | |
| meta: dict | |
| # Register new dispatcher on fastcopy to avoid circular references | |
| def _copy_buffereddata(d: BufferedData) -> BufferedData: | |
| return BufferedData(data=fastcopy.copy(d.data), index=d.index, meta=fastcopy.copy(d.meta)) | |
| fastcopy.dispatch[BufferedData] = _copy_buffereddata | |
| class Buffer(ABC): | |
| """ | |
| Buffer is an abstraction of device storage, third-party services or data structures, | |
| For example, memory queue, sum-tree, redis, or di-store. | |
| """ | |
| def __init__(self, size: int) -> None: | |
| self._middleware = [] | |
| self.size = size | |
| def push(self, data: Any, meta: Optional[dict] = None) -> BufferedData: | |
| """ | |
| Overview: | |
| Push data and it's meta information in buffer. | |
| Arguments: | |
| - data (:obj:`Any`): The data which will be pushed into buffer. | |
| - meta (:obj:`dict`): Meta information, e.g. priority, count, staleness. | |
| Returns: | |
| - buffered_data (:obj:`BufferedData`): The pushed data. | |
| """ | |
| raise NotImplementedError | |
| def sample( | |
| self, | |
| size: Optional[int] = None, | |
| indices: Optional[List[str]] = None, | |
| replace: bool = False, | |
| sample_range: Optional[slice] = None, | |
| ignore_insufficient: bool = False, | |
| groupby: Optional[str] = None, | |
| unroll_len: Optional[int] = None | |
| ) -> Union[List[BufferedData], List[List[BufferedData]]]: | |
| """ | |
| Overview: | |
| Sample data with length ``size``. | |
| Arguments: | |
| - size (:obj:`Optional[int]`): The number of the data that will be sampled. | |
| - indices (:obj:`Optional[List[str]]`): Sample with multiple indices. | |
| - replace (:obj:`bool`): If use replace is true, you may receive duplicated data from the buffer. | |
| - sample_range (:obj:`slice`): Sample range slice. | |
| - ignore_insufficient (:obj:`bool`): If ignore_insufficient is true, sampling more than buffer size | |
| with no repetition will not cause an exception. | |
| - groupby (:obj:`Optional[str]`): Groupby key in meta, i.e. groupby="episode" | |
| - unroll_len (:obj:`Optional[int]`): Number of consecutive frames within a group. | |
| Returns: | |
| - sample_data (:obj:`Union[List[BufferedData], List[List[BufferedData]]]`): | |
| A list of data with length ``size``, may be nested if groupby is set. | |
| """ | |
| raise NotImplementedError | |
| def update(self, index: str, data: Optional[Any] = None, meta: Optional[dict] = None) -> bool: | |
| """ | |
| Overview: | |
| Update data and meta by index | |
| Arguments: | |
| - index (:obj:`str`): Index of data. | |
| - data (:obj:`any`): Pure data. | |
| - meta (:obj:`dict`): Meta information. | |
| Returns: | |
| - success (:obj:`bool`): Success or not, if data with the index not exist in buffer, return false. | |
| """ | |
| raise NotImplementedError | |
| def delete(self, index: str): | |
| """ | |
| Overview: | |
| Delete one data sample by index | |
| Arguments: | |
| - index (:obj:`str`): Index | |
| """ | |
| raise NotImplementedError | |
| def save_data(self, file_name: str): | |
| """ | |
| Overview: | |
| Save buffer data into a file. | |
| Arguments: | |
| - file_name (:obj:`str`): file name of buffer data | |
| """ | |
| raise NotImplementedError | |
| def load_data(self, file_name: str): | |
| """ | |
| Overview: | |
| Load buffer data from a file. | |
| Arguments: | |
| - file_name (:obj:`str`): file name of buffer data | |
| """ | |
| raise NotImplementedError | |
| def count(self) -> int: | |
| raise NotImplementedError | |
| def clear(self) -> None: | |
| raise NotImplementedError | |
| def get(self, idx: int) -> BufferedData: | |
| """ | |
| Overview: | |
| Get item by subscript index | |
| Arguments: | |
| - idx (:obj:`int`): Subscript index | |
| Returns: | |
| - buffered_data (:obj:`BufferedData`): Item from buffer | |
| """ | |
| raise NotImplementedError | |
| def use(self, func: Callable) -> "Buffer": | |
| """ | |
| Overview: | |
| Use algorithm middleware to modify the behavior of the buffer. | |
| Every middleware should be a callable function, it will receive three argument parts, including: | |
| 1. The buffer instance, you can use this instance to visit every thing of the buffer, including the storage. | |
| 2. The functions called by the user, there are three methods named `push` , `sample` and `clear` , \ | |
| so you can use these function name to decide which action to choose. | |
| 3. The remaining arguments passed by the user to the original function, will be passed in `*args` . | |
| Each middleware handler should return two parts of the value, including: | |
| 1. The first value is `done` (True or False), if done==True, the middleware chain will stop immediately, \ | |
| no more middleware will be executed during this execution | |
| 2. The remaining values, will be passed to the next middleware or the default function in the buffer. | |
| Arguments: | |
| - func (:obj:`Callable`): The middleware handler | |
| Returns: | |
| - buffer (:obj:`Buffer`): The instance self | |
| """ | |
| self._middleware.append(func) | |
| return self | |
| def view(self) -> "Buffer": | |
| r""" | |
| Overview: | |
| A view is a new instance of buffer, with a deepcopy of every property except the storage. | |
| The storage is shared among all the buffer instances. | |
| Returns: | |
| - buffer (:obj:`Buffer`): The instance self | |
| """ | |
| return copy.copy(self) | |
| def __copy__(self) -> "Buffer": | |
| raise NotImplementedError | |
| def __len__(self) -> int: | |
| return self.count() | |
| def __getitem__(self, idx: int) -> BufferedData: | |
| return self.get(idx) | |