Spaces:
Sleeping
Sleeping
| from types import MethodType | |
| from typing import Union, Any, List, Callable, Dict, Optional, Tuple | |
| from functools import partial, wraps | |
| from easydict import EasyDict | |
| from ditk import logging | |
| import copy | |
| import platform | |
| import numbers | |
| import enum | |
| import time | |
| import treetensor.numpy as tnp | |
| from ding.utils import ENV_MANAGER_REGISTRY, import_module, one_time_warning, make_key_as_identifier, WatchDog, \ | |
| remove_illegal_item | |
| from ding.envs import BaseEnv, BaseEnvTimestep | |
# One-shot guard so that env space information is logged only once per process;
# it is flipped to False inside ``BaseEnvManager.launch``.
# NOTE: a ``global`` statement is redundant at module scope, so a plain assignment suffices.
space_log_flag = True
class EnvState(enum.IntEnum):
    """
    Overview:
        Lifecycle states of a managed sub-environment, tracked per env_id by the env manager.
    """
    VOID = 0  # placeholder before launch / after close
    INIT = 1  # instantiated (in ``_create_state``) but not reset yet
    RUN = 2  # reset finished successfully, ready to step
    RESET = 3  # reset currently in progress
    DONE = 4  # episode budget (``episode_num``) exhausted
    ERROR = 5  # reset/step failed beyond the retry limit
    NEED_RESET = 6  # episode finished while ``auto_reset`` is disabled; caller must reset
def timeout_wrapper(func: Callable = None, timeout: Optional[int] = None) -> Callable:
    """
    Overview:
        Watch the function that must be finished within a period of time. If timeout, raise the captured error.
        Usable both as ``timeout_wrapper(func, timeout=t)`` and as a decorator factory
        ``@timeout_wrapper(timeout=t)``.
    Arguments:
        - func (:obj:`Callable`): The function to be watched; if None, return a partial for decorator usage.
        - timeout (:obj:`Optional[int]`): Timeout in seconds; if None, ``func`` is returned unchanged.
    Returns:
        - wrapper (:obj:`Callable`): The watched function, or ``func`` itself when no timeout applies.
    """
    if func is None:
        return partial(timeout_wrapper, timeout=timeout)
    if timeout is None:
        return func
    windows_flag = platform.system().lower() == 'windows'
    if windows_flag:
        one_time_warning("Timeout wrapper is not implemented in windows platform, so ignore it default")
        return func

    @wraps(func)  # preserve the wrapped function's name/docstring
    def wrapper(*args, **kwargs):
        watchdog = WatchDog(timeout)
        try:
            watchdog.start()
        except ValueError:
            # watchdog invalid case (cannot be armed); fall back to an unguarded call
            return func(*args, **kwargs)
        try:
            # Any exception (including the watchdog's captured error) propagates to the caller;
            # re-raising it explicitly was redundant, so only the cleanup is kept.
            return func(*args, **kwargs)
        finally:
            watchdog.stop()

    return wrapper
