Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

/ python / ideep / transform_ideep_net.py






import argparse
import copy
import json

import numpy as np

from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace, utils
import caffe2.python._import_c_extension as C



def pairwise(iterable):
    """Return an iterator over consecutive overlapping pairs of ``iterable``.

    For input ``s0, s1, s2, ...`` the result yields ``(s0, s1), (s1, s2), ...``.
    Inputs with fewer than two elements produce no pairs.
    """
    import itertools

    leading, trailing = itertools.tee(iterable)
    # Drop the first element of the second copy so the two copies are
    # offset by one position when zipped together.
    next(trailing, None)
    return zip(leading, trailing)


def last_producer(ops, blob):
    """Return the index of the last operator in ``ops`` that outputs ``blob``.

    Args:
        ops: iterable of operators; each must expose an ``output`` container.
        blob: blob name to search for.

    Returns:
        The largest index ``i`` for which ``blob in ops[i].output``.

    Raises:
        ValueError: if no operator in ``ops`` lists ``blob`` as an output.
    """
    for i, op in reversed(list(enumerate(ops))):
        if blob in op.output:
            return i
    # Exceptions do not interpolate %-style arguments the way logging calls
    # do, so the message has to be formatted before it is raised.
    raise ValueError(
        "Failed to find last producer of blob, {}".format(blob))


def blob_uses(net, blob):
    """Return the indices of the ops in ``net`` that consume ``blob``.

    An op counts as a user when ``blob`` appears among its ``input`` or its
    ``control_input`` entries. Indices are returned in ascending order.
    """
    return [
        index
        for index, candidate in enumerate(net.op)
        if blob in candidate.input or blob in candidate.control_input
    ]


def GetArgumentParser():
    """Build the command-line parser for the optimization script.

    The parser accepts three optional file arguments (``--init_net`` and
    ``--pred_net`` opened in binary mode, ``--verify_input`` opened in text
    mode) and three boolean switches selecting which fusions to apply
    (``--fuse_bn``, ``--fuse_mul_add``, ``--fuse_conv_relu``), all off by
    default.
    """
    parser = argparse.ArgumentParser(description="Caffe2 optimization")

    # (flag, file mode, help text) for every file-valued option.
    file_options = (
        ("--init_net", 'rb', "init net"),
        ("--pred_net", 'rb', "predict net"),
        ("--verify_input", 'r', "input dims for verification"),
    )
    for flag, mode, help_text in file_options:
        parser.add_argument(flag,
                            type=argparse.FileType(mode),
                            help=help_text)

    # Each fusion pass is opt-in through its own switch.
    for flag in ("--fuse_bn", "--fuse_mul_add", "--fuse_conv_relu"):
        parser.add_argument(flag, default=False, action='store_true')

    return parser


