from caffe2.python import control, core, test_util, workspace
import logging
logger = logging.getLogger(__name__)
class TestControl(test_util.TestCase):
    """Exercise the execution-step builders exposed by the ``control`` module.

    Every test builds a step with one of the ``control`` helpers (``For``,
    ``While``, ``Until``, ``DoWhile``, ``DoUntil``, ``If``, ``IfNot``,
    ``Switch``, ``SwitchNot``, ``BoolNet``, ``CombineConditions``,
    ``MergeConditionNets``), runs it in a plan right after ``init_net_``,
    and then inspects blobs fetched from the workspace.
    """

    def setUp(self):
        """Create the nets shared by all tests.

        * ``init_net_`` creates the counter and the constants ``N_`` and 0,
          and also zero-fills the blobs the counting nets write to.
        * ``cnt_net_`` calls ``CountUp`` once, ``cnt_2_net_`` calls it twice;
          both then publish ``RetrieveCount`` as their external output.
        * ``cond_net_`` / ``not_cond_net_`` compare the current count with
          ``N_`` using ``LT`` / ``GE``.
        * ``true_cond_net_`` / ``false_cond_net_`` compare the constants
          0 and ``N_`` using ``LT`` / ``GT``.
        * ``idle_net_`` only performs a ``ConstantFill``.
        """
        super(TestControl, self).setUp()
        self.N_ = 10

        self.init_net_ = core.Net("init-net")
        cnt = self.init_net_.CreateCounter([], init_count=0)
        const_n = self.init_net_.ConstantFill(
            [], shape=[], value=self.N_, dtype=core.DataType.INT64)
        const_0 = self.init_net_.ConstantFill(
            [], shape=[], value=0, dtype=core.DataType.INT64)

        self.cnt_net_ = core.Net("cnt-net")
        self.cnt_net_.CountUp([cnt])
        curr_cnt = self.cnt_net_.RetrieveCount([cnt])
        # Fill the count blob from the init net as well, so it can be fetched
        # (and compared against 0) in tests where cnt_net_ never runs.
        self.init_net_.ConstantFill(
            [], [curr_cnt], shape=[], value=0, dtype=core.DataType.INT64)
        self.cnt_net_.AddExternalOutput(curr_cnt)

        self.cnt_2_net_ = core.Net("cnt-2-net")
        self.cnt_2_net_.CountUp([cnt])
        self.cnt_2_net_.CountUp([cnt])
        curr_cnt_2 = self.cnt_2_net_.RetrieveCount([cnt])
        # Same zero-fill as above, for the output blob of cnt_2_net_.
        self.init_net_.ConstantFill(
            [], [curr_cnt_2], shape=[], value=0, dtype=core.DataType.INT64)
        self.cnt_2_net_.AddExternalOutput(curr_cnt_2)

        self.cond_net_ = core.Net("cond-net")
        cond_blob = self.cond_net_.LT([curr_cnt, const_n])
        self.cond_net_.AddExternalOutput(cond_blob)

        self.not_cond_net_ = core.Net("not-cond-net")
        cond_blob = self.not_cond_net_.GE([curr_cnt, const_n])
        self.not_cond_net_.AddExternalOutput(cond_blob)

        self.true_cond_net_ = core.Net("true-cond-net")
        true_blob = self.true_cond_net_.LT([const_0, const_n])
        self.true_cond_net_.AddExternalOutput(true_blob)

        self.false_cond_net_ = core.Net("false-cond-net")
        false_blob = self.false_cond_net_.GT([const_0, const_n])
        self.false_cond_net_.AddExternalOutput(false_blob)

        self.idle_net_ = core.Net("idle-net")
        self.idle_net_.ConstantFill(
            [], shape=[], value=0, dtype=core.DataType.INT64)

    def CheckNetOutput(self, nets_and_expects):
        """
        Check the last external output of each net.
        nets_and_expects is a list of tuples (net, expect); for every tuple
        the blob named by the net's last external output is fetched from the
        workspace and compared with expect.
        """
        for net, expect in nets_and_expects:
            output = workspace.FetchBlob(
                net.Proto().external_output[-1])
            self.assertEqual(output, expect)

    def CheckNetAllOutput(self, net, expects):
        """
        Check every external output of the net.
        expects is a list of bools, one per external output and in the same
        order; the number of external outputs must equal len(expects).
        """
        output_names = net.Proto().external_output
        self.assertEqual(len(output_names), len(expects))
        # The lengths were just asserted equal, so zip drops nothing.
        for blob_name, expect in zip(output_names, expects):
            output = workspace.FetchBlob(blob_name)
            self.assertEqual(output, expect)

    def BuildAndRunPlan(self, step):
        """Run init_net_ followed by step in a fresh plan and assert that
        workspace.RunPlan returns True."""
        plan = core.Plan("test")
        plan.AddStep(control.Do('init', self.init_net_))
        plan.AddStep(step)
        self.assertEqual(workspace.RunPlan(plan), True)

    def ForLoopTest(self, nets_or_steps):
        """Wrap nets_or_steps in control.For with N_ iterations and expect
        the count published by cnt_net_ to equal N_."""
        step = control.For('myFor', nets_or_steps, self.N_)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, self.N_)])

    def testForLoopWithNets(self):
        """For loop over a single net and over a list of nets."""
        self.ForLoopTest(self.cnt_net_)
        self.ForLoopTest([self.cnt_net_, self.idle_net_])

    def testForLoopWithStep(self):
        """For loop over a step and over a mixed list of step and net."""
        step = control.Do('count', self.cnt_net_)
        self.ForLoopTest(step)
        self.ForLoopTest([step, self.idle_net_])

    def WhileLoopTest(self, nets_or_steps):
        """Wrap nets_or_steps in control.While conditioned on cond_net_ and
        expect the count published by cnt_net_ to equal N_."""
        step = control.While('myWhile', self.cond_net_, nets_or_steps)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, self.N_)])

    def testWhileLoopWithNet(self):
        """While loop over a single net and over a list of nets."""
        self.WhileLoopTest(self.cnt_net_)
        self.WhileLoopTest([self.cnt_net_, self.idle_net_])

    def testWhileLoopWithStep(self):
        """While loop over a step and over a mixed list of step and net."""
        step = control.Do('count', self.cnt_net_)
        self.WhileLoopTest(step)
        self.WhileLoopTest([step, self.idle_net_])

    def UntilLoopTest(self, nets_or_steps):
        """Wrap nets_or_steps in control.Until conditioned on not_cond_net_
        and expect the count published by cnt_net_ to equal N_."""
        step = control.Until('myUntil', self.not_cond_net_, nets_or_steps)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, self.N_)])

    def testUntilLoopWithNet(self):
        """Until loop over a single net and over a list of nets."""
        self.UntilLoopTest(self.cnt_net_)
        self.UntilLoopTest([self.cnt_net_, self.idle_net_])

    def testUntilLoopWithStep(self):
        """Until loop over a step and over a mixed list of step and net."""
        step = control.Do('count', self.cnt_net_)
        self.UntilLoopTest(step)
        self.UntilLoopTest([step, self.idle_net_])

    def DoWhileLoopTest(self, nets_or_steps):
        """Wrap nets_or_steps in control.DoWhile conditioned on cond_net_ and
        expect the count published by cnt_net_ to equal N_."""
        step = control.DoWhile('myDoWhile', self.cond_net_, nets_or_steps)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, self.N_)])

    def testDoWhileLoopWithNet(self):
        """DoWhile loop over a single net and over a list of nets."""
        self.DoWhileLoopTest(self.cnt_net_)
        self.DoWhileLoopTest([self.idle_net_, self.cnt_net_])

    def testDoWhileLoopWithStep(self):
        """DoWhile loop over a step and over a mixed list of net and step."""
        step = control.Do('count', self.cnt_net_)
        self.DoWhileLoopTest(step)
        self.DoWhileLoopTest([self.idle_net_, step])

    def DoUntilLoopTest(self, nets_or_steps):
        """Wrap nets_or_steps in control.DoUntil conditioned on not_cond_net_
        and expect the count published by cnt_net_ to equal N_."""
        step = control.DoUntil('myDoUntil', self.not_cond_net_, nets_or_steps)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, self.N_)])

    def testDoUntilLoopWithNet(self):
        """DoUntil loop over a single net and over a list of nets."""
        self.DoUntilLoopTest(self.cnt_net_)
        self.DoUntilLoopTest([self.cnt_net_, self.idle_net_])

    def testDoUntilLoopWithStep(self):
        """DoUntil loop over a step and over a mixed list of net and step."""
        step = control.Do('count', self.cnt_net_)
        self.DoUntilLoopTest(step)
        self.DoUntilLoopTest([self.idle_net_, step])

    def IfCondTest(self, cond_net, expect, cond_on_blob):
        """Run cnt_net_ under control.If and check its published count.

        cond_net is the net producing the condition, expect is the count
        expected from cnt_net_ afterwards. When cond_on_blob is true, the
        condition net is run first in its own step and control.If is given
        the name of its last external output; otherwise control.If is given
        the net itself.
        """
        if cond_on_blob:
            step = control.Do(
                'if-all',
                control.Do('count', cond_net),
                control.If('myIf', cond_net.Proto().external_output[-1],
                           self.cnt_net_))
        else:
            step = control.If('myIf', cond_net, self.cnt_net_)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, expect)])

    def testIfCondTrueOnNet(self):
        """If with a true condition net: count is expected to be 1."""
        self.IfCondTest(self.true_cond_net_, 1, False)

    def testIfCondTrueOnBlob(self):
        """If with a true condition blob: count is expected to be 1."""
        self.IfCondTest(self.true_cond_net_, 1, True)

    def testIfCondFalseOnNet(self):
        """If with a false condition net: count is expected to stay 0."""
        self.IfCondTest(self.false_cond_net_, 0, False)

    def testIfCondFalseOnBlob(self):
        """If with a false condition blob: count is expected to stay 0."""
        self.IfCondTest(self.false_cond_net_, 0, True)

    def IfElseCondTest(self, cond_net, cond_value, expect, cond_on_blob):
        """Run control.If with cnt_net_ as the true branch and cnt_2_net_ as
        the false branch, then check the branch selected by cond_value.

        cond_value states which branch the caller expects to run (true ->
        cnt_net_, false -> cnt_2_net_); expect is the count expected from
        that branch. cond_on_blob selects between passing the condition as
        a blob name (after running cond_net in its own step) or as a net.
        """
        if cond_value:
            run_net = self.cnt_net_
        else:
            run_net = self.cnt_2_net_
        if cond_on_blob:
            step = control.Do(
                'if-else-all',
                control.Do('count', cond_net),
                control.If('myIfElse', cond_net.Proto().external_output[-1],
                           self.cnt_net_, self.cnt_2_net_))
        else:
            step = control.If('myIfElse', cond_net,
                              self.cnt_net_, self.cnt_2_net_)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(run_net, expect)])

    def testIfElseCondTrueOnNet(self):
        """If/else with a true condition net: cnt_net_ should report 1."""
        self.IfElseCondTest(self.true_cond_net_, True, 1, False)

    def testIfElseCondTrueOnBlob(self):
        """If/else with a true condition blob: cnt_net_ should report 1."""
        self.IfElseCondTest(self.true_cond_net_, True, 1, True)

    def testIfElseCondFalseOnNet(self):
        """If/else with a false condition net: cnt_2_net_ should report 2."""
        self.IfElseCondTest(self.false_cond_net_, False, 2, False)

    def testIfElseCondFalseOnBlob(self):
        """If/else with a false condition blob: cnt_2_net_ should report 2."""
        self.IfElseCondTest(self.false_cond_net_, False, 2, True)

    def IfNotCondTest(self, cond_net, expect, cond_on_blob):
        """Run cnt_net_ under control.IfNot and check its published count.

        Parameters mirror IfCondTest: cond_net produces the condition,
        expect is the count expected from cnt_net_, and cond_on_blob selects
        between passing the condition as a blob name (after running cond_net
        in its own step) or as a net.
        """
        if cond_on_blob:
            step = control.Do(
                'if-not',
                control.Do('count', cond_net),
                control.IfNot('myIfNot', cond_net.Proto().external_output[-1],
                              self.cnt_net_))
        else:
            step = control.IfNot('myIfNot', cond_net, self.cnt_net_)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, expect)])

    def testIfNotCondTrueOnNet(self):
        """IfNot with a true condition net: count is expected to stay 0."""
        self.IfNotCondTest(self.true_cond_net_, 0, False)

    def testIfNotCondTrueOnBlob(self):
        """IfNot with a true condition blob: count is expected to stay 0."""
        self.IfNotCondTest(self.true_cond_net_, 0, True)

    def testIfNotCondFalseOnNet(self):
        """IfNot with a false condition net: count is expected to be 1."""
        self.IfNotCondTest(self.false_cond_net_, 1, False)

    def testIfNotCondFalseOnBlob(self):
        """IfNot with a false condition blob: count is expected to be 1."""
        self.IfNotCondTest(self.false_cond_net_, 1, True)

    def IfNotElseCondTest(self, cond_net, cond_value, expect, cond_on_blob):
        """Run control.IfNot with cnt_net_ and cnt_2_net_ as its two branches,
        then check the branch selected by cond_value.

        cond_value states which branch the caller expects to run (true ->
        cnt_2_net_, false -> cnt_net_); expect is the count expected from
        that branch. cond_on_blob selects between passing the condition as
        a blob name (after running cond_net in its own step) or as a net.
        """
        if cond_value:
            run_net = self.cnt_2_net_
        else:
            run_net = self.cnt_net_
        if cond_on_blob:
            step = control.Do(
                'if-not-else',
                control.Do('count', cond_net),
                control.IfNot('myIfNotElse',
                              cond_net.Proto().external_output[-1],
                              self.cnt_net_, self.cnt_2_net_))
        else:
            step = control.IfNot('myIfNotElse', cond_net,
                                 self.cnt_net_, self.cnt_2_net_)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(run_net, expect)])

    def testIfNotElseCondTrueOnNet(self):
        """IfNot/else with a true condition net: cnt_2_net_ should report 2."""
        self.IfNotElseCondTest(self.true_cond_net_, True, 2, False)

    def testIfNotElseCondTrueOnBlob(self):
        """IfNot/else with a true condition blob: cnt_2_net_ should report 2."""
        self.IfNotElseCondTest(self.true_cond_net_, True, 2, True)

    def testIfNotElseCondFalseOnNet(self):
        """IfNot/else with a false condition net: cnt_net_ should report 1."""
        self.IfNotElseCondTest(self.false_cond_net_, False, 1, False)

    def testIfNotElseCondFalseOnBlob(self):
        """IfNot/else with a false condition blob: cnt_net_ should report 1."""
        self.IfNotElseCondTest(self.false_cond_net_, False, 1, True)

    def testSwitch(self):
        """Switch over (false, cnt_net_) and (true, cnt_2_net_): only the
        branch paired with the true condition is expected to have counted."""
        step = control.Switch(
            'mySwitch',
            (self.false_cond_net_, self.cnt_net_),
            (self.true_cond_net_, self.cnt_2_net_)
        )
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, 0), (self.cnt_2_net_, 2)])

    def testSwitchNot(self):
        """SwitchNot over (false, cnt_net_) and (true, cnt_2_net_): only the
        branch paired with the false condition is expected to have counted."""
        step = control.SwitchNot(
            'mySwitchNot',
            (self.false_cond_net_, self.cnt_net_),
            (self.true_cond_net_, self.cnt_2_net_)
        )
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, 1), (self.cnt_2_net_, 0)])

    def testBoolNet(self):
        """BoolNet accepts one pair, several pairs, or a list of pairs; each
        external output is expected to hold the bool it was given."""
        bool_net = control.BoolNet(('a', True))
        step = control.Do('bool', bool_net)
        self.BuildAndRunPlan(step)
        self.CheckNetAllOutput(bool_net, [True])

        bool_net = control.BoolNet(('a', True), ('b', False))
        step = control.Do('bool', bool_net)
        self.BuildAndRunPlan(step)
        self.CheckNetAllOutput(bool_net, [True, False])

        bool_net = control.BoolNet([('a', True), ('b', False)])
        step = control.Do('bool', bool_net)
        self.BuildAndRunPlan(step)
        self.CheckNetAllOutput(bool_net, [True, False])

    def testCombineConditions(self):
        """CombineConditions over a true and a false condition net.

        The two condition nets are run in the same step, ahead of the
        combined net, which is then expected to yield True for 'Or' and
        False for 'And'.
        """
        # combined by 'Or'
        combine_net = control.CombineConditions(
            'test', [self.true_cond_net_, self.false_cond_net_], 'Or')
        step = control.Do('combine',
                          self.true_cond_net_,
                          self.false_cond_net_,
                          combine_net)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(combine_net, True)])

        # combined by 'And'
        combine_net = control.CombineConditions(
            'test', [self.true_cond_net_, self.false_cond_net_], 'And')
        step = control.Do('combine',
                          self.true_cond_net_,
                          self.false_cond_net_,
                          combine_net)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(combine_net, False)])

    def testMergeConditionNets(self):
        """MergeConditionNets over a true and a false condition net.

        Only the merged net is run here (the condition nets are not run as
        separate steps); it is expected to yield True for 'Or' and False
        for 'And'.
        """
        # merged by 'Or'
        merge_net = control.MergeConditionNets(
            'test', [self.true_cond_net_, self.false_cond_net_], 'Or')
        step = control.Do('merge', merge_net)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(merge_net, True)])

        # merged by 'And'
        merge_net = control.MergeConditionNets(
            'test', [self.true_cond_net_, self.false_cond_net_], 'And')
        step = control.Do('merge', merge_net)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(merge_net, False)])