Spaces:
Sleeping
Sleeping
| import pytest | |
| import torch | |
| from itertools import product | |
| from ding.world_model.model.ensemble import EnsembleFC, EnsembleModel | |
# Candidate sizes; `args` holds one (state, action, reward) tuple for every
# combination, with the last position varying fastest.
state_size = [16]
action_size = [16, 1]
reward_size = [1]
args = [
    (s, a, r)
    for s in state_size
    for a in action_size
    for r in reward_size
]
def test_EnsembleFC():
    """Check that EnsembleFC maps an (ensemble, batch, in_dim) tensor to (ensemble, batch, out_dim)."""
    ensemble_size = 7
    batch = 64
    in_dim, out_dim = 4, 8
    # Build the layer before sampling the input so the RNG is consumed in the same order.
    layer = EnsembleFC(in_dim, out_dim, ensemble_size)
    inputs = torch.randn(ensemble_size, batch, in_dim)
    expected_shape = (ensemble_size, batch, out_dim)
    assert layer(inputs).shape == expected_shape
@pytest.mark.parametrize('state_size, action_size, reward_size', args)
def test_EnsembleModel(state_size, action_size, reward_size):
    """Check the output structure and shapes of EnsembleModel.

    Parametrized over the module-level ``args`` grid of (state, action, reward) sizes.
    Without the parametrize decorator, pytest would try to resolve the three
    arguments as fixtures and fail.

    The model is fed a tensor of shape (ensemble_size, B, state_size + action_size).
    The test asserts that it returns exactly two outputs, both of shape
    (ensemble_size, B, state_size + reward_size).
    """
    ensemble_size, B = 7, 64
    model = EnsembleModel(state_size, action_size, reward_size, ensemble_size)
    # Input is the concatenated state+action feature for every ensemble member.
    x = torch.randn(ensemble_size, B, state_size + action_size)
    y = model(x)
    assert len(y) == 2
    assert y[0].shape == y[1].shape == (ensemble_size, B, state_size + reward_size)