Spaces:
Sleeping
Sleeping
| from typing import Any, Optional, Union, Tuple, Dict | |
| from multiprocessing import Array | |
| import ctypes | |
| import numpy as np | |
| import torch | |
| _NTYPE_TO_CTYPE = { | |
| np.bool_: ctypes.c_bool, | |
| np.uint8: ctypes.c_uint8, | |
| np.uint16: ctypes.c_uint16, | |
| np.uint32: ctypes.c_uint32, | |
| np.uint64: ctypes.c_uint64, | |
| np.int8: ctypes.c_int8, | |
| np.int16: ctypes.c_int16, | |
| np.int32: ctypes.c_int32, | |
| np.int64: ctypes.c_int64, | |
| np.float32: ctypes.c_float, | |
| np.float64: ctypes.c_double, | |
| } | |
| class ShmBuffer(): | |
| """ | |
| Overview: | |
| Shared memory buffer to store numpy array. | |
| """ | |
| def __init__( | |
| self, | |
| dtype: Union[type, np.dtype], | |
| shape: Tuple[int], | |
| copy_on_get: bool = True, | |
| ctype: Optional[type] = None | |
| ) -> None: | |
| """ | |
| Overview: | |
| Initialize the buffer. | |
| Arguments: | |
| - dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer. | |
| - shape (:obj:`Tuple[int]`): The shape of the data to limit the size of the buffer. | |
| - copy_on_get (:obj:`bool`): Whether to copy data when calling get method. | |
| - ctype (:obj:`Optional[type]`): Origin class type, e.g. np.ndarray, torch.Tensor. | |
| """ | |
| if isinstance(dtype, np.dtype): # it is type of gym.spaces.dtype | |
| dtype = dtype.type | |
| self.buffer = Array(_NTYPE_TO_CTYPE[dtype], int(np.prod(shape))) | |
| self.dtype = dtype | |
| self.shape = shape | |
| self.copy_on_get = copy_on_get | |
| self.ctype = ctype | |
| def fill(self, src_arr: np.ndarray) -> None: | |
| """ | |
| Overview: | |
| Fill the shared memory buffer with a numpy array. (Replace the original one.) | |
| Arguments: | |
| - src_arr (:obj:`np.ndarray`): array to fill the buffer. | |
| """ | |
| assert isinstance(src_arr, np.ndarray), type(src_arr) | |
| # for np.array with shape (4, 84, 84) and float32 dtype, reshape is 15~20x faster than flatten | |
| # for np.array with shape (4, 84, 84) and uint8 dtype, reshape is 5~7x faster than flatten | |
| # so we reshape dst_arr rather than flatten src_arr | |
| dst_arr = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape) | |
| np.copyto(dst_arr, src_arr) | |
| def get(self) -> np.ndarray: | |
| """ | |
| Overview: | |
| Get the array stored in the buffer. | |
| Return: | |
| - data (:obj:`np.ndarray`): A copy of the data stored in the buffer. | |
| """ | |
| data = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape) | |
| if self.copy_on_get: | |
| data = data.copy() # must use np.copy, torch.from_numpy and torch.as_tensor still use the same memory | |
| if self.ctype is torch.Tensor: | |
| data = torch.from_numpy(data) | |
| return data | |
| class ShmBufferContainer(object): | |
| """ | |
| Overview: | |
| Support multiple shared memory buffers. Each key-value is name-buffer. | |
| """ | |
| def __init__( | |
| self, | |
| dtype: Union[Dict[Any, type], type, np.dtype], | |
| shape: Union[Dict[Any, tuple], tuple], | |
| copy_on_get: bool = True | |
| ) -> None: | |
| """ | |
| Overview: | |
| Initialize the buffer container. | |
| Arguments: | |
| - dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer. | |
| - shape (:obj:`Union[Dict[Any, tuple], tuple]`): If `Dict[Any, tuple]`, use a dict to manage \ | |
| multiple buffers; If `tuple`, use single buffer. | |
| - copy_on_get (:obj:`bool`): Whether to copy data when calling get method. | |
| """ | |
| if isinstance(shape, dict): | |
| self._data = {k: ShmBufferContainer(dtype[k], v, copy_on_get) for k, v in shape.items()} | |
| elif isinstance(shape, (tuple, list)): | |
| self._data = ShmBuffer(dtype, shape, copy_on_get) | |
| else: | |
| raise RuntimeError("not support shape: {}".format(shape)) | |
| self._shape = shape | |
| def fill(self, src_arr: Union[Dict[Any, np.ndarray], np.ndarray]) -> None: | |
| """ | |
| Overview: | |
| Fill the one or many shared memory buffer. | |
| Arguments: | |
| - src_arr (:obj:`Union[Dict[Any, np.ndarray], np.ndarray]`): array to fill the buffer. | |
| """ | |
| if isinstance(self._shape, dict): | |
| for k in self._shape.keys(): | |
| self._data[k].fill(src_arr[k]) | |
| elif isinstance(self._shape, (tuple, list)): | |
| self._data.fill(src_arr) | |
| def get(self) -> Union[Dict[Any, np.ndarray], np.ndarray]: | |
| """ | |
| Overview: | |
| Get the one or many arrays stored in the buffer. | |
| Return: | |
| - data (:obj:`np.ndarray`): The array(s) stored in the buffer. | |
| """ | |
| if isinstance(self._shape, dict): | |
| return {k: self._data[k].get() for k in self._shape.keys()} | |
| elif isinstance(self._shape, (tuple, list)): | |
| return self._data.get() | |