Spaces:
Sleeping
Sleeping
| import sys | |
| import timeit | |
| import torch | |
| import random | |
| import pytest | |
| import numpy as np | |
| from ding.data.buffer import DequeBuffer | |
| from ding.data.buffer.middleware import clone_object, PriorityExperienceReplay | |
# Buffer capacities to benchmark, e.g. 1000, 10000, 100000.
size_list = [1000, 10000]
# Side lengths of the square tensors stored in the buffer,
# e.g. 32*32, 128*128, 512*512.
data_dim_list = [32, 128]
# Number of executions of the timed operation per timeit measurement;
# also used to normalise the reported timings.
repeats = 100
class BufferBenchmark:
    """Bundle a ``DequeBuffer`` with the operations the benchmark times.

    One random ``data_dim x data_dim`` tensor is created up front and the
    same sample is pushed on every call, so the timed operations do not
    include data generation.
    """

    def __init__(self, buffer_size, data_dim, buffer_type='base') -> None:
        """Create the buffer and attach the middleware chosen by ``buffer_type``.

        Args:
            buffer_size: Capacity passed to ``DequeBuffer``.
            data_dim: Side length of the square tensor stored under ``"obs"``.
            buffer_type: ``"clone"`` attaches ``clone_object``; ``"priority"``
                attaches ``PriorityExperienceReplay`` and adds a ``priority``
                entry to the push metadata. Any other value leaves the
                buffer without middleware.
        """
        buffer = DequeBuffer(size=buffer_size)
        meta = {}
        if buffer_type == "clone":
            buffer.use(clone_object())
        if buffer_type == "priority":
            buffer.use(PriorityExperienceReplay(buffer))
            meta["priority"] = 2.0
        self._buffer = buffer
        self._meta = meta
        self._data = {"obs": torch.rand(data_dim, data_dim)}

    def data_storage(self) -> float:
        """Return the size of the sample tensor's storage in KB."""
        storage = self._data["obs"].storage()
        return sys.getsizeof(storage) / 1024

    def count(self) -> int:
        """Return the number of items currently held by the buffer."""
        return self._buffer.count()

    def push_op(self) -> None:
        """Push the shared sample together with the base metadata."""
        self._buffer.push(self._data, meta=self._meta)

    def push_with_group_info(self, num_keys=256) -> None:
        """Push the shared sample tagged with a random group id.

        The group id is an integer drawn from ``[0, num_keys)``. It is
        written to a copy of the metadata, so the base metadata stays
        untouched.
        """
        tagged_meta = dict(self._meta)
        tagged_meta['group'] = int(random.random() * num_keys)
        self._buffer.push(self._data, meta=tagged_meta)

    def sample_op(self) -> None:
        """Sample 128 items without replacement."""
        self._buffer.sample(128, replace=False)

    def replace_sample_op(self) -> None:
        """Sample 128 items with replacement."""
        self._buffer.sample(128, replace=True)

    def groupby_sample_op(self) -> None:
        """Sample 128 items grouped by the ``group`` metadata key."""
        self._buffer.sample(128, groupby="group")
def get_mean_std(res):
    """Return the mean and standard deviation of ``res`` per 1000 operations.

    Each entry of ``res`` is the total time of one measurement that executed
    the operation ``repeats`` times, so both statistics are rescaled by
    ``1000 / repeats``.
    """
    timings = np.asarray(res)
    mean_per_1000 = timings.mean() * 1000.0 / repeats
    std_per_1000 = timings.std() * 1000.0 / repeats
    return mean_per_1000, std_per_1000
@pytest.mark.parametrize('buffer_type', ['base', 'clone', 'priority'])
def test_benchmark(buffer_type):
    """Time push and sample operations for every buffer size / data dim pair.

    The test is parametrized over the buffer types ``BufferBenchmark``
    distinguishes, so pytest can supply ``buffer_type`` itself; calling the
    function directly with a buffer type still works.

    Args:
        buffer_type: ``'base'``, ``'clone'`` or ``'priority'``; forwarded to
            ``BufferBenchmark``.

    For each configuration the timings (rescaled to 1000 operations by
    ``get_mean_std``) are printed for pushing, sampling without replacement,
    sampling with replacement and, except for the priority buffer, groupby
    sampling.
    """
    for size in size_list:
        # Every sampling op draws 128 items, so smaller buffers cannot be
        # sampled without replacement. The check only depends on ``size``.
        assert size >= 128, "size is too small, please set an int no less than 128!"
        for dim in data_dim_list:
            buffer_test = BufferBenchmark(size, dim, buffer_type)

            print("exp-buffer_{}_{}-data_{:.2f}_KB".format(buffer_type, size, buffer_test.data_storage()))

            # test pushing
            mean, std = get_mean_std(timeit.repeat(buffer_test.push_op, number=repeats))
            print("Empty Push Test:         mean {:.4f} s, std {:.4f} s".format(mean, std))

            # fill the buffer before sampling tests
            for _ in range(size):
                buffer_test.push_with_group_info()
            assert buffer_test.count() == size, "buffer is not full when testing sampling!"

            # test sampling without replace
            mean, std = get_mean_std(timeit.repeat(buffer_test.sample_op, number=repeats))
            print("No-Replace Sample Test:  mean {:.4f} s, std {:.4f} s".format(mean, std))

            # test sampling with replace
            mean, std = get_mean_std(timeit.repeat(buffer_test.replace_sample_op, number=repeats))
            print("Replace Sample Test:     mean {:.4f} s, std {:.4f} s".format(mean, std))

            # test groupby sampling
            # NOTE(review): skipped for the priority buffer, presumably because
            # that middleware does not support groupby sampling - confirm.
            if buffer_type != 'priority':
                mean, std = get_mean_std(timeit.repeat(buffer_test.groupby_sample_op, number=repeats))
                print("Groupby Sample Test:     mean {:.4f} s, std {:.4f} s".format(mean, std))

            print("=" * 100)