import random

import torch
from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe
from typing import Iterator, List, Optional, TypeVar

__all__ = ["ShufflerIterDataPipe", ]

T_co = TypeVar('T_co', covariant=True)


# @functional_datapipe('shuffle')
class ShufflerIterDataPipe(IterDataPipe[T_co]):
r"""
Shuffle the input MapDataPipe via its indices (functional name: ``shuffle``).

    When it is used with :class:`~torch.utils.data.DataLoader`, the way the random
    seed is set up differs depending on :attr:`num_workers`.

    For single-process mode (:attr:`num_workers == 0`), the random seed is set before
    the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process
    mode (:attr:`num_workers > 0`), ``worker_init_fn`` is used to set up a random seed
    for each worker process (see the seeding sketch at the end of the ``Example``
    section below).

Args:
datapipe: MapDataPipe being shuffled
        indices: a list of indices of the MapDataPipe. If not provided, contiguous 0-based
            indexing (``range(len(datapipe))``) is assumed

Example:
>>> # xdoctest: +SKIP
>>> from torchdata.datapipes.map import SequenceWrapper
>>> dp = SequenceWrapper(range(10))
>>> shuffle_dp = dp.shuffle().set_seed(0)
>>> list(shuffle_dp)
[7, 8, 1, 5, 3, 4, 2, 0, 9, 6]
>>> list(shuffle_dp)
[6, 1, 9, 5, 2, 4, 7, 3, 8, 0]
>>> # Reset seed for Shuffler
>>> shuffle_dp = shuffle_dp.set_seed(0)
>>> list(shuffle_dp)
[7, 8, 1, 5, 3, 4, 2, 0, 9, 6]
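        >>> # A minimal sketch of per-worker seeding with a multi-process DataLoader
        >>> # (illustrative only; it assumes the shuffler itself is passed to the DataLoader
        >>> # as ``dataset``, so ``get_worker_info().dataset`` is this datapipe):
        >>> import torch
        >>> def worker_init_fn(worker_id):
        ...     info = torch.utils.data.get_worker_info()
        ...     info.dataset.set_seed(info.seed)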

    Note:
        Even though this ``shuffle`` operation takes a ``MapDataPipe`` as input, it returns an
        ``IterDataPipe`` rather than a ``MapDataPipe``. A ``MapDataPipe`` is expected to be
        insensitive to the order of its data for the sake of random reads, whereas an
        ``IterDataPipe`` depends on the order in which data is processed.
"""
datapipe: MapDataPipe[T_co]
_enabled: bool
_seed: Optional[int]
    _rng: random.Random

def __init__(self,
datapipe: MapDataPipe[T_co],
*,
indices: Optional[List] = None,
) -> None:
super().__init__()
self.datapipe = datapipe
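        # Default to contiguous 0-based indices covering the whole datapipe when none are given.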
self.indices = list(range(len(datapipe))) if indices is None else indices
self._enabled = True
self._seed = None
self._rng = random.Random()
        # ``reset`` replaces this with a freshly shuffled sample before each iteration;
        # keep a copy here so that popping from it can never mutate ``self.indices``.
        self._shuffled_indices: List = list(self.indices)

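    # Enable or disable shuffling in place; returns ``self`` so calls can be chained.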
def set_shuffle(self, shuffle=True):
self._enabled = shuffle
return self
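
    # Fix the seed used for the next shuffle; it is consumed and cleared by ``reset``.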
def set_seed(self, seed: int):
self._seed = seed
return self

    def __iter__(self) -> Iterator[T_co]:
        if not self._enabled:
            # Shuffling disabled: yield elements in their original index order.
            for idx in self.indices:
                yield self.datapipe[idx]
        else:
            # Consume the shuffled indices prepared by ``reset`` one at a time.
            while self._shuffled_indices:
                idx = self._shuffled_indices.pop()
                yield self.datapipe[idx]

    def reset(self) -> None:
        if self._enabled and self._seed is None:
            # No seed was set explicitly; draw a fresh one from the global torch RNG.
            self._seed = int(torch.empty((), dtype=torch.int64).random_().item())
        self._rng.seed(self._seed)
        # Clear the seed so the next epoch draws a new one unless ``set_seed`` is called again.
        self._seed = None
        # Pre-shuffle the indices for the upcoming iteration.
        self._shuffled_indices = self._rng.sample(self.indices, len(self.indices))

def __len__(self) -> int:
return len(self.datapipe)
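
    # Explicit (de)serialization of the shuffler, including the RNG state and the
    # iterator bookkeeping fields maintained by the IterDataPipe base class.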
def __getstate__(self):
state = (
self.datapipe,
self.indices,
self._enabled,
self._seed,
self._rng.getstate(),
self._shuffled_indices,
self._valid_iterator_id,
self._number_of_samples_yielded,
)
if IterDataPipe.getstate_hook is not None:
return IterDataPipe.getstate_hook(state)
        return state

def __setstate__(self, state):
(
self.datapipe,
self.indices,
self._enabled,
self._seed,
rng_state,
self._shuffled_indices,
self._valid_iterator_id,
self._number_of_samples_yielded,
) = state
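        # Recreate the RNG and restore its exact internal state.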
self._rng = random.Random()
self._rng.setstate(rng_state)
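

# Register ``shuffle`` as a functional transform on MapDataPipe, i.e. ``dp.shuffle()``.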
MapDataPipe.register_datapipe_as_function("shuffle", ShufflerIterDataPipe)