Spaces:
Sleeping
Sleeping
| from typing import Iterable, Any, Optional, List | |
| from collections.abc import Sequence | |
| import numbers | |
| import time | |
| import copy | |
| from threading import Thread | |
| from queue import Queue | |
| import numpy as np | |
| import torch | |
| import treetensor.torch as ttorch | |
| from ding.utils.default_helper import get_shape0 | |
def to_device(item: Any, device: str, ignore_keys: Optional[list] = None) -> Any:
    """
    Overview:
        Transfer data to certain device.
    Arguments:
        - item (:obj:`Any`): The item to be transferred.
        - device (:obj:`str`): The device wanted.
        - ignore_keys (:obj:`list`): Keys of a top-level ``dict`` item whose values are returned untouched.
            ``None`` (the default) means no key is ignored. The keys are not forwarded to nested containers.
    Returns:
        - item (:obj:`Any`): The transferred item.
    Examples:
        >>> setup_data_dict['module'] = nn.Linear(3, 5)
        >>> device = 'cuda'
        >>> cuda_d = to_device(setup_data_dict, device, ignore_keys=['module'])
        >>> assert cuda_d['module'].weight.device == torch.device('cpu')
    Examples:
        >>> setup_data_dict['module'] = nn.Linear(3, 5)
        >>> device = 'cuda'
        >>> cuda_d = to_device(setup_data_dict, device)
        >>> assert cuda_d['module'].weight.device == torch.device('cuda:0')
    .. note::
        Now supports item type: :obj:`torch.nn.Module`, :obj:`ttorch.Tensor`, :obj:`torch.Tensor`,
        :obj:`Sequence`, :obj:`dict`, :obj:`numbers.Integral`, :obj:`numbers.Real`, :obj:`np.ndarray`,
        :obj:`np.bool_`, :obj:`str`, :obj:`torch.distributions.Distribution` and :obj:`None`.
        Non-``str`` sequences (e.g. tuples) are returned as ``list``; numbers, numpy values, strings, ``None``
        and distributions are returned unchanged.
    """
    if ignore_keys is None:
        ignore_keys = []
    if isinstance(item, torch.nn.Module):
        return item.to(device)
    elif isinstance(item, ttorch.Tensor):
        if 'prev_state' in item:
            # ``prev_state`` is moved through ``to_device`` instead of the tree's own ``.to``. It is detached from
            # ``item`` only for the duration of the ``.to`` call and re-attached afterwards, so the caller's
            # object is not left without its ``prev_state``.
            raw_prev_state = item.prev_state
            del item.prev_state
            try:
                moved = item.to(device)
            finally:
                item.prev_state = raw_prev_state
            moved.prev_state = to_device(raw_prev_state, device)
            return moved
        return item.to(device)
    elif isinstance(item, torch.Tensor):
        return item.to(device)
    elif isinstance(item, str):
        # ``str`` is a ``Sequence``; it must be returned before the generic sequence branch below.
        return item
    elif isinstance(item, Sequence):
        return [to_device(t, device) for t in item]
    elif isinstance(item, dict):
        return {k: (v if k in ignore_keys else to_device(v, device)) for k, v in item.items()}
    elif isinstance(item, numbers.Real) or isinstance(item, (np.ndarray, np.bool_)):
        # ``numbers.Integral`` is a subclass of ``numbers.Real``, so integers and bools are covered too.
        return item
    elif item is None:
        return item
    elif isinstance(item, torch.distributions.Distribution):  # for compatibility
        return item
    else:
        raise TypeError("not support item type: {}".format(type(item)))
