Learn more » Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go modules, Bower components, Debian packages, RPM packages, and NuGet packages.

neilisaac / torch   python

Repository URL to install this package:

/ python / nomnigraph.py



import errno
import os
import shutil
from subprocess import PIPE, Popen

import caffe2.python._import_c_extension as C
from caffe2.proto import caffe2_pb2
from caffe2.python import core


class NNModule(object):
    """Python wrapper around the C++ nomnigraph ``NNModule`` binding.

    The wrapped extension object is stored in ``self._NNModule``. Almost
    every method forwards to it, or to its data-flow graph
    (``self._NNModule.dataFlow()``). A few methods forward to module-level
    functions of the ``C`` extension instead.
    """

    def __init__(self, net=None, device_map=None):
        """Build the module from a Caffe2 net, or create an empty one.

        Args:
            net: a ``core.Net`` or ``caffe2_pb2.NetDef`` to convert, or
                ``None`` to create an empty ``C.NNModule()``.
            device_map: optional mapping whose values expose
                ``SerializeToString()`` (presumably ``caffe2_pb2.DeviceOption``
                messages -- confirm against callers). When given, the
                distributed converter is used.

        Raises:
            TypeError: if ``net`` is neither ``core.Net`` nor
                ``caffe2_pb2.NetDef``. ``TypeError`` subclasses ``Exception``,
                so handlers written for the previous plain ``Exception``
                still catch it.
        """
        if net is None:
            self._NNModule = C.NNModule()
            return

        # Validate the type up front so that neither C converter ever
        # receives a missing payload. Previously the distributed path was
        # handed ``None`` for unsupported types.
        if isinstance(net, core.Net):
            serialized_proto = net.Proto().SerializeToString()
        elif isinstance(net, caffe2_pb2.NetDef):
            serialized_proto = net.SerializeToString()
        else:
            raise TypeError(
                "NNModule can be constructed with core.Net or caffe2_pb2.NetDef types"
            )
        # NOTE: an empty NetDef serializes to b"". That is a valid payload,
        # so it must not be rejected by a truthiness test.

        if device_map is not None:
            # Distributed: each device entry is serialized for the C binding.
            serialized_device_map = {
                k: device_map[k].SerializeToString() for k in device_map
            }
            self._NNModule = C.NNModuleFromProtobufDistributed(
                serialized_proto, serialized_device_map
            )
        else:
            # The default converter also returns an operator list. It is kept
            # on the instance (``_OpList``) but is not read anywhere in this
            # class, and is not set on the distributed path.
            self._NNModule, self._OpList = C.NNModuleFromProtobuf(serialized_proto)

    @property
    def dataFlow(self):
        """The data-flow graph object returned by the C++ module."""
        return self._NNModule.dataFlow()

    @property
    def controlFlow(self):
        """Result of the C++ module's ``getExecutionOrder()``."""
        return self._NNModule.getExecutionOrder()

    @property
    def nodes(self):
        """The ``nodes`` attribute of the data-flow graph."""
        return self._NNModule.dataFlow().nodes

    @property
    def operators(self):
        """The ``operators`` attribute of the data-flow graph."""
        return self._NNModule.dataFlow().operators

    @property
    def tensors(self):
        """The ``tensors`` attribute of the data-flow graph."""
        return self._NNModule.dataFlow().tensors

    def createNode(self, val):
        """Create a data-flow node holding ``val`` and return the binding's result."""
        return self._NNModule.dataFlow().createNode(val)

    def deleteNode(self, node):
        """Delete ``node`` from the data-flow graph."""
        return self._NNModule.dataFlow().deleteNode(node)

    def createEdge(self, a, b):
        """Create a data-flow edge from ``a`` to ``b``."""
        return self._NNModule.dataFlow().createEdge(a, b)

    def deleteEdge(self, a, b=None):
        """Delete a data-flow edge.

        With ``b`` given, calls ``deleteEdge(a, b)``. Otherwise calls
        ``deleteEdge(a)``; presumably ``a`` is then an edge object (confirm
        against the binding).
        """
        # Compare against None rather than testing truthiness, so that a
        # falsy-but-valid node never silently selects the one-argument form.
        if b is not None:
            self._NNModule.dataFlow().deleteEdge(a, b)
        else:
            self._NNModule.dataFlow().deleteEdge(a)

    def replaceNode(self, old_node, new_node):
        """Replace ``old_node`` with ``new_node`` in the data-flow graph."""
        return self._NNModule.dataFlow().replaceNode(old_node, new_node)

    def replaceProducer(self, tensor, new_producer):
        """Forward to the module-level ``C.replaceProducer``."""
        C.replaceProducer(tensor, new_producer)

    def replaceAllUsesWith(self, old_tensor, new_tensor):
        """Forward to the module-level ``C.replaceAllUsesWith``."""
        C.replaceAllUsesWith(old_tensor, new_tensor)

    def replaceAsConsumer(self, old_consumer, new_consumer):
        """Forward to the module-level ``C.replaceAsConsumer``."""
        C.replaceAsConsumer(old_consumer, new_consumer)

    def replaceSubgraph(self, subgraph, new_node, inputs, outputs):
        """Replace ``subgraph`` with ``new_node``, given its ``inputs`` and ``outputs``."""
        self._NNModule.replaceSubgraph(subgraph, new_node, inputs, outputs)

    def deleteSubgraph(self, subgraph):
        """Delete ``subgraph`` from the module."""
        self._NNModule.deleteSubgraph(subgraph)

    def createUniqueDataNode(self, prefix="_unique"):
        """Create a data node via the binding, using ``prefix`` for its name.

        Uniqueness of the resulting name is presumably guaranteed by the C++
        side (per the binding's name); confirm there.
        """
        return self._NNModule.createUniqueDataNode(prefix)

    def convertToCaffe2Proto(self, old_proto=None):
        """Convert the module back into a new ``caffe2_pb2.NetDef``.

        Args:
            old_proto: NetDef passed through to the C++ converter. An empty
                ``caffe2_pb2.NetDef()`` is used when omitted.

        Returns:
            A freshly parsed ``caffe2_pb2.NetDef`` built from the converter's
            serialized output. ``old_proto`` itself is not modified here.
        """
        if old_proto is None:
            old_proto = caffe2_pb2.NetDef()
        output = self._NNModule.convertToCaffe2Proto(old_proto)
        new_proto = caffe2_pb2.NetDef()
        new_proto.ParseFromString(output)
        return new_proto

    def match(self, pattern):
        """Lazily yield each truthy ``C.matchSubgraph(node, pattern)`` result.

        Every node returned by the data-flow graph's ``getMutableNodes()`` is
        tried as a match root, in that order.
        """
        for node in self.dataFlow.getMutableNodes():
            result = C.matchSubgraph(node, pattern)
            if result:
                yield result


