Spaces:
Sleeping
Sleeping
| #unittest for priority_calculator | |
| import unittest | |
| import pytest | |
| import numpy as np | |
| from unittest.mock import Mock, patch | |
| from ding.framework import OnlineRLContext, OfflineRLContext | |
| from ding.framework import task, Parallel | |
| from ding.framework.middleware.functional import priority_calculator | |
| class MockPolicy(Mock): | |
| def priority_fun(self, data): | |
| return np.random.rand(len(data)) | |
| def test_priority_calculator(): | |
| policy = MockPolicy() | |
| ctx = OnlineRLContext() | |
| ctx.trajectories = [ | |
| { | |
| 'obs': np.random.rand(2, 2), | |
| 'next_obs': np.random.rand(2, 2), | |
| 'reward': np.random.rand(1), | |
| 'info': {} | |
| } for _ in range(10) | |
| ] | |
| priority_calculator_middleware = priority_calculator(priority_calculation_fn=policy.priority_fun) | |
| priority_calculator_middleware(ctx) | |
| assert len(ctx.trajectories) == 10 | |
| assert all([isinstance(traj['priority'], float) for traj in ctx.trajectories]) | |