Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ python / dataio_test.py






from caffe2.python.dataio import (
    CompositeReader,
    CompositeReaderBuilder,
    ReaderBuilder,
    ReaderWithDelay,
    ReaderWithLimit,
    ReaderWithTimeLimit,
)
from caffe2.python.dataset import Dataset
from caffe2.python.db_file_reader import DBFileReader
from caffe2.python.pipeline import pipe
from caffe2.python.schema import Struct, NewRecord, FeedRecord
from caffe2.python.session import LocalSession
from caffe2.python.task import TaskGroup, final_output, WorkspaceType
from caffe2.python.test_util import TestCase
from caffe2.python.cached_reader import CachedReader
from caffe2.python import core, workspace, schema
from caffe2.python.net_builder import ops

import numpy as np
import numpy.testing as npt
import os
import shutil
import unittest
import tempfile


def make_source_dataset(ws, size=100, offset=0, name=None):
    """Create and populate a single-column ("label") source Dataset.

    The record blobs are created under the ``name`` name scope and fed with
    the consecutive integers ``offset .. offset + size - 1``; the init net is
    run on ``ws`` before the dataset is returned.

    Args:
        ws: workspace used to feed the record and run the init net.
        size: number of rows to generate.
        offset: first label value.
        name: name scope / dataset name; falls back to "src" when falsy.

    Returns:
        The populated ``Dataset``.
    """
    scope = name if name else "src"
    init_net = core.Net("{}_init".format(scope))
    labels = np.array(range(offset, offset + size))
    with core.NameScope(scope):
        values = Struct(('label', labels))
        blobs = NewRecord(init_net, values)
        dataset = Dataset(blobs, name=scope)
        FeedRecord(blobs, values, ws)
    ws.run(init_net)
    return dataset


def make_destination_dataset(ws, schema, name=None):
    """Create an empty Dataset with the given schema, ready to be written to.

    Args:
        ws: workspace in which the init net is run.
        schema: record schema for the new dataset. (Inside this function the
            parameter shadows the module-level ``schema`` import.)
        name: name scope / dataset name; falls back to 'dst' when falsy.

    Returns:
        The initialized, empty ``Dataset``.
    """
    scope = name if name else 'dst'
    init_net = core.Net('{}_init'.format(scope))
    with core.NameScope(scope):
        dataset = Dataset(schema, name=scope)
        # Register the empty-dataset initialization on the init net ...
        dataset.init_empty(init_net)
    # ... and execute it so the blobs exist in ws before anyone writes.
    ws.run(init_net)
    return dataset


class TestReaderBuilder(ReaderBuilder):
    """Fixture ``ReaderBuilder`` serving a single-column ("label") dataset.

    ``setup`` materializes a source dataset holding the integers
    ``offset .. offset + size - 1`` in the given workspace, and
    ``new_reader`` returns that dataset object.
    """

    # This is a fixture helper, not a test. Because its name starts with
    # "Test", pytest would otherwise try to collect it and warn (it defines
    # __init__), so opt out of collection explicitly.
    __test__ = False

    def __init__(self, name, size, offset):
        """
        Args:
            name: name scope / dataset name passed to make_source_dataset.
            size: number of rows the source dataset will hold.
            offset: first label value.
        """
        # Inside a method body, `schema` resolves to the module-level import,
        # not to the `schema` method defined on this class.
        self._schema = schema.Struct(
            ('label', schema.Scalar()),
        )
        self._name = name
        self._size = size
        self._offset = offset
        # Populated by setup(); None until then.
        self._src_ds = None

    def schema(self):
        """Return the record schema: a Struct with one scalar 'label' field."""
        return self._schema

    def setup(self, ws):
        """Build the source dataset in ``ws`` and return an empty dict."""
        self._src_ds = make_source_dataset(
            ws, offset=self._offset, size=self._size, name=self._name)
        return {}

    def new_reader(self, **kwargs):
        """Return the Dataset built by ``setup``; ``kwargs`` are ignored.

        NOTE(review): this returns a Dataset rather than a Reader --
        presumably the caller turns it into a reader via its ``reader()``
        method; confirm against CompositeReaderBuilder.
        """
        return self._src_ds


