Spaces:
Sleeping
Sleeping
| import gym | |
| from easydict import EasyDict | |
| from copy import deepcopy | |
| import numpy as np | |
| from collections import namedtuple | |
| from typing import Any, Union, List, Tuple, Dict, Callable, Optional | |
| from ditk import logging | |
| try: | |
| import envpool | |
| except ImportError: | |
| import sys | |
| logging.warning("Please install envpool first, use 'pip install envpool'") | |
| envpool = None | |
| from ding.envs import BaseEnvTimestep | |
| from ding.utils import ENV_MANAGER_REGISTRY, deep_merge_dicts | |
| from ding.torch_utils import to_ndarray | |
# NOTE(review): ``ENV_MANAGER_REGISTRY`` is imported by this file but never referenced in the
# visible code -- confirm whether a registration decorator on this class was lost.
class PoolEnvManager:
    '''
    Overview:
        Envpool now supports Atari, Classic Control, Toy Text, ViZDoom.
        Here we list some commonly used env_ids as follows.
        For more examples, you can refer to <https://envpool.readthedocs.io/en/latest/api/atari.html>.

        - Atari: "Pong-v5", "SpaceInvaders-v5", "Qbert-v5"
        - Classic Control: "CartPole-v0", "CartPole-v1", "Pendulum-v1"
    '''

    @classmethod
    def default_config(cls) -> EasyDict:
        '''
        Overview:
            Return a deep copy of the class-level ``config`` wrapped in an ``EasyDict``, so the
            caller can modify the result without touching the class attribute.
        Returns:
            - cfg (:obj:`EasyDict`): Copy of ``cls.config``.
        '''
        return EasyDict(deepcopy(cls.config))

    config = dict(
        type='envpool',
        # Sync mode: batch_size == env_num
        # Async mode: batch_size < env_num
        env_num=8,
        batch_size=8,
    )

    def __init__(self, cfg: EasyDict) -> None:
        '''
        Overview:
            Store the configuration; no environment is created until ``launch`` is called.
        Arguments:
            - cfg (:obj:`EasyDict`): Must provide ``env_num`` and ``batch_size``. ``launch`` additionally \
                reads ``env_id``, ``episodic_life``, ``reward_clip``, ``stack_num``, ``gray_scale`` and \
                ``frame_skip`` from it.
        '''
        self._cfg = cfg
        self._env_num = cfg.env_num
        self._batch_size = cfg.batch_size
        self._ready_obs = {}
        # The manager starts in the "closed" state; ``launch`` opens it.
        self._closed = True
        self._seed = None

    def launch(self) -> None:
        '''
        Overview:
            Create the envpool environments from the stored config and collect the initial
            observations through ``reset``.
        Raises:
            - AssertionError: If the manager is already launched (not closed).
            - ImportError: If ``envpool`` could not be imported by this module.
        '''
        assert self._closed, "Please first close the env manager"
        if envpool is None:
            # The module-level import fallback sets ``envpool = None``; fail with a clear message
            # instead of an opaque AttributeError on ``None.make``.
            raise ImportError("Please install envpool first, use 'pip install envpool'")
        # Fall back to seed 0 when ``seed`` was never called.
        seed = 0 if self._seed is None else self._seed
        self._envs = envpool.make(
            task_id=self._cfg.env_id,
            env_type="gym",
            num_envs=self._env_num,
            batch_size=self._batch_size,
            seed=seed,
            episodic_life=self._cfg.episodic_life,
            reward_clip=self._cfg.reward_clip,
            stack_num=self._cfg.stack_num,
            gray_scale=self._cfg.gray_scale,
            frame_skip=self._cfg.frame_skip
        )
        self._closed = False
        self.reset()

    def reset(self) -> None:
        '''
        Overview:
            Reset all environments asynchronously and keep receiving batches until an observation
            has been collected for each of the ``env_num`` environments. Also zeroes the
            per-environment episode returns.
        '''
        self._ready_obs = {}
        self._envs.async_reset()
        while True:
            obs, _, _, info = self._envs.recv()
            env_id = info['env_id']
            obs = obs.astype(np.float32)
            # Combine the newly received observations with the ones gathered so far.
            self._ready_obs = deep_merge_dicts(dict(zip(env_id, obs)), self._ready_obs)
            if len(self._ready_obs) == self._env_num:
                break
        self._eval_episode_return = [0. for _ in range(self._env_num)]

    def step(self, action: dict) -> Dict[int, namedtuple]:
        '''
        Overview:
            Send one action per listed environment, receive the next batch of results and wrap
            them into timesteps. ``ready_obs`` is replaced by the observations of this batch.
        Arguments:
            - action (:obj:`dict`): Mapping from env id to the action for that environment.
        Returns:
            - timesteps (:obj:`Dict[int, namedtuple]`): Mapping from the env id reported by envpool to its \
                ``BaseEnvTimestep``. When an episode finishes, ``info['eval_episode_return']`` holds the \
                return accumulated for that environment.
        '''
        send_ids = np.array(list(action.keys()))
        actions = np.array(list(action.values()))
        if actions.ndim == 2:
            # Drop the trailing axis of column-shaped actions before sending them.
            actions = actions.squeeze(1)
        self._envs.send(actions, send_ids)

        obs, rew, done, info = self._envs.recv()
        obs = obs.astype(np.float32)
        rew = rew.astype(np.float32)
        # The received batch may belong to different envs than the ones just sent to.
        recv_ids = info['env_id']

        timesteps = {}
        self._ready_obs = {}
        for idx, eid in enumerate(recv_ids):
            finished = bool(done[idx])
            reward = to_ndarray([rew[idx]])
            self._eval_episode_return[eid] += reward
            # Report the real env id, not the position inside the batch: the two differ whenever
            # the received batch is not ordered 0..n-1 (e.g. async mode).
            timestep = BaseEnvTimestep(obs[idx], reward, finished, info={'env_id': int(eid)})
            if finished:
                timestep.info['eval_episode_return'] = self._eval_episode_return[eid]
                self._eval_episode_return[eid] = 0.
            timesteps[eid] = timestep
            self._ready_obs[eid] = obs[idx]
        return timesteps

    def close(self) -> None:
        '''
        Overview:
            Mark the manager as closed. Calling it on an already closed manager is a no-op.
        '''
        if self._closed:
            return
        # Envpool has no `close` API
        self._closed = True

    def seed(self, seed: int, dynamic_seed=False) -> None:
        '''
        Overview:
            Store the seed that the next ``launch`` passes to envpool.
        Arguments:
            - seed (:obj:`int`): Base seed.
            - dynamic_seed (:obj:`bool`): Not supported by envpool; a warning is logged when requested.
        '''
        # The i-th environment seed in Envpool will be set with i+seed, so we don't do extra transformation here
        self._seed = seed
        if dynamic_seed:
            logging.warning("envpool doesn't support dynamic_seed in different episode")

    # NOTE(review): ``env_num`` and ``ready_obs`` are kept as plain methods to preserve the
    # visible interface -- confirm whether ``@property`` decorators were lost.
    def env_num(self) -> int:
        '''
        Overview:
            Return the number of environments managed, as given by ``cfg.env_num``.
        '''
        return self._env_num

    def ready_obs(self) -> Dict[int, Any]:
        '''
        Overview:
            Return the latest observations, keyed by env id, as filled by ``reset`` or ``step``.
        '''
        return self._ready_obs