Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ python / net_builder.py

## @package net_builder
# Module caffe2.python.net_builder





from caffe2.python import core, context
from caffe2.python.task import Task, TaskGroup
from caffe2.python.control_ops_util import add_if_op, add_while_op


class NetBuilder(context.Managed):
    """
    Scope-driven mechanism for building nets, loops and conditional blocks.

    A NetBuilder is used as a context manager. While it is the active
    builder, ops added via `ops.<OpType>(...)` go into its current net, and
    nested builders (loops, conditionals, stop guards, ...) are appended to
    it as child steps.

    Args:
      name: NetBuilder's name
      initial_scope: list of blobs that are available for reading/writing
    Example:
        from caffe2.python.net_builder import NetBuilder, ops
        with NetBuilder() as nb:
            c = ops.Const(5)
            d = ops.Const(0)
            with ops.loop():
                ops.stop_if(ops.LE([c, ops.Const(0)]))
                ops.Add([c, ops.Const(-1)], [c])
                with ops.If(ops.GE([c, ops.Const(3)])):
                    ops.Add([d, ops.Const(10)], [d])
            ops.Print(c, [])
            ops.Print(d, [])
        step = core.to_execution_step(nb)
    """
    def __init__(self, name=None, initial_scope=None, _stop_blob_required=False,
                 _stop_blob=None, _fullname=None, _use_control_ops=False):
        """
        Args:
          name: short name; joined with the enclosing builder's name by '/'.
          initial_scope: blobs made visible to this builder in addition to
            the enclosing builder's lexical scope (each entry is str()-ed).
          _stop_blob_required: if True, leaving the context normally asserts
            that a stop condition was set with `stop` / `stop_if`.
          _stop_blob: an existing stop blob to use instead of creating one.
          _fullname: full name used verbatim; cannot be combined with `name`.
          _use_control_ops: if True, children must be Nets or control-op
            builders and are merged into a single net on exit; stop blobs
            cannot be used in this mode.
        """
        parent = NetBuilder.current(required=False)
        assert not _fullname or not name, 'Cannot set both _fullname and name'
        assert not _use_control_ops or \
            (not _stop_blob_required and not _stop_blob), \
            'Stop blobs are not used with control operators'
        self.name = _fullname or '/'.join(
            n for n in (parent.name if parent else None, name) if n
        )
        self._frozen = False
        # Net that new ops are appended to; None means "open a new one".
        self._current_net = None
        # Ordered steps of this builder (nets, nested builders, task groups).
        self._children = []
        if parent:
            # make sure parent has an up to date lexical scope computed
            parent._update_lexical_scope()
        # Blob names visible from outside: the parent's scope + initial_scope.
        self._init_lexical_scope = set(parent._lexical_scope) if parent else set()
        if initial_scope:
            self._init_lexical_scope |= {str(b) for b in initial_scope}
        self._lexical_scope = set(self._init_lexical_scope)
        self._stop_blob = _stop_blob
        self._stop_blob_required = _stop_blob_required
        self._use_control_ops = _use_control_ops

    def stop_blob(self):
        """
        Returns the BlobReference to the stop_blob of this NetBuilder.
        If one is not yet available, creates one and sets it to False in the
        current net. This function assumes that the stop_blob() will be used
        immediately in the current net. When the current net is not the
        builder's first child, an extra 'stop_blob_init' net that also sets
        the blob to False is prepended, so the blob is defined from the
        builder's first step.
        """
        assert not self._use_control_ops, \
            'Stop blobs are not used with control operators'
        if self._stop_blob is None:
            net = self.current_net()
            self._stop_blob = core.BlobReference(
                net.NextName('stop_blob'), net=net)
            net.Const(False, blob_out=self._stop_blob)
            # Identity check: "is the current net the very first child?"
            if self._current_net is not self._children[0]:
                self._children.insert(0, core.Net('stop_blob_init'))
                self._children[0].Const(False, blob_out=self._stop_blob)
        return self._stop_blob

    def stop_if(self, blob):
        """
        ORs `blob` into this builder's stop blob, then closes the current net
        so that subsequent ops are added to a new net.
        """
        assert not self._use_control_ops, \
            'Stop blobs are not used with control operators'
        stop_blob = self.stop_blob()
        # Emit the Or into this builder's own current net (where the stop
        # blob was just set up), not into whichever builder happens to be
        # globally current.
        self.current_net().Or([stop_blob, blob], [stop_blob])
        self._current_net = None

    def _assert_mutable(self):
        """Asserts that this builder has not been frozen yet."""
        assert not self._frozen, (
            'This NetBuilder (%s) has been built already.' % self.name)

    def _update_lexical_scope(self):
        """
        Updates lexical scope based on the current list of children.
        Lexical scope contains names of blobs that are currently available
        and were introduced in the net builder.
        """
        self._lexical_scope = set(self._init_lexical_scope)
        for child in self._children:
            if isinstance(child, core.Net):
                self._lexical_scope |= child.UsedBlobNames()
            elif isinstance(child, NetBuilder) and child._use_control_ops:
                # Only control-op child builders contribute their scope;
                # other nested builders and non-Net children do not.
                self._lexical_scope |= child._lexical_scope

    def _reset_children(self):
        """Drops all children and restores the initial lexical scope."""
        self._current_net = None
        self._children = []
        self._lexical_scope = set(self._init_lexical_scope)

    def add(self, child):
        """
        Appends `child` as the next step of this builder and returns it.

        A `core.Net` child becomes the current net; any other child closes
        the current net. In control-op mode only Nets and control-op
        NetBuilders are accepted.
        """
        self._assert_mutable()

        if self._use_control_ops:
            assert isinstance(child, core.Net) or (
                isinstance(child, NetBuilder) and child._use_control_ops), \
                "Expected Net or NetBuilder with control ops"

        self._current_net = None
        self._children.append(child)
        # to-do : check it's not a dag net
        if isinstance(child, core.Net):
            self._current_net = child
        self._update_lexical_scope()
        return child

    def current_net(self, name=None):
        """
        Returns the net that receives new ops. A new Net is created and added
        when no net is open, or whenever `name` is given.
        """
        self._assert_mutable()
        if self._current_net is None or name is not None:
            self.add(core.Net(name))
        return self._current_net

    def freeze(self):
        """Freezes children that support it and marks this builder frozen."""
        for child in self._children:
            if hasattr(child, 'freeze'):
                child.freeze()
        self._current_net = None
        self._frozen = True

    def get(self):
        """Freezes the builder and returns its list of children."""
        self.freeze()
        return self._children

    def __exit__(self, etype, *args):
        """
        Leaves the builder's context and freezes it.

        In control-op mode the children are first merged into a single net.
        When the block exits without an exception, also asserts that a stop
        condition was set if this builder requires one.
        """
        super(NetBuilder, self).__exit__(etype, *args)

        if self._use_control_ops and self._children:
            children = self._children
            # Resetting restores the lexical scope to the blobs visible from
            # outside this builder; merge_nets keeps only those as outputs.
            self._reset_children()
            merged_net = NetBuilder.merge_nets(
                children, self._lexical_scope)
            assert merged_net is not None, \
                "Expected a non-empty merge of children"
            self._children = [merged_net]

        self.freeze()
        if etype is not None:
            return
        assert (not self._stop_blob_required) or self._stop_blob is not None, (
            'This NetBuilder (%s) requires a stop condition ' % self.name +
            'to be set with `stop` or `stop_if`')

    @staticmethod
    def merge_nets(nets_or_builders, outer_blob_names):
        """
        Concatenates nets (and the single net of each control-op builder)
        into one net and returns it, or returns None for an empty input.

        The first net of the sequence is extended in place with AppendNet.
        Its external outputs are then restricted to `outer_blob_names`.
        """
        # Only nets or builders with control ops are allowed.
        # Need to pay attention to external outputs, e.g.
        #   ...
        #   IfNet1 (cond_blob):
        #       (Net1)
        #           X = 1
        #       IfNet2 (...):
        #           X = X + 1
        #   ...
        # In this example there're two children in then branch of IfNet1:
        # a subnet Net1 that creates blob X and sets its value to one, and
        # a net builder IfNet2 that (conditionally) increments X.
        # From IfNet2's point of view X is an external input
        # and output blob, it will be put into IfNet2 net's external_output.
        # At the same time, from the point of view of IfNet1 X is purely local.
        # Net.AppendNet just merges external outputs of the networks, so
        # without checking this the result of Net1.AppendNet(IfNet2's net)
        # would have blob X in external_output

        net = None
        for n in nets_or_builders:
            if isinstance(n, NetBuilder):
                assert n._use_control_ops, \
                    "Merging of NetBuilder supported only for control ops"
                nets = n.get()
                assert len(nets) == 1 and isinstance(nets[0], core.Net), \
                    "Invalid control op net builder"
                cur = nets[0]
            else:
                assert isinstance(n, core.Net)
                cur = n
            if net is not None:
                net.AppendNet(cur)
            else:
                net = cur
        if net is not None:
            # correct external output
            external_outputs = [o for o in net.Proto().external_output
                                if o in outer_blob_names]
            net.Proto().external_output[:] = external_outputs
        return net

    def __str__(self):
        return self.name or 'Un-named NetBuilder'


