neilisaac / torch (Python), version 1.8.0

quantization/server/utils.py
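
"""Caffe2 net transformation utilities used by the quantization server:
fusing SpatialBN, Mul/Add (scale/bias) and Relu ops into preceding ops,
and swapping Concat/Relu pairs so that Relu runs before Concat."""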



import copy
import logging
from collections import defaultdict

import numpy as np
from caffe2.python import core, utils
from caffe2.python.fb import hardcode_scale_zp  # type: ignore[import]


logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def pairwise(iterable):
    "s -> (s0,s1), (s1,s2), (s2, s3), ..."
    from itertools import tee

    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)


def blob_uses(net, blob):
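    """Return the indices of all ops in net that take blob as an input
    or control input."""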
    u = []
    for i, op in enumerate(net.op):
        if blob in op.input or blob in op.control_input:
            u.append(i)
    return u


def fuse_first_bn(net, params, removed_tensors, begin_op_index):
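    """Fuse the first fusable Conv/ConvTranspose + SpatialBN pair found at
    or after begin_op_index by folding the BN affine transform into the
    conv weights and bias.

    Returns (net, params, removed_tensors, next_begin_op_index), where
    next_begin_op_index is None if no fusion was performed.
    """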
    net = copy.deepcopy(net)
    params = copy.deepcopy(params)

    for i, conv in enumerate(net.op[begin_op_index:], begin_op_index):
        if conv.type not in ["Conv", "ConvTranspose"]:
            continue

        uses = blob_uses(net, conv.output[0])
        if len(uses) == 0:
            continue

        j = uses[0]
        bn = net.op[j]
        if bn.type != "SpatialBN" or (len(uses) > 1 and conv.output[0] != bn.output[0]):
            if bn.type == "SpatialBN":
                logger.debug("Can't fuse if more than one user {}".format(uses))
            # Can't fuse if more than one user unless SpatialBN is inplace
            # An example of inplace SpatialBN where we want to allow multiple uses:
            # x = Conv(...)
            # ... // no interfering use or def of x (will be checked below)
            # x = SpatialBN(x, ...)
            # ...
            # z = Foo(..., x, ...)
            # ...
            # w = Boo(..., x, ...)
            # Here, we still want to fuse Conv and SpatialBN
            continue

        # There shouldn't be any def of conv.output[0], nor any use or def
        # of bn.output[0], between conv and bn
        if any(
            blob in net.op[k].input or blob in net.op[k].output
            for blob in [conv.output[0], bn.output[0]]
            for k in range(i + 1, j)
        ):
            logger.debug(
                "Can't fuse because of the following interfering uses or defs:"
            )
            for k in range(i, j + 1):
                logger.debug(net.op[k])
            continue

        # else, can fuse
        fused_conv = copy.deepcopy(conv)
        fused_conv.output[0] = bn.output[0]
        conv_weight = params[conv.input[1]]
        if len(conv.input) > 2:
            conv_bias = params[conv.input[2]]
        else:
            conv_bias = np.zeros(len(params[bn.input[2]])).astype(np.float32)

        bn_scale = params[bn.input[1]]
        bn_bias = params[bn.input[2]]
        bn_running_mean = params[bn.input[3]]
        bn_running_var = params[bn.input[4]]

        # First, the BN computation can be phrased as follows:
        # (X - running_mean) * (1.0 / sqrt(running_var + eps)) *
        # bn_scale + bias
        # Thus, we can rewrite it as:
        # X * (bn_scale / sqrt(running_var + eps)) +
        # (bias - running_mean * bn_scale / sqrt(running_var + eps))
        # i.e. as the affine transform
        # X * A + B
        # where
        # A = bn_scale / sqrt(running_var + eps)
        # B = bias - running_mean * A
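        # Numeric sanity check (ignoring eps): with bn_scale = 2, bias = 1,
        # running_mean = 3 and running_var = 1, A = 2 / sqrt(1) = 2 and
        # B = 1 - 3 * 2 = -5, so the BN output is 2 * X - 5.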
        eps = 1.0e-5
        for arg in bn.arg:
            if arg.name == "epsilon":
                eps = arg.f
        A = bn_scale / np.sqrt(bn_running_var + eps)
        B = bn_bias - bn_running_mean * A

        # This identity should hold if we have correctly fused
        # np.testing.assert_array_equal(
        #     params[conv.output[0]] * A + B,
        #     params[bn.output[0]])

        # Now, we have that the computation made is the following:
        # ((X `conv` W) + b) * A + B
        # Then, we can simply fuse this as follows:
        # (X `conv` (W * A)) + b * A + B
        # which is simply
        # (X `conv` Q) + C
        # where

        # Q = W * A
        # C = b * A + B

        # For ConvTranspose, from the view of convolutions as a
        # Toeplitz multiplication, we have W_ = W^T, so the weights
        # are laid out as (R, S, K, K) (vs (S, R, K, K) for a Conv),
        # so the weights broadcast slightly differently. Remember, our
        # BN scale 'A' is of size (S,).
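        # e.g. for a Conv with weight shape (S, R, K, K), A is reshaped
        # to (S, 1, 1, 1); for a ConvTranspose with weight shape
        # (R, S, K, K), A is reshaped to (1, S, 1, 1).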

        A_ = (
            A.reshape((-1,) + tuple([1] * (conv_weight.ndim - 1)))
            if conv.type == "Conv"
            else A.reshape((1, -1) + tuple([1] * (conv_weight.ndim - 2)))
        )

        C = conv_bias * A + B
        Q = conv_weight * A_

        assert params[conv.input[1]].shape == Q.shape
        if len(conv.input) > 2:
            assert params[conv.input[2]].shape == C.shape
        else:
            assert bn_bias.shape == C.shape

        params[conv.input[1]] = Q
        if len(conv.input) > 2:
            params[conv.input[2]] = C
        else:
            params[bn.input[2]] = C
            fused_conv.input.append(bn.input[2])

        new_ops = net.op[:i] + [fused_conv] + net.op[i + 1 : j] + net.op[j + 1 :]
        del net.op[:]
        removed_tensors.append(bn.input[1])
        if len(conv.input) > 2:
            removed_tensors.append(bn.input[2])
        removed_tensors.append(bn.input[3])
        removed_tensors.append(bn.input[4])
        del params[bn.input[1]]
        if len(conv.input) > 2:
            del params[bn.input[2]]
        del params[bn.input[3]]
        del params[bn.input[4]]
        net.op.extend(new_ops)
        return net, params, removed_tensors, i + 1

    return net, params, removed_tensors, None


def fuse_bn(net, params, ignore_failure):
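    """Repeatedly apply fuse_first_bn until no more Conv + SpatialBN pairs
    can be fused. Raises if any SpatialBN op remains afterwards, unless
    ignore_failure is set."""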
    # Run until we hit a fixed point
    removed_tensors = []
    begin_op_index = 0
    while True:
        (next_net, next_params, removed_tensors, begin_op_index) = fuse_first_bn(
            net, params, removed_tensors, begin_op_index
        )
        if begin_op_index is None:
            if any(op.type == "SpatialBN" for op in next_net.op) and not ignore_failure:
                raise Exception(
                    "Model contains SpatialBN op after fusion: {}".format(next_net)
                )
            return (next_net, next_params, removed_tensors)
        net, params, removed_tensors = (next_net, next_params, removed_tensors)


def fuse_first_scale(net, params, removed_tensors):
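    """Fuse the first SpatialBN -> Mul -> Add chain into a single SpatialBN
    by folding the Mul scale and Add bias into the BN scale and bias."""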
    net = copy.deepcopy(net)
    params = copy.deepcopy(params)

    for ((i, current), (j, next_)) in pairwise(enumerate(net.op)):
        if next_.input[0] != current.output[0]:
            continue

        if (
            current.type != "SpatialBN"
            or next_.type != "Mul"
            or len(net.op) <= j + 1
            or net.op[j + 1].type != "Add"
        ):
            continue

        # else, can fuse
        bn = current
        mul = next_
        add = net.op[j + 1]

        fused_bn = copy.deepcopy(bn)
        fused_bn.output[0] = add.output[0]
        bn_scale = params[bn.input[1]]
        mul_scale = params[mul.input[1]]
        bn_bias = params[bn.input[2]]
        add_bias = params[add.input[1]]

        params[bn.input[1]] = bn_scale * mul_scale
        params[bn.input[2]] = mul_scale * bn_bias + add_bias

        new_ops = net.op[:i] + [fused_bn] + net.op[j + 2 :]
        del net.op[:]
        removed_tensors.append(mul.input[1])
        removed_tensors.append(add.input[1])
        del params[mul.input[1]]
        del params[add.input[1]]
        net.op.extend(new_ops)
        break
    return net, params, removed_tensors


def fuse_scale(net, params, ignore_failure):
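    """Repeatedly apply fuse_first_scale until no more SpatialBN -> Mul -> Add
    chains can be fused."""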
    # Run until we hit a fixed point
    removed_tensors = []
    while True:
        (next_net, next_params, removed_tensors) = fuse_first_scale(
            net, params, removed_tensors
        )
        if len(next_net.op) == len(net.op):
            return (next_net, next_params, removed_tensors)
        net, params, removed_tensors = (next_net, next_params, removed_tensors)


def fuse_first_relu(net, begin_op_index, ignore_op_with_output=None):
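    """Fuse the first fusable op (Conv, ConvTranspose, Sum or SpatialBN)
    with a following Relu by replacing the pair with a single fused op
    (e.g. Conv + Relu becomes ConvRelu).

    Returns (net, next_begin_op_index), where next_begin_op_index is None
    if no fusion was performed.
    """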
    net = copy.deepcopy(net)

    for i, conv in enumerate(net.op[begin_op_index:], begin_op_index):
        if conv.type not in ["Conv", "ConvTranspose", "Sum", "SpatialBN"]:
            continue

        uses = blob_uses(net, conv.output[0])
        if len(uses) == 0 or (
            ignore_op_with_output and conv.output[0] in ignore_op_with_output
        ):
            continue

        j = uses[0]
        relu = net.op[j]
        if relu.type != "Relu" or (len(uses) > 1 and conv.output[0] != relu.output[0]):
            # Can't fuse if more than one user unless Relu is inplace
            if relu.type == "Relu":
                logger.debug("Can't fuse if more than one user {}".format(uses))
            continue

        # There shouldn't be any def of conv.output[0], nor any use or def
        # of relu.output[0], between conv and relu
        if any(
            blob in net.op[k].input or blob in net.op[k].output
            for blob in [conv.output[0], relu.output[0]]
            for k in range(i + 1, j)
        ):
            logger.debug(
                "Can't fuse because of the following interfering uses or defs:"
            )
            for k in range(i, j + 1):
                logger.debug(net.op[k])
            continue

        # else, can fuse
        fused_conv = copy.deepcopy(conv)
        fused_conv.type = conv.type + "Relu"
        fused_conv.output[0] = relu.output[0]

        new_ops = net.op[:i] + [fused_conv] + net.op[i + 1 : j] + net.op[j + 1 :]
        del net.op[:]
        net.op.extend(new_ops)
        return net, i + 1
    return net, None


def fuse_relu(net, ignore_failure, ignore_op_with_output=None):
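    """Repeatedly apply fuse_first_relu until no more Relu ops can be fused.
    Raises if any Relu op remains afterwards, unless ignore_failure is set."""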
    # Run until we hit a fixed point
    begin_op_index = 0
    while True:
        next_net, begin_op_index = fuse_first_relu(
            net, begin_op_index, ignore_op_with_output
        )
        if begin_op_index is None:
            if any(op.type == "Relu" for op in next_net.op) and not ignore_failure:
                raise Exception(
                    "Model contains Relu op after fusion: {}".format(next_net)
                )
            return next_net
        net = next_net


def last_producer(ops, blob):
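    """Return the index of the last op in ops whose first output is blob."""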
    for (i, op) in reversed(list(enumerate(ops))):
        if op.output[0] == blob:
            return i
    raise ValueError("Failed to find last producer of blob {}".format(blob))


def swap_first_concat_relu(net, ignore_op_with_output=None):
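    """Swap the first Concat -> Relu pair so that a Relu is applied to each
    input of the Concat instead. This is valid because Relu is elementwise,
    so relu(concat(xs)) == concat([relu(x) for x in xs])."""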
    net = copy.deepcopy(net)

    for ((i, current), (j, next_)) in pairwise(enumerate(net.op)):
        if next_.input[0] != current.output[0]:
            continue

        if current.type != "Concat" or next_.type != "Relu":
            continue

        if ignore_op_with_output and current.output[0] in ignore_op_with_output:
            continue

        # else, can swap
        concat = copy.deepcopy(current)
        relu = copy.deepcopy(next_)
        pre_ops = copy.deepcopy(net.op[:i])
        post_ops = copy.deepcopy(net.op[j + 1 :])

        # Delete the Relu after Concat
        concat.output[0] = relu.output[0]

        # Insert Relu after each op that produces inputs to Concat
        for blob in concat.input:
            k = last_producer(pre_ops, blob)
            producer = pre_ops[k]
            assert producer.output[0] == blob
            producer.output[0] = blob + "_pre_relu"

            new_relu = copy.deepcopy(relu)
            new_relu.input[0] = producer.output[0]
            new_relu.output[0] = blob

            pre_ops = pre_ops[: k + 1] + [new_relu] + pre_ops[k + 1 :]

        new_ops = pre_ops + [concat] + post_ops
        del net.op[:]
        net.op.extend(new_ops)
        break
    return net