class TestCompositeReader(TestCase):
    """Tests for CompositeReader and CompositeReaderBuilder.

    Each test builds three single-column sources with disjoint label ranges,
    pipes a composite reader over them into a destination dataset that has
    one sub-struct per source, and checks each source's labels arrived
    intact. The comparison is order-insensitive (the written labels are
    sorted) because several runtime threads write concurrently.
    """

    @staticmethod
    def _source_layout(count=3, rows=100):
        """Return (names, offsets, rows) describing ``count`` disjoint sources."""
        src_names = ["src_{}".format(idx) for idx in range(count)]
        starts = [idx * rows for idx in range(count)]
        return src_names, starts, rows

    @unittest.skipIf(os.environ.get('JENKINS_URL'), 'Flaky test on Jenkins')
    def test_composite_reader(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        src_names, starts, rows = self._source_layout()
        sources = []
        for src_name, start in zip(src_names, starts):
            sources.append(
                make_source_dataset(ws, offset=start, size=rows, name=src_name))

        originals = [ws.fetch_blob(str(ds.field_blobs[0])) for ds in sources]
        # Sanity check: creating later sources must not clobber earlier ones.
        for values, start in zip(originals, starts):
            npt.assert_array_equal(values, range(start, start + rows))

        # Destination mirrors the sources: one cloned sub-schema per name.
        fields = [(src_name, ds.content().clone_schema())
                  for src_name, ds in zip(src_names, sources)]
        dst_ds = make_destination_dataset(ws, schema.Struct(*fields))

        with TaskGroup() as tg:
            composite = CompositeReader(
                src_names, [ds.reader() for ds in sources])
            pipe(composite, dst_ds.writer(), num_runtime_threads=3)
        session.run(tg)

        for idx, (src_name, values) in enumerate(zip(src_names, originals)):
            label_blob = str(dst_ds.content()[src_name].label())
            npt.assert_array_equal(
                values, sorted(ws.fetch_blob(label_blob)), "i: {}".format(idx))

    @unittest.skipIf(os.environ.get('JENKINS_URL'), 'Flaky test on Jenkins')
    def test_composite_reader_builder(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        src_names, starts, rows = self._source_layout()
        builders = [TestReaderBuilder(offset=start, size=rows, name=src_name)
                    for src_name, start in zip(src_names, starts)]

        # Destination schema: one sub-schema per builder, keyed by name.
        fields = [(src_name, builder.schema())
                  for src_name, builder in zip(src_names, builders)]
        dst_ds = make_destination_dataset(ws, schema.Struct(*fields))

        with TaskGroup() as tg:
            composite_builder = CompositeReaderBuilder(src_names, builders)
            # setup() materializes each builder's source dataset in ws.
            composite_builder.setup(ws=ws)
            pipe(composite_builder.new_reader(), dst_ds.writer(),
                 num_runtime_threads=3)
        session.run(tg)

        for src_name, start in zip(src_names, starts):
            label_blob = str(dst_ds.content()[src_name].label())
            npt.assert_array_equal(
                range(start, start + rows), sorted(ws.fetch_blob(label_blob)),
                "name: {}".format(src_name))


class TestReaderWithLimit(TestCase):
    """Tests for ReaderWithLimit / ReaderWithTimeLimit and runtime threads.

    ``test_runtime_threads`` checks how often the task init / instance-init /
    instance-exit / exit hooks fire when a processor runs on several runtime
    threads. The ``test_*_limit_reader_*`` cases pipe a (possibly limited,
    possibly delayed) reader into a fresh destination dataset and check how
    many rows arrived and what the reader's data_finished blob holds.
    """

    def test_runtime_threads(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        src_ds = make_source_dataset(ws)
        totals = [None] * 3

        def proc(rec):
            # executed once
            with ops.task_init():
                counter1 = ops.CreateCounter([], ['global_counter'])
                counter2 = ops.CreateCounter([], ['global_counter2'])
                counter3 = ops.CreateCounter([], ['global_counter3'])
            # executed once per thread
            with ops.task_instance_init():
                task_counter = ops.CreateCounter([], ['task_counter'])
            # executed on each iteration
            ops.CountUp(counter1)
            ops.CountUp(task_counter)
            # executed once per thread
            with ops.task_instance_exit():
                with ops.loop(ops.RetrieveCount(task_counter)):
                    ops.CountUp(counter2)
                ops.CountUp(counter3)
            # executed once
            with ops.task_exit():
                totals[0] = final_output(ops.RetrieveCount(counter1))
                totals[1] = final_output(ops.RetrieveCount(counter2))
                totals[2] = final_output(ops.RetrieveCount(counter3))
            return rec

        # Read full data set from original reader
        with TaskGroup() as tg:
            pipe(src_ds.reader(), num_runtime_threads=8, processor=proc)
        session.run(tg)
        self.assertEqual(totals[0].fetch(), 100)
        self.assertEqual(totals[1].fetch(), 100)
        self.assertEqual(totals[2].fetch(), 8)

        # Read with a count-limited reader
        with TaskGroup() as tg:
            q1 = pipe(src_ds.reader(), num_runtime_threads=2)
            q2 = pipe(
                ReaderWithLimit(q1.reader(), num_iter=25),
                num_runtime_threads=3)
            pipe(q2, processor=proc, num_runtime_threads=6)
        session.run(tg)
        self.assertEqual(totals[0].fetch(), 25)
        self.assertEqual(totals[1].fetch(), 25)
        self.assertEqual(totals[2].fetch(), 6)

    def _test_limit_reader_init_shared(self, size):
        """Create (ws, session, src_ds, dst_ds) for a limit-reader test.

        ``src_ds`` holds labels ``0 .. size - 1``; ``dst_ds`` is an empty
        dataset with the same schema.
        """
        ws = workspace.C.Workspace()
        session = LocalSession(ws)

        # Make source dataset
        src_ds = make_source_dataset(ws, size=size)

        # Make an identically-sized empty destination Dataset
        dst_ds = make_destination_dataset(ws, src_ds.content().clone_schema())

        return ws, session, src_ds, dst_ds

    def _test_limit_reader_shared(self, reader_class, size, expected_read_len,
                                  expected_read_len_threshold,
                                  expected_finish, num_threads, read_delay,
                                  **limiter_args):
        """Pipe a limited reader into a fresh destination and check results.

        Args:
            reader_class: limiting reader wrapper under test (e.g.
                ReaderWithLimit or ReaderWithTimeLimit).
            size: number of rows in the source dataset.
            expected_read_len: expected number of rows written.
            expected_read_len_threshold: allowed +/- slack on the row count.
            expected_finish: expected value of the reader's data_finished
                blob after the run.
            num_threads: number of runtime threads used by the pipe.
            read_delay: if > 0, the source reader is wrapped in
                ReaderWithDelay with this delay before limiting.
            **limiter_args: forwarded to ``reader_class``.
        """
        ws, session, src_ds, dst_ds = \
            self._test_limit_reader_init_shared(size)

        # Read through the limiting reader (optionally with a per-read delay).
        # WorkspaceType.GLOBAL is required because we are fetching
        # reader.data_finished() after the TaskGroup finishes.
        with TaskGroup(workspace_type=WorkspaceType.GLOBAL) as tg:
            if read_delay > 0:
                reader = reader_class(ReaderWithDelay(src_ds.reader(),
                                                      read_delay),
                                      **limiter_args)
            else:
                reader = reader_class(src_ds.reader(), **limiter_args)
            pipe(reader, dst_ds.writer(), num_runtime_threads=num_threads)
        session.run(tg)
        # Fetch and sort the written labels once; both the length check and
        # the content check below use this same snapshot.
        written = sorted(ws.blobs[str(dst_ds.content().label())].fetch())
        read_len = len(written)

        # Do a fuzzy match (expected_read_len +/- expected_read_len_threshold)
        # to eliminate flakiness for time-limited tests
        self.assertGreaterEqual(
            read_len,
            expected_read_len - expected_read_len_threshold)
        self.assertLessEqual(
            read_len,
            expected_read_len + expected_read_len_threshold)
        # Whatever was read must be exactly the first read_len labels.
        self.assertEqual(written, list(range(read_len)))
        self.assertEqual(ws.blobs[str(reader.data_finished())].fetch(),
                         expected_finish)

    def test_count_limit_reader_without_limit(self):
        # No iter count specified, should read all records.
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=100,
                                       expected_read_len_threshold=0,
                                       expected_finish=True,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=None)

    def test_count_limit_reader_with_zero_limit(self):
        # Zero iter count specified, should read 0 records.
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=0,
                                       expected_read_len_threshold=0,
                                       expected_finish=False,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=0)

    def test_count_limit_reader_with_low_limit(self):
        # Read with limit smaller than size of dataset
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=10,
                                       expected_read_len_threshold=0,
                                       expected_finish=False,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=10)

    def test_count_limit_reader_with_high_limit(self):
        # Read with limit larger than size of dataset
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=100,
                                       expected_read_len_threshold=0,
                                       expected_finish=True,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=110)

    def test_time_limit_reader_without_limit(self):
        # duration=0 means no time limit, so all records should be read.
        self._test_limit_reader_shared(ReaderWithTimeLimit,
                                       size=100,
                                       expected_read_len=100,
                                       expected_read_len_threshold=0,
                                       expected_finish=True,
                                       num_threads=8,
                                       read_delay=0.1,
                                       duration=0)

    def test_time_limit_reader_with_short_limit(self):
        # Read with insufficient time limit
        size = 50
        num_threads = 4
        sleep_duration = 0.25
        duration = 1
        expected_read_len = int(round(num_threads * duration / sleep_duration))
        # Because the time limit check happens before the delay + read op,
        # subtract a little bit of time to ensure we don't get in an extra read
        duration = duration - 0.25 * sleep_duration

        # NOTE: `expected_read_len_threshold` was added because this test case
        # has significant execution variation under stress. Under stress, we may
        # read strictly less than the expected # of samples; anywhere from
        # [0,N] where N = expected_read_len.
        # Hence we set expected_read_len to N/2, plus or minus N/2.
        self._test_limit_reader_shared(ReaderWithTimeLimit,
                                       size=size,
                                       expected_read_len=expected_read_len / 2,
                                       expected_read_len_threshold=expected_read_len / 2,
                                       expected_finish=False,
                                       num_threads=num_threads,
                                       read_delay=sleep_duration,
                                       duration=duration)

    def test_time_limit_reader_with_long_limit(self):
        # Read with ample time limit
        # NOTE: we don't use `expected_read_len_threshold` because the duration,
        # read_delay, and # threads should be more than sufficient
        self._test_limit_reader_shared(ReaderWithTimeLimit,
                                       size=50,
                                       expected_read_len=50,
                                       expected_read_len_threshold=0,
                                       expected_finish=True,
                                       num_threads=4,
                                       read_delay=0.2,
                                       duration=10)


class TestDBFileReader(TestCase):
    def setUp(self):
Loading ...