import copy
import warnings
from torch.utils.data.datapipes.datapipe import MapDataPipe
__all__ = ["SequenceWrapperMapDataPipe", ]
class SequenceWrapperMapDataPipe(MapDataPipe):
    r"""
    Wraps a sequence object into a MapDataPipe.

    Indexing and ``len`` are forwarded directly to the wrapped object, so
    anything supporting ``obj[index]`` and ``len(obj)`` (e.g. a ``range``,
    ``list`` or ``dict``) can be wrapped.

    Args:
        sequence: Sequence object to be wrapped into a MapDataPipe
        deepcopy: Option to deepcopy input sequence object. If the deep copy
            raises ``TypeError`` (the object cannot be deep-copied), a warning
            is emitted and the original object is stored by reference instead.

    .. note::
        If ``deepcopy`` is set to False explicitly, users should ensure
        that the data pipeline doesn't contain any in-place operations over
        the wrapped sequence, in order to prevent data inconsistency
        across iterations.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> dp = SequenceWrapper(range(10))
        >>> list(dp)
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        >>> dp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400})
        >>> dp['a']
        100
    """

    def __init__(self, sequence, deepcopy=True):
        # Default to sharing the caller's object; replace it with a private
        # copy only when requested and when the copy actually succeeds.
        wrapped = sequence
        if deepcopy:
            try:
                wrapped = copy.deepcopy(sequence)
            except TypeError:
                # Objects holding e.g. locks or file handles cannot be
                # deep-copied; fall back to the shared reference, but warn.
                warnings.warn(
                    "The input sequence can not be deepcopied, "
                    "please be aware of in-place modification would affect source data"
                )
        self.sequence = wrapped

    def __getitem__(self, index):
        """Return the element stored under ``index`` in the wrapped object."""
        seq = self.sequence
        return seq[index]

    def __len__(self):
        """Return the length of the wrapped object."""
        seq = self.sequence
        return len(seq)