Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Bower components, Debian packages, RPM packages, and NuGet packages.

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ utils / data / datapipes / iter / callable.py

import copy
import functools
from collections import namedtuple
from typing import Any, Callable, Dict, Iterator, List, Optional, Sized, TypeVar, Union

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data._utils.collate import default_collate
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.utils.common import (_check_unpickable_fn,
                                                     validate_input_col)

# Public API of this module: the names exported by ``from ... import *``.
__all__ = ["CollatorIterDataPipe", "MapperIterDataPipe"]

# Covariant type variable for the items yielded by ``MapperIterDataPipe``
# (it parameterizes ``IterDataPipe[T_co]`` and the ``Iterator[T_co]`` return of ``__iter__``).
T_co = TypeVar("T_co", covariant=True)


@functional_datapipe("map")
class MapperIterDataPipe(IterDataPipe[T_co]):
    r"""
    Applies a function over each item from the source DataPipe (functional name: ``map``).
    The function can be any regular Python function or partial object. Lambda
    function is not recommended as it is not supported by pickle.

    When ``input_col`` is given, the result is written into a copy of the item
    (a shallow copy for lists/dicts; tuples are rebuilt as new tuples), so items
    produced by the source DataPipe are never modified in place.

    Args:
        datapipe: Source Iterable DataPipe
        fn: Function being applied over each item
        input_col: Index or indices of the data to which ``fn`` is applied, such as:

            - ``None`` as default to apply ``fn`` to the data directly.
            - Integer(s) is used for list/tuple.
            - Key(s) is used for dict.

        output_col: Index of data where result of ``fn`` is placed. ``output_col`` can be specified
            only when ``input_col`` is not ``None``. A list or tuple holding exactly one
            element is unwrapped to that element.

            - ``None`` as default to replace the index that ``input_col`` specified; For ``input_col`` with
              multiple indices, the left-most one is used, and other indices will be removed.
            - Integer is used for list/tuple. ``-1`` represents to append result at the end.
            - Key is used for dict. New key is acceptable.

    Raises:
        ValueError: if ``output_col`` is given while ``input_col`` is ``None``, or if
            ``output_col`` is a list/tuple that does not contain exactly one element.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
        >>> def add_one(x):
        ...     return x + 1
        >>> dp = IterableWrapper(range(10))
        >>> map_dp_1 = dp.map(add_one)  # Invocation via functional form is preferred
        >>> list(map_dp_1)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> # We discourage the usage of `lambda` functions as they are not serializable with `pickle`
        >>> # Use `functools.partial` or explicitly define the function instead
        >>> map_dp_2 = Mapper(dp, lambda x: x + 1)
        >>> list(map_dp_2)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    """
    datapipe: IterDataPipe
    fn: Callable

    def __init__(
        self,
        datapipe: IterDataPipe,
        fn: Callable,
        input_col=None,
        output_col=None,
    ) -> None:
        super().__init__()
        self.datapipe = datapipe

        _check_unpickable_fn(fn)
        self.fn = fn  # type: ignore[assignment]

        self.input_col = input_col
        if input_col is None and output_col is not None:
            raise ValueError("`output_col` must be None when `input_col` is None.")
        if isinstance(output_col, (list, tuple)):
            # Reject empty sequences here too: they previously reached
            # ``output_col[0]`` and failed with an opaque IndexError.
            if len(output_col) != 1:
                raise ValueError("`output_col` must be a single-element list or tuple")
            output_col = output_col[0]
        self.output_col = output_col
        validate_input_col(fn, input_col)

    def _apply_fn(self, data):
        """Apply ``fn`` to a single item according to ``input_col``/``output_col``.

        Returns ``fn(data)`` directly when neither column is configured. Otherwise
        the result is stored in a copy of ``data`` and that copy is returned. Tuples
        are copied to a list and converted back to a tuple before returning. Any
        other container gets a ``copy.copy``.
        """
        if self.input_col is None and self.output_col is None:
            return self.fn(data)

        if self.input_col is None:
            res = self.fn(data)
        elif isinstance(self.input_col, (list, tuple)):
            # Several input columns are passed to ``fn`` as positional arguments.
            args = tuple(data[col] for col in self.input_col)
            res = self.fn(*args)
        else:
            res = self.fn(data[self.input_col])

        # Never write into the item received from the source DataPipe. Tuples are
        # immutable, so they are copied to a list (and converted back below).
        # Anything else gets a shallow copy. Writing into the original object would
        # leak the mapped value back into the source. If the source yields the same
        # objects again (e.g. on a later epoch), ``fn`` would then be re-applied to
        # already-mapped data.
        if isinstance(data, tuple):
            t_flag = True
            data = list(data)
        else:
            t_flag = False
            data = copy.copy(data)

        if self.output_col is None:
            if isinstance(self.input_col, (list, tuple)):
                # The result replaces the left-most input column. The remaining input
                # columns are deleted from the highest index down, so one deletion
                # does not shift the positions still waiting to be deleted.
                data[self.input_col[0]] = res
                for idx in sorted(self.input_col[1:], reverse=True):
                    del data[idx]
            else:
                data[self.input_col] = res
        else:
            if self.output_col == -1:
                data.append(res)
            else:
                data[self.output_col] = res

        # Convert list back to tuple
        return tuple(data) if t_flag else data

    def __iter__(self) -> Iterator[T_co]:
        """Yield ``_apply_fn(item)`` for every item of the source DataPipe."""
        for data in self.datapipe:
            yield self._apply_fn(data)

    def __len__(self) -> int:
        """Return the source DataPipe's length.

        Raises:
            TypeError: if the source DataPipe is not ``Sized``.
        """
        if isinstance(self.datapipe, Sized):
            return len(self.datapipe)
        raise TypeError(
            "{} instance doesn't have valid length".format(type(self).__name__)
        )


