Spaces:
Sleeping
Sleeping
| import pytest | |
| from copy import deepcopy | |
| from ding.entry.serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream | |
| from dizoo.classic_control.pendulum.config.mbrl.pendulum_sac_mbpo_config \ | |
| import main_config as pendulum_sac_mbpo_main_config,\ | |
| create_config as pendulum_sac_mbpo_create_config | |
| from dizoo.classic_control.pendulum.config.mbrl.pendulum_mbsac_mbpo_config \ | |
| import main_config as pendulum_mbsac_mbpo_main_config,\ | |
| create_config as pendulum_mbsac_mbpo_create_config | |
| from dizoo.classic_control.pendulum.config.mbrl.pendulum_stevesac_mbpo_config \ | |
| import main_config as pendulum_stevesac_mbpo_main_config,\ | |
| create_config as pendulum_stevesac_mbpo_create_config | |
def test_dyna():
    """Smoke-test ``serial_pipeline_dyna`` on the pendulum SAC + MBPO config.

    Runs the pipeline for a single training iteration with seed 0 and fails
    the test if the pipeline raises.

    Raises:
        AssertionError: if the pipeline raises; the original exception is
            chained as ``__cause__`` so its traceback appears in the report.
    """
    # Deep copies so the shared module-level config objects are not mutated.
    config = [deepcopy(pendulum_sac_mbpo_main_config), deepcopy(pendulum_sac_mbpo_create_config)]
    # NOTE(review): presumably shortens world-model training so this smoke test
    # finishes quickly -- confirm against the world model's early-stopping logic.
    config[0].world_model.model.max_epochs_since_update = 0
    try:
        serial_pipeline_dyna(config, seed=0, max_train_iter=1)
    except Exception as err:
        # Raise explicitly rather than `assert False`: assert statements are
        # stripped under `python -O`, which would make this test pass silently.
        raise AssertionError("pipeline fail") from err
def test_dream():
    """Smoke-test ``serial_pipeline_dream`` on the pendulum MBSAC and STEVE-SAC MBPO configs.

    Each ``[main_config, create_config]`` pair is deep-copied, patched, and run
    for a single training iteration with seed 0. The test fails on the first
    pair whose pipeline raises, naming that pair's index in ``configs``.

    Raises:
        AssertionError: if a pipeline raises; the original exception is
            chained as ``__cause__`` so its traceback appears in the report.
    """
    # Deep copies so the shared module-level config objects are not mutated.
    configs = [
        [deepcopy(pendulum_mbsac_mbpo_main_config), deepcopy(pendulum_mbsac_mbpo_create_config)],
        [deepcopy(pendulum_stevesac_mbpo_main_config), deepcopy(pendulum_stevesac_mbpo_create_config)],
    ]
    for idx, config in enumerate(configs):
        # NOTE(review): presumably shortens world-model training so this smoke
        # test finishes quickly -- confirm against the world model's logic.
        config[0].world_model.model.max_epochs_since_update = 0
        try:
            serial_pipeline_dream(config, seed=0, max_train_iter=1)
        except Exception as err:
            # Explicit raise (not `assert False`) so the check survives
            # `python -O`; the index identifies which config pair failed.
            raise AssertionError(f"pipeline fail (config #{idx})") from err