import copy
import logging
from collections import defaultdict
import numpy as np
from caffe2.python import core, utils
from caffe2.python.fb import hardcode_scale_zp # type: ignore[import]
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
def pairwise(iterable):
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
from itertools import tee
a, b = tee(iterable)
next(b, None)
return zip(a, b)
def blob_uses(net, blob):
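    """Return the indices of all ops in `net` that consume `blob`, either as a
    data input or as a control input.
    """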
u = []
for i, op in enumerate(net.op):
if blob in op.input or blob in op.control_input:
u.append(i)
return u
def fuse_first_bn(net, params, removed_tensors, begin_op_index):
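    """Fuse the first Conv/ConvTranspose + SpatialBN pair found at or after
    begin_op_index, folding the BN parameters into the conv weights and bias.
    Returns (net, params, removed_tensors, next_index); next_index is None if
    no fusion was performed.
    """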
net = copy.deepcopy(net)
params = copy.deepcopy(params)
for i, conv in enumerate(net.op[begin_op_index:], begin_op_index):
if conv.type not in ["Conv", "ConvTranspose"]:
continue
uses = blob_uses(net, conv.output[0])
if len(uses) == 0:
continue
j = uses[0]
bn = net.op[j]
if bn.type != "SpatialBN" or (len(uses) > 1 and conv.output[0] != bn.output[0]):
if bn.type == "SpatialBN":
logger.debug("Can't fuse if more than one user {}".format(uses))
# Can't fuse if more than one user unless SpatialBN is inplace
# An example of inplace SpatialBN where we want to allow multiple uses:
# x = Conv(...)
            # ... // no interfering use or def of x (will be checked below)
# x = SpatialBN(x, ...)
# ...
# z = Foo(..., x, ...)
# ...
# w = Boo(..., x, ...)
# Here, we still want to fuse Conv and SpatialBN
continue
        # There must be no def of conv.output[0], nor any use or def of bn.output[0], between conv and bn
if any(
blob in net.op[k].input or blob in net.op[k].output
for blob in [conv.output[0], bn.output[0]]
for k in range(i + 1, j)
):
            logger.debug(
                "Can't fuse because of the following interfering uses or defs:"
            )
for k in range(i, j + 1):
logger.debug(net.op[k])
continue
# else, can fuse
fused_conv = copy.deepcopy(conv)
fused_conv.output[0] = bn.output[0]
conv_weight = params[conv.input[1]]
if len(conv.input) > 2:
conv_bias = params[conv.input[2]]
else:
conv_bias = np.zeros(len(params[bn.input[2]])).astype(np.float32)
bn_scale = params[bn.input[1]]
bn_bias = params[bn.input[2]]
bn_running_mean = params[bn.input[3]]
bn_running_var = params[bn.input[4]]
        # First, the SpatialBN computation can be phrased as:
        #   (X - running_mean) * (1.0 / sqrt(running_var + eps)) * bn_scale + bn_bias
        # Distributing the scale, this is just the affine transform
        #   X * A + B
        # where
        #   A = bn_scale / sqrt(running_var + eps)
        #   B = bn_bias - running_mean * A
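        # For example, with the hypothetical values bn_scale = 2,
        # running_var = 3, eps = 1, running_mean = 1, bn_bias = 5:
        #   A = 2 / sqrt(3 + 1) = 1 and B = 5 - 1 * 1 = 4,
        # so the BN op reduces to X + 4.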
eps = 1.0e-5
for arg in bn.arg:
if arg.name == "epsilon":
eps = arg.f
A = bn_scale * 1.0 / (np.sqrt(bn_running_var + eps))
B = bn_bias - bn_running_mean * A
# This identity should hold if we have correctly fused
# np.testing.assert_array_equal(
# params[conv.output[0]] * A + B,
# params[bn.output[0]])
        # Now, the fused computation is:
        #   ((X `conv` W) + b) * A + B
        # which we can fold into the convolution as:
        #   (X `conv` (W * A)) + (b * A + B)
        # i.e. simply
        #   (X `conv` Q) + C
        # where
        #   Q = W * A
        #   C = b * A + B
        # For ConvTranspose, from the view of convolution as a Toeplitz
        # multiplication, we have W_ = W^T, so the weights are laid out as
        # (R, S, K, K) (vs. (S, R, K, K) for a Conv), and the per-output-channel
        # scale A, of size (S,), broadcasts along a different axis.
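        # Concretely, A has shape (S,); below it is reshaped to (S, 1, 1, 1)
        # to scale a Conv weight of shape (S, R, K, K), and to (1, S, 1, 1)
        # to scale a ConvTranspose weight of shape (R, S, K, K).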
A_ = (
A.reshape((-1,) + tuple([1] * (conv_weight.ndim - 1)))
if conv.type == "Conv"
else A.reshape((1, -1) + tuple([1] * (conv_weight.ndim - 2)))
)
C = conv_bias * A + B
Q = conv_weight * A_
assert params[conv.input[1]].shape == Q.shape
if len(conv.input) > 2:
assert params[conv.input[2]].shape == C.shape
else:
assert bn_bias.shape == C.shape
params[conv.input[1]] = Q
if len(conv.input) > 2:
params[conv.input[2]] = C
else:
params[bn.input[2]] = C
fused_conv.input.append(bn.input[2])
new_ops = net.op[:i] + [fused_conv] + net.op[i + 1 : j] + net.op[j + 1 :]
del net.op[:]
removed_tensors.append(bn.input[1])
if len(conv.input) > 2:
removed_tensors.append(bn.input[2])
removed_tensors.append(bn.input[3])
removed_tensors.append(bn.input[4])
del params[bn.input[1]]
if len(conv.input) > 2:
del params[bn.input[2]]
del params[bn.input[3]]
del params[bn.input[4]]
net.op.extend(new_ops)
return net, params, removed_tensors, i + 1
return net, params, removed_tensors, None
def fuse_bn(net, params, ignore_failure):
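    """Repeatedly apply fuse_first_bn until no more Conv + SpatialBN pairs can
    be fused. Raises if any SpatialBN remains, unless ignore_failure is set.
    Returns (net, params, removed_tensors).
    """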
# Run until we hit a fixed point
removed_tensors = []
begin_op_index = 0
while True:
(next_net, next_params, removed_tensors, begin_op_index) = fuse_first_bn(
net, params, removed_tensors, begin_op_index
)
if begin_op_index is None:
if any(op.type == "SpatialBN" for op in next_net.op) and not ignore_failure:
                raise Exception(
                    "Model contains SpatialBN op after fusion: {}".format(next_net)
                )
return (next_net, next_params, removed_tensors)
net, params, removed_tensors = (next_net, next_params, removed_tensors)
def fuse_first_scale(net, params, removed_tensors):
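    """Fold the first SpatialBN -> Mul -> Add chain into the SpatialBN by
    merging the Mul scale and Add bias into the BN scale/bias parameters.
    Returns the updated (net, params, removed_tensors).
    """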
net = copy.deepcopy(net)
params = copy.deepcopy(params)
for ((i, current), (j, next_)) in pairwise(enumerate(net.op)):
if next_.input[0] != current.output[0]:
continue
if (
current.type != "SpatialBN"
or next_.type != "Mul"
or len(net.op) <= j + 1
or net.op[j + 1].type != "Add"
):
continue
# else, can fuse
bn = current
mul = next_
add = net.op[j + 1]
fused_bn = copy.deepcopy(bn)
fused_bn.output[0] = add.output[0]
bn_scale = params[bn.input[1]]
mul_scale = params[mul.input[1]]
bn_bias = params[bn.input[2]]
add_bias = params[add.input[1]]
params[bn.input[1]] = bn_scale * mul_scale
params[bn.input[2]] = mul_scale * bn_bias + add_bias
new_ops = net.op[:i] + [fused_bn] + net.op[j + 2 :]
del net.op[:]
removed_tensors.append(mul.input[1])
removed_tensors.append(add.input[1])
del params[mul.input[1]]
del params[add.input[1]]
net.op.extend(new_ops)
break
return net, params, removed_tensors
def fuse_scale(net, params, ignore_failure):
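    """Repeatedly apply fuse_first_scale until no more fusions occur, i.e.
    until the op count stops shrinking. Returns (net, params, removed_tensors).
    (ignore_failure is currently unused.)
    """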
# Run until we hit a fixed point
removed_tensors = []
while True:
(next_net, next_params, removed_tensors) = fuse_first_scale(
net, params, removed_tensors
)
if len(next_net.op) == len(net.op):
return (next_net, next_params, removed_tensors)
net, params, removed_tensors = (next_net, next_params, removed_tensors)
def fuse_first_relu(net, begin_op_index, ignore_op_with_output=None):
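    """Fuse the first fusible op + Relu pair found at or after begin_op_index
    into a single fused op (e.g. Conv -> ConvRelu). Returns (net, next_index);
    next_index is None if no fusion was performed.
    """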
net = copy.deepcopy(net)
for i, conv in enumerate(net.op[begin_op_index:], begin_op_index):
if conv.type not in ["Conv", "ConvTranspose", "Sum", "SpatialBN"]:
continue
uses = blob_uses(net, conv.output[0])
        if len(uses) == 0 or (
            ignore_op_with_output and conv.output[0] in ignore_op_with_output
        ):
continue
j = uses[0]
relu = net.op[j]
if relu.type != "Relu" or len(uses) > 1 and conv.output[0] != relu.output[0]:
# Can't fuse if more than one user unless Relu is inplace
if relu.type == "Relu":
logger.debug("Can't fuse if more than one user {}".format(uses))
continue
        # There must be no def of conv.output[0], nor any use or def of relu.output[0], between conv and relu
if any(
blob in net.op[k].input or blob in net.op[k].output
for blob in [conv.output[0], relu.output[0]]
for k in range(i + 1, j)
):
            logger.debug(
                "Can't fuse because of the following interfering uses or defs:"
            )
for k in range(i, j + 1):
logger.debug(net.op[k])
continue
# else, can fuse
fused_conv = copy.deepcopy(conv)
fused_conv.type = conv.type + "Relu"
fused_conv.output[0] = relu.output[0]
new_ops = net.op[:i] + [fused_conv] + net.op[i + 1 : j] + net.op[j + 1 :]
del net.op[:]
net.op.extend(new_ops)
return net, i + 1
return net, None
def fuse_relu(net, ignore_failure, ignore_op_with_output=None):
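    """Repeatedly apply fuse_first_relu until no more fusions occur. Raises if
    any Relu remains, unless ignore_failure is set. Returns the rewritten net.
    """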
# Run until we hit a fixed point
begin_op_index = 0
while True:
next_net, begin_op_index = fuse_first_relu(
net, begin_op_index, ignore_op_with_output
)
if begin_op_index is None:
if any(op.type == "Relu" for op in next_net.op) and not ignore_failure:
raise Exception("Model contains Relu op after fusion: %s", next_net)
return next_net
net = next_net
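

# A sketch of how these passes are typically composed (assumes `net_def` is a
# caffe2 NetDef and `params` maps blob names to numpy arrays; both names are
# illustrative). fuse_scale runs first so that Mul/Add pairs are folded into
# SpatialBN before fuse_bn folds SpatialBN into the convs:
#
#   net_def, params, removed = fuse_scale(net_def, params, ignore_failure=False)
#   net_def, params, removed = fuse_bn(net_def, params, ignore_failure=False)
#   net_def = fuse_relu(net_def, ignore_failure=False)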
def last_producer(ops, blob):
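    """Return the index of the last op in `ops` whose first output is `blob`."""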
for (i, op) in reversed(list(enumerate(ops))):
if op.output[0] == blob:
return i
    raise ValueError("Failed to find last producer of blob {}".format(blob))
def swap_first_concat_relu(net, ignore_op_with_output=None):
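    """Rewrite the first Concat -> Relu pair so a Relu is applied to each
    Concat input instead. This is numerically identical (Relu is elementwise)
    and exposes producer + Relu pairs that fuse_relu can then fuse.
    """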
net = copy.deepcopy(net)
for ((i, current), (j, next_)) in pairwise(enumerate(net.op)):
if next_.input[0] != current.output[0]:
continue
if current.type != "Concat" or next_.type != "Relu":
continue
if ignore_op_with_output and current.output[0] in ignore_op_with_output:
continue
# else, can swap
concat = copy.deepcopy(current)
relu = copy.deepcopy(next_)
pre_ops = copy.deepcopy(net.op[:i])
post_ops = copy.deepcopy(net.op[j + 1 :])
# Delete the Relu after Concat
concat.output[0] = relu.output[0]
# Insert Relu after each op that produces inputs to Concat
for blob in concat.input:
k = last_producer(pre_ops, blob)
producer = pre_ops[k]
assert producer.output[0] == blob
producer.output[0] = blob + "_pre_relu"
new_relu = copy.deepcopy(relu)
new_relu.input[0] = producer.output[0]
new_relu.output[0] = blob
pre_ops = pre_ops[: k + 1] + [new_relu] + pre_ops[k + 1 :]
new_ops = pre_ops + [concat] + post_ops
        del net.op[:]
        net.op.extend(new_ops)
        break
    return net