def to_dtype(item: Any, dtype: type) -> Any:
    """
    Overview:
        Change data to certain dtype.
    Arguments:
        - item (:obj:`Any`): The item for changing the dtype.
        - dtype (:obj:`type`): The type wanted.
    Returns:
        - item (:obj:`object`): The item with changed dtype. Non-``str`` sequences are returned as ``list``;
            strings are returned unchanged.
    Examples (tensor):
        >>> t = torch.randint(0, 10, (3, 5))
        >>> tfloat = to_dtype(t, torch.float)
        >>> assert tfloat.dtype == torch.float
    Examples (list):
        >>> tlist = [torch.randint(0, 10, (3, 5))]
        >>> tlfloat = to_dtype(tlist, torch.float)
        >>> assert tlfloat[0].dtype == torch.float
    Examples (dict):
        >>> tdict = {'t': torch.randint(0, 10, (3, 5)), 'name': 'x'}
        >>> tdictf = to_dtype(tdict, torch.float)
        >>> assert tdictf['t'].dtype == torch.float
        >>> assert tdictf['name'] == 'x'
    .. note::
        Now supports item type: :obj:`torch.Tensor`, :obj:`str`, :obj:`Sequence`, :obj:`dict`.
    """
    if isinstance(item, torch.Tensor):
        return item.to(dtype=dtype)
    elif isinstance(item, str):
        # A ``str`` is a ``Sequence`` whose elements are again ``str``: recursing into it never terminates.
        # Pass it through unchanged, as ``to_device`` does.
        return item
    elif isinstance(item, Sequence):
        return [to_dtype(t, dtype) for t in item]
    elif isinstance(item, dict):
        return {k: to_dtype(v, dtype) for k, v in item.items()}
    else:
        raise TypeError("not support item type: {}".format(type(item)))
def to_tensor(
        item: Any,
        dtype: Optional[torch.dtype] = None,
        ignore_keys: Optional[list] = None,
        transform_scalar: bool = True
) -> Any:
    """
    Overview:
        Convert ``numpy.ndarray`` object to ``torch.Tensor``.
    Arguments:
        - item (:obj:`Any`): The ``numpy.ndarray`` objects to be converted. It can be exactly a ``numpy.ndarray``
            object or a container (list, tuple, namedtuple or dict) that contains several ``numpy.ndarray`` objects.
        - dtype (:obj:`torch.dtype`): The type of wanted tensor. If set to ``None``, its dtype will be unchanged,
            except that ``float64`` arrays become ``float32`` tensors.
        - ignore_keys (:obj:`list`): If the ``item`` is a dict, values whose keys are in ``ignore_keys`` will not
            be converted. ``None`` (the default) means no key is ignored. Applied at every nesting level.
        - transform_scalar (:obj:`bool`): If set to ``True``, a scalar will be also converted to a tensor object.
    Returns:
        - item (:obj:`Any`): The converted tensors. A plain list/tuple whose first element is a number is
            converted as a whole into one tensor; other plain lists/tuples become lists; namedtuples keep their type.
    Examples (scalar):
        >>> i = 10
        >>> t = to_tensor(i)
        >>> assert t.item() == i
    Examples (dict):
        >>> d = {'i': i}
        >>> dt = to_tensor(d, torch.int)
        >>> assert dt['i'].item() == i
    Examples (named tuple):
        >>> data_type = namedtuple('data_type', ['x', 'y'])
        >>> inputs = data_type(np.random.random(3), 4)
        >>> outputs = to_tensor(inputs, torch.float32)
        >>> assert type(outputs) == data_type
        >>> assert isinstance(outputs.x, torch.Tensor)
        >>> assert isinstance(outputs.y, torch.Tensor)
        >>> assert outputs.x.dtype == torch.float32
        >>> assert outputs.y.dtype == torch.float32
        >>> outputs = to_tensor(data_type(4, np.random.random(3)), torch.float32)
        >>> assert type(outputs) == data_type
    .. note::
        Now supports item type: :obj:`dict`, :obj:`list`, :obj:`tuple`, :obj:`np.ndarray`, :obj:`torch.Tensor`,
        scalars, :obj:`bool`, :obj:`str` and :obj:`None`.
    """
    if ignore_keys is None:
        ignore_keys = []

    def transform(d):
        # Convert a flat sequence of numbers into a single tensor.
        if dtype is None:
            return torch.as_tensor(d)
        else:
            return torch.tensor(d, dtype=dtype)

    if isinstance(item, dict):
        new_data = {}
        for k, v in item.items():
            if k in ignore_keys:
                new_data[k] = v
            else:
                new_data[k] = to_tensor(v, dtype, ignore_keys, transform_scalar)
        return new_data
    elif isinstance(item, (list, tuple)):
        if hasattr(item, '_fields'):
            # namedtuple: checked before the numeric shortcut so the concrete type is always preserved, and
            # its fields are converted with the same options as any other container.
            return type(item)(*[to_tensor(t, dtype, ignore_keys, transform_scalar) for t in item])
        elif len(item) == 0:
            return []
        elif isinstance(item[0], numbers.Real):
            # A sequence starting with a number is treated as a flat numeric sequence.
            return transform(item)
        else:
            return [to_tensor(t, dtype, ignore_keys, transform_scalar) for t in item]
    elif isinstance(item, np.ndarray):
        if dtype is None:
            if item.dtype == np.float64:
                return torch.FloatTensor(item)
            else:
                return torch.from_numpy(item)
        else:
            return torch.from_numpy(item).to(dtype)
    elif isinstance(item, bool) or isinstance(item, str):
        return item
    elif np.isscalar(item):
        if transform_scalar:
            if dtype is None:
                return torch.as_tensor(item)
            else:
                return torch.as_tensor(item).to(dtype)
        else:
            return item
    elif item is None:
        return None
    elif isinstance(item, torch.Tensor):
        if dtype is None:
            return item
        else:
            return item.to(dtype)
    else:
        raise TypeError("not support item type: {}".format(type(item)))
