import pytest
import torch
from easydict import EasyDict
from ding.policy.r2d3 import R2D3Policy
from ding.utils.data import offline_data_save_type
from tensorboardX import SummaryWriter
from ding.model.wrapper.model_wrappers import ArgmaxSampleWrapper, HiddenStateWrapper, EpsGreedySampleWrapper
import os
from typing import List
from collections import namedtuple
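
# Smoke test for ding's R2D3Policy on synthetic data: it checks the model wrappers used by the
# learn/collect/eval modes, transition processing, train-sample slicing and one learner step.
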
obs_space = 5
action_space = 4

cfg = dict(
    cuda=True,
    on_policy=False,
    priority=True,
    priority_IS_weight=True,
    model=dict(
        obs_shape=obs_space,
        action_shape=action_space,
        encoder_hidden_size_list=[128, 128, 512],
    ),
    discount_factor=0.99,
    burnin_step=2,
    nstep=5,
    learn_unroll_len=20,
    burning_step=5,
    learn=dict(
        value_rescale=True,
        update_per_collect=8,
        batch_size=64,
        learning_rate=0.0005,
        target_update_theta=0.001,
        lambda1=1.0,  # weight of the n-step TD loss
        lambda2=1.0,  # weight of the supervised (large-margin) loss
        lambda3=1e-5,  # L2 regularization weight; it is important to set the optimizer's optim_type='adamw'
        lambda_one_step_td=1,  # weight of the 1-step TD loss
        margin_function=0.8,  # margin used in the supervised loss J_E, implemented here as a constant
        per_train_iter_k=0,
        ignore_done=False,
    ),
    collect=dict(
        n_sample=32,
        traj_len_inf=True,
        env_num=8,
        pho=1 / 4,  # ratio of expert demonstration data in each training batch
    ),
    eval=dict(env_num=8, ),
    other=dict(
        eps=dict(
            type='exp',
            start=0.95,
            end=0.1,
            decay=100000,
        ),
        replay_buffer=dict(
            replay_buffer_size=int(1e4),
            alpha=0.6,
            beta=0.4,
        ),
    ),
)
cfg = EasyDict(cfg)
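

# Build a fake collector batch: a dict mapping env id -> zero observation tensor,
# the input format expected by _forward_collect / _forward_eval below.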
def get_batch(size=8):
    data = {}
    for i in range(size):
        obs = torch.zeros(obs_space)
        data[i] = obs
    return data
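

# Build a fake unrolled trajectory of `size` transitions, each carrying an observation, a discrete
# action, the DRQN hidden state, an importance-sampling weight and an alternating is_expert flag,
# in the format _get_train_sample expects.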
def get_transition(size=20):
    data = []
    import numpy as np
    for i in range(size):
        sample = {}
        sample['obs'] = torch.zeros(obs_space)
        sample['action'] = torch.tensor(np.array([int(i % action_space)]))
        sample['done'] = False
        sample['prev_state'] = [torch.randn(1, 1, 512) for __ in range(2)]
        sample['reward'] = torch.Tensor([1.])
        sample['IS'] = 1.
        sample['is_expert'] = bool(i % 2)
        data.append(sample)
    return data
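

# Exercise each policy mode (learn / collect / eval) end to end on the synthetic data above.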
@pytest.mark.parametrize('cfg', [cfg])
def test_r2d3(cfg):
    policy = R2D3Policy(cfg, enable_field=['collect', 'eval'])
    policy._init_learn()
    assert type(policy._learn_model) == ArgmaxSampleWrapper
    assert type(policy._target_model) == HiddenStateWrapper
    policy._reset_learn()
    policy._reset_learn([0])
    state = policy._state_dict_learn()
    policy._load_state_dict_learn(state)
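    # Collect and eval modes wrap the model with eps-greedy / argmax samplers respectively.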
    policy._init_collect()
    assert type(policy._collect_model) == EpsGreedySampleWrapper
    policy._reset_collect()
    policy._reset_collect([0])
    policy._init_eval()
    assert type(policy._eval_model) == ArgmaxSampleWrapper
    policy._reset_eval()
    policy._reset_eval([0])
    assert policy.default_model()[0] == 'drqn'
    var = policy._monitor_vars_learn()
    assert type(var) == list
    assert sum([type(s) == str for s in var]) == len(var)
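    # Collect phase: eps-greedy forward on a fake batch, then pack one env step into a transition.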
    batch = get_batch(8)
    out = policy._forward_collect(batch, eps=0.1)
    assert len(set(out[0].keys()).intersection({'logit', 'prev_state', 'action'})) == 3
    assert list(out[0]['logit'].shape) == [action_space]
    timestep = namedtuple('timestep', ['reward', 'done'])
    ts = timestep(1., 0.)
    ts = policy._process_transition(batch[0], out[0], ts)
    assert len(set(ts.keys()).intersection({'prev_state', 'action', 'reward', 'done', 'obs'})) == 5
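    # Slice a long fake trajectory into fixed-length training sequences and run an eval forward.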
    ts = get_transition(64 * policy._sequence_len)
    sample = policy._get_train_sample(ts)
    n_traj = len(ts) // policy._sequence_len
    assert len(sample) == (n_traj + 1 if len(ts) % policy._sequence_len != 0 else n_traj)
    out = policy._forward_eval(batch)
    assert len(set(out[0].keys()).intersection({'logit', 'action'})) == 2
    assert list(out[0]['logit'].shape) == [action_space]
    for i in range(len(sample)):
        sample[i]['IS'] = sample[i]['IS'][cfg.burnin_step:]
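    # One learner step with value rescaling enabled, then another with it disabled.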
    out = policy._forward_learn(sample)
    policy._value_rescale = False
    out = policy._forward_learn(sample)
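

if __name__ == "__main__":
    # Convenience entry point (a minimal sketch, not part of the original test): the parametrize
    # decorator above leaves the function directly callable, so this runs the same checks with the
    # module-level config without invoking pytest.
    test_r2d3(cfg)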