"""Smoke tests for reward-model serial pipelines (IRL, RND, ICM) on CartPole."""
import os
from copy import deepcopy

import pytest
from ditk import logging
from easydict import EasyDict

from ding.entry import serial_pipeline, collect_demo_data, serial_pipeline_reward_model_offpolicy, \
    serial_pipeline_reward_model_onpolicy
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config  # noqa
from dizoo.classic_control.cartpole.config.cartpole_rnd_onppo_config import cartpole_ppo_rnd_config, cartpole_ppo_rnd_create_config  # noqa
from dizoo.classic_control.cartpole.config.cartpole_ppo_icm_config import cartpole_ppo_icm_config, cartpole_ppo_icm_create_config  # noqa
# Reward-model configurations exercised by ``test_irl`` — one entry per
# reward-model ``type`` (pdeil / gail / pwil / red), each carrying the
# model-specific hyper-parameters used by the pipeline under test.
cfg = [
    dict(type='pdeil', alpha=0.5, discrete_action=False),
    dict(type='gail', input_size=5, hidden_size=64, batch_size=64),
    dict(type='pwil', s_size=4, a_size=2, sample_size=500),
    dict(
        type='red',
        sample_size=5000,
        input_size=5,
        hidden_size=64,
        update_per_collect=200,
        batch_size=128,
    ),
]
@pytest.mark.parametrize('reward_model_config', cfg)
def test_irl(reward_model_config):
    """End-to-end smoke test for IRL reward models (pdeil / gail / pwil / red).

    Trains a brief expert PPO policy, collects expert demonstration data from
    it, then runs a couple of DQN training iterations with the given reward
    model plugged into the off-policy reward-model pipeline.

    Fix: the function takes ``reward_model_config`` but had no
    ``@pytest.mark.parametrize`` decorator, so pytest could not supply the
    argument and the test errored at collection time; it is now parametrized
    over the module-level ``cfg`` list.
    """
    reward_model_config = EasyDict(reward_model_config)
    # Train a short expert policy; the pipeline takes a (main_cfg, create_cfg) tuple.
    config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
    expert_policy = serial_pipeline(config, seed=0, max_train_iter=2)
    # collect expert demo data
    collect_count = 10000
    expert_data_path = 'expert_data.pkl'
    state_dict = expert_policy.collect_mode.state_dict()
    config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
    collect_demo_data(
        config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count
    )
    # irl + rl training: fresh DQN configs with the reward model spliced in.
    cp_cartpole_dqn_config = deepcopy(cartpole_dqn_config)
    cp_cartpole_dqn_create_config = deepcopy(cartpole_dqn_create_config)
    cp_cartpole_dqn_create_config.reward_model = dict(type=reward_model_config.type)
    # gail takes a directory via 'data_path'; the other models take the pickle path directly.
    if reward_model_config.type == 'gail':
        reward_model_config['data_path'] = '.'
    else:
        reward_model_config['expert_data_path'] = expert_data_path
    cp_cartpole_dqn_config.reward_model = reward_model_config
    cp_cartpole_dqn_config.policy.collect.n_sample = 128
    serial_pipeline_reward_model_offpolicy(
        (cp_cartpole_dqn_config, cp_cartpole_dqn_create_config), seed=0, max_train_iter=2
    )
    # Best-effort cleanup of pipeline artifacts.
    # NOTE(review): os.popen spawns a shell and does not wait for completion;
    # consider shutil.rmtree / subprocess.run if cleanup must be synchronous.
    os.popen("rm -rf ckpt_* log expert_data.pkl")
def test_rnd():
    """Smoke-test the RND intrinsic-reward model via the on-policy pipeline.

    Runs two training iterations on CartPole; any exception from the
    pipeline is reported as a test failure.
    """
    rnd_cfgs = [deepcopy(c) for c in (cartpole_ppo_rnd_config, cartpole_ppo_rnd_create_config)]
    try:
        serial_pipeline_reward_model_onpolicy(rnd_cfgs, seed=0, max_train_iter=2)
    except Exception:
        assert False, "pipeline fail"
def test_icm():
    """Smoke-test the ICM intrinsic-reward model via the off-policy pipeline.

    Runs two training iterations on CartPole; any exception from the
    pipeline is reported as a test failure.
    """
    icm_cfgs = [deepcopy(c) for c in (cartpole_ppo_icm_config, cartpole_ppo_icm_create_config)]
    try:
        serial_pipeline_reward_model_offpolicy(icm_cfgs, seed=0, max_train_iter=2)
    except Exception:
        assert False, "pipeline fail"