def to_ndarray(item: Any, dtype: Optional[np.dtype] = None) -> Any:
    """
    Overview:
        Convert ``torch.Tensor`` to ``numpy.ndarray``.
    Arguments:
        - item (:obj:`Any`): The ``torch.Tensor`` objects to be converted. It can be exactly a ``torch.Tensor``
            object or a container (list, tuple, namedtuple or dict) that contains several ``torch.Tensor`` objects.
            Tensors that require grad or live on a non-CPU device are detached and copied to CPU first.
        - dtype (:obj:`np.dtype`): The type of wanted array. If set to ``None``, its dtype will be unchanged.
    Returns:
        - item (:obj:`object`): The changed arrays. An empty list/tuple yields ``None``; a plain list/tuple whose
            first element is a number becomes one array; namedtuples keep their type.
    Examples (ndarray):
        >>> t = torch.randn(3, 5)
        >>> tarray1 = to_ndarray(t)
        >>> assert tarray1.shape == (3, 5)
        >>> assert isinstance(tarray1, np.ndarray)
    Examples (list):
        >>> t = [torch.randn(5, ) for i in range(3)]
        >>> tarray1 = to_ndarray(t, np.float32)
        >>> assert isinstance(tarray1, list)
        >>> assert tarray1[0].shape == (5, )
        >>> assert isinstance(tarray1[0], np.ndarray)
    .. note::
        Now supports item type: :obj:`torch.Tensor`, :obj:`np.ndarray`, :obj:`dict`, :obj:`list`, :obj:`tuple`,
        scalars, :obj:`bool`, :obj:`str` and :obj:`None`.
    """

    def transform(d):
        # Convert a flat sequence of numbers into a single array.
        if dtype is None:
            return np.array(d)
        else:
            return np.array(d, dtype=dtype)

    if isinstance(item, dict):
        return {k: to_ndarray(v, dtype) for k, v in item.items()}
    elif isinstance(item, (list, tuple)):
        if hasattr(item, '_fields'):
            # namedtuple: checked before the numeric shortcut so the concrete type is always preserved.
            return type(item)(*[to_ndarray(t, dtype) for t in item])
        elif len(item) == 0:
            return None
        elif isinstance(item[0], numbers.Real):
            return transform(item)
        else:
            return [to_ndarray(t, dtype) for t in item]
    elif isinstance(item, torch.Tensor):
        # ``Tensor.numpy`` refuses tensors that require grad or are not on the CPU. For CPU tensors the result
        # of ``detach().cpu().numpy()`` still shares memory with ``item``, exactly like ``item.numpy()``.
        array = item.detach().cpu().numpy()
        return array if dtype is None else array.astype(dtype)
    elif isinstance(item, np.ndarray):
        if dtype is None:
            return item
        else:
            return item.astype(dtype)
    elif isinstance(item, bool) or isinstance(item, str):
        return item
    elif np.isscalar(item):
        if dtype is None:
            return np.array(item)
        else:
            return np.array(item, dtype=dtype)
    elif item is None:
        return None
    else:
        raise TypeError("not support item type: {}".format(type(item)))
