Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go modules, Bower components, Debian packages, RPM packages, and NuGet packages.

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ utils / data / datapipes / iter / batch.py

import warnings
from torch.utils.data import IterDataPipe
from typing import TypeVar, Optional, Iterator, List, Sized, Callable

# Covariant element type used to parameterize the DataPipes in this module
# (for static type checkers only; it has no runtime effect on batching).
T_co = TypeVar("T_co", covariant=True)


class BatchIterDataPipe(IterDataPipe[List[T_co]]):
    r""" :class:`BatchIterDataPipe`.

    Iterable DataPipe to create mini-batches of data. Consecutive elements of
    ``datapipe`` are grouped into lists of ``batch_size`` elements. When the
    number of elements is not a multiple of ``batch_size``, the final batch
    holds the remaining ``length % batch_size`` elements; it is yielded when
    ``drop_last`` is ``False`` and discarded when ``drop_last`` is ``True``.

    Every yielded batch is a new list, so callers may keep batches around
    (e.g. ``list(pipe)``) without later batches overwriting them.

    args:
        datapipe: Iterable DataPipe being batched
        batch_size: The size of each batch (must be larger than 0)
        drop_last: Option to drop the last batch if it's not full
    """
    datapipe: IterDataPipe[T_co]
    batch_size: int
    drop_last: bool
    length: Optional[int]  # cached result of ``__len__``; None until computed

    def __init__(self,
                 datapipe: IterDataPipe[T_co],
                 *,
                 batch_size: int,
                 drop_last: bool = False,
                 ) -> None:
        # Kept as ``assert`` so callers still see AssertionError on bad input.
        assert batch_size > 0, "Batch size is required to be larger than 0!"
        super().__init__()
        self.datapipe = datapipe
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.length = None

    def __iter__(self) -> Iterator[List[T_co]]:
        """Yield lists of up to ``batch_size`` consecutive source elements."""
        batch: List[T_co] = []
        for x in self.datapipe:
            batch.append(x)
            if len(batch) == self.batch_size:
                yield batch
                # Start a fresh list rather than clearing the yielded one: the
                # consumer may still hold a reference to it.
                batch = []
        if batch and not self.drop_last:
            yield batch

    def __len__(self) -> int:
        """Return the number of batches this pipe yields.

        The value is derived from ``len(self.datapipe)`` and cached in
        ``self.length`` after the first successful call.

        Raises:
            NotImplementedError: if the source datapipe is not ``Sized``.
        """
        if self.length is not None:
            return self.length
        if isinstance(self.datapipe, Sized):
            source_len = len(self.datapipe)
            if self.drop_last:
                self.length = source_len // self.batch_size
            else:
                # Ceiling division: a trailing partial batch counts as one.
                self.length = (source_len + self.batch_size - 1) // self.batch_size
            return self.length
        raise NotImplementedError(
            "{} instance doesn't have valid length".format(type(self).__name__))


class BucketBatchIterDataPipe(IterDataPipe[List[T_co]]):
    r""" :class:`BucketBatchIterDataPipe`.

    Iterable DataPipe to create mini-batches of data from sorted buckets.
    When ``sort_key`` is given, elements of ``datapipe`` are first grouped into
    buckets of ``batch_size * bucket_size_mul`` consecutive elements. Each
    bucket is sorted in place by ``sort_key`` and then cut into batches of
    ``batch_size``. Because the bucket size is a multiple of ``batch_size``,
    only the final batch overall can be short. That batch is yielded when
    ``drop_last`` is ``False`` and discarded when ``drop_last`` is ``True``.
    Without ``sort_key`` the elements are batched in their original order.

    args:
        datapipe: Iterable DataPipe being batched
        batch_size: The size of each batch (must be larger than 0)
        drop_last: Option to drop the last batch if it's not full
        bucket_size_mul: The multiplier to specify the size of bucket (must be larger than 0)
        sort_key: Callable to specify the comparison key for sorting within bucket
    """
    datapipe: IterDataPipe[T_co]
    batch_size: int
    drop_last: bool
    bucket_size_mul: int
    bucket_size: int  # batch_size * bucket_size_mul
    sort_key: Optional[Callable]
    length: Optional[int]  # cached result of ``__len__``; None until computed

    def __init__(self,
                 datapipe: IterDataPipe[T_co],
                 *,
                 batch_size: int,
                 drop_last: bool = False,
                 bucket_size_mul: int = 100,
                 sort_key: Optional[Callable] = None,
                 ) -> None:
        # Kept as ``assert`` so callers still see AssertionError on bad input.
        assert batch_size > 0, "Batch size is required to be larger than 0!"
        # A non-positive multiplier would otherwise fail later with the
        # misleading "Batch size" message from the bucketing pipe.
        assert bucket_size_mul > 0, "Bucket size multiplier is required to be larger than 0!"
        super().__init__()
        self.datapipe = datapipe
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.bucket_size_mul = bucket_size_mul
        self.bucket_size = batch_size * bucket_size_mul
        self.sort_key = sort_key
        # Not every callable has ``__name__`` (e.g. functools.partial,
        # operator.itemgetter), so look it up defensively.
        if sort_key is not None and getattr(sort_key, '__name__', None) == '<lambda>':
            warnings.warn("Lambda function is not supported for pickle, "
                          "please use regular python function instead.")
        self.bucket_ds = BatchIterDataPipe(datapipe, batch_size=self.bucket_size, drop_last=False)
        self.length = None

    def __iter__(self) -> Iterator[List[T_co]]:
        """Yield batches of ``batch_size`` elements, sorted within each bucket."""
        if self.sort_key is None:
            # Bucketing without sorting keeps the original order, so plain
            # batching gives the same batches without buffering whole buckets.
            # Copy each batch so the caller owns it even if the inner pipe
            # reuses its list between yields.
            for batch in BatchIterDataPipe(self.datapipe, batch_size=self.batch_size, drop_last=self.drop_last):
                yield list(batch)
        else:
            for bucket in self.bucket_ds:
                # In-place sort within bucket
                bucket.sort(key=self.sort_key)
                for start in range(0, len(bucket), self.batch_size):
                    # Slicing copies, so each yielded batch is independent of ``bucket``.
                    batch = bucket[start:start + self.batch_size]
                    if len(batch) == self.batch_size or not self.drop_last:
                        yield batch

    def __len__(self) -> int:
        """Return the number of batches this pipe yields.

        The value is derived from ``len(self.datapipe)`` and cached in
        ``self.length`` after the first successful call.

        Raises:
            NotImplementedError: if the source datapipe is not ``Sized``.
        """
        if self.length is not None:
            return self.length
        if isinstance(self.datapipe, Sized):
            source_len = len(self.datapipe)
            if self.drop_last:
                self.length = source_len // self.batch_size
            else:
                # Ceiling division: a trailing partial batch counts as one.
                self.length = (source_len + self.batch_size - 1) // self.batch_size
            return self.length
        raise NotImplementedError(
            "{} instance doesn't have valid length".format(type(self).__name__))