Spaces:
Sleeping
Sleeping
| from copy import deepcopy | |
| import pytest | |
| import torch | |
| from easydict import EasyDict | |
| from ding.model.wrapper.model_wrappers import BaseModelWrapper, MultinomialSampleWrapper | |
| from ding.policy import PPOSTDIMPolicy | |
# Sizes of the toy task used throughout this module: the length of each
# observation vector and the length of each logit vector.
obs_shape, action_shape = 4, 2

# cfg1: the policy's default config with the model sized for the toy task.
cfg1 = EasyDict(PPOSTDIMPolicy.default_config())
setattr(cfg1.model, 'obs_shape', obs_shape)
setattr(cfg1.model, 'action_shape', action_shape)

# cfg2: an independent deep copy of cfg1 with the action space set to continuous.
cfg2 = deepcopy(cfg1)
setattr(cfg2, 'action_space', "continuous")
def get_transition_discrete(size=64):
    """Build a list of ``size`` fake transitions for the discrete-action case.

    Each transition is a dict holding random ``obs`` / ``next_obs`` tensors of
    length ``obs_shape``, a constant ``action`` of ``tensor([0])`` (long),
    random one-element ``value`` and ``reward`` tensors, a random ``logit``
    tensor of length ``action_shape``, and ``done`` fixed to ``False``.

    Args:
        size: Number of transitions to generate.

    Returns:
        A list with ``size`` transition dicts.
    """
    # Dict entries are evaluated top to bottom, so the random draws happen in
    # the order: obs, next_obs, value, logit, reward.
    return [
        {
            'obs': torch.rand(obs_shape),
            'next_obs': torch.rand(obs_shape),
            'action': torch.tensor([0], dtype=torch.long),
            'value': torch.rand(1),
            'logit': torch.rand(size=(action_shape, )),
            'done': False,
            'reward': torch.rand(1),
        }
        for _ in range(size)
    ]
# Without parametrization pytest has nothing to bind ``cfg`` to and reports
# "fixture 'cfg' not found".
# NOTE(review): only cfg1 is exercised here; cfg2 selects a continuous action
# space, while the sample data below is discrete and the collect-model
# assertion presumably only holds for the discrete case -- confirm before
# adding cfg2 to the list.
@pytest.mark.parametrize('cfg', [cfg1])
def test_stdim(cfg):
    """Smoke-test PPOSTDIMPolicy: construction, state-dict round trip, one learn step.

    Args:
        cfg: Policy config used to build the policy (supplied by parametrize).
    """
    policy = PPOSTDIMPolicy(cfg, enable_field=['collect', 'eval', 'learn'])

    # Exact type checks on purpose: the test pins the concrete wrapper class,
    # so a subclass instance should not pass.
    assert type(policy._learn_model) is BaseModelWrapper
    assert type(policy._collect_model) is MultinomialSampleWrapper

    # The learn-mode state dict must be loadable back into the same policy.
    state = policy._state_dict_learn()
    policy._load_state_dict_learn(state)

    # One learn step on a batch of fake discrete transitions must not raise.
    sample = get_transition_discrete(size=64)
    policy._forward_learn(sample)