from typing import TYPE_CHECKING
from easydict import EasyDict
import treetensor.torch as ttorch
from ding.policy import get_random_policy
from ding.envs import BaseEnvManager
from ding.framework import task
from .functional import inferencer, rolloutor, TransitionList

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext


class StepCollector:
    """
    Overview:
        The class of the collector running by steps, including model inference and transition \
            process. Use the `__call__` method to execute the whole collection process.
    """

    def __new__(cls, *args, **kwargs):
        # In distributed mode, only processes holding the COLLECTOR role instantiate this
        # middleware; all other processes receive a no-op placeholder.
        if task.router.is_active and not task.has_role(task.role.COLLECTOR):
            return task.void()
        return super(StepCollector, cls).__new__(cls)

    def __init__(self, cfg: EasyDict, policy, env: BaseEnvManager, random_collect_size: int = 0) -> None:
        """
        Arguments:
            - cfg (:obj:`EasyDict`): Config.
            - policy (:obj:`Policy`): The policy to be collected.
            - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
                its derivatives are supported.
            - random_collect_size (:obj:`int`): The count of samples that will be collected randomly, \
                typically used in initial runs.
        """
        self.cfg = cfg
        self.env = env
        self.policy = policy
        self.random_collect_size = random_collect_size
        self._transitions = TransitionList(self.env.env_num)
        self._inferencer = task.wrap(inferencer(cfg.seed, policy, env))
        self._rolloutor = task.wrap(rolloutor(policy, env, self._transitions))

    def __call__(self, ctx: "OnlineRLContext") -> None:
        """
        Overview:
            An encapsulation of inference and rollout middleware. Stop when completing \
                the target number of steps.
        Input of ctx:
            - env_step (:obj:`int`): The env steps which will increase during collection.
        """
        old = ctx.env_step
        if self.random_collect_size > 0 and old < self.random_collect_size:
            target_size = self.random_collect_size - old
            random_policy = get_random_policy(self.cfg, self.policy, self.env)
            current_inferencer = task.wrap(inferencer(self.cfg.seed, random_policy, self.env))
        else:
            # compatible with old config, a train sample = unroll_len step
            target_size = self.cfg.policy.collect.n_sample * self.cfg.policy.collect.unroll_len
            current_inferencer = self._inferencer

        while True:
            current_inferencer(ctx)
            self._rolloutor(ctx)
            if ctx.env_step - old >= target_size:
                ctx.trajectories, ctx.trajectory_end_idx = self._transitions.to_trajectories()
                self._transitions.clear()
                break
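

# A minimal usage sketch for StepCollector (an illustration under stated assumptions, not the
# canonical pipeline). It assumes a DI-engine style config in which `cfg.seed`,
# `cfg.policy.collect.n_sample` and `cfg.policy.collect.unroll_len` are set, since those fields
# are read above. `make_policy` and `make_env_manager` are hypothetical helpers standing in for
# the project's actual policy/env construction; the `task.start`/`task.use`/`task.run` wiring
# follows the usual DI-engine middleware pattern.
def _step_collector_example(make_policy, make_env_manager):
    from ding.framework import OnlineRLContext
    cfg = EasyDict(dict(seed=0, policy=dict(collect=dict(n_sample=64, unroll_len=1))))
    policy = make_policy(cfg)  # placeholder: a policy exposing a collect_mode
    env = make_env_manager(cfg)  # placeholder: a BaseEnvManager or derivative
    with task.start(async_mode=False, ctx=OnlineRLContext()):
        # Fill the first 1000 env steps with random actions, then collect with the learned policy.
        task.use(StepCollector(cfg, policy.collect_mode, env, random_collect_size=1000))
        task.run(max_step=10)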


class PPOFStepCollector:
    """
    Overview:
        The class of the collector running by steps, including model inference and transition \
            process. Use the `__call__` method to execute the whole collection process.
    """

    def __new__(cls, *args, **kwargs):
        # In distributed mode, only processes holding the COLLECTOR role instantiate this
        # middleware; all other processes receive a no-op placeholder.
        if task.router.is_active and not task.has_role(task.role.COLLECTOR):
            return task.void()
        return super(PPOFStepCollector, cls).__new__(cls)

    def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll_len: int = 1) -> None:
        """
        Arguments:
            - seed (:obj:`int`): Random seed.
            - policy (:obj:`Policy`): The policy to be collected.
            - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
                its derivatives are supported.
            - n_sample (:obj:`int`): The number of training samples to collect per call.
            - unroll_len (:obj:`int`): The unroll length of a single sample, defaults to 1.
        """
        self.env = env
        self.env.seed(seed)
        self.policy = policy
        self.n_sample = n_sample
        self.unroll_len = unroll_len
        self._transitions = TransitionList(self.env.env_num)
        # Each env starts with its own episode id; new ids are assigned monotonically on episode end.
        self._env_episode_id = [_ for _ in range(env.env_num)]
        self._current_id = env.env_num

    def __call__(self, ctx: "OnlineRLContext") -> None:
        """
        Overview:
            An encapsulation of inference and rollout middleware. Stop when completing \
                the target number of steps.
        Input of ctx:
            - env_step (:obj:`int`): The env steps which will increase during collection.
        """
        device = self.policy._device
        old = ctx.env_step
        target_size = self.n_sample * self.unroll_len

        if self.env.closed:
            self.env.launch()

        while True:
            obs = ttorch.as_tensor(self.env.ready_obs).to(dtype=ttorch.float32)
            obs = obs.to(device)
            inference_output = self.policy.collect(obs, **ctx.collect_kwargs)
            inference_output = inference_output.cpu()
            action = inference_output.action.numpy()
            timesteps = self.env.step(action)
            ctx.env_step += len(timesteps)

            obs = obs.cpu()
            for i, timestep in enumerate(timesteps):
                transition = self.policy.process_transition(obs[i], inference_output[i], timestep)
                transition.collect_train_iter = ttorch.as_tensor([ctx.train_iter])
                transition.env_data_id = ttorch.as_tensor([self._env_episode_id[timestep.env_id]])
                self._transitions.append(timestep.env_id, transition)
                if timestep.done:
                    # On episode end, reset the policy state for that env and start a new episode id.
                    self.policy.reset([timestep.env_id])
                    self._env_episode_id[timestep.env_id] = self._current_id
                    self._current_id += 1
                    ctx.env_episode += 1

            if ctx.env_step - old >= target_size:
                ctx.trajectories, ctx.trajectory_end_idx = self._transitions.to_trajectories()
                self._transitions.clear()
                break
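

# A minimal usage sketch for PPOFStepCollector (an illustration under stated assumptions).
# Unlike StepCollector, it takes its arguments explicitly instead of reading a config, and its
# `__call__` drives inference and env stepping itself rather than delegating to the inferencer
# and rolloutor middleware. `make_ppof_policy` and `make_env_manager` are hypothetical helpers;
# the policy is assumed to expose `_device`, `collect`, `process_transition` and `reset`, as
# used above.
def _ppof_step_collector_example(make_ppof_policy, make_env_manager):
    from ding.framework import OnlineRLContext
    policy = make_ppof_policy()  # placeholder: a PPOF-style policy
    env = make_env_manager()  # placeholder: a BaseEnvManager or derivative
    with task.start(async_mode=False, ctx=OnlineRLContext()):
        task.use(PPOFStepCollector(seed=0, policy=policy, env=env, n_sample=128, unroll_len=1))
        task.run(max_step=10)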


class EpisodeCollector:
    """
    Overview:
        The class of the collector running by episodes, including model inference and transition \
            process. Use the `__call__` method to execute the whole collection process.
    """

    def __init__(self, cfg: EasyDict, policy, env: BaseEnvManager, random_collect_size: int = 0) -> None:
        """
        Arguments:
            - cfg (:obj:`EasyDict`): Config.
            - policy (:obj:`Policy`): The policy to be collected.
            - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
                its derivatives are supported.
            - random_collect_size (:obj:`int`): The count of episodes that will be collected randomly, \
                typically used in initial runs.
        """
        self.cfg = cfg
        self.env = env
        self.policy = policy
        self.random_collect_size = random_collect_size
        self._transitions = TransitionList(self.env.env_num)
        self._inferencer = task.wrap(inferencer(cfg.seed, policy, env))
        self._rolloutor = task.wrap(rolloutor(policy, env, self._transitions))

    def __call__(self, ctx: "OnlineRLContext") -> None:
        """
        Overview:
            An encapsulation of inference and rollout middleware. Stop when completing the \
                target number of episodes.
        Input of ctx:
            - env_episode (:obj:`int`): The count of env episodes, which will increase during collection.
        """
        old = ctx.env_episode
        if self.random_collect_size > 0 and old < self.random_collect_size:
            target_size = self.random_collect_size - old
            random_policy = get_random_policy(self.cfg, self.policy, self.env)
            current_inferencer = task.wrap(inferencer(self.cfg.seed, random_policy, self.env))
        else:
            target_size = self.cfg.policy.collect.n_episode
            current_inferencer = self._inferencer

        while True:
            current_inferencer(ctx)
            self._rolloutor(ctx)
            if ctx.env_episode - old >= target_size:
                ctx.episodes = self._transitions.to_episodes()
                self._transitions.clear()
                break
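

# A minimal usage sketch for EpisodeCollector (an illustration under stated assumptions). It
# assumes `cfg.seed` and `cfg.policy.collect.n_episode` are set, since those fields are read
# above, and that downstream middleware consumes `ctx.episodes`. `make_policy` and
# `make_env_manager` are hypothetical construction helpers.
def _episode_collector_example(make_policy, make_env_manager):
    from ding.framework import OnlineRLContext
    cfg = EasyDict(dict(seed=0, policy=dict(collect=dict(n_episode=8))))
    policy = make_policy(cfg)  # placeholder: a policy exposing a collect_mode
    env = make_env_manager(cfg)  # placeholder: a BaseEnvManager or derivative
    with task.start(async_mode=False, ctx=OnlineRLContext()):
        task.use(EpisodeCollector(cfg, policy.collect_mode, env))
        task.run(max_step=10)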


# TODO battle collector