def to_list(item: Any) -> Any:
    """
    Overview:
        Recursively turn ``torch.Tensor`` and ``numpy.ndarray`` objects into python lists (via ``tolist``),
        keeping element values and their types as ``tolist`` produces them.
    Arguments:
        - item (:obj:`Any`): The object to convert: ``None``, a tensor, an array, a scalar (including ``str``
            and ``bool``), or a ``list``/``tuple``/``dict`` nesting of those.
    Returns:
        - item (:obj:`Any`): The converted object. Tuples become lists; scalars and ``None`` are returned as-is.
    Raises:
        - TypeError: If an unsupported object is met anywhere in the structure.
    Examples:
        >>> data = {
        ...     'tensor': torch.randn(4),
        ...     'list': [True, False, False],
        ...     'tuple': (4, 5, 6),
        ...     'bool': True,
        ...     'int': 10,
        ...     'float': 10.,
        ...     'array': np.random.randn(4),
        ...     'str': "asdf",
        ...     'none': None,
        ... }
        >>> transformed_data = to_list(data)
    .. note::
        Now supports item type: :obj:`torch.Tensor`, :obj:`numpy.ndarray`, :obj:`dict`, :obj:`list`,
        :obj:`tuple` and :obj:`None`.
    """
    if item is None:
        return None
    if isinstance(item, (torch.Tensor, np.ndarray)):
        return item.tolist()
    if isinstance(item, (list, tuple)):
        return list(map(to_list, item))
    if isinstance(item, dict):
        return {key: to_list(value) for key, value in item.items()}
    if np.isscalar(item):
        return item
    raise TypeError("not support item type: {}".format(type(item)))
def tensor_to_list(item: Any) -> Any:
    """
    Overview:
        Recursively turn ``torch.Tensor`` objects into python lists (via ``Tensor.tolist``).
        Unlike ``to_list``, ``numpy.ndarray`` objects are not accepted.
    Arguments:
        - item (:obj:`Any`): ``None``, a tensor, a scalar, or a ``list``/``tuple``/``dict`` nesting of those.
    Returns:
        - item (:obj:`Any`): The converted object. Tuples become lists; scalars and ``None`` are returned as-is.
    Raises:
        - TypeError: If an unsupported object (e.g. an ``np.ndarray``) is met anywhere in the structure.
    Examples (2d-tensor):
        >>> t = torch.randn(3, 5)
        >>> tlist1 = tensor_to_list(t)
        >>> assert len(tlist1) == 3
        >>> assert len(tlist1[0]) == 5
    Examples (1d-tensor):
        >>> t = torch.randn(3, )
        >>> tlist1 = tensor_to_list(t)
        >>> assert len(tlist1) == 3
    Examples (list):
        >>> t = [torch.randn(5, ) for i in range(3)]
        >>> tlist1 = tensor_to_list(t)
        >>> assert len(tlist1) == 3
        >>> assert len(tlist1[0]) == 5
    Examples (dict):
        >>> td = {'t': torch.randn(3, 5)}
        >>> tdlist1 = tensor_to_list(td)
        >>> assert len(tdlist1['t']) == 3
        >>> assert len(tdlist1['t'][0]) == 5
    .. note::
        Now supports item type: :obj:`torch.Tensor`, :obj:`dict`, :obj:`list`, :obj:`tuple` and :obj:`None`.
    """
    if item is None:
        return None
    if isinstance(item, torch.Tensor):
        return item.tolist()
    if isinstance(item, (list, tuple)):
        return list(map(tensor_to_list, item))
    if isinstance(item, dict):
        return {key: tensor_to_list(value) for key, value in item.items()}
    if np.isscalar(item):
        return item
    raise TypeError("not support item type: {}".format(type(item)))
