Spaces:
Sleeping
Sleeping
| from abc import ABC, abstractmethod | |
| from typing import Any, List, Tuple | |
| import gym | |
| import copy | |
| from easydict import EasyDict | |
| from collections import namedtuple | |
| from ding.utils import import_module, ENV_REGISTRY | |
# Immutable record for the result of one environment step; fields are, in order,
# ``obs``, ``reward``, ``done`` and ``info``.
BaseEnvTimestep = namedtuple('BaseEnvTimestep', 'obs reward done info')
# ``gym.Env`` and ``ABC`` may be built by different metaclasses; a class inheriting from both
# needs a metaclass derived from each of them, otherwise Python reports a metaclass conflict.
class FinalMeta(type(ABC), type(gym.Env)):
    """
    Overview:
        Combined metaclass deriving from the metaclass of ``ABC`` and the metaclass of ``gym.Env``. \
        It adds no behaviour of its own.
    """
class BaseEnv(gym.Env, ABC, metaclass=FinalMeta):
    """
    Overview:
        Basic environment class, extended from ``gym.Env``. In this base class ``__init__``, ``reset``,
        ``close``, ``step``, ``seed``, ``__repr__`` and ``enable_save_replay`` only raise
        ``NotImplementedError``; concrete environments are expected to override them.
    Interface:
        ``__init__``, ``reset``, ``close``, ``step``, ``seed``, ``random_action``,
        ``create_collector_env_cfg``, ``create_evaluator_env_cfg``, ``enable_save_replay``
    """

    def __init__(self, cfg: dict) -> None:
        """
        Overview:
            Lazy init, only related arguments will be initialized in ``__init__`` method, and the concrete
            env will be initialized the first time ``reset`` method is called.
        Arguments:
            - cfg (:obj:`dict`): Environment configuration in dict type.
        Raises:
            - NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def reset(self) -> Any:
        """
        Overview:
            Reset the env to an initial state and return an initial observation.
        Returns:
            - obs (:obj:`Any`): Initial observation after reset.
        Raises:
            - NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def close(self) -> None:
        """
        Overview:
            Close env and all the related resources, it should be called after the usage of env instance.
        Raises:
            - NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def step(self, action: Any) -> 'BaseEnvTimestep':
        """
        Overview:
            Run one timestep of the environment's dynamics/simulation.
        Arguments:
            - action (:obj:`Any`): The ``action`` input to step with.
        Returns:
            - timestep (:obj:`BaseEnvTimestep`): The result timestep of env executing one step.
        Raises:
            - NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def seed(self, seed: int) -> None:
        """
        Overview:
            Set the seed for this env's random number generator(s).
        Arguments:
            - seed (:obj:`int`): Random seed.
        Raises:
            - NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def __repr__(self) -> str:
        """
        Overview:
            Return the information string of this env instance.
        Returns:
            - info (:obj:`str`): Information of this env instance, like type and arguments.
        Raises:
            - NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    # Declared static because the function takes no ``self``: it is called on the class itself
    # (``env_fn.create_collector_env_cfg(cfg)``) and, being static, also works when called on an instance.
    @staticmethod
    def create_collector_env_cfg(cfg: dict) -> List[dict]:
        """
        Overview:
            Return a list of all of the environment from input config, used in env manager
            (a series of vectorized env), and this method is mainly responsible for envs collecting data.
        Arguments:
            - cfg (:obj:`dict`): Original input env config. The key ``collector_env_num`` is removed from it
              in place and determines the length of the returned list.
        Returns:
            - env_cfg_list (:obj:`List[dict]`): List of ``cfg`` including all the config collector envs.
        Raises:
            - KeyError: If ``cfg`` has no ``collector_env_num`` key.

        .. note::
            Elements(env config) in collector_env_cfg/evaluator_env_cfg can be different, such as server ip
            and port. This default implementation does not copy: every element of the returned list is the
            very same ``cfg`` object, so a change made through one element is visible through all of them.
        """
        collector_env_num = cfg.pop('collector_env_num')
        return [cfg for _ in range(collector_env_num)]

    # Static for the same reason as ``create_collector_env_cfg``.
    @staticmethod
    def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
        """
        Overview:
            Return a list of all of the environment from input config, used in env manager
            (a series of vectorized env), and this method is mainly responsible for envs evaluating performance.
        Arguments:
            - cfg (:obj:`dict`): Original input env config. The key ``evaluator_env_num`` is removed from it
              in place and determines the length of the returned list.
        Returns:
            - env_cfg_list (:obj:`List[dict]`): List of ``cfg`` including all the config evaluator envs.
        Raises:
            - KeyError: If ``cfg`` has no ``evaluator_env_num`` key.

        .. note::
            Every element of the returned list is the very same ``cfg`` object (no copy is made).
        """
        evaluator_env_num = cfg.pop('evaluator_env_num')
        return [cfg for _ in range(evaluator_env_num)]

    # optional method
    def enable_save_replay(self, replay_path: str) -> None:
        """
        Overview:
            Save replay file in the given path, and this method need to be self-implemented by each env class.
        Arguments:
            - replay_path (:obj:`str`): The path to save replay file.
        Raises:
            - NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    # optional method
    def random_action(self) -> Any:
        """
        Overview:
            Return random action generated from the original action space, usually it is convenient for test.
            The base implementation does nothing and returns ``None``.
        Returns:
            - random_action (:obj:`Any`): Action generated randomly.
        """
        pass
def get_vec_env_setting(cfg: dict, collect: bool = True, eval_: bool = True) -> Tuple[type, List[dict], List[dict]]:
    """
    Overview:
        Look up the env class registered under ``cfg.type`` and build the per-env config lists
        for a vectorized env: (env_fn, collector_env_cfg, evaluator_env_cfg).
    Arguments:
        - cfg (:obj:`dict`): Original input env config in user config, such as ``cfg.env``. It is read \
            through both ``cfg.get(...)`` and attribute access (``cfg.type``).
        - collect (:obj:`bool`): Whether to build the collector env configs.
        - eval_ (:obj:`bool`): Whether to build the evaluator env configs.
    Returns:
        - env_fn (:obj:`type`): Callable object, call it with proper arguments and then get a new env instance.
        - collector_env_cfg (:obj:`List[dict]`): Configs of the data-collecting envs, or ``None`` when \
            ``collect`` is False.
        - evaluator_env_cfg (:obj:`List[dict]`): Configs of the evaluation envs, or ``None`` when \
            ``eval_`` is False.
    .. note::
        Elements (env config) in collector_env_cfg/evaluator_env_cfg can be different, such as server ip and port.
    """
    extra_modules = cfg.get('import_names', [])
    import_module(extra_modules)
    env_cls = ENV_REGISTRY.get(cfg.type)

    # Keep this order: the collector configs are requested first, then the evaluator configs,
    # both from the same ``cfg`` object.
    if collect:
        collector_cfgs = env_cls.create_collector_env_cfg(cfg)
    else:
        collector_cfgs = None

    if eval_:
        evaluator_cfgs = env_cls.create_evaluator_env_cfg(cfg)
    else:
        evaluator_cfgs = None

    return env_cls, collector_cfgs, evaluator_cfgs
def get_env_cls(cfg: EasyDict) -> type:
    """
    Overview:
        Return the env class that ``ENV_REGISTRY`` holds under ``cfg.type``, after passing the
        optional ``import_names`` entry of ``cfg`` to ``import_module``.
    Arguments:
        - cfg (:obj:`EasyDict`): Original input env config in user config, such as ``cfg.env``. \
            ``import_names`` defaults to an empty list when absent.
    Returns:
        - env_cls_type (:obj:`type`): The registered env class, usable as a callable to create env instances.
    """
    extra_modules = cfg.get('import_names', [])
    import_module(extra_modules)
    env_cls = ENV_REGISTRY.get(cfg.type)
    return env_cls
def create_model_env(cfg: EasyDict) -> Any:
    """
    Overview:
        Create model env, which is used in model-based RL. The env class is resolved with
        ``get_env_cls`` and instantiated with the remaining config entries as keyword arguments.
    Arguments:
        - cfg (:obj:`EasyDict`): Env config. Must contain ``type``; ``import_names`` is optional. \
            The input is deep-copied first, so the caller's config is not modified.
    Returns:
        - model_env (:obj:`Any`): Instance created by calling the resolved env class with ``**cfg`` \
            (without the ``type`` and ``import_names`` entries).
    """
    cfg = copy.deepcopy(cfg)
    model_env_fn = get_env_cls(cfg)
    # ``import_names`` is optional for ``get_env_cls`` (it uses ``cfg.get(..., [])``), so it must not
    # be required here either; drop it only if present.
    cfg.pop('import_names', None)
    cfg.pop('type')
    return model_env_fn(**cfg)