Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ python / core.py

## @package core
# Module caffe2.python.core





from collections import namedtuple, OrderedDict, defaultdict
from past.builtins import basestring
from future.utils import viewitems, viewkeys, viewvalues
from itertools import chain
from six import binary_type, string_types, text_type

from caffe2.proto import caffe2_pb2
from caffe2.python import scope, utils, workspace
from caffe2.python.lazy import TriggerLazyImport
from caffe2.python.control_ops_grad import \
    gen_do_gradient, gen_if_gradient, gen_while_gradient, disambiguate_grad_if_op_output

import caffe2.python._import_c_extension as C

import copy
import pickle
import numpy as np
import sys
import traceback
import os

# Mac os specific message
if (sys.platform == 'darwin' and 'leveldb' in C.registered_dbs()):
    # Homebrew's leveldb ships its own memory allocator which triggers a
    # benign malloc_zone_unregister() warning; tell the user it is harmless.
    print('If you are using homebrew leveldb on a Mac OS, you might see an '
          'error warning you that malloc_zone_unregister() failed. This is '
          'not a caffe2 issue but is due to the homebrew leveldb having an '
          'incompatible memory allocator. It does not affect usage.')

# Convenience redirections to functions inside scope.
# Re-exported so callers can use core.DeviceScope / core.NameScope without
# importing caffe2.python.scope directly.
DeviceScope = scope.DeviceScope
NameScope = scope.NameScope


# Bring datatype enums to the main namespace
class DataType:
    """Namespace class holding the TensorProto.DataType enum values.

    Attributes (e.g. DataType.FLOAT) are copied onto this class at import
    time by _InitDataType() from caffe2_pb2.TensorProto.DataType.
    """
    pass


def _InitDataType():
    """Mirror every caffe2_pb2.TensorProto.DataType enum entry onto DataType."""
    enum_entries = caffe2_pb2.TensorProto.DataType.items()
    for enum_name, enum_value in enum_entries:
        setattr(DataType, enum_name, enum_value)


# Populate DataType at import time so e.g. core.DataType.FLOAT is available.
_InitDataType()


def _GetRegisteredOperators():
    """Return the names of all currently registered operators as a set."""
    return {op_name for op_name in workspace.RegisteredOperators()}


# Cached snapshot of registered operator names; refreshed by
# RefreshRegisteredOperators() after new operators are registered.
_REGISTERED_OPERATORS = _GetRegisteredOperators()


def RefreshRegisteredOperators(trigger_lazy=True):
    """Re-query the workspace and update the cached operator-name set.

    Args:
        trigger_lazy: when True, run lazy imports first so operators
            registered by lazily-loaded modules are picked up.
    """
    global _REGISTERED_OPERATORS
    if trigger_lazy:
        TriggerLazyImport()
    _REGISTERED_OPERATORS = _GetRegisteredOperators()


# Accumulates the argument tails passed to GlobalInit() calls; exposed via
# GetGlobalInitArgs().
_GLOBAL_INIT_ARGS = []


def GlobalInit(args):
    """Initialize the caffe2 C++ runtime with argv-style arguments.

    args[0] plays the role of the program name; the remaining entries are
    recorded so GetGlobalInitArgs() can report them later.
    """
    TriggerLazyImport()
    recorded = args[1:]
    _GLOBAL_INIT_ARGS.extend(recorded)
    C.global_init(args)


def GetGlobalInitArgs():
    """Return a copy of the arguments accumulated by GlobalInit() calls."""
    return list(_GLOBAL_INIT_ARGS)


def IsOperator(op_type):
    """Return True when op_type is registered under the default engine."""
    default_engine = 'DEFAULT'
    return IsOperatorWithEngine(op_type, engine=default_engine)


def IsOperatorWithEngine(op_type, engine):
    """Return True when op_type is registered for the given engine."""
    TriggerLazyImport()
    registry_key = C.op_registry_key(op_type, engine)
    return registry_key in _REGISTERED_OPERATORS