def to_item(data: Any, ignore_error: bool = True) -> Any:
    """
    Overview:
        Convert data to python native scalar (i.e. data item), and keep their dtypes unchanged.
    Arguments:
        - data (:obj:`Any`): The data that needs to be converted.
        - ignore_error (:obj:`bool`): Whether to ignore the error when the data type is not supported. That is to
            say, only the data can be transformed into a python native scalar will be returned. The flag only
            affects ``dict`` values (a value whose conversion raises ``ValueError``/``RuntimeError`` is dropped
            from the result), and it is forwarded to every nesting level.
    Returns:
        - data (:obj:`Any`): Converted data. Tuples become lists.
    Raises:
        - TypeError: If an unsupported type is met outside a ``dict`` that ignores errors.
        - ValueError / RuntimeError: When ``ignore_error`` is ``False`` and an array/tensor holds more than one
            element.
    Examples:
        >>> data = {
        ...     'tensor': torch.randn(1),
        ...     'list': [True, False, torch.randn(1)],
        ...     'tuple': (4, 5, 6),
        ...     'bool': True,
        ...     'int': 10,
        ...     'float': 10.,
        ...     'array': np.random.randn(1),
        ...     'str': "asdf",
        ...     'none': None,
        ... }
        >>> new_data = to_item(data)
        >>> assert np.isscalar(new_data['tensor'])
        >>> assert np.isscalar(new_data['array'])
        >>> assert np.isscalar(new_data['list'][-1])
    .. note::
        Now supports item type: :obj:`np.ndarray`, :obj:`torch.Tensor`, :obj:`ttorch.Tensor`,
        :obj:`bool`, :obj:`str`, :obj:`dict`, :obj:`list`, :obj:`tuple` and :obj:`None`.
    """
    if data is None:
        return data
    elif isinstance(data, bool) or isinstance(data, str):
        return data
    elif np.isscalar(data):
        return data
    elif isinstance(data, (list, tuple)):
        return [to_item(d, ignore_error) for d in data]
    elif isinstance(data, dict):
        new_data = {}
        for k, v in data.items():
            if ignore_error:
                try:
                    new_data[k] = to_item(v, ignore_error)
                except (ValueError, RuntimeError):
                    # Best effort: values that cannot be reduced to a single scalar are left out.
                    pass
            else:
                new_data[k] = to_item(v, ignore_error)
        return new_data
    elif isinstance(data, (np.ndarray, torch.Tensor)) or isinstance(data, ttorch.Tensor):
        # ``.item()`` raises (ValueError for numpy, RuntimeError for torch) unless there is exactly one element.
        return data.item()
    else:
        raise TypeError("not support data type: {}".format(type(data)))
def same_shape(data: list) -> bool:
    """
    Overview:
        Check whether every element of a list has one and the same ``shape``.
    Arguments:
        - data (:obj:`list`): Objects exposing a hashable ``shape`` attribute (e.g. tensors or arrays).
    Returns:
        - same (:obj:`bool`): ``True`` iff exactly one distinct shape occurs; an empty list gives ``False``.
    Raises:
        - AssertionError: If ``data`` is not a ``list`` (only while assertions are enabled).
    Examples:
        >>> tlist = [torch.randn(3, 5) for i in range(5)]
        >>> assert same_shape(tlist)
        >>> tlist = [torch.randn(3, 5), torch.randn(4, 5)]
        >>> assert not same_shape(tlist)
    """
    assert isinstance(data, list)
    distinct_shapes = {element.shape for element in data}
    return len(distinct_shapes) == 1
class LogDict(dict):
    """
    Overview:
        Derived from ``dict``. Would convert ``torch.Tensor`` to ``list`` for convenient logging.
        Values are converted whenever they are stored through ``__init__``, ``__setitem__`` or ``update``;
        other ``dict`` methods that store values (e.g. ``setdefault``) are not overridden and store them as-is.
    Interfaces:
        ``__init__``, ``_transform``, ``__setitem__``, ``update``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        Overview:
            Accept the same arguments as ``dict`` and store every initial value through ``__setitem__``,
            so that tensors passed at construction time are converted as well.
        """
        super().__init__()
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

    def _transform(self, data: Any) -> Any:
        """
        Overview:
            Convert tensor objects to lists for better logging.
        Arguments:
            - data (:obj:`Any`): The input data to be converted.
        Returns:
            - new_data (:obj:`Any`): ``data.tolist()`` for a ``torch.Tensor``, otherwise ``data`` itself.
        """
        if isinstance(data, torch.Tensor):
            new_data = data.tolist()
        else:
            new_data = data
        return new_data

    def __setitem__(self, key: Any, value: Any) -> None:
        """
        Overview:
            Override the ``__setitem__`` function of built-in dict, converting ``value`` with ``_transform``.
        Arguments:
            - key (:obj:`Any`): The key of the data item.
            - value (:obj:`Any`): The value of the data item.
        """
        new_value = self._transform(value)
        super().__setitem__(key, new_value)

    def update(self, data: Any = (), **kwargs: Any) -> None:
        """
        Overview:
            Override the ``update`` function of built-in dict. Like ``dict.update`` it accepts a mapping or an
            iterable of key/value pairs plus keyword arguments (keyword values win on duplicate keys); every
            value is stored through ``__setitem__``.
        Arguments:
            - data (:obj:`Any`): The mapping or iterable of pairs for updating current object.
            - kwargs (:obj:`Any`): Additional key/value pairs.
        """
        for k, v in dict(data, **kwargs).items():
            self.__setitem__(k, v)
