Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ utils / data / datapipes / _hook_iterator.py

import inspect
import functools
from enum import Enum

import torch.autograd


class _SnapshotState(Enum):
    r"""
    Snapshot-related states an IterDataPipe can be in.

    The comment above each member describes what restoring a snapshot and creating
    an iterator do while the DataPipe is in that state.
    """
    # A snapshot may be restored; creating an iterator resets the DataPipe.
    NotStarted = 0
    # A snapshot was already restored and cannot be restored again; creating an
    # iterator does NOT reset the DataPipe.
    Restored = 1
    # A snapshot may be restored; creating a new iterator resets the DataPipe.
    Iterating = 2


def _simplify_obj_name(obj) -> str:
    """
    Return a short display label for ``obj`` for use inside DataPipe error messages.

    Plain Python functions (as recognised by ``inspect.isfunction``) are shown by their
    ``__name__``; every other object is shown via ``repr``.
    """
    return obj.__name__ if inspect.isfunction(obj) else repr(obj)


def _generate_input_args_string(obj):
    """
    Generate a string for the input arguments of an object.

    Every parameter in the constructor signature of ``obj.__class__`` that is also
    readable as an attribute on ``obj`` is rendered as ``name=value``. Values are
    formatted with ``_simplify_obj_name``, and the pairs are joined with ``", "`` in
    alphabetical order of parameter name.

    Only the attributes named by the constructor signature are looked up. Scanning
    every member (as ``inspect.getmembers`` does) would also evaluate unrelated
    properties. This helper is used while building error messages, so an exception
    from such a property would escape and hide the error being reported.
    """
    signature = inspect.signature(obj.__class__)
    missing = object()  # sentinel: attribute not present on ``obj``
    rendered = []
    # Sorted to keep the alphabetical ordering ``inspect.getmembers`` used to produce.
    for param_name in sorted(signature.parameters):
        # ``getattr`` with a default swallows only AttributeError, matching how
        # ``inspect.getmembers`` skipped unreadable attributes.
        value = getattr(obj, param_name, missing)
        if value is missing:
            continue
        rendered.append(f"{param_name}={_simplify_obj_name(value)}")
    return ', '.join(rendered)


def _generate_iterdatapipe_msg(datapipe):
    """Render ``datapipe`` as ``ClassName(arg=value, ...)`` for use in error messages."""
    class_name = datapipe.__class__.__name__
    args_text = _generate_input_args_string(datapipe)
    return f"{class_name}({args_text})"


def _gen_invalid_iterdatapipe_msg(datapipe):
    """
    Build the error message for an iterator that was invalidated because a newer
    iterator was created from the same IterDataPipe ``datapipe``.
    """
    return ("This iterator has been invalidated because another iterator has been created "
            f"from the same IterDataPipe: {_generate_iterdatapipe_msg(datapipe)}\n"
            "This may be caused by multiple references to the same IterDataPipe. We recommend "
            "using `.fork()` if that is necessary.")


# Suffix appended to the "invalidated iterator" errors raised in this module, pointing
# users at the upstream issue that tracks the single-iterator constraint.
_feedback_msg = (
    "\nFor feedback regarding this single iterator per IterDataPipe constraint, "
    "feel free to comment on this issue: https://github.com/pytorch/data/issues/45."
)


def _check_iterator_valid(datapipe, iterator_id, next_method_exists=False) -> None:
    r"""
    Given an instance of a DataPipe and an iterator ID, check if the IDs match, and if not, raises an exception.

    For a ChildDataPipe the check is delegated to its own ``_check_valid_iterator_id``
    method (which presumably compares against the ID stored on ``main_datapipe``).

    Args:
        datapipe: The DataPipe whose current valid iterator ID is checked.
        iterator_id: The ID that was assigned to the iterator being advanced. Not
            consulted when ``next_method_exists`` is True.
        next_method_exists: True when the DataPipe defines both ``__iter__`` and
            ``__next__``. In that case the only check is that at most one iterator
            has ever been created (``_valid_iterator_id`` is ``None`` or ``0``).

    Raises:
        RuntimeError: If the iterator has been invalidated by a newer one, or if a
            ChildDataPipe does not provide ``_check_valid_iterator_id``.
    """
    if next_method_exists:
        # This is the case where `IterDataPipe` has both `__iter__` and `__next__`.
        # The `_valid_iterator_id` should either be never set (`None`), or set by at most one
        # iterator (`0`). Otherwise, it means there are multiple iterators.
        if datapipe._valid_iterator_id is not None and datapipe._valid_iterator_id != 0:
            extra_msg = "\nNote that this exception is raised inside your IterDataPipe's `__next__` method"
            raise RuntimeError(_gen_invalid_iterdatapipe_msg(datapipe) + extra_msg + _feedback_msg)
    elif getattr(datapipe, "_is_child_datapipe", False) is True:
        if not hasattr(datapipe, "_check_valid_iterator_id"):
            raise RuntimeError("ChildDataPipe must have method `_check_valid_iterator_id`.")
        if not datapipe._check_valid_iterator_id(iterator_id):
            raise RuntimeError("This iterator has been invalidated, because a new iterator has been created "
                               f"from one of the ChildDataPipes of "
                               f"{_generate_iterdatapipe_msg(datapipe.main_datapipe)}." + _feedback_msg)
    elif datapipe._valid_iterator_id != iterator_id:
        raise RuntimeError(_gen_invalid_iterdatapipe_msg(datapipe) + _feedback_msg)


