Spaces:
Sleeping
Sleeping
| from collections import defaultdict | |
| import math | |
| import queue | |
| from time import sleep, time | |
| import gym | |
| from ding.framework import Supervisor | |
| from typing import TYPE_CHECKING, Any, List, Union, Dict, Optional, Callable | |
| from ding.framework.supervisor import ChildType, RecvPayload, SendPayload | |
| from ding.utils import make_key_as_identifier | |
| from ditk import logging | |
| from ding.data import ShmBufferContainer | |
| import enum | |
| import treetensor.numpy as tnp | |
| import numbers | |
| if TYPE_CHECKING: | |
| from gym.spaces import Space | |
class EnvState(enum.IntEnum):
    """
    Overview:
        Lifecycle state of a single env handled by the env supervisor.
        The main path is VOID -> RUN -> DONE; the other members mark intermediate or failure states.
    """
    # Members are numbered positionally 0..6; keep this order, since IntEnum values compare as plain ints.
    VOID, INIT, RUN, RESET, DONE, ERROR, NEED_RESET = range(7)
class EnvRetryType(str, enum.Enum):
    """
    Overview:
        How a failed env reset is retried: ``RESET`` resets the same child again, while ``RENEW`` \
        restarts the child before resetting it.
    """
    RESET, RENEW = "reset", "renew"
