Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ python / serialized_test / serialized_test_util.py






import argparse
from caffe2.proto import caffe2_pb2
from caffe2.python import gradient_checker
import caffe2.python.hypothesis_test_util as hu
from caffe2.python.serialized_test import coverage
import hypothesis as hy
import inspect
import numpy as np
import os
import shutil
import sys
import tempfile
import threading
from zipfile import ZipFile

# Subdirectory (under the output directory) holding the serialized
# operator-test archives.
operator_test_type = 'operator_test'
# Directory containing this module (symlinks resolved).
TOP_DIR = os.path.dirname(os.path.realpath(__file__))
# Default data directory is <TOP_DIR>/data.
DATA_SUFFIX = 'data'
DATA_DIR = os.path.join(TOP_DIR, DATA_SUFFIX)
# Thread-local holder for the options set by testWithArgs (output_dir,
# should_generate_output, disable_serialized_check, disable_gen_coverage).
# Readers use getattr(..., default), so unset attributes fall back to defaults;
# values set in one thread are not visible from other threads.
_output_context = threading.local()


def given(*given_args, **given_kwargs):
    """Replacement for ``hypothesis.given`` used by serialized operator tests.

    All arguments are forwarded unchanged to ``hypothesis.given``. The
    resulting test is pinned with ``hypothesis.seed(0)`` and
    ``hypothesis.settings(max_examples=1)`` so example generation is
    reproducible from run to run.

    The decorated test method runs twice. The first pass sets
    ``self.should_serialize = True``, so ``assertSerializedOperatorChecks``
    writes or compares serialized data. The second pass runs with it reset
    to ``False``.
    """
    def wrapper(f):
        # Built once and reused for both passes (both passes use the same
        # seed and settings).
        fixed_seed_func = hy.seed(0)(hy.settings(max_examples=1)(
            hy.given(*given_args, **given_kwargs)(f)))

        def func(self, *args, **kwargs):
            self.should_serialize = True
            try:
                fixed_seed_func(self, *args, **kwargs)
            finally:
                # Reset even when the serializing pass fails.
                self.should_serialize = False
            # NOTE(review): with seed(0) and max_examples=1 this
            # non-serializing pass presumably repeats the first pass's
            # example. It is kept so the number of test executions is unchanged.
            fixed_seed_func(self, *args, **kwargs)
        return func
    return wrapper


def _getGradientOrNone(op_proto):
    """Return the gradient operators reported by ``gradient_checker``.

    This is best effort: any exception raised by
    ``gradient_checker.getGradientForOp`` (or by unpacking its result) is
    swallowed. Despite the name, the fallback value is ``[]``, not ``None``.
    """
    try:
        gradient_defs, _unused = gradient_checker.getGradientForOp(op_proto)
    except Exception:
        # NOTE(review): presumably hit for operators without a registered
        # gradient; any other failure is hidden here as well.
        return []
    return gradient_defs


def _transformList(l):
    """Pack a sequence of arrays into a 1-D numpy object array.

    Each element of ``l`` becomes exactly one element of the result, whether
    the elements have different (jagged) shapes or identical ones. This lets
    ``np.savez_compressed`` store them under a single key.
    ``np.array(l)`` would instead reject jagged input or stack equal-shaped
    arrays into a single higher-rank array.
    """
    # Use the builtin ``object``: the ``np.object`` alias was deprecated in
    # NumPy 1.20 and removed in 1.24.
    ret = np.empty(len(l), dtype=object)
    # Element-wise assignment keeps every entry as its own object.
    for i, arr in enumerate(l):
        ret[i] = arr
    return ret


def _prepare_dir(path):
    """Make ``path`` an empty directory.

    Any existing tree at ``path`` is deleted first. Missing parent
    directories are created.
    """
    if not os.path.exists(path):
        os.makedirs(path)
        return
    # Wipe the previous contents, then recreate the directory itself.
    shutil.rmtree(path)
    os.makedirs(path)