@ENV_MANAGER_REGISTRY.register('base')
class BaseEnvManager(object):
    """
    Overview:
        The basic class of env manager to manage multiple vectorized environments. BaseEnvManager define all the
        necessary interfaces and derived class must extend this basic class.
        The class is implemented by the pseudo-parallelism (i.e. serial) mechanism, therefore, this class is only
        used in some tiny environments and for debug purpose.
    Interfaces:
        reset, step, seed, close, enable_save_replay, launch, default_config, reward_shaping, enable_save_figure
    Properties:
        env_num, env_ref, ready_obs, ready_obs_id, ready_imgs, done, closed, method_name_list, observation_space, \
        action_space, reward_space
    """

    @classmethod
    def default_config(cls: type) -> EasyDict:
        """
        Overview:
            Return the deepcopyed default config of env manager.
        Returns:
            - cfg (:obj:`EasyDict`): The default config of env manager.
        """
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    config = dict(
        # (int) The total episode number to be executed, defaults to inf, which means no episode limits.
        episode_num=float("inf"),
        # (int) The maximum retry times when the env is in error state, defaults to 1, i.e. no retry.
        max_retry=1,
        # (str) The retry type when the env is in error state, including ['reset', 'renew'], defaults to 'reset'.
        # The former is to reset the env to the last reset state, while the latter is to create a new env.
        retry_type='reset',
        # (bool) Whether to automatically reset sub-environments when they are done, defaults to True.
        auto_reset=True,
        # (float) WatchDog timeout (second) for ``step`` method, defaults to None, which means no timeout.
        step_timeout=None,
        # (float) WatchDog timeout (second) for ``reset`` method, defaults to None, which means no timeout.
        reset_timeout=None,
        # (float) The interval waiting time for automatically retry mechanism, defaults to 0.1.
        retry_waiting_time=0.1,
    )

    def __init__(
            self,
            env_fn: List[Callable],
            cfg: EasyDict = EasyDict({}),
    ) -> None:
        """
        Overview:
            Initialize the base env manager with callable the env function and the EasyDict-type config. Here we use
            ``env_fn`` to ensure the lazy initialization of sub-environments, which is benetificial to resource
            allocation and parallelism. ``cfg`` is the merged result between the default config of this class
            and user's config.
            This construction function is in lazy-initialization mode, the actual initialization is in ``launch``.
        Arguments:
            - env_fn (:obj:`List[Callable]`): A list of functions to create ``env_num`` sub-environments.
            - cfg (:obj:`EasyDict`): Final merged config.

        .. note::
            For more details about how to merge config, please refer to the system document of DI-engine \
            (`en link <../03_system/config.html>`_).
        """
        self._cfg = cfg
        self._env_fn = env_fn
        self._env_num = len(self._env_fn)
        self._closed = True
        self._env_replay_path = None
        # env_ref is used to acquire some common attributes of env, like obs_shape and act_shape
        self._env_ref = self._env_fn[0]()
        try:
            self._observation_space = self._env_ref.observation_space
            self._action_space = self._env_ref.action_space
            self._reward_space = self._env_ref.reward_space
        except Exception:
            # For some environment,
            # we have to reset before getting observation description.
            # However, for dmc-mujoco, we should not reset the env at the main thread,
            # when using in a subprocess mode, which would cause opengl rendering bugs,
            # leading to no response subprocesses.
            self._env_ref.reset()
            self._observation_space = self._env_ref.observation_space
            self._action_space = self._env_ref.action_space
            self._reward_space = self._env_ref.reward_space
            self._env_ref.close()
        self._env_states = {i: EnvState.VOID for i in range(self._env_num)}
        self._env_seed = {i: None for i in range(self._env_num)}
        self._episode_num = self._cfg.episode_num
        self._max_retry = max(self._cfg.max_retry, 1)
        self._auto_reset = self._cfg.auto_reset
        self._retry_type = self._cfg.retry_type
        assert self._retry_type in ['reset', 'renew'], self._retry_type
        self._step_timeout = self._cfg.step_timeout
        self._reset_timeout = self._cfg.reset_timeout
        self._retry_waiting_time = self._cfg.retry_waiting_time

    @property
    def env_num(self) -> int:
        """
        Overview:
            ``env_num`` is the number of sub-environments in env manager.
        Returns:
            - env_num (:obj:`int`): The number of sub-environments.
        """
        return self._env_num

    @property
    def env_ref(self) -> 'BaseEnv':
        """
        Overview:
            ``env_ref`` is used to acquire some common attributes of env, like obs_shape and act_shape.
        Returns:
            - env_ref (:obj:`BaseEnv`): The reference of sub-environment.
        """
        return self._env_ref

    @property
    def observation_space(self) -> 'gym.spaces.Space':  # noqa
        """
        Overview:
            ``observation_space`` is the observation space of sub-environment, following the format of gym.spaces.
        Returns:
            - observation_space (:obj:`gym.spaces.Space`): The observation space of sub-environment.
        """
        return self._observation_space

    @property
    def action_space(self) -> 'gym.spaces.Space':  # noqa
        """
        Overview:
            ``action_space`` is the action space of sub-environment, following the format of gym.spaces.
        Returns:
            - action_space (:obj:`gym.spaces.Space`): The action space of sub-environment.
        """
        return self._action_space

    @property
    def reward_space(self) -> 'gym.spaces.Space':  # noqa
        """
        Overview:
            ``reward_space`` is the reward space of sub-environment, following the format of gym.spaces.
        Returns:
            - reward_space (:obj:`gym.spaces.Space`): The reward space of sub-environment.
        """
        return self._reward_space

    @property
    def ready_obs(self) -> Dict[int, Any]:
        """
        Overview:
            Get the ready (next) observation, which is a special design to unify both aysnc/sync env manager.
            For each interaction between policy and env, the policy will input the ready_obs and output the action.
            Then the env_manager will ``step`` with the action and prepare the next ready_obs.
        Returns:
            - ready_obs (:obj:`Dict[int, Any]`): A dict with env_id keys and observation values.
        Example:
            >>> obs = env_manager.ready_obs
            >>> stacked_obs = np.concatenate(list(obs.values()))
            >>> action = policy(obs)  # here policy inputs np obs and outputs np action
            >>> action = {env_id: a for env_id, a in zip(obs.keys(), action)}
            >>> timesteps = env_manager.step(action)
        """
        active_env = [i for i, s in self._env_states.items() if s == EnvState.RUN]
        return {i: self._ready_obs[i] for i in active_env}

    @property
    def ready_obs_id(self) -> List[int]:
        """
        Overview:
            Get the ready (next) observation id, which is a special design to unify both aysnc/sync env manager.
        Returns:
            - ready_obs_id (:obj:`List[int]`): A list of env_ids for ready observations.
        """
        # In BaseEnvManager, if env_episode_count equals episode_num, this env is done.
        return [i for i, s in self._env_states.items() if s == EnvState.RUN]

    @property
    def ready_imgs(self, render_mode: Optional[str] = 'rgb_array') -> Dict[int, Any]:
        """
        Overview:
            Sometimes, we need to render the envs, this function is used to get the next ready renderd frame and \
            corresponding env id.
        Arguments:
            - render_mode (:obj:`Optional[str]`): The render mode, can be 'rgb_array' or 'depth_array', which follows \
                the definition in the ``render`` function of ``ding.utils`` . When accessed as a property, \
                ``render_mode`` always takes its default value.
        Returns:
            - ready_imgs (:obj:`Dict[int, np.ndarray]`): A dict with env_id keys and rendered frames.
        """
        from ding.utils import render
        assert render_mode in ['rgb_array', 'depth_array'], render_mode
        return {i: render(self._envs[i], render_mode) for i in self.ready_obs_id}

    @property
    def done(self) -> bool:
        """
        Overview:
            ``done`` is a flag to indicate whether env manager is done, i.e., whether all sub-environments have \
            executed enough episodes.
        Returns:
            - done (:obj:`bool`): Whether env manager is done.
        """
        return all([s == EnvState.DONE for s in self._env_states.values()])

    @property
    def method_name_list(self) -> list:
        """
        Overview:
            The public methods list of sub-environments that can be directly called from the env manager level. Other \
            methods and attributes will be accessed with the ``__getattr__`` method.
            Methods defined in this list can be regarded as the vectorized extension of methods in sub-environments.
            Sub-class of ``BaseEnvManager`` can override this method to add more methods.
        Returns:
            - method_name_list (:obj:`list`): The public methods list of sub-environments.
        """
        return [
            'reset', 'step', 'seed', 'close', 'enable_save_replay', 'render', 'reward_shaping', 'enable_save_figure'
        ]

    def env_state_done(self, env_id: int) -> bool:
        """
        Overview:
            Check whether the indicated sub-environment has finished all its episodes.
        Arguments:
            - env_id (:obj:`int`): The id of the sub-environment to check.
        Returns:
            - done (:obj:`bool`): Whether the env state is ``EnvState.DONE``.
        """
        return self._env_states[env_id] == EnvState.DONE

    def __getattr__(self, key: str) -> Any:
        """
        Note:
            If a python object doesn't have the attribute whose name is `key`, it will call this method.
            We suppose that all envs have the same attributes.
            If you need different envs, please implement other env managers.
        """
        if not hasattr(self._env_ref, key):
            raise AttributeError("env `{}` doesn't have the attribute `{}`".format(type(self._env_ref), key))
        if isinstance(getattr(self._env_ref, key), MethodType) and key not in self.method_name_list:
            raise RuntimeError("env getattr doesn't support method({}), please override method_name_list".format(key))
        self._check_closed()
        return [getattr(env, key) if hasattr(env, key) else None for env in self._envs]

    def _check_closed(self):
        """
        Overview:
            Check whether the env manager is closed. Will be called in ``__getattr__`` and ``step``.
        """
        assert not self._closed, "env manager is closed, please use the alive env manager"

    def launch(self, reset_param: Optional[Dict] = None) -> None:
        """
        Overview:
            Launch the env manager, instantiate the sub-environments and set up the environments and their parameters.
        Arguments:
            - reset_param (:obj:`Optional[Dict]`): A dict of reset parameters for each environment, key is the env_id, \
                value is the corresponding reset parameter, defaults to None.
        """
        assert self._closed, "Please first close the env manager"
        # Best-effort logging of space information; narrow catch so unrelated errors are not hidden.
        try:
            global space_log_flag
            if space_log_flag:
                logging.info("Env Space Information:")
                logging.info("\tObservation Space: {}".format(self._observation_space))
                logging.info("\tAction Space: {}".format(self._action_space))
                logging.info("\tReward Space: {}".format(self._reward_space))
                space_log_flag = False
        except Exception:
            pass
        if reset_param is not None:
            assert len(reset_param) == len(self._env_fn)
        self._create_state()
        self.reset(reset_param)

    def _create_state(self) -> None:
        # Instantiate all sub-environments and initialize the per-env bookkeeping state.
        self._env_episode_count = {i: 0 for i in range(self.env_num)}
        self._ready_obs = {i: None for i in range(self.env_num)}
        self._envs = [e() for e in self._env_fn]
        assert len(self._envs) == self._env_num
        self._reset_param = {i: {} for i in range(self.env_num)}
        self._env_states = {i: EnvState.INIT for i in range(self.env_num)}
        if self._env_replay_path is not None:
            for e, s in zip(self._envs, self._env_replay_path):
                e.enable_save_replay(s)
        self._closed = False

    def reset(self, reset_param: Optional[Dict] = None) -> None:
        """
        Overview:
            Forcely reset the sub-environments their corresponding parameters. Because in env manager all the \
            sub-environments usually are reset automatically as soon as they are done, this method is only called when \
            the caller must forcely reset all the sub-environments, such as in evaluation.
        Arguments:
            - reset_param (:obj:`List`): Dict of reset parameters for each environment, key is the env_id, \
                value is the corresponding reset parameters.
        """
        self._check_closed()
        # set seed if necessary
        env_ids = list(range(self._env_num)) if reset_param is None else list(reset_param.keys())
        for i, env_id in enumerate(env_ids):  # loop-type is necessary
            if self._env_seed[env_id] is not None:
                if self._env_dynamic_seed is not None:
                    self._envs[env_id].seed(self._env_seed[env_id], self._env_dynamic_seed)
                else:
                    self._envs[env_id].seed(self._env_seed[env_id])
                self._env_seed[env_id] = None  # seed only use once
        # reset env
        if reset_param is None:
            env_range = range(self.env_num)
        else:
            for env_id in reset_param:
                self._reset_param[env_id] = reset_param[env_id]
            env_range = reset_param.keys()
        for env_id in env_range:
            if self._env_replay_path is not None and self._env_states[env_id] == EnvState.RUN:
                logging.warning("please don't reset a unfinished env when you enable save replay, we just skip it")
                continue
            self._reset(env_id)

    def _reset(self, env_id: int) -> None:
        # Reset one sub-environment with retry support; on total failure, close everything and raise.

        def reset_fn():
            # if self._reset_param[env_id] is None, just reset specific env, not pass reset param
            if self._reset_param[env_id] is not None:
                assert isinstance(self._reset_param[env_id], dict), type(self._reset_param[env_id])
                return self._envs[env_id].reset(**self._reset_param[env_id])
            else:
                return self._envs[env_id].reset()

        exceptions = []
        for _ in range(self._max_retry):
            try:
                self._env_states[env_id] = EnvState.RESET
                obs = reset_fn()
                self._ready_obs[env_id] = obs
                self._env_states[env_id] = EnvState.RUN
                return
            except BaseException as e:
                if self._retry_type == 'renew':
                    # 'renew' discards the broken env instance and creates a fresh one.
                    err_env = self._envs[env_id]
                    err_env.close()
                    self._envs[env_id] = self._env_fn[env_id]()
                exceptions.append(e)
                time.sleep(self._retry_waiting_time)
                continue

        self._env_states[env_id] = EnvState.ERROR
        self.close()
        logging.error("Env {} reset has exceeded max retries({})".format(env_id, self._max_retry))
        runtime_error = RuntimeError(
            "Env {} reset has exceeded max retries({}), and the latest exception is: {}".format(
                env_id, self._max_retry, str(exceptions[-1])
            )
        )
        runtime_error.__traceback__ = exceptions[-1].__traceback__
        raise runtime_error

    def step(self, actions: Dict[int, Any]) -> Dict[int, BaseEnvTimestep]:
        """
        Overview:
            Execute env step according to input actions. If some sub-environments are done after this execution, \
            they will be reset automatically when ``self._auto_reset`` is True, otherwise they need to be reset when \
            the caller use the ``reset`` method of env manager.
        Arguments:
            - actions (:obj:`Dict[int, Any]`): A dict of actions, key is the env_id, value is corresponding action. \
                action can be any type, it depends on the env, and the env will handle it. Ususlly, the action is \
                a dict of numpy array, and the value is generated by the outer caller like ``policy``.
        Returns:
            - timesteps (:obj:`Dict[int, BaseEnvTimestep]`): Each timestep is a ``BaseEnvTimestep`` object, \
                usually including observation, reward, done, info. Some special customized environments will have \
                the special timestep definition. The length of timesteps is the same as the length of actions in \
                synchronous env manager.
        Example:
            >>> timesteps = env_manager.step(action)
            >>> for env_id, timestep in timesteps.items():
            >>>     if timestep.done:
            >>>         print('Env {} is done'.format(env_id))
        """
        self._check_closed()
        timesteps = {}
        for env_id, act in actions.items():
            timesteps[env_id] = self._step(env_id, act)
            if timesteps[env_id].done:
                self._env_episode_count[env_id] += 1
                if self._env_episode_count[env_id] < self._episode_num:
                    if self._auto_reset:
                        self._reset(env_id)
                    else:
                        self._env_states[env_id] = EnvState.NEED_RESET
                else:
                    self._env_states[env_id] = EnvState.DONE
            else:
                self._ready_obs[env_id] = timesteps[env_id].obs
        return timesteps

    def _step(self, env_id: int, act: Any) -> BaseEnvTimestep:
        # Step one sub-environment with retry support; on total failure, mark it as ERROR and raise.

        def step_fn():
            return self._envs[env_id].step(act)

        exceptions = []
        for _ in range(self._max_retry):
            try:
                return step_fn()
            except BaseException as e:
                exceptions.append(e)
        self._env_states[env_id] = EnvState.ERROR
        logging.error("Env {} step has exceeded max retries({})".format(env_id, self._max_retry))
        runtime_error = RuntimeError(
            "Env {} step has exceeded max retries({}), and the latest exception is: {}".format(
                env_id, self._max_retry, str(exceptions[-1])
            )
        )
        runtime_error.__traceback__ = exceptions[-1].__traceback__
        raise runtime_error

    def seed(self, seed: Union[Dict[int, int], List[int], int], dynamic_seed: bool = None) -> None:
        """
        Overview:
            Set the random seed for each environment.
        Arguments:
            - seed (:obj:`Union[Dict[int, int], List[int], int]`): Dict or List of seeds for each environment; \
                If only one seed is provided, it will be used in the same way for all environments.
            - dynamic_seed (:obj:`bool`): Whether to use dynamic seed.

        .. note::
            For more details about ``dynamic_seed``, please refer to the best practice document of DI-engine \
            (`en link <../04_best_practice/random_seed.html>`_).
        """
        if isinstance(seed, numbers.Integral):
            seed = [seed + i for i in range(self.env_num)]
            self._env_seed = seed
        elif isinstance(seed, list):
            assert len(seed) == self._env_num, "len(seed) {:d} != env_num {:d}".format(len(seed), self._env_num)
            self._env_seed = seed
        elif isinstance(seed, dict):
            if not hasattr(self, '_env_seed'):
                raise RuntimeError("please indicate all the seed of each env in the beginning")
            for env_id, s in seed.items():
                self._env_seed[env_id] = s
        else:
            raise TypeError("invalid seed arguments type: {}".format(type(seed)))
        self._env_dynamic_seed = dynamic_seed
        try:
            self._action_space.seed(seed[0])
        except Exception:  # TODO(nyz) deal with nested action_space like SMAC
            pass

    def enable_save_replay(self, replay_path: Union[List[str], str]) -> None:
        """
        Overview:
            Enable all environments to save replay video after each episode terminates.
        Arguments:
            - replay_path (:obj:`Union[List[str], str]`): List of paths for each environment; \
                Or one path for all environments.
        """
        if isinstance(replay_path, str):
            replay_path = [replay_path] * self.env_num
        self._env_replay_path = replay_path

    def enable_save_figure(self, env_id: int, figure_path: str) -> None:
        """
        Overview:
            Enable a specific env to save figure (e.g. environment statistics or episode return curve).
        Arguments:
            - figure_path (:obj:`str`): The file directory path for all environments to save figures.
        """
        assert figure_path is not None
        self._envs[env_id].enable_save_figure(figure_path)

    def close(self) -> None:
        """
        Overview:
            Close the env manager and release all the environment resources.
        """
        if self._closed:
            return
        for env in self._envs:
            env.close()
        for i in range(self._env_num):
            self._env_states[i] = EnvState.VOID
        self._closed = True

    def reward_shaping(self, env_id: int, transitions: List[dict]) -> List[dict]:
        """
        Overview:
            Execute reward shaping for a specific environment, which is often called when a episode terminates.
        Arguments:
            - env_id (:obj:`int`): The id of the environment to be shaped.
            - transitions (:obj:`List[dict]`): The transition data list of the environment to be shaped.
        Returns:
            - transitions (:obj:`List[dict]`): The shaped transition data list.
        """
        return self._envs[env_id].reward_shaping(transitions)

    @property
    def closed(self) -> bool:
        """
        Overview:
            ``closed`` is a property that returns whether the env manager is closed.
        Returns:
            - closed (:obj:`bool`): Whether the env manager is closed.
        """
        return self._closed

    def random_action(self) -> Dict:
        """
        Overview:
            Sample a random action from the reference env's action space for every ready env.
        Returns:
            - actions (:obj:`Dict`): A dict with ready env_id keys and sampled action values.
        """
        return {env_id: self._env_ref.action_space.sample() for env_id in self.ready_obs_id}
