## @package cached_reader
# Module caffe2.python.cached_reader

import os

from caffe2.python import core
from caffe2.python.db_file_reader import DBFileReader
from caffe2.python.pipeline import pipe
from caffe2.python.task import Cluster, TaskGroup


class CachedReader(DBFileReader):
    """Reader with persistent in-file cache.

    Example usage:
    cached_reader = CachedReader(
        reader,
        db_path='/tmp/cache.db',
        db_type='LevelDB',
    )
    build_cache_step = cached_reader.build_cache_step()
    with LocalSession() as session:
        session.run(build_cache_step)
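
    Once the cache is built, the cached reader is consumed like any other
    Reader. A minimal sketch (`some_writer` is a hypothetical Writer to
    pipe the cached data into):
        with TaskGroup() as tg:
            pipe(cached_reader, some_writer, num_threads=8)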

    Every time a new CachedReader is created, db_path is expected to exist
    before .setup_ex(...) and .read(...) are called.

    If db_path doesn't exist, build_cache_step is expected to be called
    first to build a cache at db_path.

    build_cache_step checks for the existence of the provided db_path and,
    if it is missing, initializes it by reading data from the original reader.
    All subsequent attempts to read will ignore the original reader
    (i.e. no additional data will be read from it).

    Args:
        original_reader: Reader.
            The original reader used to build the cache file.
            Must not be None.
        db_path: str. Path to the cache DB file.

    Optional Args:
        db_type: str. DB type of the file. A db_type is registered by
            `REGISTER_CAFFE2_DB(<db_type>, <DB Class>)`.
            Defaults to 'LevelDB'.
        name: str or None. Name of CachedReader.
            Optional name to prepend to blobs that will store the data.
            Defaults to '<db_name>_<default_name_suffix>'.
        batch_size: int.
            How many examples are read each time the read_net is run.
            Defaults to 100.
        loop_over: bool.
            If True, goes through the examples in random order endlessly.
            Defaults to False.
    """
    def __init__(
        self,
        original_reader,
        db_path,
        db_type='LevelDB',
        name=None,
        batch_size=100,
        loop_over=False,
    ):
        assert original_reader is not None, "original_reader can't be None"
        self.original_reader = original_reader

        super(CachedReader, self).__init__(
            db_path,
            db_type,
            name,
            batch_size,
            loop_over,
        )

    def _init_reader_schema(self, *args, **kwargs):
        """Prepare the reader schema.

            Since an original reader is given,
            use its schema as the ground truth.

            Returns:
                schema: schema.Struct. Used in Reader.__init__(...).
        """
        return self.original_reader._schema

    def build_cache_step(self, overwrite=False):
        """Build a step for generating cache DB file.

            If self.db_path exists and overwrite is False, build an empty step.
            Otherwise, build a step as follows.
            Pipe the original reader to the _DatasetWriter,
            so that dataset field blobs are populated.
            Then save these blobs into a file.

            Args:
                overwrite: bool. If True, ignore the existing file
                    and build a new one, overwriting the existing one.

            Returns:
                build_cache_step: ExecutionStep.
                    The step to be run for building a cache DB file.
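
            Example (a minimal sketch; reuses `cached_reader` and `session`
            from the class-level example to force a rebuild of an existing
            cache file):
                rebuild_step = cached_reader.build_cache_step(overwrite=True)
                session.run(rebuild_step)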
        """
        if os.path.exists(self.db_path) and not overwrite:
            # cache already exists, no need to rebuild it
            return core.execution_step('build_step', [])

        init_net = core.Net('init')
        self._init_field_blobs_as_empty(init_net)
        with Cluster(), core.NameScope(self.name), TaskGroup() as copy_tg:
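            # Fill the dataset's field blobs by piping the original reader's
            # data through 16 worker threads.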
            pipe(self.original_reader, self.ds.writer(), num_threads=16)
            copy_step = copy_tg.to_task().get_step()
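
        # Serialize the populated field blobs into the cache DB file.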
        save_net = core.Net('save')
        self._save_field_blobs_to_db_file(save_net)

        return core.execution_step('build_cache', [init_net, copy_step, save_net])

    def _save_field_blobs_to_db_file(self, net):
        """Save dataset field blobs to a DB file at db_path"""
        net.Save(
            self.ds.get_blobs(),
            [],
            db=self.db_path,
            db_type=self.db_type,
            blob_name_overrides=self.ds.field_names(),
            absolute_path=True,
        )
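

# For reference: reading the cache back is expected to be handled by the parent
# DBFileReader once db_path exists; CachedReader never loads it manually. A
# hedged sketch of the inverse of the Save above (assuming caffe2's `Load` op
# and its `source_blob_names` argument; `ds`, `db_path`, and `db_type` stand in
# for the corresponding instance attributes) could look like:
#
#     load_net = core.Net('load')
#     load_net.Load(
#         [],
#         ds.get_blobs(),
#         db=db_path,
#         db_type=db_type,
#         source_blob_names=ds.field_names(),
#         absolute_path=True,
#     )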