Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ python / caffe_translator_test.py

# This is a large test that translates the BVLC CaffeNet model, runs an
# example through the whole model, and verifies numerically that all the
# results look right. By default it is disabled unless you explicitly ask
# to run it (pass any command-line argument when running this file directly).

from google.protobuf import text_format
import numpy as np
import os
import sys

# Caffe is an optional, separately installed dependency. Probe for it (and for
# the translator that needs it) and record the outcome in CAFFE_FOUND so the
# tests below can be skipped cleanly when it is unavailable.
CAFFE_FOUND = False
try:
    from caffe.proto import caffe_pb2
    from caffe2.python import caffe_translator
    CAFFE_FOUND = True
except Exception as e:
    if "'caffe'" in str(e):
        # The expected case: the caffe module itself is not installed.
        print(
            "PyTorch/Caffe2 now requires a separate installation of caffe. "
            "Right now, this is not found, so we will skip the caffe "
            "translator test.")
    else:
        # Any other failure (e.g. a broken caffe_translator import) also
        # disables the test, but report it instead of swallowing it silently
        # so it is not mistaken for a missing caffe installation.
        print(
            "Failed to import caffe / caffe_translator ({}: {}); skipping "
            "the caffe translator test.".format(type(e).__name__, e))

from caffe2.python import utils, workspace, test_util
import unittest

def setUpModule():
    """Translate the BVLC CaffeNet model and run one example through it.

    All computation happens in the global Caffe2 workspace so that the tests
    in this module can fetch the resulting blobs afterwards. Does nothing if
    caffe could not be imported or the test data directory does not exist.
    """
    data_dir = 'data/testdata/caffe_translator'
    # Do nothing if caffe and test data is not found.
    if not (CAFFE_FOUND and os.path.exists(data_dir)):
        return
    caffenet = caffe_pb2.NetParameter()
    caffenet_pretrained = caffe_pb2.NetParameter()
    # The network definition is a text-format prototxt.
    with open(os.path.join(data_dir, 'deploy.prototxt')) as f:
        text_format.Merge(f.read(), caffenet)
    # The pretrained weights are a binary serialized protobuf, so they must be
    # read as bytes; text mode would decode (and corrupt) them on Python 3.
    with open(os.path.join(data_dir, 'bvlc_reference_caffenet.caffemodel'),
              'rb') as f:
        caffenet_pretrained.ParseFromString(f.read())
    # The example input from the Caffe test code is identical for every
    # translation below, so load it only once.
    data = np.load(os.path.join(data_dir, 'data_dump.npy')).astype(np.float32)
    # Translate and run the net both with and without legacy-pad removal.
    # The blobs left in the workspace for the tests presumably come from the
    # last pass (remove_legacy_pad=False), since both runs use the same names.
    for remove_legacy_pad in [True, False]:
        net, pretrained_params = caffe_translator.TranslateModel(
            caffenet, caffenet_pretrained, is_test=True,
            remove_legacy_pad=remove_legacy_pad
        )
        # Dump the translated net next to the test data for inspection.
        translated_path = os.path.join(
            data_dir, 'bvlc_reference_caffenet.translatedmodel')
        with open(translated_path, 'w') as fid:
            fid.write(str(net))
        for param in pretrained_params.protos:
            workspace.FeedBlob(param.name, utils.Caffe2TensorToNumpyArray(param))
        workspace.FeedBlob('data', data)
        # Actually running the test.
        workspace.RunNetOnce(net.SerializeToString())


@unittest.skipIf(not CAFFE_FOUND,
                 'No Caffe installation found.')
@unittest.skipIf(not os.path.exists('data/testdata/caffe_translator'),
                 'No testdata existing for the caffe translator test. Exiting.')
class TestNumericalEquivalence(test_util.TestCase):
    """Compare blobs of the translated net against Caffe reference dumps."""

    def testBlobs(self):
        """Check shape and scaled values of each intermediate blob.

        Relies on setUpModule() having run the translated net in the global
        workspace. Each blob ``<name>`` is compared with the reference array
        stored in ``<name>_dump.npy`` in the test data directory.
        """
        data_dir = 'data/testdata/caffe_translator'
        names = [
            "conv1", "pool1", "norm1", "conv2", "pool2", "norm2", "conv3",
            "conv4", "conv5", "pool5", "fc6", "fc7", "fc8", "prob"
        ]
        for name in names:
            print('Verifying {}'.format(name))
            caffe2_result = workspace.FetchBlob(name)
            reference = np.load(os.path.join(data_dir, name + '_dump.npy'))
            failure_msg = 'Mismatch in blob {}'.format(name)
            self.assertEqual(caffe2_result.shape, reference.shape, failure_msg)
            # Normalize by the largest magnitude so the decimal tolerance is
            # relative to the blob's dynamic range. max(|x|) instead of max(x)
            # avoids dividing by a zero, tiny, or negative maximum, which would
            # make the comparison ill-defined or spuriously strict.
            scale = np.max(np.abs(caffe2_result))
            if scale == 0:
                scale = 1.0
            np.testing.assert_almost_equal(
                caffe2_result / scale,
                reference / scale,
                decimal=5,
                err_msg=failure_msg
            )


if __name__ == '__main__':
    # Opt-in guard: this test is expensive, so it only runs when the script is
    # given at least one extra command-line argument; otherwise it explains
    # how to run it and exits successfully.
    if len(sys.argv) != 1:
        unittest.main()
    else:
        print(
            'If you do not explicitly ask to run this test, I will not run it. '
            'Pass in any argument to have the test run for you.'
        )
        sys.exit(0)