class Operations(object):
    """
    Operations to be used in the context of a NetBuilder.
    """
    def net(self, net=None, name=None):
        """
        Retrieves the current net, or add a new net to the builder.
        Args:
            net:   If provided, add the given net to the active builder.
                   Else, returns the current Net or creates a new one as needed.
            name:  if provided, creates a new Net with given name and makes
                   it the new current net of the active builder. Cannot
                   be provided if net is provided.
        """
        assert name is None or net is None, (
            'Cannot provide both `net` and `name`.')
        if net is not None:
            NetBuilder.current().add(net)
            return net
        return NetBuilder.current().current_net(name=name)

    def __getattr__(self, op_type):
        """
        Adds an operator call to the currently active Net.
        """
        if op_type.startswith('__'):
            raise AttributeError()
        # We want hasattr to work properly even if no context is active.
        if NetBuilder.current(required=False) is None:
            raise AttributeError('No active NetBuilder.')
        return getattr(self.net(), op_type)

    def task_group(self):
        """
        Creates a TaskGroup inside a new `task.Cluster()` /
        `task.Node('local')` context and adds it to the active NetBuilder,
        so it runs as that builder's next step.
        Returns:
            The newly created TaskGroup.
        """
        from caffe2.python import task
        builder = NetBuilder.current()
        with task.Cluster(), task.Node('local'):
            new_group = task.TaskGroup()
            builder.add(new_group)
            return new_group

    def stop(self):
        """
        Unconditionally stops the current execution step.
            Example:
                ops.Print(a, 0)
                ops.stop()
                ops.Print(b, 0)
            Here 'b' is never printed.
        """
        always_true = ops.Const(True)
        return self.stop_if(always_true)

    def stop_if(self, blob):
        """
        Stops the current execution step when the boolean blob `blob` is set.
        Delegates to the active NetBuilder's `stop_if`.
            Example:
                ops.Print(a, 0)
                ops.stop_if(ops.LE([x, ops.Const(0)]))
                ops.Print(b, 0)
            Here 'b' is printed only when the scalar tensor 'x' is
            greater than 0.
        """
        builder = NetBuilder.current()
        return builder.stop_if(blob)

    def loop(self, iters=None, name=None):
        """
        Adds a looping NetBuilder as the next step of the active NetBuilder
        and returns it.

        With `iters` (a constant or a BlobReference) the loop runs that many
        iterations. Without it, the loop repeats until `ops.stop` or
        `ops.stop_if` fires.
            Examples:
                a = ops.Const(5)
                with ops.loop():
                    ops.stop_if(ops.LE([a, ops.Const(0)]))
                    ops.Print(a, 0)
                    ops.Add([a, ops.Const(-1)], [a])
            Prints 'a' five times, with values 5 down to 1.

                with ops.loop(10) as loop:
                    ops.LogInfo(loop.iter())
            Prints the numbers 0 through 9.

                x = ops.Add([ops.Const(10), ops.Const(10)])
                with ops.loop(x) as loop:
                    ops.LogInfo(loop.iter())
            Prints the numbers 0 through 19.
        """
        builder = NetBuilder.current()
        loop_builder = _Loop(iters, name=name)
        return builder.add(loop_builder)

    def stop_guard(self, has_stopped_blob=None, name=None):
        """
        Adds a run-once NetBuilder as the next step of the active NetBuilder
        and returns it. After it runs, a bool tensor reports whether its body
        was halted by `stop` or `stop_if`.
            Example:
                a = ops.Const(True)
                with ops.stop_guard() as sg1:
                    ops.stop_if(a)
                    ops.Print(ops.Const('did not stop'))
                b = ops.Const(False)
                with ops.stop_guard() as sg2:
                    ops.stop_if(b)
                    ops.Print(ops.Const('did not stop'))
                ops.Print(sg1.has_stopped(), [])
                ops.Print(sg2.has_stopped(), [])
            'did not stop' is printed once (from sg2), then True and False.
        """
        builder = NetBuilder.current()
        guard = _StopGuard(has_stopped_blob=has_stopped_blob, name=name)
        return builder.add(guard)

    def If(self, cond, name=None):
        """
        Adds a NetBuilder that runs once, as the next step of the active
        NetBuilder, only when the blob `cond` is True. Returns it.
            Example:
                with ops.If(ops.Const(True)):
                    ops.Print(ops.Const('Will print'))
                with ops.If(ops.Const(False)):
                    ops.Print(ops.Const('Wont print'))
            Only 'Will print' is printed, once.
        """
        builder = NetBuilder.current()
        conditional = _RunIf(cond, name=name)
        return builder.add(conditional)

    def IfNet(self, cond, name=None):
        """
        Like `If`, but implemented with the 'If' operator rather than
        execution-step logic. Returns the added builder.
        """
        builder = NetBuilder.current()
        conditional_net = _RunIfNet(cond, name=name)
        return builder.add(conditional_net)

    def Else(self, name=None):
        """
        Else branch of IfNet, has to be specified immediately after IfNet.
            Example:
Loading ...