def build_log_buffer() -> LogDict:
    """
    Overview:
        Create a fresh, empty log buffer. It is a ``LogDict`` (a ``dict`` subclass) whose ``__setitem__`` and
        ``update`` store ``torch.Tensor`` values as python lists.
    Returns:
        - log_buffer (:obj:`LogDict`): A new, empty log buffer.
    Examples:
        >>> log_buffer = build_log_buffer()
        >>> log_buffer['not_tensor'] = torch.randn(3)
        >>> assert isinstance(log_buffer['not_tensor'], list)
        >>> assert len(log_buffer['not_tensor']) == 3
        >>> log_buffer.update({'not_tensor': 4, 'a': 5})
        >>> assert log_buffer['not_tensor'] == 4
    """
    buffer = LogDict()
    return buffer
class CudaFetcher(object):
    """
    Overview:
        Fetch data from source, and transfer it to a specified device.
        A background producer thread pulls items from ``data_source`` and moves them to ``device`` with
        ``to_device`` while a dedicated ``torch.cuda.Stream`` is current. It then buffers them in a bounded
        queue. Once the source is exhausted, ``__next__`` raises ``StopIteration``. An exception raised while
        fetching or transferring an item is re-raised from ``__next__`` instead of leaving the consumer
        blocked forever.
    Interfaces:
        ``__init__``, ``__iter__``, ``__next__``, ``run``, ``close``.
    """

    class _ProducerError(object):
        """Queue marker wrapping an exception raised in the producer thread."""

        def __init__(self, error: Exception) -> None:
            self.error = error

    # Queue marker put by the producer after ``next`` on the source raised ``StopIteration``.
    _END_OF_SOURCE = object()

    def __init__(self, data_source: Iterable, device: str, queue_size: int = 4, sleep: float = 0.1) -> None:
        """
        Overview:
            Initialize the CudaFetcher object using the given arguments.
        Arguments:
            - data_source (:obj:`Iterable`): The data source: an iterator, or any iterable (it is wrapped
                with ``iter``).
            - device (:obj:`str`): The device to put data to, such as "cuda:0".
            - queue_size (:obj:`int`): The internal size of queue, such as 4.
            - sleep (:obj:`float`): Sleeping time when the internal queue is full.
        """
        # Objects that already implement ``__next__`` are used as they are; other iterables become iterators.
        self._source = data_source if hasattr(data_source, '__next__') else iter(data_source)
        self._queue = Queue(maxsize=queue_size)
        self._stream = torch.cuda.Stream()
        # Daemon thread: a fetcher that is never ``close``-d must not keep the interpreter alive at exit.
        self._producer_thread = Thread(target=self._producer, args=(), name='cuda_fetcher_producer', daemon=True)
        self._sleep = sleep
        self._device = device
        self._end_flag = False
        # Set once the end-of-source (or error) marker has been consumed.
        self._exhausted = False

    def __iter__(self) -> 'CudaFetcher':
        """
        Overview:
            Return ``self`` so the fetcher can be used directly in a ``for`` loop (after ``run``).
        """
        return self

    def __next__(self) -> Any:
        """
        Overview:
            Response to the request for data. Return one data item from the internal queue, blocking until one
            is available.
        Returns:
            - item (:obj:`Any`): The data item on the required device.
        Raises:
            - StopIteration: The data source is exhausted (or the producer already failed earlier).
            - Exception: Whatever the producer raised while fetching or transferring an item.
        """
        if self._exhausted:
            raise StopIteration
        item = self._queue.get()
        if item is self._END_OF_SOURCE:
            self._exhausted = True
            raise StopIteration
        if isinstance(item, self._ProducerError):
            self._exhausted = True
            raise item.error
        return item

    def run(self) -> None:
        """
        Overview:
            Start ``producer`` thread: Keep fetching data from source, change the device, and put into
            ``queue`` for request.
        Examples:
            >>> dataloader = iter([torch.randn(3, 3) for _ in range(10)])
            >>> dataloader = CudaFetcher(dataloader, device='cuda', sleep=0.1)
            >>> dataloader.run()
            >>> data = next(dataloader)
        """
        self._end_flag = False
        self._producer_thread.start()

    def close(self) -> None:
        """
        Overview:
            Stop ``producer`` thread by setting ``end_flag`` to ``True``. The call does not wait for the thread.
            Items already queued can still be read. After that, ``__next__`` blocks, because nothing more is
            produced.
        """
        self._end_flag = True

    def _producer(self) -> None:
        """
        Overview:
            Keep fetching data from source, change the device, and put into ``queue`` for request. Stops when
            ``close`` is called, when the source is exhausted, or when fetching/transferring raises.
        """
        with torch.cuda.stream(self._stream):
            while not self._end_flag:
                if self._queue.full():
                    time.sleep(self._sleep)
                    continue
                try:
                    data = to_device(next(self._source), self._device)
                except StopIteration:
                    # The queue had room when checked and this thread is its only producer, so ``put`` cannot block.
                    self._queue.put(self._END_OF_SOURCE)
                    return
                except Exception as err:
                    # Forward the failure to the consumer instead of dying silently.
                    self._queue.put(self._ProducerError(err))
                    return
                self._queue.put(data)