def fuse_first_bn(net, params, removed_tensors):
    """Fold the first fusable SpatialBN op into the convolution before it.

    Scans adjacent op pairs in ``net`` for a ``Conv``/``ConvTranspose``
    immediately followed by a ``SpatialBN`` that consumes the convolution's
    first output, where that output has exactly one user. The first such
    pair is replaced by a single convolution whose weight and bias absorb
    the batch-norm affine transform. At most one pair is fused per call.

    Args:
        net: net whose ``op`` list is scanned. It is deep-copied; the
            caller's object is not modified.
        params: mapping from blob name to numpy array. It is deep-copied;
            the caller's object is not modified.
        removed_tensors: list that receives the names of the four SpatialBN
            parameter blobs dropped by the fusion (appended in place).

    Returns:
        Tuple ``(net, params, removed_tensors)`` with the rewritten copies
        and the same ``removed_tensors`` list. When nothing could be fused
        the copies are returned unchanged.
    """
    net = copy.deepcopy(net)
    params = copy.deepcopy(params)

    for ((i, current), (j, next_)) in pairwise(enumerate(net.op)):
        # Filter on op type first so that unrelated ops which have no
        # inputs or outputs are skipped before input[0]/output[0] is read.
        if current.type not in ("Conv", "ConvTranspose") \
           or next_.type != "SpatialBN":
            continue

        if next_.input[0] != current.output[0]:
            continue

        if len(blob_uses(net, current.output[0])) != 1:
            # Can't fuse if more than one user
            continue

        # else, can fuse
        conv = current
        bn = next_
        fused_conv = copy.deepcopy(conv)
        fused_conv.output[0] = bn.output[0]

        # Fix fused_conv to ensure we have a bias passed.
        if len(fused_conv.input) != 3:
            bias_name = "{}_bias".format(conv.input[1])
            net.external_input.extend([bias_name])
            fused_conv.input.extend([bias_name])
            for arg in fused_conv.arg:
                if arg.name == "no_bias":
                    arg.i = 0

        bn_scale = params[bn.input[1]]
        bn_bias = params[bn.input[2]]
        bn_running_mean = params[bn.input[3]]
        bn_running_var = params[bn.input[4]]

        conv_weight = params[conv.input[1]]
        # A missing bias is a zero bias with one entry per output channel.
        # Size it from the BN scale rather than conv_weight.shape[0]:
        # for ConvTranspose the output channels sit on axis 1 of the weight
        # (see the layout note below), so axis 0 would give the wrong length.
        conv_bias = params[conv.input[2]] if len(conv.input) == 3 \
            else np.zeros(bn_scale.shape, dtype=np.float32)

        # First, BN computation can be phrased as follows:
        # (X - running_mean) * (1.0 / sqrt(running_var + eps)) *
        # bn_scale + bias
        # Thus, we can rewrite it as:
        # X * bn_scale * 1.0 / (sqrt(running_var + eps)) + (bias -
        # running_mean * (1.0 / sqrt(running_var + eps)) * bn_scale)
        # Thus, can just have the affine transform
        # X * A + B
        # where
        # A = bn_scale * 1.0 / (sqrt(running_var + eps))
        # B =  (bias - running_mean * (1.0 / sqrt(running_var + eps))
        # * bn_scale)
        eps = 1.0e-5
        for arg in bn.arg:
            if arg.name == "epsilon":
                eps = arg.f
        A = bn_scale * 1.0 / (np.sqrt(bn_running_var + eps))
        B = bn_bias - bn_running_mean * A

        # This identity should hold if we have correctly fused
        # np.testing.assert_array_equal(
        #     params[conv.output[0]] * A + B,
        #     params[bn.output[0]])

        # Now, we have that the computation made is the following:
        # ((X `conv` W) + b) * A + B
        # Then, we can simply fuse this as follows:
        # (X `conv` (W * A)) + b * A + B
        # which is simply
        # (X `conv` Q) + C
        # where

        # Q = W * A
        # C = b * A + B

        # For ConvTranspose, from the view of convolutions as a
        # Toeplitz multiplication, we have W_ = W^T, so the weights
        # are laid out as (R, S, K, K) (vs (S, R, K, K) for a Conv),
        # so the weights broadcast slightly differently. Remember, our
        # BN scale 'A' is of size (S,)

        A_ = A.reshape(-1, 1, 1, 1) if conv.type == "Conv" else \
            A.reshape(1, -1, 1, 1)

        C = conv_bias * A + B
        Q = conv_weight * A_

        params[fused_conv.input[1]] = Q
        params[fused_conv.input[2]] = C
        # Build the replacement op list before clearing net.op, since the
        # slices are taken from the list that is about to be emptied.
        new_ops = net.op[:i] + [fused_conv] + net.op[j + 1:]
        del net.op[:]
        # The BN parameters are now folded into Q and C; report and drop them.
        removed_tensors.append(bn.input[1])
        removed_tensors.append(bn.input[2])
        removed_tensors.append(bn.input[3])
        removed_tensors.append(bn.input[4])
        del params[bn.input[1]]
        del params[bn.input[2]]
        del params[bn.input[3]]
        del params[bn.input[4]]
        net.op.extend(new_ops)
        # Only one fusion per call; the caller loops until a fixed point.
        break
    return net, params, removed_tensors


