from caffe2.python import net_printer
from caffe2.python.checkpoint import Job
from caffe2.python.net_builder import ops
from caffe2.python.task import Task, final_output, WorkspaceType
import unittest
def example_loop():
with Task():
total = ops.Const(0)
total_large = ops.Const(0)
total_small = ops.Const(0)
total_tiny = ops.Const(0)
with ops.loop(10) as loop:
outer = ops.Mul([loop.iter(), ops.Const(10)])
with ops.loop(loop.iter()) as inner:
val = ops.Add([outer, inner.iter()])
with ops.If(ops.GE([val, ops.Const(80)])) as c:
ops.Add([total_large, val], [total_large])
with c.Elif(ops.GE([val, ops.Const(50)])) as c:
ops.Add([total_small, val], [total_small])
with c.Else():
ops.Add([total_tiny, val], [total_tiny])
ops.Add([total, val], total)
def example_task():
with Task():
with ops.task_init():
one = ops.Const(1)
two = ops.Add([one, one])
with ops.task_init():
three = ops.Const(3)
accum = ops.Add([two, three])
# here, accum should be 5
with ops.task_exit():
# here, accum should be 6, since this executes after lines below
seven_1 = ops.Add([accum, one])
six = ops.Add([accum, one])
ops.Add([accum, one], [accum])
seven_2 = ops.Add([accum, one])
o6 = final_output(six)
o7_1 = final_output(seven_1)
o7_2 = final_output(seven_2)
with Task(num_instances=2):
with ops.task_init():
one = ops.Const(1)
with ops.task_instance_init():
local = ops.Const(2)
ops.Add([one, local], [one])
ops.LogInfo('ble')
return o6, o7_1, o7_2
def example_job():
with Job() as job:
with job.init_group:
example_loop()
example_task()
return job
class TestNetPrinter(unittest.TestCase):
def test_print(self):
self.assertTrue(len(net_printer.to_string(example_job())) > 0)
def test_valid_job(self):
job = example_job()
with job:
with Task():
# distributed_ctx_init_* ignored by analyzer
ops.Add(['distributed_ctx_init_a', 'distributed_ctx_init_b'])
# net_printer.analyze(example_job())
print(net_printer.to_string(example_job()))
def test_undefined_blob(self):
job = example_job()
with job:
with Task():
ops.Add(['a', 'b'])
with self.assertRaises(AssertionError) as e:
net_printer.analyze(job)
self.assertEqual("Blob undefined: a", str(e.exception))
def test_multiple_definition(self):
job = example_job()
with job:
with Task(workspace_type=WorkspaceType.GLOBAL):
ops.Add([ops.Const(0), ops.Const(1)], 'out1')
with Task(workspace_type=WorkspaceType.GLOBAL):
ops.Add([ops.Const(2), ops.Const(3)], 'out1')
with self.assertRaises(AssertionError):
net_printer.analyze(job)