Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ python / onnx / backend_rep.py

# @package onnx
# Module caffe2.python.onnx.backend_rep





from caffe2.python import core
from caffe2.proto import caffe2_pb2
from onnx.backend.base import BackendRep, namedtupledict

class Caffe2Rep(BackendRep):
    """Backend representation that runs a Caffe2 init/predict net pair.

    Values passed to :meth:`run` are fed into ``workspace`` as blobs. Both
    nets are created on the first call, the init net is run once, and the
    predict net is run on every call.
    """

    def __init__(self, init_net, predict_net, workspace, uninitialized):
        """Store the nets, the workspace and the uninitialized input names.

        Args:
            init_net: Net that is created and run once, on the first
                ``run`` call.
            predict_net: Net that is run on every ``run`` call. Its
                ``device_option`` and ``external_output`` are also read.
            workspace: Object providing ``FeedBlob``, ``CreateNet``,
                ``RunNet``, ``FetchBlob`` and ``FetchInt8Blob``.
            uninitialized: Sequence of blob-name strings for the external
                inputs that are not yet initialized in ``workspace``.
        """
        super(Caffe2Rep, self).__init__()
        self.init_net = init_net
        self.predict_net = predict_net
        self.workspace = workspace
        # The list of uninitialized external_inputs in workspace, we need this to
        # pair the name with given sequence inputs.
        self.uninitialized = uninitialized
        # Both flags are flipped by the first ``run`` call so the nets are
        # created, and the init net executed, only once per instance.
        self.nets_created = False
        self.ran_init_net = False

    @property
    def _name_scope(self):
        """Return ``'gpu_<device_id>'`` for a CUDA predict net, else ``''``."""
        if self.predict_net.device_option.device_type == caffe2_pb2.CUDA:
            return 'gpu_{}'.format(self.predict_net.device_option.device_id)
        return ''

    def run(self, inputs, **kwargs):
        """Feed ``inputs`` into the workspace, run the nets, fetch outputs.

        Args:
            inputs: One of

                * a ``dict`` mapping blob name to value; every item is fed
                  under its own key;
                * a ``list`` or ``tuple`` of values, paired positionally
                  with ``self.uninitialized`` (lengths must match);
                * any other object, fed as the value of the first
                  uninitialized input.
            **kwargs: Forwarded unchanged to ``BackendRep.run``.

        Returns:
            An ``Outputs`` namedtupledict whose fields are the names in
            ``predict_net.external_output`` and whose values are the fetched
            blobs, in the same order.

        Raises:
            RuntimeError: If a list/tuple of inputs does not have exactly
                one value per uninitialized input, or if a single input is
                given while there is no uninitialized input to receive it.
        """
        super(Caffe2Rep, self).run(inputs, **kwargs)
        with core.DeviceScope(self.predict_net.device_option):
            if isinstance(inputs, dict):
                with core.NameScope(self._name_scope):
                    for key, value in inputs.items():
                        self.workspace.FeedBlob(key, value)
            elif isinstance(inputs, (list, tuple)):
                if len(self.uninitialized) != len(inputs):
                    raise RuntimeError('Expected {} values for uninitialized '
                                       'graph inputs ({}), but got {}.'.format(
                                           len(self.uninitialized),
                                           ', '.join(self.uninitialized),
                                           len(inputs)))
                # Lengths were just checked equal, so zip pairs every name
                # with exactly one value.
                # namescope already baked into protobuf
                for name, value in zip(self.uninitialized, inputs):
                    self.workspace.FeedBlob(name, value)
            else:
                # single input
                if not self.uninitialized:
                    # Fail with a descriptive error, like the sequence branch
                    # above, instead of an opaque IndexError.
                    raise RuntimeError('Expected no values for uninitialized '
                                       'graph inputs, but got a single input.')
                self.workspace.FeedBlob(self.uninitialized[0], inputs)
            if not self.nets_created:
                self.workspace.CreateNet(self.init_net)
                self.workspace.CreateNet(self.predict_net)
                self.nets_created = True
            if not self.ran_init_net:
                self.workspace.RunNet(self.init_net.name)
                self.ran_init_net = True
            self.workspace.RunNet(self.predict_net.name)
        output_values = []
        for name in self.predict_net.external_output:
            try:
                output_values.append(self.workspace.FetchBlob(name))
            except Exception:
                # Deliberate fallback: any failure of the plain fetch is
                # retried through the Int8 fetch path; an error raised there
                # propagates to the caller.
                output_values.append(self.workspace.FetchInt8Blob(name))
        return namedtupledict('Outputs',
                              self.predict_net.external_output)(*output_values)