def fuse_bn(net, params, ignore_failure):
    """Fuse SpatialBN ops into preceding convolutions until none can be fused.

    Calls ``fuse_first_bn`` repeatedly; iteration stops as soon as a pass
    leaves the number of ops unchanged.

    Args:
        net: net to optimize.
        params: mapping from blob name to value, as used by ``fuse_first_bn``.
        ignore_failure: when false, a ``SpatialBN`` op still present after
            the last pass is treated as an error.

    Returns:
        Tuple ``(net, params, removed_tensors)`` from the final pass, where
        ``removed_tensors`` accumulates the blob names removed by all passes.

    Raises:
        Exception: if a ``SpatialBN`` op remains and ``ignore_failure`` is
            false.
    """
    # Run until we hit a fixed point
    removed_tensors = []
    while True:
        (next_net, next_params, removed_tensors) = \
            fuse_first_bn(net, params, removed_tensors)
        if len(next_net.op) == len(net.op):
            # No op was removed in this pass, so nothing more can be fused.
            if (
                any(op.type == "SpatialBN" for op in next_net.op) and
                not ignore_failure
            ):
                # Format the message explicitly; Exception does not
                # interpolate %-style arguments.
                raise Exception(
                    "Model contains SpatialBN op after fusion: {}".format(
                        next_net))
            return (next_net, next_params, removed_tensors)
        net, params = next_net, next_params


def fuse_first_mul_add(net, params, removed_tensors):
    """Rewrite the first adjacent Mul + Add pair as a single SpatialBN op.

    The replacement op runs in test mode with epsilon 1e-9. It takes the
    Mul's second input as scale and the Add's second input as bias, with a
    newly created all-zeros mean and all-ones variance blob. Those two blobs
    are added to ``params`` and to the net's external inputs. At most one
    pair is rewritten per call.

    Args:
        net: net whose ``op`` list is scanned. It is deep-copied; the
            caller's object is not modified.
        params: mapping from blob name to numpy array. It is deep-copied;
            the caller's object is not modified.
        removed_tensors: passed through unchanged.

    Returns:
        Tuple ``(net, params, removed_tensors)`` with the rewritten copies.

    Raises:
        Exception: if a Mul is directly followed by an Add that does not
            take the Mul's first output as its first input, or if that
            output does not have exactly one user.
    """
    # Imported locally: the module does not define a logger of its own.
    import logging

    net = copy.deepcopy(net)
    params = copy.deepcopy(params)

    for ((i, current), (j, next_)) in pairwise(enumerate(net.op)):
        if current.type != "Mul" or next_.type != "Add":
            continue

        if next_.input[0] != current.output[0]:
            raise Exception("Failure to fuse")

        if len(blob_uses(net, current.output[0])) != 1:
            raise Exception("Failure to fuse")

        logging.getLogger(__name__).info("Fusing at index %s", i)
        mul_ = current
        add_ = next_
        batch_norm = copy.deepcopy(mul_)
        batch_norm.type = "SpatialBN"
        batch_norm.arg.extend([utils.MakeArgument("is_test", 1)])
        batch_norm.arg.extend([utils.MakeArgument("epsilon", float(1e-9))])

        def s(x):
            # Derive the helper blob names from the Add's output name.
            return "{}{}".format(add_.output[0], x)
        fake_mean = s("_mean")
        fake_var = s("_var")

        del batch_norm.input[:]
        batch_norm.input.extend([mul_.input[0],
                                 mul_.input[1],
                                 add_.input[1],
                                 fake_mean,
                                 fake_var])
        # Zero mean and unit variance, shaped like the Mul's scale blob.
        params[fake_mean] = np.zeros_like(params[mul_.input[1]])
        params[fake_var] = np.ones_like(params[mul_.input[1]])
        net.external_input.extend([fake_mean, fake_var])

        batch_norm.output[0] = add_.output[0]
        # Build the replacement op list before clearing net.op.
        new_ops = net.op[:i] + [batch_norm] + net.op[j + 1:]
        del net.op[:]
        net.op.extend(new_ops)
        # Only one rewrite per call; the caller loops until a fixed point.
        break
    return net, params, removed_tensors