def IsGPUDeviceType(device_type):
    """Return True when device_type denotes a GPU backend (CUDA or HIP)."""
    gpu_device_types = (caffe2_pb2.CUDA, caffe2_pb2.HIP)
    return device_type in gpu_device_types


def DeviceOption(
    device_type,
    device_id=0,
    random_seed=None,
    node_name=None,
    numa_node_id=None,
    extra_info=None,
):
    """Build a caffe2_pb2.DeviceOption proto from the given fields.

    Optional fields are written only when a value is supplied, leaving
    them unset in the proto otherwise.

    Args:
        device_type: a caffe2_pb2 device enum value (e.g. CPU, CUDA).
        device_id: ordinal of the device within its type.
        random_seed: optional per-device random seed.
        node_name: optional node name for distributed setups.
        numa_node_id: optional NUMA node; only valid for CPU devices.
        extra_info: optional iterable of extra string annotations.

    Returns:
        A populated caffe2_pb2.DeviceOption.
    """
    opt = caffe2_pb2.DeviceOption()
    opt.device_type = device_type
    opt.device_id = device_id
    if random_seed is not None:
        opt.random_seed = random_seed
    if node_name is not None:
        opt.node_name = node_name
    if numa_node_id is not None:
        # NUMA placement only makes sense on CPU.
        assert device_type == caffe2_pb2.CPU
        opt.numa_node_id = numa_node_id
    if extra_info is not None:
        opt.extra_info.extend(extra_info)
    return opt


def device_option_equal(opt1, opt2, ignore_node_name=True, ignore_random_seed=True):
    """Compare two DeviceOption protos for equality.

    Args:
        opt1, opt2: DeviceOption protos (or None / falsy placeholders).
        ignore_node_name: when True (default), node_name is not compared.
        ignore_random_seed: when True (default), random_seed is not compared.

    Returns:
        True when the options denote the same device under the chosen
        comparison settings.
    """
    if not opt1 or not opt2:
        # One (or both) missing: equal only if both are the same falsy value.
        return opt1 == opt2
    if not ignore_node_name and opt1.node_name != opt2.node_name:
        return False
    if not ignore_random_seed and opt1.random_seed != opt2.random_seed:
        return False
    if not opt1.device_type or not opt2.device_type:
        # At least one option is for CPU (device_type 0), check if both are.
        return not opt1.device_type and not opt2.device_type
    # Both are non-CPU: require the same backend AND the same device index.
    # The original code compared only device_id, which wrongly treated
    # different backends (e.g. CUDA:0 vs HIP:0) as equal.
    return (opt1.device_type == opt2.device_type
            and opt1.device_id == opt2.device_id)


def InferBlobDevices(net):
    """Compute a mapping from blob name to the device option it lives on.

    Each output blob is assigned the device option of the op that produces
    it; when a blob is written by several ops, the last producer wins.

    Args:
        net: a Net-like object whose Proto().op yields OperatorDef entries.

    Returns:
        dict mapping blob name -> DeviceOption.
    """
    mapping = {}
    for op in net.Proto().op:
        op_device = op.device_option
        if op_device is None:
            # Fall back to CPU.  Protobuf message constructors accept
            # keyword arguments only; the original positional call
            # caffe2_pb2.DeviceOption(caffe2_pb2.CPU) would raise TypeError
            # if this branch were ever taken.
            op_device = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU)
        # TODO: T18892922, use device annotations
        for b in op.output:
            mapping[b] = op_device
    return mapping


def InferOpBlobDevicesAsDict(op):
    """Like InferOpBlobDevices, but keyed by blob name.

    Returns:
        (input_dict, output_dict): blob name -> DeviceOption for the op's
        inputs and outputs respectively.
    """
    input_devices, output_devices = InferOpBlobDevices(op)
    input_dict = dict(zip(op.input, input_devices))
    output_dict = dict(zip(op.output, output_devices))
    return input_dict, output_dict


