Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ python / control_ops_grad_test.py






import unittest
from caffe2.python import core, test_util, workspace
from caffe2.python.control_ops_grad import disambiguate_grad_if_op_output
from caffe2.python.model_helper import ModelHelper
import numpy as np


class TestControl(test_util.TestCase):
    def test_disambiguate_grad_if_op_output(self):
        """Renaming an If grad op's output must propagate into its subnets.

        Per the comment below, core.py's _DisambiguateGradOpOutput ends up
        calling ``disambiguate_grad_if_op_output`` when another branch of
        the net also produces the same gradient blob. The If op's own output
        at the given index, and every matching output inside both the
        then_net and else_net subnets, must be renamed. Outputs that did not
        match the old name must be left untouched.
        """
        # These blobs are never read in this test: the If op below is only
        # constructed and rewritten, not executed.
        workspace.FeedBlob("cond", np.array(True))
        workspace.FeedBlob("then_grad", np.array(1))
        workspace.FeedBlob("else_grad", np.array(2))

        then_model = ModelHelper(name="then_test_model")
        then_model.net.Copy("then_grad", "input_grad")

        else_model = ModelHelper(name="else_test_model")
        else_model.net.Copy("else_grad", "else_temp_grad")
        else_model.net.Copy("else_temp", "input_grad")

        # For BuildGradientGenerators, the forward pass needs else_temp as
        # one of its outputs, which later results in a grad op like this:
        grad_op = core.CreateOperator(
            "If",
            ["cond", "then_grad", "else_grad"],
            ["input_grad", "else_temp_grad"],
            then_net=then_model.net.Proto(),
            else_net=else_model.net.Proto(),
        )

        # In certain cases another branch of the net also generates
        # input_grad, and _DisambiguateGradOpOutput in core.py is called.
        new_grad_output = "input_grad" + "_autosplit_" + "0"
        disambiguate_grad_if_op_output(grad_op, 0, new_grad_output)

        # The If op's output at index 0 is renamed; index 1 is untouched.
        self.assertEqual(grad_op.output[0], new_grad_output)
        self.assertEqual(grad_op.output[1], "else_temp_grad")

        # Index the subnets by argument name so that a missing branch fails
        # loudly instead of silently skipping its checks.
        nets = {arg.name: arg.n for arg in grad_op.arg}
        self.assertEqual(set(nets), {"then_net", "else_net"})

        # then_net's only op wrote input_grad, so it must now write the new name.
        then_net = nets["then_net"]
        self.assertEqual(then_net.op[0].output[0], new_grad_output)

        # In else_net only the op that wrote input_grad is rewritten.
        else_net = nets["else_net"]
        self.assertEqual(else_net.op[0].output[0], "else_temp_grad")
        self.assertEqual(else_net.op[1].output[0], new_grad_output)


if __name__ == "__main__":
    # Only when executed directly as a script: hand control to unittest's
    # command-line runner, which runs the tests defined in this module.
    unittest.main()