def fuse_mul_add(net, params):
    """Apply ``fuse_first_mul_add`` repeatedly until the op count stops changing.

    Returns the ``(net, params, removed_tensors)`` tuple produced by the
    final pass, i.e. the first pass that leaves the number of ops unchanged.
    """
    removed_tensors = []
    while True:
        fused_net, fused_params, removed_tensors = fuse_first_mul_add(
            net, params, removed_tensors)
        reached_fixed_point = len(fused_net.op) == len(net.op)
        if reached_fixed_point:
            return (fused_net, fused_params, removed_tensors)
        # Something was fused; feed the result into the next pass.
        net, params = fused_net, fused_params


def add_tensor(net, name, blob):
    ''' Append to 'net' a fill operator that produces the tensor 'blob'
        under the output name 'name'. The operator is only added to
        'net.op'; it is not run here.

        float32, int32 and int64 arrays are stored with their own shape
        and values. A uint8 array is stored as a single string element
        (shape [1]) holding its raw bytes.

        Raises KeyError if blob.dtype is not one of the supported dtypes.
    '''
    kTypeNameMapper = {
        np.dtype('float32'): "GivenTensorFill",
        np.dtype('int32'): "GivenTensorIntFill",
        np.dtype('int64'): "GivenTensorInt64Fill",
        np.dtype('uint8'): "GivenTensorStringFill",
    }

    shape = blob.shape
    values = blob
    # pass array of uint8 as a string to save storage
    # storing uint8_t has a large overhead for now
    if blob.dtype == np.dtype('uint8'):
        shape = [1]
        # tobytes() yields the raw buffer contents. str(blob.data) would
        # yield the memoryview's repr ("<memory at 0x...>") on Python 3.
        values = [blob.tobytes()]

    op = core.CreateOperator(
        kTypeNameMapper[blob.dtype],
        [], [name],
        arg=[
            utils.MakeArgument("shape", shape),
            utils.MakeArgument("values", values),
        ]
    )
    net.op.extend([op])


def gen_init_net_from_blobs(blobs):
    ''' Build a new NetDef containing one fill operator, created by
        add_tensor, for every (name, blob) entry of the dict 'blobs'.
    '''
    init_net = caffe2_pb2.NetDef()
    for blob_name, blob_value in blobs.items():
        add_tensor(init_net, blob_name, blob_value)
    return init_net


def fuse_conv_relu(net):
    """Return a copy of ``net`` rewritten by the MKL-DNN optimization transform.

    Every op of a deep copy of ``net`` is assigned the IDEEP device option.
    The copy is then serialized, passed through
    ``C.transform_optimizeForMKLDNN``, and the result is parsed into a fresh
    ``NetDef``. The caller's ``net`` is left untouched.
    """
    ideep_net = copy.deepcopy(net)
    ideep_option = core.DeviceOption(caffe2_pb2.IDEEP)
    for operator in ideep_net.op:
        operator.device_option.CopyFrom(ideep_option)

    optimized_bytes = C.transform_optimizeForMKLDNN(
        ideep_net.SerializeToString())
    optimized_net = caffe2_pb2.NetDef()
    optimized_net.ParseFromString(optimized_bytes)
    return optimized_net