class EnvSupervisor(Supervisor):
    """
    Overview:
        Manage multiple envs with supervisor.
        New features (compared to env manager):
            - Consistent interface in multi-process and multi-threaded mode.
            - Add asynchronous features and recommend using asynchronous methods.
            - Reset is performed after an error is encountered in the step method.
        Breaking changes (compared to env manager):
            - Without some states.
    """

    def __init__(
        self,
        type_: ChildType = ChildType.PROCESS,
        env_fn: Optional[List[Callable]] = None,
        retry_type: EnvRetryType = EnvRetryType.RESET,
        max_try: Optional[int] = None,
        max_retry: Optional[int] = None,
        auto_reset: bool = True,
        reset_timeout: Optional[float] = None,
        step_timeout: Optional[float] = None,
        retry_waiting_time: Optional[float] = None,
        episode_num: Union[int, float] = float("inf"),
        shared_memory: bool = True,
        copy_on_get: bool = True,
        **kwargs
    ) -> None:
        """
        Overview:
            Supervisor that manages a group of envs.
        Arguments:
            - type_ (:obj:`ChildType`): Type of child (process or thread).
            - env_fn (:obj:`Optional[List[Callable]]`): Functions that create the environments, one child \
                is registered per function. ``env_fn[0]`` is also used to build a reference env.
            - retry_type (:obj:`EnvRetryType`): On a failed reset, reset again (``RESET``) or restart the \
                child before resetting (``RENEW``).
            - max_try (:obj:`Optional[int]`): Max consecutive reset attempts per env (step errors are \
                handled by a reset), defaults to 1.
            - max_retry (:obj:`Optional[int]`): Deprecated alias of max_try.
            - auto_reset (:obj:`bool`): Auto reset env if reach done.
            - reset_timeout (:obj:`Optional[float]`): Timeout in seconds for reset.
            - step_timeout (:obj:`Optional[float]`): Timeout in seconds for step.
            - retry_waiting_time (:obj:`Optional[float]`): Wait time in seconds before each reset retry.
            - episode_num (:obj:`Union[int, float]`): Number of finished episodes per env after which \
                auto reset stops.
            - shared_memory (:obj:`bool`): Use shared memory in multiprocessing.
            - copy_on_get (:obj:`bool`): Use copy on get in multiprocessing.
        """
        if kwargs:
            logging.warning("Unknown parameters on env supervisor: {}".format(kwargs))
        super().__init__(type_=type_)
        if type_ is not ChildType.PROCESS and (shared_memory or copy_on_get):
            logging.warning("shared_memory and copy_on_get only works in process mode.")
        self._shared_memory = type_ is ChildType.PROCESS and shared_memory
        self._copy_on_get = type_ is ChildType.PROCESS and copy_on_get
        self._env_fn = env_fn
        # Defaults for a supervisor created without env_fn; previously this crashed on env_fn[0].
        self._env_ref = None
        self._observation_space = None
        self._action_space = None
        self._reward_space = None
        self._obs_buffers = None
        if env_fn:
            self._create_env_ref()
            if self._shared_memory:
                obs_space = self._observation_space
                if isinstance(obs_space, gym.spaces.Dict):
                    # For multi_agent case, such as multiagent_mujoco and petting_zoo mpe.
                    # Now only for the case that each agent in the team have the same obs structure
                    # and corresponding shape.
                    shape = {k: v.shape for k, v in obs_space.spaces.items()}
                    dtype = {k: v.dtype for k, v in obs_space.spaces.items()}
                else:
                    shape = obs_space.shape
                    dtype = obs_space.dtype
                self._obs_buffers = {
                    env_id: ShmBufferContainer(dtype, shape, copy_on_get=self._copy_on_get)
                    for env_id in range(len(self._env_fn))
                }
                for env_init in env_fn:
                    self.register(env_init, shm_buffer=self._obs_buffers, shm_callback=self._shm_callback)
            else:
                for env_init in env_fn:
                    self.register(env_init)
        self._retry_type = retry_type
        self._auto_reset = auto_reset
        if max_retry:
            logging.warning("The `max_retry` is going to be deprecated, use `max_try` instead!")
        self._max_try = max_try or max_retry or 1
        self._reset_timeout = reset_timeout
        self._step_timeout = step_timeout
        self._retry_waiting_time = retry_waiting_time
        self._env_replay_path = None
        self._episode_num = episode_num
        self._init_states()

    def _init_states(self):
        """
        Overview:
            (Re)initialize the per-run bookkeeping. Called at the end of ``__init__`` and of ``shutdown``; \
            note that this also clears the seeds and replay paths previously configured.
        """
        self._env_seed = {}
        self._env_dynamic_seed = None
        self._env_replay_path = None
        self._env_states = {}
        self._reset_param = {}
        self._ready_obs = {}
        self._env_episode_count = {i: 0 for i in range(self.env_num)}
        # Consecutive failed resets per env; cleared on a successful reset.
        self._retry_times = defaultdict(lambda: 0)
        # Send time of the pending step/reset per env; math.inf means "nothing pending".
        self._last_called = defaultdict(lambda: {"step": math.inf, "reset": math.inf})

    def _shm_callback(self, payload: RecvPayload, obs_buffers: Any):
        """
        Overview:
            This method will be called in child worker, so we can put large data into shared memory
            and replace the original payload data to none, then reduce the serialization/deserialization cost.
            ``_set_shared_obs`` puts the observation back on the receiving side.
        Arguments:
            - payload (:obj:`RecvPayload`): Result payload of a reset/step call in the child.
            - obs_buffers (:obj:`Any`): Mapping from env_id to its shared memory buffer container.
        """
        if payload.data is None:
            return
        if payload.method == "reset":
            obs_buffers[payload.proc_id].fill(payload.data)
            payload.data = None
        elif payload.method == "step":
            obs_buffers[payload.proc_id].fill(payload.data.obs)
            # namedtuple._replace returns a NEW tuple; it must be assigned back, otherwise the
            # observation is still serialized and shared memory brings no benefit.
            payload.data = payload.data._replace(obs=None)

    def _create_env_ref(self):
        """
        Overview:
            Build a reference env from ``env_fn[0]`` to acquire common attributes such as the observation, \
            action and reward spaces. The env is reset once before reading them and is always closed.
        """
        self._env_ref = self._env_fn[0]()
        try:
            self._env_ref.reset()
            self._observation_space = self._env_ref.observation_space
            self._action_space = self._env_ref.action_space
            self._reward_space = self._env_ref.reward_space
        finally:
            # Close even if reset or an attribute access fails, so the reference env is never leaked.
            self._env_ref.close()

    def step(
        self,
        actions: Union[Dict[int, List[Any]], List[Any]],
        block: bool = True,
    ) -> Optional[List[tnp.ndarray]]:
        """
        Overview:
            Execute env step according to input actions. And reset an env if done.
        Arguments:
            - actions (:obj:`Union[Dict[int, List[Any]], List[Any]]`): Actions came from outer caller like \
                policy, in structure of {env_id: actions}; a list is mapped to env ids by position.
            - block (:obj:`bool`): If block, return timesteps, else return none.
        Returns:
            - timesteps (:obj:`Optional[List[tnp.ndarray]]`): Each timestep is a tnp.array with observation, \
                reward, done, info, env_id.
        """
        assert not self.closed, "Env supervisor has closed."
        if isinstance(actions, list):
            actions = {i: p for i, p in enumerate(actions)}
        assert actions, "Action is empty!"
        send_payloads = []
        for env_id, act in actions.items():
            payload = SendPayload(proc_id=env_id, method="step", args=[act])
            send_payloads.append(payload)
            self.send(payload)
        if not block:
            # Retrieve the data for these steps from the recv method
            return None
        # Wait for all steps returns
        recv_payloads = self.recv_all(
            send_payloads, ignore_err=True, callback=self._recv_callback, timeout=self._step_timeout
        )
        return [payload.data for payload in recv_payloads]

    def recv(self, ignore_err: bool = False) -> RecvPayload:
        """
        Overview:
            Wait for recv payload, this function will block the thread. The queue is polled every 0.1s \
            and timeout errors are injected for envs whose step/reset exceeded its timeout. Payloads \
            carrying an error are handled by the callback (which schedules a reset) and never returned.
        Arguments:
            - ignore_err (:obj:`bool`): Kept for interface compatibility; payloads with an error object are \
                discarded regardless of its value. This option will not catch the exception.
        Returns:
            - recv_payload (:obj:`RecvPayload`): The first received payload without error.
        """
        # Iterate instead of recursing: every empty poll or error payload used to add stack frames,
        # so a long enough wait overflowed the interpreter's recursion limit.
        while True:
            self._detect_timeout()
            try:
                payload = super().recv(ignore_err=True, timeout=0.1)
            except queue.Empty:
                continue
            payload = self._recv_callback(payload=payload)
            if not payload.err:
                return payload

    def _detect_timeout(self):
        """
        Overview:
            Enqueue a ``TimeoutError`` payload for every env whose pending step (checked first) or reset \
            has exceeded its timeout, so the normal error path restarts it.
        """
        now = time()
        for env_id, last_called in self._last_called.items():
            if self._step_timeout and now - last_called["step"] > self._step_timeout:
                method, err = "step", TimeoutError("Step timeout on env {}".format(env_id))
            elif self._reset_timeout and now - last_called["reset"] > self._reset_timeout:
                method, err = "reset", TimeoutError("Reset timeout on env {}".format(env_id))
            else:
                continue
            # Mark the call as handled so the same timeout is not enqueued again on every poll,
            # which would trigger duplicate resets and inflate the retry counter.
            last_called[method] = math.inf
            self._recv_queue.put(RecvPayload(proc_id=env_id, method=method, err=err))

    @property
    def env_num(self) -> int:
        return len(self._children)

    @property
    def observation_space(self) -> 'Space':
        return self._observation_space

    @property
    def action_space(self) -> 'Space':
        return self._action_space

    @property
    def reward_space(self) -> 'Space':
        return self._reward_space

    @property
    def ready_obs(self) -> tnp.array:
        """
        Overview:
            Get the ready (next) observation in ``tnp.array`` type, which is uniform for both async/sync scenarios.
            Only envs in ``EnvState.RUN`` are included, ordered by env id.
        Return:
            - ready_obs (:obj:`tnp.array`): A stacked treenumpy-type observation data.
        Example:
            >>> obs = env_manager.ready_obs
            >>> action = model(obs)  # model input np obs and output np action
            >>> timesteps = env_manager.step(action)
        """
        active_env = sorted(i for i, s in self._env_states.items() if s == EnvState.RUN)
        obs = [self._ready_obs.get(i) for i in active_env]
        if len(obs) == 0:
            return tnp.array([])
        return tnp.stack(obs)

    @property
    def ready_obs_id(self) -> List[int]:
        return [i for i, s in self.env_states.items() if s == EnvState.RUN]

    @property
    def done(self) -> bool:
        return all(s == EnvState.DONE for s in self.env_states.values())

    @property
    def method_name_list(self) -> List[str]:
        return ['reset', 'step', 'seed', 'close', 'enable_save_replay']

    @property
    def env_states(self) -> Dict[int, EnvState]:
        # Envs that have not reported anything yet are reported as VOID.
        return {env_id: self._env_states.get(env_id) or EnvState.VOID for env_id in range(self.env_num)}

    def env_state_done(self, env_id: int) -> bool:
        return self.env_states[env_id] == EnvState.DONE

    def launch(self, reset_param: Optional[Dict] = None, block: bool = True) -> None:
        """
        Overview:
            Set up the environments and their parameters: start the children, send seeds, reset every env \
            and enable replay saving if configured.
        Arguments:
            - reset_param (:obj:`Optional[Dict]`): Dict of reset parameters for each environment, key is the env_id, \
                value is the corresponding reset parameters.
            - block (:obj:`bool`): Whether will block the process and wait for reset states.
        """
        assert self.closed, "Please first close the env supervisor before launch it"
        if reset_param is not None:
            assert len(reset_param) == self.env_num
        self.start_link()
        self._send_seed(self._env_seed, self._env_dynamic_seed, block=block)
        self.reset(reset_param, block=block)
        self._enable_env_replay()

    def reset(self, reset_param: Optional[Union[Dict[int, Any], List[Any]]] = None, block: bool = True) -> None:
        """
        Overview:
            Reset environments. The given parameters are remembered and reused by auto reset.
        Arguments:
            - reset_param (:obj:`Optional[Union[Dict[int, Any], List[Any]]]`): Reset keyword parameters for each \
                environment, keyed by env_id (a list is mapped by position). Empty/None resets all envs.
            - block (:obj:`bool`): Whether will block the process and wait for reset states.
        """
        if not reset_param:
            reset_param = {i: {} for i in range(self.env_num)}
        elif isinstance(reset_param, list):
            reset_param = {i: p for i, p in enumerate(reset_param)}
        send_payloads = []
        for env_id, kw_param in reset_param.items():
            self._reset_param[env_id] = kw_param  # For auto reset
            send_payloads += self._reset(env_id, kw_param=kw_param)
        if not block:
            return
        self.recv_all(send_payloads, ignore_err=True, callback=self._recv_callback, timeout=self._reset_timeout)

    def _recv_callback(
        self, payload: RecvPayload, remain_payloads: Optional[Dict[str, SendPayload]] = None
    ) -> RecvPayload:
        """
        Overview:
            The callback function for each received payload, within this method will modify the state of \
            each environment, replace objects in shared memory, and determine if a retry is needed due to an error.
        Arguments:
            - payload (:obj:`RecvPayload`): The received payload.
            - remain_payloads (:obj:`Optional[Dict[str, SendPayload]]`): The callback may be called many times \
                until remain_payloads be cleared, you can append new payload into remain_payloads to call this \
                callback recursively.
        Returns:
            - payload (:obj:`RecvPayload`): The (possibly rewritten) payload.
        """
        self._set_shared_obs(payload=payload)
        self.change_state(payload=payload)
        if payload.method == "reset":
            return self._recv_reset_callback(payload=payload, remain_payloads=remain_payloads)
        elif payload.method == "step":
            return self._recv_step_callback(payload=payload, remain_payloads=remain_payloads)
        return payload

    def _set_shared_obs(self, payload: RecvPayload):
        """
        Overview:
            Put the observation written to shared memory by ``_shm_callback`` back into a successful \
            reset/step payload. No-op when shared memory is disabled or the payload carries an error.
        """
        if self._obs_buffers is None or payload.err is not None:
            return
        if payload.method == "reset":
            payload.data = self._obs_buffers[payload.proc_id].get()
        elif payload.method == "step":
            # _replace returns a new namedtuple; assign it back so the payload actually gets the obs.
            payload.data = payload.data._replace(obs=self._obs_buffers[payload.proc_id].get())

    def _recv_reset_callback(
        self, payload: RecvPayload, remain_payloads: Optional[Dict[str, SendPayload]] = None
    ) -> RecvPayload:
        """
        Overview:
            Handle a reset result. On success, store the observation as ready and clear the retry counter. \
            On failure, wait ``retry_waiting_time``, restart the child in ``RENEW`` mode, and send a new reset \
            (added to ``remain_payloads``).
        Raises:
            - RuntimeError: If the env failed to reset ``max_try`` times in a row (the supervisor is shut down first).
        """
        assert payload.method == "reset", "Recv error callback({}) in reset callback!".format(payload.method)
        if remain_payloads is None:
            remain_payloads = {}
        env_id = payload.proc_id
        if not payload.err:
            self._retry_times[env_id] = 0
            self._ready_obs[env_id] = payload.data
            return payload
        self._retry_times[env_id] += 1
        if self._retry_times[env_id] >= self._max_try:
            self.shutdown(5)
            raise RuntimeError(
                "Env {} reset has exceeded max_try({}), and the latest exception is: {}".format(
                    env_id, self._max_try, payload.err
                )
            )
        if self._retry_waiting_time:
            sleep(self._retry_waiting_time)
        if self._retry_type == EnvRetryType.RENEW:
            self._children[env_id].restart()
        for p in self._reset(env_id):
            remain_payloads[p.req_id] = p
        return payload

    def _recv_step_callback(
        self, payload: RecvPayload, remain_payloads: Optional[Dict[str, SendPayload]] = None
    ) -> RecvPayload:
        """
        Overview:
            Convert a step result into a ``tnp.array`` timestep. A failed step schedules a reset and yields \
            an abnormal timestep; a done step counts the episode and auto resets while fewer than \
            ``episode_num`` episodes have finished.
        """
        assert payload.method == "step", "Recv error callback({}) in step callback!".format(payload.method)
        if remain_payloads is None:
            remain_payloads = {}
        env_id = payload.proc_id
        if payload.err:
            for p in self._reset(env_id):
                remain_payloads[p.req_id] = p
            info = {"abnormal": True, "err": payload.err}
            payload.data = tnp.array({'obs': None, 'reward': None, 'done': None, 'info': info, 'env_id': env_id})
            return payload
        obs, reward, done, info, *_ = payload.data
        if done:
            self._env_episode_count[env_id] += 1
            if self._env_episode_count[env_id] < self._episode_num and self._auto_reset:
                for p in self._reset(env_id):
                    remain_payloads[p.req_id] = p
        # make the type and content of key as similar as identifier,
        # in order to call them as attribute (e.g. timestep.xxx), such as ``TimeLimit.truncated`` in cartpole info
        info = make_key_as_identifier(info)
        payload.data = tnp.array({'obs': obs, 'reward': reward, 'done': done, 'info': info, 'env_id': env_id})
        self._ready_obs[env_id] = obs
        return payload

    def _reset(self, env_id: int, kw_param: Optional[Dict[str, Any]] = None) -> List[SendPayload]:
        """
        Overview:
            Reset an environment. This method does not wait for the result to be returned.
            When replay saving is enabled, an env still in ``RUN`` state is skipped.
        Arguments:
            - env_id (:obj:`int`): Environment id.
            - kw_param (:obj:`Optional[Dict[str, Any]]`): Reset parameters for the environment; falls back to \
                the parameters last passed to ``reset`` for this env (or none).
        Returns:
            - send_payloads (:obj:`List[SendPayload]`): The request payloads for the reset action.
        """
        assert not self.closed, "Env supervisor has closed."
        send_payloads = []
        kw_param = kw_param or self._reset_param.get(env_id, {})
        if self._env_replay_path is not None and self.env_states[env_id] == EnvState.RUN:
            logging.warning("Please don't reset an unfinished env when you enable save replay, we just skip it")
            return send_payloads
        # Reset env
        payload = SendPayload(proc_id=env_id, method="reset", kwargs=kw_param)
        send_payloads.append(payload)
        self.send(payload)
        return send_payloads

    def _send_seed(self, env_seed: Dict[int, int], env_dynamic_seed: Optional[bool] = None, block: bool = True) -> None:
        """
        Overview:
            Send ``seed`` calls to every env with a non-None seed, optionally waiting for the replies.
        """
        send_payloads = []
        for env_id, seed in env_seed.items():
            if seed is None:
                continue
            args = [seed]
            if env_dynamic_seed is not None:
                args.append(env_dynamic_seed)
            payload = SendPayload(proc_id=env_id, method="seed", args=args)
            send_payloads.append(payload)
            self.send(payload)
        if not block or not send_payloads:
            return
        self.recv_all(send_payloads, ignore_err=True, callback=self._recv_callback, timeout=self._reset_timeout)

    def change_state(self, payload: RecvPayload):
        """
        Overview:
            Mark the call as answered and update the env state: ERROR on error, RUN after a reset, \
            DONE after a step whose done flag (index 2 of the data) is truthy.
        """
        self._last_called[payload.proc_id][payload.method] = math.inf  # Have received
        if payload.err:
            self._env_states[payload.proc_id] = EnvState.ERROR
        elif payload.method == "reset":
            self._env_states[payload.proc_id] = EnvState.RUN
        elif payload.method == "step":
            if payload.data[2]:
                self._env_states[payload.proc_id] = EnvState.DONE

    def send(self, payload: SendPayload) -> None:
        """
        Overview:
            Record the send time (used by timeout detection) and forward the payload to the child.
        """
        self._last_called[payload.proc_id][payload.method] = time()
        return super().send(payload)

    def seed(self, seed: Union[Dict[int, int], List[int], int], dynamic_seed: Optional[bool] = None) -> None:
        """
        Overview:
            Set the seed for each environment. The seed function will not be called until supervisor.launch \
            was called.
        Arguments:
            - seed (:obj:`Union[Dict[int, int], List[int], int]`): List (or tuple) of seeds for each environment; \
                Or one seed for the first environment and other seeds are generated automatically (seed + env_id); \
                Or a dict of {env_id: seed}. \
                Note that in threading mode, no matter how many seeds are given, only the last one will take effect. \
                Because the execution in the thread is asynchronous, the results of each experiment \
                are different even if a fixed seed is used.
            - dynamic_seed (:obj:`Optional[bool]`): Dynamic seed is used in the training environment, \
                trying to make the random seed of each episode different, they are all generated in the reset \
                method by a random generator 100 * np.random.randint(1 , 1000) (but the seed of this random \
                number generator is fixed by the environmental seed method, guaranteeing the reproducibility \
                of the experiment). You need not pass the dynamic_seed parameter in the seed method, or pass \
                the parameter as True.
        Raises:
            - TypeError: If ``seed`` is not an integer, list/tuple or dict.
        """
        self._env_seed = {}
        if isinstance(seed, numbers.Integral):
            self._env_seed = {i: seed + i for i in range(self.env_num)}
        elif isinstance(seed, (list, tuple)):
            assert len(seed) == self.env_num, "len(seed) {:d} != env_num {:d}".format(len(seed), self.env_num)
            self._env_seed = {i: _seed for i, _seed in enumerate(seed)}
        elif isinstance(seed, dict):
            self._env_seed = {env_id: s for env_id, s in seed.items()}
        else:
            raise TypeError("Invalid seed arguments type: {}".format(type(seed)))
        self._env_dynamic_seed = dynamic_seed

    def enable_save_replay(self, replay_path: Union[List[str], str]) -> None:
        """
        Overview:
            Set each env's replay save path. The paths are sent to the envs on ``launch``.
        Arguments:
            - replay_path (:obj:`Union[List[str], str]`): List of paths for each environment; \
                Or one path for all environments.
        """
        if isinstance(replay_path, str):
            replay_path = [replay_path] * self.env_num
        self._env_replay_path = replay_path

    def _enable_env_replay(self):
        """
        Overview:
            Send ``enable_save_replay`` to every env that has a configured path and wait for the replies.
        """
        if self._env_replay_path is None:
            return
        send_payloads = []
        for env_id, s in enumerate(self._env_replay_path):
            payload = SendPayload(proc_id=env_id, method="enable_save_replay", args=[s])
            send_payloads.append(payload)
            self.send(payload)
        self.recv_all(send_payloads=send_payloads)

    def __getattr__(self, key: str) -> List[Any]:
        """
        Overview:
            Forward unknown attributes to the supervisor base class, but only if the reference env has them.
        """
        # Read ``_env_ref`` from the instance dict: normal lookup while it is not yet set (e.g. during
        # __init__ or copy/pickle reconstruction) would re-enter __getattr__ and recurse forever.
        env_ref = self.__dict__.get("_env_ref")
        if env_ref is None or not hasattr(env_ref, key):
            raise AttributeError("env `{}` doesn't have the attribute `{}`".format(type(env_ref), key))
        return super().__getattr__(key)

    def close(self, timeout: Optional[float] = None) -> None:
        """
        In order to be compatible with BaseEnvManager, the new version can use `shutdown` directly.
        """
        self.shutdown(timeout=timeout)

    def shutdown(self, timeout: Optional[float] = None) -> None:
        """
        Overview:
            If running, send ``close`` to every env and wait for the replies (passing ``timeout`` to \
            ``recv_all``), then shut down the base supervisor and reinitialize all bookkeeping states.
        """
        if self._running:
            send_payloads = []
            for env_id in range(self.env_num):
                payload = SendPayload(proc_id=env_id, method="close")
                send_payloads.append(payload)
                self.send(payload)
            self.recv_all(send_payloads=send_payloads, ignore_err=True, timeout=timeout)
        super().shutdown(timeout=timeout)
        self._init_states()

    @property
    def closed(self) -> bool:
        return not self._running