Spaces:
Sleeping
Sleeping
| import pytest | |
| import torch | |
| from copy import deepcopy | |
| from ding.entry import serial_pipeline | |
| from ding.entry.serial_entry_bco import serial_pipeline_bco | |
| from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config | |
| from dizoo.classic_control.cartpole.config.cartpole_bco_config import cartpole_bco_config, cartpole_bco_create_config | |
| def test_bco(): | |
| expert_policy_state_dict_path = './expert_policy.pth' | |
| expert_config = [deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)] | |
| expert_policy = serial_pipeline(expert_config, seed=0) | |
| torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path) | |
| config = [deepcopy(cartpole_bco_config), deepcopy(cartpole_bco_create_config)] | |
| config[0].policy.collect.model_path = expert_policy_state_dict_path | |
| try: | |
| serial_pipeline_bco( | |
| config, [deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)], seed=0, max_train_iter=3 | |
| ) | |
| except Exception as e: | |
| print(e) | |
| assert False, "pipeline fail" | |