@ENV_MANAGER_REGISTRY.register('base_v2')
class BaseEnvManagerV2(BaseEnvManager):
    """
    Overview:
        The basic class of env manager to manage multiple vectorized environments. BaseEnvManager define all the
        necessary interfaces and derived class must extend this basic class.
        The class is implemented by the pseudo-parallelism (i.e. serial) mechanism, therefore, this class is only
        used in some tiny environments and for debug purpose.
        ``V2`` means this env manager is designed for new task pipeline and interfaces coupled with treetensor.

    .. note::
        For more details about new task pipeline, please refer to the system document of DI-engine \
        (`system en link <../03_system/index.html>`_).

    Interfaces:
        reset, step, seed, close, enable_save_replay, launch, default_config, reward_shaping, enable_save_figure
    Properties:
        env_num, env_ref, ready_obs, ready_obs_id, ready_imgs, done, closed, method_name_list, observation_space, \
        action_space, reward_space
    """

    @property
    def ready_obs(self) -> tnp.array:
        """
        Overview:
            Get the ready (next) observation, which is a special design to unify both aysnc/sync env manager.
            For each interaction between policy and env, the policy will input the ready_obs and output the action.
            Then the env_manager will ``step`` with the action and prepare the next ready_obs.
            For ``V2`` version, the observation is transformed and packed up into ``tnp.array`` type, which allows
            more convenient operations.
        Return:
            - ready_obs (:obj:`tnp.array`): A stacked treenumpy-type observation data.
        Example:
            >>> obs = env_manager.ready_obs
            >>> action = policy(obs)  # here policy inputs treenp obs and output np action
            >>> timesteps = env_manager.step(action)
        """
        active_env = [i for i, s in self._env_states.items() if s == EnvState.RUN]
        obs = [self._ready_obs[i] for i in active_env]
        if isinstance(obs[0], dict):  # transform each element to treenumpy array
            obs = [tnp.array(o) for o in obs]
        return tnp.stack(obs)

    def step(self, actions: List[tnp.ndarray]) -> List[tnp.ndarray]:
        """
        Overview:
            Execute env step according to input actions. If some sub-environments are done after this execution, \
            they will be reset automatically by default.
        Arguments:
            - actions (:obj:`List[tnp.ndarray]`): A list of treenumpy-type actions, the value is generated by the \
                outer caller like ``policy``.
        Returns:
            - timesteps (:obj:`List[tnp.ndarray]`): A list of timestep, Each timestep is a ``tnp.ndarray`` object, \
                usually including observation, reward, done, info, env_id. Some special environments will have \
                the special timestep definition. The length of timesteps is the same as the length of actions in \
                synchronous env manager. For the compatibility of treenumpy, here we use ``make_key_as_identifier`` \
                and ``remove_illegal_item`` functions to modify the original timestep.
        Example:
            >>> timesteps = env_manager.step(action)
            >>> for timestep in timesteps:
            >>>     if timestep.done:
            >>>         print('Env {} is done'.format(timestep.env_id))
        """
        # Pair the given actions with the currently ready envs, then delegate to the base implementation.
        actions = {env_id: a for env_id, a in zip(self.ready_obs_id, actions)}
        timesteps = super().step(actions)
        new_data = []
        for env_id, timestep in timesteps.items():
            obs, reward, done, info = timestep
            # make the type and content of key as similar as identifier,
            # in order to call them as attribute (e.g. timestep.xxx), such as ``TimeLimit.truncated`` in cartpole info
            info = make_key_as_identifier(info)
            info = remove_illegal_item(info)
            new_data.append(tnp.array({'obs': obs, 'reward': reward, 'done': done, 'info': info, 'env_id': env_id}))
        return new_data
