from typing import Dict
import os
import torch
import torch.nn as nn
import numpy as np
import gym
from gym import spaces
from ditk import logging
from ding.envs import DingEnvWrapper, EvalEpisodeReturnWrapper, \
    BaseEnvManagerV2
from ding.config import compile_config
from ding.policy import PPOPolicy
from ding.utils import set_pkg_seed
from ding.model import VAC
from ding.framework import task, ding_init
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \
    gae_estimator, online_logger
from easydict import EasyDict
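
# PPO config for the custom environment. obs_shape can stay None here because a
# custom Encoder (defined below) is passed to VAC in main(); the 138-dim
# actor/critic head hidden sizes match the encoder's concatenated output
# (32 + 10 + 64 + 32 = 138).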
my_env_ppo_config = dict(
    exp_name='my_env_ppo_seed0',
    env=dict(
        collector_env_num=4,
        evaluator_env_num=4,
        n_evaluator_episode=4,
        stop_value=195,
    ),
    policy=dict(
        cuda=True,
        action_space='discrete',
        model=dict(
            obs_shape=None,
            action_shape=2,
            action_space='discrete',
            critic_head_hidden_size=138,
            actor_head_hidden_size=138,
        ),
        learn=dict(
            epoch_per_collect=2,
            batch_size=64,
            learning_rate=0.001,
            value_weight=0.5,
            entropy_weight=0.01,
            clip_ratio=0.2,
            learner=dict(hook=dict(save_ckpt_after_iter=100)),
        ),
        collect=dict(
            n_sample=256, unroll_len=1, discount_factor=0.9, gae_lambda=0.95, collector=dict(transform_obs=True, )
        ),
        eval=dict(evaluator=dict(eval_freq=100, ), ),
    ),
)
my_env_ppo_config = EasyDict(my_env_ppo_config)
main_config = my_env_ppo_config

my_env_ppo_create_config = dict(
    env_manager=dict(type='base'),
    policy=dict(type='ppo'),
)
my_env_ppo_create_config = EasyDict(my_env_ppo_create_config)
create_config = my_env_ppo_create_config
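
# MyEnv is a toy environment used only to demonstrate a composite observation:
# a nested Dict ('key_0'), a (seq_len, feature_dim) sequence ('key_1'),
# an H x W x 3 image ('key_2') and a bounded 2-dim vector ('key_3').
# Rewards and termination are sampled randomly, so nothing meaningful is learned;
# the point is the observation/encoder plumbing.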
class MyEnv(gym.Env):

    def __init__(self, seq_len=5, feature_dim=10, image_size=(10, 10, 3)):
        super().__init__()
        # Define the action space
        self.action_space = spaces.Discrete(2)
        # Define the observation space
        self.observation_space = spaces.Dict(
            {
                'key_0': spaces.Dict(
                    {
                        'k1': spaces.Box(low=0, high=np.inf, shape=(1, ), dtype=np.float32),
                        'k2': spaces.Box(low=-1, high=1, shape=(1, ), dtype=np.float32),
                    }
                ),
                'key_1': spaces.Box(low=-np.inf, high=np.inf, shape=(seq_len, feature_dim), dtype=np.float32),
                'key_2': spaces.Box(low=0, high=255, shape=image_size, dtype=np.uint8),
                'key_3': spaces.Box(low=0, high=np.array([np.inf, 3]), shape=(2, ), dtype=np.float32),
            }
        )

    def reset(self):
        # Generate a random initial state
        return self.observation_space.sample()

    def step(self, action):
        # Sample a random reward and done flag (placeholders in this example)
        reward = np.random.uniform(low=0.0, high=1.0)
        done = False
        if np.random.uniform(low=0.0, high=1.0) > 0.7:
            done = True
        info = {}
        # Return the next state, reward, done flag and info
        return self.observation_space.sample(), reward, done, info
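
# Wrap MyEnv for DI-engine; EvalEpisodeReturnWrapper tracks the episode return
# so the evaluator can report it.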
def ding_env_maker():
    return DingEnvWrapper(
        MyEnv(), cfg={'env_wrapper': [
            lambda env: EvalEpisodeReturnWrapper(env),
        ]}
    )
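
# Encoder maps each observation key to a flat feature and concatenates them:
# 'key_0' -> MLP (32), 'key_1' -> small Transformer with a class token (feature_dim),
# 'key_2' -> CNN (64), 'key_3' -> MLP (32). With feature_dim=10 the output is 138-dim,
# matching actor_head_hidden_size / critic_head_hidden_size in the config.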
class Encoder(nn.Module):

    def __init__(self, feature_dim: int):
        super(Encoder, self).__init__()
        # Define the networks for each input type
        self.fc_net_1_k1 = nn.Sequential(nn.Linear(1, 8), nn.ReLU())
        self.fc_net_1_k2 = nn.Sequential(nn.Linear(1, 8), nn.ReLU())
        self.fc_net_1 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        """
        The implementation of transformer_encoder refers to the Vision Transformer (ViT) code:
        https://arxiv.org/abs/2010.11929
        https://pytorch.org/vision/main/_modules/torchvision/models/vision_transformer.html
        """
        self.class_token = nn.Parameter(torch.zeros(1, 1, feature_dim))
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=feature_dim, nhead=2, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=1)
        self.conv_net = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(), nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.conv_fc_net = nn.Sequential(nn.Flatten(), nn.Linear(3200, 64), nn.ReLU())
        self.fc_net_2 = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 32), nn.ReLU(), nn.Flatten())

    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Unpack the input dict
        dict_input = inputs['key_0']  # dict of (B, ) tensors
        transformer_input = inputs['key_1']  # (B, seq_len, feature_dim)
        conv_input = inputs['key_2']  # (B, H, W, 3)
        fc_input = inputs['key_3']  # (B, 2)
        B = fc_input.shape[0]
        # Pass each input through its corresponding network
        dict_output = self.fc_net_1(
            torch.cat(
                [self.fc_net_1_k1(dict_input['k1'].unsqueeze(-1)),
                 self.fc_net_1_k2(dict_input['k2'].unsqueeze(-1))],
                dim=1
            )
        )
        batch_class_token = self.class_token.expand(B, -1, -1)
        transformer_output = self.transformer_encoder(torch.cat([batch_class_token, transformer_input], dim=1))
        transformer_output = transformer_output[:, 0]
        conv_output = self.conv_fc_net(self.conv_net(conv_input.permute(0, 3, 1, 2)))
        fc_output = self.fc_net_2(fc_input)
        # Concatenate the outputs along the feature dimension
        encoded_output = torch.cat([dict_output, transformer_output, conv_output, fc_output], dim=1)
        return encoded_output
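
# A minimal shape sanity check (an illustrative addition, not part of the original
# pipeline): the helper name and the fake batch below are assumptions, sized to match
# the shapes documented in Encoder.forward, to verify that the encoder output matches
# the 138-dim head size used in the config.
def _check_encoder_output_dim(batch_size: int = 4):
    encoder = Encoder(feature_dim=10)
    fake_obs = {
        'key_0': {'k1': torch.rand(batch_size), 'k2': torch.rand(batch_size)},
        'key_1': torch.rand(batch_size, 5, 10),
        'key_2': torch.rand(batch_size, 10, 10, 3),
        'key_3': torch.rand(batch_size, 2),
    }
    assert encoder(fake_obs).shape == (batch_size, 138)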
def main():
    logging.getLogger().setLevel(logging.INFO)
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    ding_init(cfg)
    with task.start(async_mode=False, ctx=OnlineRLContext()):
        # Vectorized environments for data collection and evaluation
        collector_env = BaseEnvManagerV2(
            env_fn=[ding_env_maker for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
        )
        evaluator_env = BaseEnvManagerV2(
            env_fn=[ding_env_maker for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
        )
        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
        # Plug the custom encoder into the actor-critic model
        encoder = Encoder(feature_dim=10)
        model = VAC(encoder=encoder, **cfg.policy.model)
        policy = PPOPolicy(cfg.policy, model=model)
        # Middleware pipeline: evaluate, collect, compute GAE, train, checkpoint, log
        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
        task.use(StepCollector(cfg, policy.collect_mode, collector_env))
        task.use(gae_estimator(cfg, policy.collect_mode))
        task.use(multistep_trainer(policy.learn_mode, log_freq=50))
        task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
        task.use(online_logger(train_show_freq=3))
        task.run()


if __name__ == "__main__":
    main()