class SerializedTestCase(hu.HypothesisTestCase):
    """Hypothesis test case that can serialize operator runs to disk.

    When ``should_serialize`` is set, ``assertReferenceChecks`` also runs a
    serialized check. In generate mode it writes the operator, its gradient
    operators, inputs, outputs and device type to a zip archive. Otherwise
    it compares the current run against a previously written archive.
    Generate mode is selected by ``_output_context.should_generate_output``.
    The output directory and the disable switches also come from the
    module-level ``_output_context`` (populated by ``testWithArgs``).
    """

    # Set to True by the module-level ``given`` decorator for its
    # serializing pass; only then is data written to / compared against disk.
    should_serialize = False

    def get_output_dir(self):
        """Return the directory holding this suite's serialized archives.

        Uses ``<output_dir>/operator_test`` if it exists, where ``output_dir``
        comes from ``_output_context`` (default ``DATA_DIR``). Otherwise it
        falls back to ``<cwd>/<package path of this module>/data/operator_test``.
        The fallback path is returned without checking that it exists.
        """
        output_dir_arg = getattr(_output_context, 'output_dir', DATA_DIR)
        output_dir = os.path.join(output_dir_arg, operator_test_type)

        if os.path.exists(output_dir):
            return output_dir

        # Fall back to pwd: turn this module's dotted package path
        # (presumably ``caffe2.python.serialized_test``) into a relative dir.
        cwd = os.getcwd()
        serialized_util_module_components = __name__.split('.')
        serialized_util_module_components.pop()
        serialized_dir = '/'.join(serialized_util_module_components)
        output_dir_fallback = os.path.join(cwd, serialized_dir, DATA_SUFFIX)
        return os.path.join(output_dir_fallback, operator_test_type)

    def get_output_filename(self):
        """Return ``<test file basename up to first '.'>.<test method name>``."""
        class_path = inspect.getfile(self.__class__)
        file_name_components = os.path.basename(class_path).split('.')
        test_file = file_name_components[0]

        function_name_components = self.id().split('.')
        test_function = function_name_components[-1]

        return test_file + '.' + test_function

    def serialize_test(self, inputs, outputs, grad_ops, op, device_option):
        """Write this run to ``<output_dir>/<test_name>.zip``.

        The archive contains:

        * ``op.pb``: the serialized ``op``.
        * ``grad_<i>.pb``: one file per gradient operator.
        * ``inout.npz``: ``inputs``, ``outputs`` and the integer
          ``device_type`` of ``device_option``.

        Files are staged in ``<output_dir>/<test_name>``. The staging
        directory is removed afterwards, including when writing fails.
        """
        output_dir = self.get_output_dir()
        test_name = self.get_output_filename()
        full_dir = os.path.join(output_dir, test_name)
        _prepare_dir(full_dir)

        try:
            inputs = _transformList(inputs)
            outputs = _transformList(outputs)
            device_type = int(device_option.device_type)

            op_path = os.path.join(full_dir, 'op.pb')
            grad_paths = []
            inout_path = os.path.join(full_dir, 'inout')

            with open(op_path, 'wb') as f:
                f.write(op.SerializeToString())
            for i, grad in enumerate(grad_ops):
                grad_path = os.path.join(full_dir, 'grad_{}.pb'.format(i))
                grad_paths.append(grad_path)
                with open(grad_path, 'wb') as f:
                    f.write(grad.SerializeToString())

            # np.savez_compressed appends the '.npz' extension itself.
            np.savez_compressed(
                inout_path,
                inputs=inputs,
                outputs=outputs,
                device_type=device_type)

            zip_path = os.path.join(output_dir, test_name + '.zip')
            with ZipFile(zip_path, 'w') as z:
                z.write(op_path, 'op.pb')
                z.write(inout_path + '.npz', 'inout.npz')
                for path in grad_paths:
                    z.write(path, os.path.basename(path))
        finally:
            # Never leave a half-written staging directory in the data dir.
            shutil.rmtree(full_dir, ignore_errors=True)

    def compare_test(self, inputs, outputs, grad_ops, atol=1e-7, rtol=1e-7):
        """Compare this run against the stored ``<test_name>.zip`` archive.

        If ``inputs`` match the stored inputs (same count, element-wise
        ``np.array_equal``), the given ``outputs`` and ``grad_ops`` are
        checked. Otherwise the stored operator is re-run on the stored inputs,
        using the stored device type, and those outputs and gradient ops are
        checked instead.

        Raises:
            AssertionError: if an output differs beyond ``atol``/``rtol`` or
                a gradient operator differs from the stored one.
        """

        def parse_proto(x):
            proto = caffe2_pb2.OperatorDef()
            proto.ParseFromString(x)
            return proto

        source_dir = self.get_output_dir()
        test_name = self.get_output_filename()
        temp_dir = tempfile.mkdtemp()
        # The try/finally guarantees temp_dir is removed even when a
        # comparison assertion below fails.
        try:
            with ZipFile(os.path.join(source_dir, test_name + '.zip')) as z:
                z.extractall(temp_dir)

            op_path = os.path.join(temp_dir, 'op.pb')
            inout_path = os.path.join(temp_dir, 'inout.npz')

            # Load serialized inputs/outputs. The context manager closes the
            # .npz file handle before temp_dir is deleted.
            with np.load(inout_path, encoding='bytes',
                         allow_pickle=True) as loaded:
                loaded_inputs = loaded['inputs'].tolist()
                # A different number of inputs also counts as "not equal";
                # zip() alone would silently ignore the extra ones.
                inputs_equal = len(inputs) == len(loaded_inputs) and all(
                    np.array_equal(x, y)
                    for x, y in zip(inputs, loaded_inputs))
                loaded_outputs = loaded['outputs'].tolist()
                if not inputs_equal:
                    device_type = int(loaded['device_type'])

            # If inputs are not the same, run the serialized inputs through
            # the serialized op.
            if not inputs_equal:
                with open(op_path, 'rb') as f:
                    op_proto = parse_proto(f.read())
                device_option = caffe2_pb2.DeviceOption(
                    device_type=device_type)
                outputs = hu.runOpOnInput(
                    device_option, op_proto, loaded_inputs)
                grad_ops = _getGradientOrNone(op_proto)

            # Assert outputs are equal.
            for x, y in zip(outputs, loaded_outputs):
                np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)

            # Assert gradient ops are equal.
            for i, grad_op in enumerate(grad_ops):
                grad_path = os.path.join(temp_dir, 'grad_{}.pb'.format(i))
                with open(grad_path, 'rb') as f:
                    grad_proto = parse_proto(f.read())
                self._assertSameOps(grad_proto, grad_op)
        finally:
            # ignore_errors so a cleanup failure cannot mask an assertion.
            shutil.rmtree(temp_dir, ignore_errors=True)

    def _assertSameOps(self, op1, op2):
        """Assert two OperatorDefs are equal, ignoring the order of ``arg``.

        Copies are sorted, so the caller's protos are left untouched.
        """

        def _normalized(op):
            copy = caffe2_pb2.OperatorDef()
            copy.CopyFrom(op)
            copy.arg.sort(key=lambda arg: arg.name)
            return copy

        self.assertEqual(_normalized(op1), _normalized(op2))

    def assertSerializedOperatorChecks(
            self,
            inputs,
            outputs,
            gradient_operator,
            op,
            device_option,
            atol=1e-7,
            rtol=1e-7,
    ):
        """Serialize or compare this run; a no-op unless ``should_serialize``.

        In generate mode (``_output_context.should_generate_output``) the
        run is written with ``serialize_test``. Unless
        ``disable_gen_coverage`` is set, ``coverage.gen_serialized_test_coverage``
        is then called. Otherwise the run is checked with ``compare_test``.
        """
        if self.should_serialize:
            if getattr(_output_context, 'should_generate_output', False):
                self.serialize_test(
                    inputs, outputs, gradient_operator, op, device_option)
                if not getattr(_output_context, 'disable_gen_coverage', False):
                    coverage.gen_serialized_test_coverage(
                        self.get_output_dir(), TOP_DIR)
            else:
                self.compare_test(
                    inputs, outputs, gradient_operator, atol, rtol)

    def assertReferenceChecks(
        self,
        device_option,
        op,
        inputs,
        reference,
        input_device_options=None,
        threshold=1e-4,
        output_to_grad=None,
        grad_reference=None,
        atol=None,
        outputs_to_check=None,
        ensure_outputs_are_inferred=False,
    ):
        """Run the base reference check, then the serialized check.

        The serialized check is skipped when
        ``_output_context.disable_serialized_check`` is set. For it,
        ``threshold`` is used as ``rtol`` and, when ``atol`` is None, also as
        ``atol``.

        Returns:
            The outputs returned by the base class's ``assertReferenceChecks``.
        """
        outs = super(SerializedTestCase, self).assertReferenceChecks(
            device_option,
            op,
            inputs,
            reference,
            input_device_options,
            threshold,
            output_to_grad,
            grad_reference,
            atol,
            outputs_to_check,
            ensure_outputs_are_inferred,
        )
        if not getattr(_output_context, 'disable_serialized_check', False):
            grad_ops = _getGradientOrNone(op)
            rtol = threshold
            if atol is None:
                atol = threshold
            self.assertSerializedOperatorChecks(
                inputs,
                outs,
                grad_ops,
                op,
                device_option,
                atol,
                rtol,
            )
        # Propagate the base class's result instead of silently dropping it.
        return outs


def testWithArgs():
    """Parse the serialized-test flags, then hand over to ``unittest.main()``.

    The recognised flags are stored on ``_output_context``. The remaining
    positional arguments replace ``sys.argv[1:]``, so unittest only sees them.
    """
    import unittest

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-G', '--generate-serialized', action='store_true', dest='generate',
        help='generate output files (default=false, compares to current files)')
    parser.add_argument(
        '-O', '--output', default=DATA_DIR,
        help='output directory (default: %(default)s)')
    parser.add_argument(
        '-D', '--disable-serialized_check', action='store_true', dest='disable',
        help='disable checking serialized tests')
    parser.add_argument(
        '-C', '--disable-gen-coverage', action='store_true',
        dest='disable_coverage',
        help='disable generating coverage markdown file')
    parser.add_argument('unittest_args', nargs='*')
    parsed = parser.parse_args()

    sys.argv[1:] = parsed.unittest_args

    context_values = (
        ('should_generate_output', parsed.generate),
        ('output_dir', parsed.output),
        ('disable_serialized_check', parsed.disable),
        ('disable_gen_coverage', parsed.disable_coverage),
    )
    for attr_name, attr_value in context_values:
        setattr(_output_context, attr_name, attr_value)

    unittest.main()