Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ contrib / warpctc / ctc_ops_test.py





import numpy as np
from caffe2.proto import caffe2_pb2

from caffe2.python import core, workspace, dyndep, test_util

# Load the warpctc contrib operator library so the `CTC` operator used by the
# tests below (via `net.CTC`) can be created.
# NOTE(review): the '@/...' argument looks like a build-system target label
# rather than a filesystem path — confirm against the build setup.
dyndep.InitOpsLibrary('@/caffe2/caffe2/contrib/warpctc:ctc_ops')
# Initialise caffe2 globally, passing only a program name as argv.
workspace.GlobalInit(["python"])


def softmax(w):
    """Return the softmax of ``w`` taken along its last axis.

    The per-row maximum is subtracted before exponentiating so that large
    values do not overflow; the shift cancels out in the normalisation.
    """
    shifted = np.subtract(w, np.max(w, axis=-1, keepdims=True))
    unnormalised = np.exp(shifted)
    total = unnormalised.sum(axis=-1, keepdims=True)
    return unnormalised / total


class CTCOpsTest(test_util.TestCase):
    """Checks the warpctc ``CTC`` operator against a hand-computed loss.

    Each test runs a tiny problem (T=2 time steps, batch N=1, 5 classes,
    target sequence [1, 2]) on one device, both with and without the optional
    ``input_lengths`` input, either with gradients (``is_test=False``) or
    forward-only (``is_test=True``).
    """

    @staticmethod
    def _cpu_option():
        """Return a fresh CPU ``DeviceOption``."""
        return caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU)

    @staticmethod
    def _cuda_option():
        """Return a fresh ``DeviceOption`` for CUDA device 0."""
        return caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CUDA,
                                       device_id=0)

    def _skip_unless_cuda(self):
        """Skip the current test unless caffe2 was built with CUDA and a GPU exists.

        Without this guard the GPU tests error out on CPU-only builds or
        GPU-less hosts instead of being reported as skipped.
        """
        if not workspace.has_cuda_support or workspace.NumCudaDevices() < 1:
            self.skipTest("CUDA is not available")

    def verify_cost(self, device_option, is_test, skip_input_lengths=False):
        """Run the CTC op once and compare its cost with the analytic value.

        Args:
            device_option: caffe2 ``DeviceOption`` used for the CTC op and for
                feeding the ``inputs`` blob.
            is_test: if true, build the forward-only op (outputs ``costs`` and
                ``workspace``). Otherwise also request the op's
                ``inputs_grad_to_be_copied`` output, add gradient operators
                for ``costs``, and check the two gradients agree.
            skip_input_lengths: if true, omit the optional ``input_lengths``
                input. The op then uses its default, presumably the full
                length T for every sequence (TODO confirm).
        """
        alphabet_size = 5
        N = 1
        T = 2

        # Activations laid out as (T, N, alphabet_size).
        inputs = np.asarray(
            [
                [[0.1, 0.6, 0.1, 0.1, 0.1]],
                [[0.1, 0.1, 0.6, 0.1, 0.1]],
            ]
        ).reshape(T, N, alphabet_size).astype(np.float32)

        # Flat 1-D label data for the batch. Its length is set by
        # label_lengths, not by T; it equals T here only by coincidence.
        labels = np.asarray([1, 2], dtype=np.int32)
        label_lengths = np.asarray([2], dtype=np.int32).reshape(N)
        input_lengths = np.asarray([T], dtype=np.int32)

        net = core.Net("test-net")
        input_blobs = ["inputs", "labels", "label_lengths"]
        if not skip_input_lengths:
            input_blobs.append("input_lengths")
        if is_test:
            output_blobs = ["costs", "workspace"]
        else:
            # Keep the op's own gradient output under a separate name so it
            # can be compared with the "inputs_grad" blob that
            # AddGradientOperators creates below.
            output_blobs = ["inputs_grad_to_be_copied", "costs", "workspace"]
        net.CTC(input_blobs,
                output_blobs,
                is_test=is_test,
                device_option=device_option)
        if not is_test:
            net.AddGradientOperators(["costs"])

        # Only "inputs" is fed with device_option. The label and length blobs
        # are fed without one, presumably because the op reads them on the
        # host (TODO confirm).
        self.ws.create_blob("inputs").feed(inputs, device_option=device_option)
        self.ws.create_blob("labels").feed(labels)
        self.ws.create_blob("label_lengths").feed(label_lengths)
        if not skip_input_lengths:
            self.ws.create_blob("input_lengths").feed(input_lengths)
        self.ws.run(net)

        # With target [1, 2] and only T=2 steps, the single valid alignment
        # emits label 1 at t=0 and label 2 at t=1. This assumes the blank
        # symbol is neither 1 nor 2; warpctc's default blank is 0 (TODO
        # confirm). The likelihood is therefore the product of these two
        # softmax probabilities.
        probs = softmax(inputs)
        expected = probs[0, 0, 1] * probs[1, 0, 2]

        costs = self.ws.blobs["costs"].fetch()
        self.assertEqual(costs.shape, (N,))
        self.assertEqual(costs.dtype, np.float32)
        # The cost is checked as a negative log-likelihood: exp(-cost) must
        # recover the likelihood computed above.
        self.assertAlmostEqual(np.exp(-costs[0]), expected)
        if not is_test:
            # Make sure inputs_grad was added by AddGradientOperators and that
            # it equals the inputs_grad_to_be_copied blob returned by the CTC
            # op. np.testing survives `python -O` and reports the mismatching
            # elements.
            np.testing.assert_array_equal(
                self.ws.blobs["inputs_grad"].fetch(),
                self.ws.blobs["inputs_grad_to_be_copied"].fetch(),
            )

    def test_ctc_cost_cpu(self):
        self.verify_cost(self._cpu_option(), is_test=False)
        self.verify_cost(self._cpu_option(), is_test=False,
                         skip_input_lengths=True)

    def test_ctc_cost_gpu(self):
        self._skip_unless_cuda()
        self.verify_cost(self._cuda_option(), is_test=False)
        self.verify_cost(self._cuda_option(), is_test=False,
                         skip_input_lengths=True)

    def test_ctc_forward_only_cpu(self):
        self.verify_cost(self._cpu_option(), is_test=True)
        self.verify_cost(self._cpu_option(), is_test=True,
                         skip_input_lengths=True)

    def test_ctc_forward_only_gpu(self):
        self._skip_unless_cuda()
        self.verify_cost(self._cuda_option(), is_test=True)
        self.verify_cost(self._cuda_option(), is_test=True,
                         skip_input_lengths=True)