def create_env_manager(manager_cfg: EasyDict, env_fn: List[Callable]) -> BaseEnvManager:
    """
    Overview:
        Build an env manager instance from a merged manager config and a list of env-creating functions.
    Arguments:
        - manager_cfg (:obj:`EasyDict`): Final merged env manager config.
        - env_fn (:obj:`List[Callable]`): A list of functions to create ``env_num`` sub-environments.
    ArgumentsKeys:
        - type (:obj:`str`): Env manager type set in ``ENV_MANAGER_REGISTRY.register`` , such as ``base`` .
        - import_names (:obj:`List[str]`): A list of module names (paths) to import before creating env manager, such \
            as ``ding.envs.env_manager.base_env_manager`` .
    Returns:
        - env_manager (:obj:`BaseEnvManager`): The created env manager.

    .. tip::
        The input ``manager_cfg`` is deepcopied before any modification, so the caller's config stays intact.
    """
    cfg = copy.deepcopy(manager_cfg)
    if 'import_names' in cfg:
        import_module(cfg.pop('import_names'))
    return ENV_MANAGER_REGISTRY.build(cfg.pop('type'), env_fn=env_fn, cfg=cfg)
def get_env_manager_cls(cfg: EasyDict) -> type:
    """
    Overview:
        Look up the env manager class indicated by ``cfg.type`` , which is used to access related class \
        variables/methods without instantiation.
    Arguments:
        - cfg (:obj:`EasyDict`): Final merged env manager config.
    ArgumentsKeys:
        - type (:obj:`str`): Env manager type set in ``ENV_MANAGER_REGISTRY.register`` , such as ``base`` .
        - import_names (:obj:`List[str]`): A list of module names (paths) to import before creating env manager, such \
            as ``ding.envs.env_manager.base_env_manager`` .
    Returns:
        - env_manager_cls (:obj:`type`): The corresponding env manager class.
    """
    # Import any registration-side modules first so the registry lookup can succeed.
    module_names = cfg.get('import_names', [])
    import_module(module_names)
    return ENV_MANAGER_REGISTRY.get(cfg.type)