#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import itertools

import torch
import torch.distributed as dist
from torch.utils.data.sampler import BatchSampler as torchBatchSampler
from torch.utils.data.sampler import Sampler


class YoloBatchSampler(torchBatchSampler):
    """
    This batch sampler will generate mini-batches of (dim, index) tuples from another sampler.
    It works just like the :class:`torch.utils.data.sampler.BatchSampler`,
    but it will prepend a dimension, whilst ensuring it stays the same across one mini-batch.
    """

    def __init__(self, *args, input_dimension=None, mosaic=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_dim = input_dimension
        self.new_input_dim = None
        self.mosaic = mosaic

    def __iter__(self):
        self.__set_input_dim()
        for batch in super().__iter__():
            yield [(self.input_dim, idx, self.mosaic) for idx in batch]
            # Re-check after every mini-batch, so a dimension change requested
            # mid-epoch takes effect on the next batch, never inside one.
            self.__set_input_dim()

    def __set_input_dim(self):
        """ This function changes the input dimension of the dataset, if a new one was requested (e.g. by multiscale training). """
        if self.new_input_dim is not None:
            self.input_dim = (self.new_input_dim[0], self.new_input_dim[1])
            self.new_input_dim = None
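

# A minimal usage sketch (not part of the original module): wrapping a plain
# SequentialSampler to show the (input_dim, index, mosaic) tuples this sampler
# yields. The helper name and dimensions below are illustrative only; the
# dataset consuming these batches is assumed to unpack the tuples itself.
def _demo_yolo_batch_sampler():
    from torch.utils.data.sampler import SequentialSampler

    sampler = YoloBatchSampler(
        SequentialSampler(range(8)),
        batch_size=4,
        drop_last=False,
        input_dimension=(640, 640),
    )
    for i, batch in enumerate(sampler):
        print(batch[0])  # first tuple of the batch: (input_dim, index, mosaic)
        if i == 0:
            # Request a new dimension; it takes effect from the next batch on.
            sampler.new_input_dim = (480, 480)
    # Expected output:
    #   ((640, 640), 0, True)
    #   ((480, 480), 4, True)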


class InfiniteSampler(Sampler):
    """
    In training, we only care about the "infinite stream" of training data.
    So this sampler produces an infinite stream of indices and
    all workers cooperate to correctly shuffle the indices and sample different indices.
    The sampler in each worker effectively produces `indices[worker_id::num_workers]`
    where `indices` is an infinite stream of indices consisting of
    `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
    or `range(size) + range(size) + ...` (if shuffle is False)
    """

    def __init__(
        self,
        size: int,
        shuffle: bool = True,
        seed: int = 0,
        rank=0,
        world_size=1,
    ):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
            shuffle (bool): whether to shuffle the indices or not
            seed (int): the initial seed of the shuffle. Must be the same
                across all workers.
            rank (int): rank of the current process; only used as a fallback
                when torch.distributed is not initialized.
            world_size (int): total number of processes; only used as a fallback
                when torch.distributed is not initialized.
        """
        self._size = size
        assert size > 0
        self._shuffle = shuffle
        self._seed = int(seed)

        # Prefer the live distributed context when it is initialized;
        # otherwise fall back to the explicitly passed rank/world_size.
        if dist.is_available() and dist.is_initialized():
            self._rank = dist.get_rank()
            self._world_size = dist.get_world_size()
        else:
            self._rank = rank
            self._world_size = world_size

    def __iter__(self):
        # Each rank takes every world_size-th index from the shared infinite
        # stream, offset by its own rank, so ranks never overlap.
        start = self._rank
        yield from itertools.islice(
            self._infinite_indices(), start, None, self._world_size
        )

    def _infinite_indices(self):
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            if self._shuffle:
                yield from torch.randperm(self._size, generator=g)
            else:
                yield from torch.arange(self._size)

    def __len__(self):
        # Nominal per-rank epoch length; the index stream itself is infinite.
        return self._size // self._world_size
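

# A minimal single-process sketch (not part of the original module): with
# world_size=2 the two ranks slice disjoint, interleaved positions out of the
# same seeded stream, so together they cover every index exactly once per
# epoch. The helper name is illustrative; in real training such a sampler is
# typically handed to a DataLoader (e.g. wrapped in a batch sampler).
def _demo_infinite_sampler():
    for rank in range(2):
        sampler = InfiniteSampler(
            size=6, shuffle=True, seed=0, rank=rank, world_size=2
        )
        # Take one epoch's worth of indices from the infinite stream.
        indices = [int(i) for i in itertools.islice(iter(sampler), 6)]
        print(f"rank {rank}: {indices}")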