def Optimize(args):
    """Load a model, apply the requested fusions, and write the result.

    Steps:
      1. Parse the init and predict nets from ``args.init_net`` and
         ``args.pred_net``, then run the init net to collect all parameter
         blobs.
      2. If ``args.verify_input`` is given, feed random float32 inputs of
         the listed shapes and record the original net's external outputs.
      3. Apply the passes selected by ``args.fuse_mul_add``,
         ``args.fuse_bn`` and ``args.fuse_conv_relu``, in that order.
      4. If verifying, rerun the optimized net on the same inputs and
         compare each external output with the recorded one.
      5. Print the op list and write ``init_net.pb`` and ``predict_net.pb``
         to the current directory.

    Args:
        args: parsed command-line namespace from ``GetArgumentParser``.

    Raises:
        AssertionError: if verification is enabled and an output of the
            optimized net differs from the original beyond
            ``atol=1e-3, rtol=1e-3``.
    """
    init_net = caffe2_pb2.NetDef()
    predict_net = caffe2_pb2.NetDef()
    init_net.ParseFromString(args.init_net.read())
    predict_net.ParseFromString(args.pred_net.read())

    workspace.ResetWorkspace()
    workspace.RunNetOnce(init_net)
    param_dict = {p: workspace.FetchBlob(p) for p in workspace.Blobs()}

    external_inputs = {}
    external_outputs = {}
    if args.verify_input:
        value_info = json.load(args.verify_input)
        # The last element of each entry is used as the input's shape.
        input_shapes = {k : v[-1] for (k, v) in value_info.items()}
        print("input info: {}".format(input_shapes))
        for k, v in input_shapes.items():
            external_inputs[k] = np.random.randn(*v).astype(np.float32)
            workspace.FeedBlob(k, external_inputs[k])
        # Reference run of the unmodified net.
        workspace.RunNetOnce(predict_net)
        for o in predict_net.external_output:
            external_outputs[o] = workspace.FetchBlob(o)

    if args.fuse_mul_add:
        predict_net, param_dict, _ = fuse_mul_add(predict_net, param_dict)
    if args.fuse_bn:
        predict_net, param_dict, _ = fuse_bn(predict_net, param_dict, False)
    if args.fuse_conv_relu:
        predict_net = fuse_conv_relu(predict_net)

    external_outputs_opt = {}
    if args.verify_input:
        workspace.ResetWorkspace()
        # The conv+relu pass assigns ops to IDEEP, so feed blobs there too.
        if args.fuse_conv_relu:
            device_option = core.DeviceOption(caffe2_pb2.IDEEP)
        else:
            device_option = core.DeviceOption(caffe2_pb2.CPU)
        with core.DeviceScope(device_option):
            for k, v in param_dict.items():
                workspace.FeedBlob(k, v, device_option)
            for k, v in external_inputs.items():
                workspace.FeedBlob(k, v, device_option)
            workspace.RunNetOnce(predict_net)
            for o in predict_net.external_output:
                external_outputs_opt[o] = workspace.FetchBlob(o)
                # Raise explicitly instead of using a bare `assert`, which
                # is stripped under `python -O` and would silently disable
                # the verification.
                if not np.allclose(external_outputs[o],
                                   external_outputs_opt[o],
                                   atol=1e-3,
                                   rtol=1e-3):
                    raise AssertionError(
                        "Output blob {} of the optimized net does not match "
                        "the original net".format(o))

    for i, o in enumerate(predict_net.op):
        print("op[{}]: {}".format(i, o.type))
    # Regenerate the init net from the (possibly rewritten) parameters.
    init_net = gen_init_net_from_blobs(param_dict)
    with open('init_net.pb', 'wb') as f:
        f.write(init_net.SerializeToString())
    with open('predict_net.pb', 'wb') as f:
        f.write(predict_net.SerializeToString())

if __name__ == "__main__":
    # Script entry point: parse the command-line flags, then run the
    # optimization pipeline with them.
    cli_parser = GetArgumentParser()
    args = cli_parser.parse_args()
    Optimize(args)