def InferOpBlobDevices(op):
    """Ask the C++ device-inference routine where an op's blobs live.

    Serializes the op, queries the C extension, and parses the answers
    back into DeviceOption protos.

    Returns:
        (input_info, output_info): lists of caffe2_pb2.DeviceOption, one
        entry per op input / output.
    """
    device_info = C.infer_op_input_output_device(op.SerializeToString())

    def _parse_all(serialized_options):
        # Each entry is a serialized DeviceOption proto.
        parsed = []
        for serialized in serialized_options:
            dev = caffe2_pb2.DeviceOption()
            dev.ParseFromString(serialized)
            parsed.append(dev)
        return parsed

    return _parse_all(device_info[0]), _parse_all(device_info[1])


def InferOpDeviceAsBlobDevices(op):
    """Assume every blob of an op lives on the op's own device.

    Returns:
        (input_dev, output_dev): lists repeating the op's DeviceOption once
        per input / output blob (a default DeviceOption when unset).
    """
    if op.device_option:
        device = op.device_option
    else:
        device = caffe2_pb2.DeviceOption()
    return [device] * len(op.input), [device] * len(op.output)


# Sparse gradient representation: `indices` selects which rows of the
# parameter the entries in `values` provide gradients for.
GradientSlice = namedtuple('GradientSlice', ['indices', 'values'])


class BlobReference(object):
    """A wrapper around a blob in a net.

    BlobReference gives us a way to refer to the network that the blob is
    generated from. Note that blobs are, essentially, just strings in the
    current workspace.
    """

    def __init__(self, name, net=None):
        """Initializes a blob reference.

        Note that this does not prepend the namescope. If needed, use
        ScopedBlobReference() to prepend the existing namespace.

        Args:
            name: the blob name as str, bytes (decoded as UTF-8), or any
                object convertible via str().
            net: the net this blob was generated from, or None for a
                free-standing reference.
        """
        if isinstance(name, string_types):
            self._name = name
        elif isinstance(name, binary_type):
            self._name = name.decode('utf-8')
        else:
            self._name = str(name)
        self._from_net = net
        # meta allows helper functions to put whatever metainformation needed
        # there.
        self.meta = {}

    def __hash__(self):
        # Hash by name only, consistent with __eq__ (the source net is
        # ignored for identity purposes).
        return hash(self._name)

    def __eq__(self, other):
        # Equality is by name; str/bytes compare directly against the name,
        # so BlobReference("x") == "x" holds.
        if isinstance(other, string_types):
            return self._name == other
        elif isinstance(other, binary_type):
            return self._name == other.decode('utf-8')
        elif isinstance(other, BlobReference):
            return self._name == other._name
        else:
            return False

    def __ne__(self, other):
        return not(self == other)

    def __str__(self):
        return self._name

    def __repr__(self):
        return 'BlobReference("{}")'.format(self._name)

    def __add__(self, other):
        # Concatenation yields a new reference on the same net; only str
        # operands (not bytes) are supported.
        if not isinstance(other, string_types):
            raise RuntimeError('Cannot add BlobReference to a non-string.')
        return BlobReference(self._name + other, self._from_net)

    def __radd__(self, other):
        if not isinstance(other, string_types):
            raise RuntimeError('Cannot add a non-string to BlobReference.')
        return BlobReference(other + self._name, self._from_net)

    def Net(self):
        """Returns the net this blob was generated from (may be None)."""
        return self._from_net

    def GetNameScope(self):
        """Returns the scope prefix of the name, including the trailing
        separator; empty string when the name carries no scope."""
        return self._name[:self._name.rfind(scope._NAMESCOPE_SEPARATOR) + 1]

    def GetUnscopedName(self):
        """Returns the name with its scope prefix stripped."""
        return self._name[self._name.rfind(scope._NAMESCOPE_SEPARATOR) + 1:]

    def _CreateAndAddToNet(self, op_type, inputs=None, *args, **kwargs):
        """Internal function that routes the operator generation to the
        network's __getattr__ function.
        """
        inputs = [] if inputs is None else inputs
        # Normalize a single blob/string into a one-element input list.
        if isinstance(inputs, BlobReference) or isinstance(inputs, string_types):
            inputs = [inputs]
        # add self to the input list.
        inputs.insert(0, self)
        return self._from_net.__getattr__(op_type)(inputs, *args, **kwargs)

    def __getattr__(self, op_type):
        """A wrapper allowing one to initiate operators from a blob reference.

        Example: for a blob reference b that comes from network n, doing
            b.Relu(...)
        is equivalent to doing
            net.Relu([b], ...)
        """
        # Refuse dunder lookups so protocols like copy/pickle do not receive
        # bogus operator factories.
        if op_type.startswith('__'):
            raise AttributeError('Attribute {} not found.'.format(op_type))
        if self._from_net is None:
            raise AttributeError(
                'You cannot use a blob reference that does not have a net '
                'source to create operators. Create the operator from an '
                'explicit net object.')
        if not IsOperator(op_type):
            raise AttributeError(
                'Method ' + op_type + ' is not a registered operator.' +
                ' Did you mean: [' +
                ",".join(workspace.C.nearby_opnames(op_type)) + ']'
            )
        return lambda *args, **kwargs: self._CreateAndAddToNet(
            op_type, *args, **kwargs)

    def __dir__(self):
        """Lists regular attributes plus registered operator names so that
        e.g. interactive tab-completion offers b.Relu."""
        TriggerLazyImport()
        additional_methods = [
            op
            for op in _REGISTERED_OPERATORS
            if '_ENGINE_' not in op or '_ENGINE_CUDNN' in op]
        return sorted(set(chain(
            dir(type(self)),
            viewkeys(self.__dict__),
            additional_methods
        )))