def get_tensor_data(data: Any) -> Any:
    """
    Overview:
        Get pure tensor data from the given data (without disturbing grad computation graph).
        Every tensor is detached and cloned, so the result neither requires grad nor shares memory with the input.
    Arguments:
        - data (:obj:`Any`): The original data. It can be exactly a tensor or a container (Sequence or dict).
            ``None`` and ``str`` values are returned unchanged.
    Returns:
        - output (:obj:`Any`): The output data. Non-``str`` sequences are returned as ``list``.
    Raises:
        - TypeError: If an unsupported type is met.
    Examples:
        >>> a = {
        ...     'tensor': torch.tensor([1, 2, 3.], requires_grad=True),
        ...     'list': [torch.tensor([1, 2, 3.], requires_grad=True) for _ in range(2)],
        ...     'none': None
        ... }
        >>> tensor_a = get_tensor_data(a)
        >>> assert not tensor_a['tensor'].requires_grad
        >>> for t in tensor_a['list']:
        >>>     assert not t.requires_grad
    """
    if isinstance(data, torch.Tensor):
        return data.detach().clone()
    elif data is None:
        return None
    elif isinstance(data, str):
        # A ``str`` is a ``Sequence`` of ``str``; recursing into it would never terminate.
        return data
    elif isinstance(data, Sequence):
        return [get_tensor_data(d) for d in data]
    elif isinstance(data, dict):
        return {k: get_tensor_data(v) for k, v in data.items()}
    else:
        raise TypeError("not support type in get_tensor_data: {}".format(type(data)))
def unsqueeze(data: Any, dim: int = 0) -> Any:
    """
    Overview:
        Unsqueeze the tensor data.
    Arguments:
        - data (:obj:`Any`): The original data. It can be exactly a tensor or a container (Sequence or dict).
        - dim (:obj:`int`): The dimension to be unsqueezed, applied to every tensor inside ``data``.
    Returns:
        - output (:obj:`Any`): The output data. Sequences are returned as ``list``.
    Raises:
        - TypeError: If an unsupported type (including ``str``) is met.
    Examples (tensor):
        >>> t = torch.randn(3, 3)
        >>> tt = unsqueeze(t, dim=0)
        >>> assert tt.shape == torch.Size([1, 3, 3])
    Examples (list):
        >>> t = [torch.randn(3, 3)]
        >>> tt = unsqueeze(t, dim=1)
        >>> assert tt[0].shape == torch.Size([3, 1, 3])
    Examples (dict):
        >>> t = {"t": torch.randn(3, 3)}
        >>> tt = unsqueeze(t, dim=0)
        >>> assert tt["t"].shape == torch.Size([1, 3, 3])
    """
    if isinstance(data, torch.Tensor):
        return data.unsqueeze(dim)
    elif isinstance(data, Sequence) and not isinstance(data, str):
        # ``str`` is excluded: it is a Sequence of ``str`` and recursion into it would never terminate.
        return [unsqueeze(d, dim) for d in data]
    elif isinstance(data, dict):
        return {k: unsqueeze(v, dim) for k, v in data.items()}
    else:
        raise TypeError("not support type in unsqueeze: {}".format(type(data)))