def render(s):
    """Render ``str(s)`` with the ``graph-easy`` tool if available, else print it.

    When a ``graph-easy`` executable is found on ``PATH``:

    * the UTF-8 encoded text is piped to the tool's stdin;
    * the call blocks until the tool exits;
    * the tool writes to the inherited stdout;
    * a broken pipe (the tool exiting before reading everything) is
      tolerated rather than raised.

    Otherwise the text is printed.

    Args:
        s: any object; it is converted with ``str()`` first.
    """
    text = str(s)
    executable = shutil.which("graph-easy")
    if executable is None:
        print(text)
        return

    payload = text.encode("utf-8")
    # communicate() writes the input, closes stdin and waits for the child,
    # ignoring broken-pipe errors on both the write and the close. The old
    # hand-written version guarded only the write: with a buffered stdin the
    # EPIPE usually surfaces at close(), where it escaped. Any other error
    # also left stdin open and the child un-waited. The context manager
    # guarantees cleanup on every path.
    with Popen([executable], stdin=PIPE) as proc:
        proc.communicate(payload)


# Public names re-exported from the C extension. ``Operator`` and ``Data``
# are shorter aliases bound to the very same objects as
# ``NeuralNetOperator`` and ``NeuralNetData``.
NeuralNetOperator = Operator = C.NeuralNetOperator
NeuralNetData = Data = C.NeuralNetData
NNSubgraph = C.NNSubgraph
NNMatchGraph = C.NNMatchGraph
Graph = C.Graph
Annotation = C.Annotation