def _set_datapipe_valid_iterator_id(datapipe):
    r"""
    Register a new iterator on ``datapipe`` and return the ID now considered valid.

    For a regular DataPipe the ID starts at ``0`` and is incremented on every call,
    after which ``datapipe.reset()`` is invoked. A ChildDataPipe must instead provide
    ``_set_main_datapipe_valid_iterator_id``, which performs the update itself.

    Returns:
        The value of ``datapipe._valid_iterator_id`` after the update.

    Raises:
        RuntimeError: If a ChildDataPipe lacks ``_set_main_datapipe_valid_iterator_id``.
    """
    is_child = getattr(datapipe, "_is_child_datapipe", False) is True
    if not is_child:
        previous_id = datapipe._valid_iterator_id
        datapipe._valid_iterator_id = 0 if previous_id is None else previous_id + 1
        datapipe.reset()
    elif hasattr(datapipe, "_set_main_datapipe_valid_iterator_id"):
        # reset() is called within this method when appropriate
        datapipe._set_main_datapipe_valid_iterator_id()
    else:
        raise RuntimeError("ChildDataPipe must have method `_set_main_datapipe_valid_iterator_id`.")
    return datapipe._valid_iterator_id


def hook_iterator(namespace, profile_name):
    r"""
    Hook that is applied to all `__iter__` of metaclass `_DataPipeMeta`. This is done for the purpose of
    profiling and checking if an iterator is still valid.

    Args:
        namespace: The class namespace being built. Its ``'__iter__'`` entry (and its
            ``'__next__'`` entry, if present) is replaced in place with a wrapped version.
        profile_name: Label passed to ``torch.autograd.profiler.record_function``
            around iteration steps when the autograd profiler is enabled.
    """
    def profiler_record_fn_context():
        return torch.autograd.profiler.record_function(profile_name)

    class IteratorDecorator:
        r"""
        Wrap the iterator and modify its `__next__` method. This decorator is applied to
        DataPipes whose `__iter__` method is NOT a generator function. Such `__iter__`
        methods commonly return `self`, but not necessarily.
        """
        def __init__(self, iterator, source_dp, iterator_id, has_next_method):
            self.iterator = iterator
            self.source_dp = source_dp
            self.iterator_id = iterator_id
            # Sampled once here, so this iterator keeps its profiling choice for its lifetime.
            self._profiler_enabled = torch.autograd._profiler_enabled()
            # Check if `__iter__` returns `self` and `DataPipe` has `__next__`. In that case the
            # wrapped `__next__` already increments the sample counter, so `_get_next` must not.
            self.self_and_has_next_method = self.iterator is self.source_dp and has_next_method

        def __iter__(self):
            return self

        def _get_next(self):
            r"""
            Return next with logic related to iterator validity, profiler, and incrementation of samples yielded.
            """
            _check_iterator_valid(self.source_dp, self.iterator_id)
            result = next(self.iterator)
            if not self.self_and_has_next_method:
                self.source_dp._number_of_samples_yielded += 1
            return result

        def __next__(self):
            # TODO: Add try-except to in-place reduce traceback from the Exception
            # See: https://github.com/pytorch/data/issues/284
            if self._profiler_enabled:
                with profiler_record_fn_context():
                    return self._get_next()
            else:  # Decided against using `contextlib.nullcontext` for performance reasons
                return self._get_next()

        def __getattr__(self, name):
            # `__getattr__` only runs when normal lookup fails. If `iterator` itself is missing
            # (e.g. an instance created via `__new__` without `__init__`, as `copy.copy` does
            # before probing for `__setstate__`), delegating would re-enter `__getattr__`
            # for `iterator` and recurse forever.
            if name == "iterator":
                raise AttributeError(name)
            return getattr(self.iterator, name)

    func = namespace['__iter__']

    # ``__iter__`` of IterDataPipe is a generator function
    if inspect.isgeneratorfunction(func):
        @functools.wraps(func)
        def wrap_generator(*args, **kwargs):
            gen = func(*args, **kwargs)
            datapipe = args[0]
            if datapipe._fast_forward_iterator:
                # A fast-forward iterator was installed on the DataPipe (set elsewhere, presumably
                # by snapshot restoration). Serve items from it instead of `gen`, and clear it so
                # the next `__iter__` call behaves normally.
                it = datapipe._fast_forward_iterator
                datapipe._fast_forward_iterator = None
                datapipe._snapshot_state = _SnapshotState.Iterating
                while True:
                    try:
                        yield next(it)
                    except StopIteration:
                        return
            iterator_id = _set_datapipe_valid_iterator_id(datapipe)  # This ID is tied to each created iterator
            _profiler_enabled = torch.autograd._profiler_enabled()
            try:
                if _profiler_enabled:
                    with profiler_record_fn_context():
                        response = gen.send(None)
                else:
                    response = gen.send(None)

                while True:
                    datapipe._number_of_samples_yielded += 1
                    request = yield response
                    # Pass through here every time `__next__` is called
                    if _profiler_enabled:
                        with profiler_record_fn_context():
                            _check_iterator_valid(datapipe, iterator_id)
                            response = gen.send(request)
                    else:  # Decided against using `contextlib.nullcontext` for performance reasons
                        _check_iterator_valid(datapipe, iterator_id)
                        response = gen.send(request)
            except StopIteration:
                return
            except Exception as e:
                # TODO: Simplify the traceback message to skip over `response = gen.send(None)`
                #       Part of https://github.com/pytorch/data/issues/284
                msg = "thrown by __iter__ of"
                single_iterator_msg = "single iterator per IterDataPipe constraint"
                try:
                    full_msg = f"{msg} {_generate_iterdatapipe_msg(datapipe)}"
                except Exception:
                    # Annotating the message is best-effort. If rendering the DataPipe fails
                    # (e.g. its signature cannot be inspected or an argument attribute raises),
                    # the original exception must still propagate instead of this new error.
                    full_msg = None
                if full_msg is not None and hasattr(e.args, '__len__'):
                    if len(e.args) == 0 or not isinstance(e.args[0], str):  # If an exception message doesn't exist
                        e.args = (f'\nThis exception is {full_msg}',)
                    # Skip if an inner DataPipe already annotated the message, or it is a
                    # single-iterator error that already names the DataPipe.
                    elif msg not in e.args[0] and single_iterator_msg not in e.args[0]:
                        e.args = (e.args[0] + f'\nThis exception is {full_msg}',) + e.args[1:]
                raise

        namespace['__iter__'] = wrap_generator
    else:  # ``__iter__`` of IterDataPipe is NOT a generator function
        # IterDataPipe is an iterator with both ``__iter__`` and ``__next__``
        # And ``__iter__`` may or may not return `self`
        if '__next__' in namespace:  # If `__next__` exists, put a wrapper around it
            next_func = namespace['__next__']

            @functools.wraps(next_func)
            def wrap_next(*args, **kwargs):
                # Unlike IteratorDecorator, the profiler state is checked on every call here.
                if torch.autograd._profiler_enabled():
                    with profiler_record_fn_context():
                        result = next_func(*args, **kwargs)
                else:
                    result = next_func(*args, **kwargs)
                datapipe = args[0]
                datapipe._number_of_samples_yielded += 1
                return result

            namespace['__next__'] = wrap_next

            # Note that if the `__next__` and `__iter__` do something completely unrelated. It may cause issue but
            # the user will be violating the iterator protocol. Potential issue:
            # 1. Valid iterator ID may not update or checked properly
            # 2. The number of samples yielded will be miscounted

        # Regardless if `__next__` exists or not, `__iter__` needs a wrapper to track the number of valid iterators
        @functools.wraps(func)
        def wrap_iter(*args, **kwargs):
            iter_ret = func(*args, **kwargs)
            datapipe = args[0]
            datapipe._snapshot_state = _SnapshotState.Iterating
            if datapipe._fast_forward_iterator:
                # Hand back the installed fast-forward iterator as-is (no validity tracking)
                # and clear it so the next `__iter__` call behaves normally.
                iter_ret = datapipe._fast_forward_iterator
                datapipe._fast_forward_iterator = None
                return iter_ret
            iterator_id = _set_datapipe_valid_iterator_id(datapipe)  # This ID is tied to each created iterator
            return IteratorDecorator(iter_ret, datapipe, iterator_id, '__next__' in namespace)

        namespace['__iter__'] = wrap_iter