def _collate_helper(conversion, item):
    # TODO(VitalyFedyunin): Verify that item is any sort of batch
    if len(item.items) > 1:
        # TODO(VitalyFedyunin): Compact all batch dataframes into one
        raise Exception("Only supports one DataFrame per batch")
    df = item[0]
    columns_name = df_wrapper.get_columns(df)
    tuple_names: List = []
    tuple_values: List = []

    for name in conversion.keys():
        if name not in columns_name:
            raise Exception("Conversion keys missmatch")

    for name in columns_name:
        if name in conversion:
            if not callable(conversion[name]):
                raise Exception('Collate (DF)DataPipe requires callable as dict values')
            collation_fn = conversion[name]
        else:
            # TODO(VitalyFedyunin): Add default collation into df_wrapper
            try:
                import torcharrow.pytorch as tap  # type: ignore[import]
                collation_fn = tap.rec.Default()
            except Exception as e:
                raise Exception("unable to import default collation function from the TorchArrow") from e

        tuple_names.append(str(name))
        value = collation_fn(df[name])
        tuple_values.append(value)

    # TODO(VitalyFedyunin): We can dynamically extract types from the tuple_values here
    # TODO(VitalyFedyunin): Instead of ignoring mypy error, make sure tuple_names is not empty
    tpl_cls = namedtuple("CollateResult", tuple_names)  # type: ignore[misc]
    tuple = tpl_cls(*tuple_values)
    return tuple


@functional_datapipe("collate")
class CollatorIterDataPipe(MapperIterDataPipe):
    r"""
    Collates samples from DataPipe to Tensor(s) by a custom collate function (functional name: ``collate``).
    By default, it uses :func:`torch.utils.data.default_collate`.

    .. note::
        While writing a custom collate function, you can import :func:`torch.utils.data.default_collate` for the
        default behavior and `functools.partial` to specify any additional arguments.

    Args:
        datapipe: Iterable DataPipe being collated
        conversion: Used only when ``collate_fn`` is ``None``. Either a callable applied to
            each item (default: :func:`torch.utils.data.default_collate`; passing ``None``
            also selects it), or a dict mapping DataFrame column names to the callables
            that collate those columns. In the dict form each item must be a batch holding
            a single DataFrame, and columns missing from the dict use TorchArrow's default
            collation.
        collate_fn: Customized collate function to collect and combine data or a batch of data.
            Default function collates to Tensor(s) based on data type. When given, it takes
            precedence over ``conversion``.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Convert integer data to float Tensor
        >>> class MyIterDataPipe(torch.utils.data.IterDataPipe):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        ...     def __len__(self):
        ...         return self.end - self.start
        ...
        >>> ds = MyIterDataPipe(start=3, end=7)
        >>> print(list(ds))
        [3, 4, 5, 6]
        >>> def collate_fn(batch):
        ...     return torch.tensor(batch, dtype=torch.float)
        ...
        >>> collated_ds = CollatorIterDataPipe(ds, collate_fn=collate_fn)
        >>> print(list(collated_ds))
        [tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
    """

    def __init__(
        self,
        datapipe: IterDataPipe,
        conversion: Optional[
            Union[
                Callable[..., Any],
                Dict[Union[str, Any], Union[Callable, Any]],
            ]
        ] = default_collate,
        collate_fn: Optional[Callable] = None,
    ) -> None:
        # TODO(VitalyFedyunin): Replace `Callable[..., Any]` with `Callable[[IColumn], Any]`
        # TODO(VitalyFedyunin): Replace with `Dict[Union[str, IColumn], Union[Callable, Enum]]`
        if collate_fn is not None:
            # An explicit collate function always wins over ``conversion``.
            super().__init__(datapipe, fn=collate_fn)
            return

        if conversion is None:
            # ``None`` is permitted by the annotation. It used to be routed to
            # ``_collate_helper`` and fail only once iteration started, so treat it
            # as a request for the default collate function instead.
            conversion = default_collate

        if callable(conversion):
            super().__init__(datapipe, fn=conversion)
        else:
            # TODO(VitalyFedyunin): Validate passed dictionary
            collate_fn = functools.partial(_collate_helper, conversion)
            super().__init__(datapipe, fn=collate_fn)