def squeeze(data: Any, dim: int = 0) -> Any:
    """
    Overview:
        Squeeze the tensor data.
    Arguments:
        - data (:obj:`Any`): The original data. It can be exactly a tensor or a container (Sequence or dict).
        - dim (:obj:`int`): The dimension to be squeezed, applied to every tensor inside ``data``. As with
            ``torch.Tensor.squeeze``, a dimension whose size is not 1 is left unchanged.
    Returns:
        - output (:obj:`Any`): The output data. Sequences are returned as ``list``.
    Raises:
        - TypeError: If an unsupported type (including ``str``) is met.
    Examples (tensor):
        >>> t = torch.randn(1, 3, 3)
        >>> tt = squeeze(t, dim=0)
        >>> assert tt.shape == torch.Size([3, 3])
    Examples (list):
        >>> t = [torch.randn(3, 1, 3)]
        >>> tt = squeeze(t, dim=1)
        >>> assert tt[0].shape == torch.Size([3, 3])
    Examples (dict):
        >>> t = {"t": torch.randn(1, 3, 3)}
        >>> tt = squeeze(t, dim=0)
        >>> assert tt["t"].shape == torch.Size([3, 3])
    """
    if isinstance(data, torch.Tensor):
        return data.squeeze(dim)
    elif isinstance(data, Sequence) and not isinstance(data, str):
        # ``str`` is excluded: it is a Sequence of ``str`` and recursion into it would never terminate.
        return [squeeze(d, dim) for d in data]
    elif isinstance(data, dict):
        return {k: squeeze(v, dim) for k, v in data.items()}
    else:
        raise TypeError("not support type in squeeze: {}".format(type(data)))
def get_null_data(template: Any, num: int) -> List[Any]:
    """
    Overview:
        Build ``num`` independent "null" copies of a template transition. Each copy is a deep copy of
        ``template`` with ``'null'`` and ``'done'`` set to ``True`` and its ``'reward'`` zeroed in place.
        The template itself is not modified.
    Arguments:
        - template (:obj:`Any`): A dict-like transition whose ``'reward'`` entry supports ``zero_()``
            (e.g. a ``torch.Tensor``).
        - num (:obj:`int`): How many null items to create.
    Returns:
        - output (:obj:`List[Any]`): The generated null items.
    Examples:
        >>> temp = {'obs': [1, 2, 3], 'action': 1, 'done': False, 'reward': torch.tensor(1.)}
        >>> null_data = get_null_data(temp, 2)
        >>> assert len(null_data) == 2
        >>> assert null_data[0]['null'] and null_data[0]['done']
    """

    def _make_null() -> Any:
        # Deep copy first so that nothing below touches ``template``.
        null_item = copy.deepcopy(template)
        null_item['null'] = True
        null_item['done'] = True
        null_item['reward'].zero_()
        return null_item

    return [_make_null() for _ in range(num)]
def zeros_like(h: Any) -> Any:
    """
    Overview:
        Recursively build zero tensors mirroring the tensors inside ``h`` (same shape, dtype and device,
        as produced by ``torch.zeros_like``).
    Arguments:
        - h (:obj:`Any`): A tensor, or a ``list``/``tuple``/``dict`` nesting of tensors.
    Returns:
        - output (:obj:`Any`): The zero tensors in the same structure; tuples become lists.
    Raises:
        - TypeError: If any other type is met.
    Examples (tensor):
        >>> t = torch.randn(3, 3)
        >>> tt = zeros_like(t)
        >>> assert tt.shape == torch.Size([3, 3])
        >>> assert torch.sum(torch.abs(tt)) < 1e-8
    Examples (list):
        >>> t = [torch.randn(3, 3)]
        >>> tt = zeros_like(t)
        >>> assert tt[0].shape == torch.Size([3, 3])
        >>> assert torch.sum(torch.abs(tt[0])) < 1e-8
    Examples (dict):
        >>> t = {"t": torch.randn(3, 3)}
        >>> tt = zeros_like(t)
        >>> assert tt["t"].shape == torch.Size([3, 3])
        >>> assert torch.sum(torch.abs(tt["t"])) < 1e-8
    """
    if isinstance(h, torch.Tensor):
        return torch.zeros_like(h)
    if isinstance(h, (list, tuple)):
        return list(map(zeros_like, h))
    if isinstance(h, dict):
        return {key: zeros_like(value) for key, value in h.items()}
    raise TypeError("not support type: {}".format(h))