Spaces:
Sleeping
Sleeping
| import pytest | |
| import torch | |
| from torch import nn | |
| from itertools import product | |
| from easydict import EasyDict | |
| from ding.world_model.ddppo import DDPPOWorldMode, get_batch_jacobian, get_neighbor_index | |
| from ding.utils import deep_merge_dicts | |
# Parameter grid for the tests: every (state_size, action_size) combination
# of the option lists below, as a list of 2-tuples.
state_size = [16]
action_size = [16, 1]
args = list(product(state_size, action_size))
class TestDDPPO:
    """Smoke tests for ``DDPPOWorldMode`` and the helpers imported from ``ding.world_model.ddppo``.

    Every test that takes ``state_size`` / ``action_size`` is parametrized over
    the module-level ``args`` grid. Without that parametrization pytest would
    look for fixtures with those names and fail at setup.
    """

    # Size of the synthetic data set and of the slice actually passed to the
    # training routines.
    _NUM_SAMPLES = 1280
    _TRAIN_BATCH = 64

    def get_world_model(self, state_size, action_size):
        """Build a ``DDPPOWorldMode`` from its default config with CUDA disabled.

        Args:
            state_size: value written to ``cfg.model.state_size``.
            action_size: value written to ``cfg.model.action_size``.

        Returns:
            The constructed model, with ``serial_calc_nn`` set to ``True``.
        """
        cfg = DDPPOWorldMode.default_config()
        cfg.model.max_epochs_since_update = 0
        cfg = deep_merge_dicts(
            cfg, dict(cuda=False, model=dict(state_size=state_size, action_size=action_size, reward_size=1))
        )
        # Stand-in environment: its termination function reports "not done"
        # (all False) for every observation in the batch.
        fake_env = EasyDict(termination_fn=lambda obs: torch.zeros_like(obs.sum(-1)).bool())
        model = DDPPOWorldMode(cfg, fake_env, None)
        # NOTE(review): presumably selects the serial nearest-neighbour code
        # path inside the model -- confirm against ddppo.py.
        model.serial_calc_nn = True
        return model

    @staticmethod
    def _make_transitions(state_size, action_size, num_samples=1280):
        """Create random ``(inputs, labels)`` tensors for the training tests.

        ``inputs`` is ``[states, actions]`` and ``labels`` is
        ``[rewards, next_states]``, both concatenated along dim 1, where
        ``next_states = states + mean(actions)`` and
        ``rewards = mean(next_states)`` (means taken over dim 1, keepdim).
        """
        states = torch.rand(num_samples, state_size)
        actions = torch.rand(num_samples, action_size)
        next_states = states + actions.mean(1, keepdim=True)
        rewards = next_states.mean(1, keepdim=True)
        inputs = torch.cat([states, actions], dim=1)
        labels = torch.cat([rewards, next_states], dim=1)
        return inputs, labels

    def test_get_neighbor_index(self):
        """Check ``get_neighbor_index`` on two well-separated clusters of three points."""
        k = 2
        data = torch.tensor([[0, 0, 0], [0, 0, 1], [0, 0, -1], [5, 0, 0], [5, 0, 1], [5, 0, -1]])
        idx = get_neighbor_index(data, k, serial=True)
        target_idx = torch.tensor([[2, 1], [0, 2], [0, 1], [5, 4], [3, 5], [3, 4]])
        assert idx.shape == target_idx.shape
        # Compare each row as a sorted set of neighbours. Summing the raw
        # differences lets wrong indices cancel out, while strict element-wise
        # equality would depend on how ties between equidistant neighbours
        # (e.g. points 1 and 2 around point 0) are ordered.
        sorted_idx = torch.sort(idx, dim=1).values
        sorted_target = torch.sort(target_idx, dim=1).values
        assert bool((sorted_idx == sorted_target).all())

    def test_get_batch_jacobian(self):
        """Check the output shape of ``get_batch_jacobian`` for a linear layer."""
        B, in_dim, out_dim = 64, 4, 8
        net = nn.Linear(in_dim, out_dim)
        x = torch.randn(B, in_dim)
        jacobian = get_batch_jacobian(net, x, out_dim)
        assert jacobian.shape == (B, out_dim, in_dim)

    @pytest.mark.parametrize('state_size, action_size', args)
    def test_get_jacobian(self, state_size, action_size):
        """Check shape and grad-tracking of ``model._get_jacobian`` on the gradient model."""
        B, ensemble_size = 64, 7
        model = self.get_world_model(state_size, action_size)
        train_input_reg = torch.randn(ensemble_size, B, state_size + action_size)
        jacobian = model._get_jacobian(model.gradient_model, train_input_reg)
        # Output dim is state_size + 1; the extra row matches reward_size=1
        # set in get_world_model.
        assert jacobian.shape == (ensemble_size, B, state_size + 1, state_size + action_size)
        assert jacobian.requires_grad

    @pytest.mark.parametrize('state_size, action_size', args)
    def test_step(self, state_size, action_size):
        """Check the shapes returned by ``model.step`` for a batch of 128."""
        states = torch.rand(128, state_size)
        actions = torch.rand(128, action_size)
        model = self.get_world_model(state_size, action_size)
        # Set by hand because no training has run before step() is called.
        model.elite_model_idxes = [0, 1]
        rewards, next_obs, dones = model.step(states, actions)
        assert rewards.shape == (128, )
        assert next_obs.shape == (128, state_size)
        assert dones.shape == (128, )

    @pytest.mark.parametrize('state_size, action_size', args)
    def test_train_rollout_model(self, state_size, action_size):
        """Run ``_train_rollout_model`` once on a small synthetic batch (must not raise)."""
        inputs, labels = self._make_transitions(state_size, action_size, self._NUM_SAMPLES)
        model = self.get_world_model(state_size, action_size)
        batch = self._TRAIN_BATCH
        model._train_rollout_model(inputs[:batch], labels[:batch])

    # The misspelling "graident" is kept so existing test selections by name
    # keep working.
    @pytest.mark.parametrize('state_size, action_size', args)
    def test_train_graident_model(self, state_size, action_size):
        """Run ``_train_gradient_model`` once on a small synthetic batch (must not raise)."""
        inputs, labels = self._make_transitions(state_size, action_size, self._NUM_SAMPLES)
        model = self.get_world_model(state_size, action_size)
        batch = self._TRAIN_BATCH
        # The same slice is passed for both pairs of arguments.
        model._train_gradient_model(inputs[:batch], labels[:batch], inputs[:batch], labels[:batch])

    @pytest.mark.parametrize('state_size, action_size', args)
    def test_others(self, state_size, action_size):
        """Train both sub-models once, then call the state save/load and save-best helpers."""
        inputs, labels = self._make_transitions(state_size, action_size, self._NUM_SAMPLES)
        model = self.get_world_model(state_size, action_size)
        batch = self._TRAIN_BATCH
        model._train_rollout_model(inputs[:batch], labels[:batch])
        model._train_gradient_model(inputs[:batch], labels[:batch], inputs[:batch], labels[:batch])
        model._save_states()
        model._load_states()
        model._save_best(0, [1, 2, 3])