def ScopedName(name):
    """Return `name` prefixed with the current name scope."""
    text = name.decode('ascii') if isinstance(name, binary_type) else name
    return scope.CurrentNameScope() + text


def ScopedBlobReference(name, *args, **kwargs):
    """Return a BlobReference whose name carries the current scope prefix."""
    scoped_name = ScopedName(name)
    return BlobReference(scoped_name, *args, **kwargs)


def _RectifyInputOutput(blobs, net=None):
    """A helper function to rectify the input or output of the CreateOperator
    interface.

    Accepts a single str/bytes (scoped and wrapped), a single BlobReference,
    or a list/tuple mixing the two; always returns a list of BlobReference.
    """
    if isinstance(blobs, string_types) or isinstance(blobs, binary_type):
        # A bare string gets the current name scope prepended and is
        # returned as a one-element list.
        # TODO(jiayq): enforce using BlobReference instead of raw strings.
        return [ScopedBlobReference(blobs, net=net)]
    if type(blobs) is BlobReference:
        # Already a BlobReference: simply wrap it in a list.
        return [blobs]
    if type(blobs) in (list, tuple):
        # A sequence: rectify each element with the same rules, type-checking
        # as we go.  type() (not isinstance) preserves the original's
        # strictness about BlobReference subclasses.
        result = []
        for blob in blobs:
            if isinstance(blob, string_types) or isinstance(blob, binary_type):
                result.append(ScopedBlobReference(blob, net=net))
            elif type(blob) is BlobReference:
                result.append(blob)
            else:
                raise TypeError(
                    "I/O blob #{} of unsupported type: {} of type {}"
                    .format(len(result), str(blob), type(blob)))
        return result
    raise TypeError(
        "Unknown input/output type: %s of type %s." %
        (str(blobs), type(blobs))
    )


def CreateOperator(
    operator_type,
    inputs,
    outputs,
Loading ...