# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import numpy as np
from fusion_attention import AttentionMask, FusionAttention
from fusion_base import Fusion
from fusion_simplified_layernorm import FusionSimplifiedLayerNormalization, FusionSkipSimplifiedLayerNormalization
from fusion_utils import NumpyHelper
from onnx import NodeProto, TensorProto, helper
from onnx_model import OnnxModel
from onnx_model_bert import BertOnnxModel
logger = logging.getLogger(__name__)
class FusionT5Attention(FusionAttention):
"""
Fuse T5 Attention subgraph into one Attention node.
"""
def __init__(
self,
model: OnnxModel,
hidden_size: int,
num_heads: int,
attention_mask: AttentionMask,
):
super().__init__(
model,
hidden_size,
num_heads,
attention_mask,
use_multi_head_attention=False,
search_op_types=["Softmax"],
)
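        # static_kv records which decoder pattern was matched: 1 for cross attention (key/value come
        # from the static encoder output), 0 for self attention (key/value grow with the KV cache).
        # See fuse_t5_decoder.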
self.static_kv = 1
def make_attention_node(
self,
mask_index: str | None,
q_matmul: NodeProto,
k_matmul: NodeProto,
v_matmul: NodeProto,
num_heads: int,
hidden_size: int,
input: str,
output: str,
attn_bias: str | None,
scale: float,
) -> NodeProto | None:
"""Create an Attention node.
Args:
            mask_index (str | None): name of the mask input, or None if there is no padding mask
            q_matmul (NodeProto): MatMul node of the fully connected layer for Q
            k_matmul (NodeProto): MatMul node of the fully connected layer for K
            v_matmul (NodeProto): MatMul node of the fully connected layer for V
            num_heads (int): number of attention heads. If the model is pruned, it is the number of heads after pruning.
            hidden_size (int): hidden dimension. If the model is pruned, it is the hidden dimension after pruning.
            input (str): input name
            output (str): output name
            attn_bias (str | None): name of the attention bias input (e.g. relative position bias), or None
            scale (float): scale factor applied to the query-key product before Softmax
        Returns:
            NodeProto | None: the node created, or None if fusion failed.
"""
assert num_heads > 0
if hidden_size > 0 and (hidden_size % num_heads) != 0:
logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
return None
q_weight = self.model.get_initializer(q_matmul.input[1])
k_weight = self.model.get_initializer(k_matmul.input[1])
v_weight = self.model.get_initializer(v_matmul.input[1])
if q_weight is None or k_weight is None or v_weight is None:
matmul = q_matmul if q_weight is None else k_matmul if k_weight is None else v_matmul
            logger.warning(
                f"{matmul.input[1]} is not an initializer. "
                "Please set do_constant_folding=True in torch.onnx.export to unblock attention fusion"
            )
return None
qw = NumpyHelper.to_array(q_weight)
kw = NumpyHelper.to_array(k_weight)
vw = NumpyHelper.to_array(v_weight)
        # Sanity check: Q and K weights must have the same shape.
assert qw.shape == kw.shape
qw_in_size = qw.shape[0]
kw_in_size = kw.shape[0]
vw_in_size = vw.shape[0]
assert qw_in_size == kw_in_size == vw_in_size
if hidden_size > 0 and hidden_size != qw_in_size:
logger.warning(
f"Input hidden size ({hidden_size}) is not same as weight matrix dimension of q,k,v ({qw_in_size}). "
"Please provide a correct input hidden size or pass in 0"
)
qw_out_size = np.prod(qw.shape[1:])
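        # Stacking along axis 1 gives shape (in_size, 3, out_size); flattened row-major this is the
        # packed (in_size, 3 * out_size) QKV weight layout that the com.microsoft Attention op expects.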
qkv_weight = np.stack((qw, kw, vw), axis=1)
qkv_weight_dim = 3 * qw_out_size
attention_node_name = self.model.create_node_name("Attention")
weight = helper.make_tensor(
name=attention_node_name + "_qkv_weight",
data_type=TensorProto.FLOAT,
dims=[qw_in_size, qkv_weight_dim],
vals=qkv_weight.tobytes(),
raw=True,
)
self.model.add_initializer(weight, self.this_graph_name)
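        # Input slots of the com.microsoft Attention op: input, weight, bias, mask_index, past, attention_bias.
        # T5 projections have no bias term, so the bias slot stays empty.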
attention_inputs = [
input,
attention_node_name + "_qkv_weight",
"",
]
if mask_index:
attention_inputs.append(mask_index)
else:
attention_inputs.append("")
if attn_bias:
attention_inputs.append("") # no past
attention_inputs.append(attn_bias)
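        # Trailing optional inputs that are empty can be dropped from the node.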
while attention_inputs and attention_inputs[-1] == "":
attention_inputs.pop()
attention_node = helper.make_node(
"Attention",
inputs=attention_inputs,
outputs=[output],
name=attention_node_name,
)
attention_node.domain = "com.microsoft"
attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
if scale is not None:
attention_node.attribute.extend([helper.make_attribute("scale", scale)])
if self.mask_filter_value is not None:
attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
return attention_node
def create_mha_node(
self,
query: str,
key: str,
value: str,
mask_index: str | None,
attn_bias: str | None,
past_key: str | None,
past_value: str | None,
output: str,
present_key: str | None,
present_value: str | None,
num_heads: int,
hidden_size: int,
) -> NodeProto | None:
assert num_heads > 0 and hidden_size > 0 and query and key and value
if (hidden_size % num_heads) != 0:
logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
return None
attention_node_name = self.model.create_node_name("MultiHeadAttention")
attention_inputs = [
query,
key,
value,
"", # bias
]
if mask_index:
attention_inputs.append(mask_index)
else:
attention_inputs.append("")
if attn_bias:
attention_inputs.append(attn_bias)
else:
attention_inputs.append("")
if past_key:
assert past_value
attention_inputs.append(past_key)
attention_inputs.append(past_value)
while attention_inputs and attention_inputs[-1] == "":
attention_inputs.pop()
attention_outputs = [output]
if present_key:
assert present_value
attention_outputs.append(present_key)
attention_outputs.append(present_value)
print(f"{attention_inputs=}, {attention_outputs=}, {attention_node_name=}")
attention_node = helper.make_node(
"MultiHeadAttention",
inputs=attention_inputs,
outputs=attention_outputs,
name=attention_node_name,
)
attention_node.domain = "com.microsoft"
attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
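        # T5 applies no 1/sqrt(head_size) scaling to the query-key product, so fix scale at 1.0.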
attention_node.attribute.extend([helper.make_attribute("scale", 1.0)])
if self.mask_filter_value is not None:
attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
self.increase_counter("MultiHeadAttention")
return attention_node
def fuse(self, node, input_name_to_nodes, output_name_to_node):
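        # Try the encoder (no KV cache) pattern first; fall back to the decoder pattern.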
if self.fuse_t5_encoder(node, input_name_to_nodes, output_name_to_node):
return
self.fuse_t5_decoder(node, input_name_to_nodes, output_name_to_node)
def fuse_t5_encoder(self, softmax_node, input_name_to_nodes, output_name_to_node):
assert softmax_node.op_type == "Softmax"
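        # Trace downstream from Softmax: MatMul (probs x V) -> Transpose -> Reshape back to (batch, seq, hidden).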
qkv_nodes = self.model.match_child_path(
softmax_node,
["MatMul", "Transpose", "Reshape"],
edges=[(0, 0), (0, 0), (0, 0)],
input_name_to_nodes=input_name_to_nodes,
)
if qkv_nodes is None:
return False
matmul_qkv, _, reshape_qkv = qkv_nodes
qkv_shape_nodes = self.model.match_parent_path(
reshape_qkv,
["Concat", "Unsqueeze", "Gather", "Shape"],
[1, 0, 0, 0],
output_name_to_node,
)
if qkv_shape_nodes is None:
return False
input_shape_node = qkv_shape_nodes[-1]
v_nodes = self.model.match_parent_path(
matmul_qkv,
["Transpose", "Reshape", "MatMul"],
[1, 0, 0],
output_name_to_node,
)
if v_nodes is None:
return False
_, reshape_v, matmul_v = v_nodes
# todo: check reshape_v parent nodes
qk_nodes = self.model.match_parent_path(
matmul_qkv,
["Softmax", "Add", "MatMul"],
[0, 0, 0],
output_name_to_node,
)
if qk_nodes is None:
return False
_, add_qk, matmul_qk = qk_nodes
mask_nodes = self.model.match_parent_path(
add_qk,
["Add", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
[1, 1, 0, 1, 0, 0],
output_name_to_node,
)
is_pattern_for_one_graph_input = mask_nodes is None
if mask_nodes is not None:
mul_node = mask_nodes[1]
else:
# Pattern for SD3 and Flux.
mask_nodes = self.model.match_parent_path(
add_qk,
["Add", "Slice", "Mul", "Sub", "Unsqueeze", "Unsqueeze"],
[1, 1, 0, 0, 1, 0],
output_name_to_node,
)
# If the model is not optimized by ORT, there might be an additional Cast node.
if mask_nodes is None:
mask_nodes = self.model.match_parent_path(
add_qk,
["Add", "Slice", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
[1, 1, 0, 0, 1, 0, 0],
output_name_to_node,
)
if mask_nodes is None:
return False
mul_node = mask_nodes[2]
_, mul_val = self.model.get_constant_input(mul_node)
if mul_val is None:
return False
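        # -10000 is the default additive mask value; record any deviation so the fused op reproduces it.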
if mul_val != -10000:
self.mask_filter_value = float(mul_val)
# If the mask is derived from shape of input_ids, it means there is no padding mask.
mask_nodes_2 = self.model.match_parent_path(
mask_nodes[-1],
["ConstantOfShape", "Concat", "Unsqueeze", "Gather", "Shape"],
[0, 0, 0, 0, 0],
output_name_to_node,
)
mask_nodes_3 = self.model.match_parent_path(
mask_nodes[-1],
["ConstantOfShape", "Concat", "Unsqueeze", "Gather", "Shape"],
[0, 0, 1, 0, 0],
output_name_to_node,
)
if (
mask_nodes_2 is not None
and any(input.name == mask_nodes_2[-1].input[0] for input in self.model.graph().input)
and mask_nodes_3 is not None
and mask_nodes_2[-1].input[0] == mask_nodes_3[-1].input[0]
and len(mask_nodes_2[1].input) == 2
):
mask_index = ""
else:
mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
res_pos_bias = None
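        # T5 adds a learned relative position bias to the attention scores; capture its producer so it
        # can feed the fused node as attention bias.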
rpb_nodes = self.model.match_parent_path(
add_qk,
["Add", "RelativePositionBias"],
[1, 0],
)
if rpb_nodes is None and is_pattern_for_one_graph_input:
# Pattern for SD3 and Flux.
rpb_nodes = self.model.match_parent_path(
add_qk,
["Add", "Slice", "RelativePositionBias"],
[1, 0, 0],
)
if rpb_nodes is None:
return False
res_pos_bias = rpb_nodes[-1].output[0]
k_nodes = self.model.match_parent_path(
matmul_qk,
["Transpose", "Reshape", "MatMul"],
[1, 0, 0],
)
if k_nodes is None:
return False
_, _, matmul_k = k_nodes
# todo: check reshape_k parent nodes
q_nodes = self.model.match_parent_path(
matmul_qk,
["Transpose", "Reshape", "MatMul"],
[0, 0, 0],
)
if q_nodes is None:
return False
_, reshape_q, matmul_q = q_nodes
# todo: check reshape_q parent nodes
if matmul_q.input[0] != input_shape_node.input[0]:
return False
q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
new_node = self.make_attention_node(
mask_index,
matmul_q,
matmul_k,
matmul_v,
num_heads=q_num_heads,
hidden_size=q_hidden_size,
input=input_shape_node.input[0],
output=reshape_qkv.output[0],
attn_bias=res_pos_bias,
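            # T5 does not scale the query-key product, so pass scale=1.0 explicitly.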
scale=1.0,
)
if new_node is None:
return False
self.nodes_to_add.append(new_node)
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
self.nodes_to_remove.append(reshape_qkv)
self.prune_graph = True
return True
def fuse_t5_decoder(self, softmax_node, input_name_to_nodes, output_name_to_node):
assert softmax_node.op_type == "Softmax"
qkv_nodes = self.model.match_child_path(
softmax_node,
["MatMul", "Transpose", "Reshape"],
edges=[(0, 0), (0, 0), (0, 0)],
input_name_to_nodes=input_name_to_nodes,
)
if qkv_nodes is None:
return
matmul_qkv, _transpose_qkv, reshape_qkv = qkv_nodes
qkv_shape_nodes = self.model.match_parent_path(
reshape_qkv,
["Concat", "Unsqueeze", "Gather", "Shape"],
[1, 0, 0, 0],
)
if qkv_shape_nodes is None:
return
input_shape_node = qkv_shape_nodes[-1]
value = None
past_value = None
present_value = None
v_nodes = self.model.match_parent_path(
matmul_qkv,
["Concat", "Transpose", "Reshape", "MatMul"],
[1, 1, 0, 0],
)
if v_nodes is None:
v_nodes = self.model.match_parent_path(
matmul_qkv,
["Transpose", "Reshape", "MatMul"],
[1, 0, 0],
)
if v_nodes is not None:
transpose_v, reshape_v, matmul_v = v_nodes
value = reshape_v.input[0]
present_value = transpose_v.output[0]
if "present_value" not in present_value:
return
if matmul_v.input[0] != input_shape_node.input[0]:
self.static_kv = 1
else:
self.static_kv = 0
else:
past_value = matmul_qkv.input[1]
if past_value in output_name_to_node:
return
if "past_value_cross" not in past_value:
return
self.static_kv = 1
else:
concat_v, _, reshape_v, _ = v_nodes
past_value = concat_v.input[0]
if past_value in output_name_to_node:
return
if "past_value_self" not in past_value:
return
present_value = concat_v.output[0]
if "present_value_self" not in present_value:
return
value = reshape_v.input[0]
self.static_kv = 0
qk_nodes = self.model.match_parent_path(
matmul_qkv,
["Softmax", "Add", "MatMul"],
[0, 0, 0],
)
if qk_nodes is None:
return
_, add_qk, matmul_qk = qk_nodes
mask_index = None
res_pos_bias = None
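        # The cross-attention pattern (static_kv == 1) carries the padding mask; the self-attention
        # pattern carries the relative position bias instead.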
if self.static_kv == 1:
mask_nodes = self.model.match_parent_path(
add_qk,
["Add", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
[1, 1, 0, 1, 0, 0],
)
if mask_nodes is not None:
mul_node = mask_nodes[1]
else:
mask_nodes = self.model.match_parent_path(
add_qk,
["Add", "Slice", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
[1, 1, 0, 0, 1, 0, 0],
)
if mask_nodes is None:
return
mul_node = mask_nodes[2]
            _, mul_val = self.model.get_constant_input(mul_node)
            if mul_val is None:
                return
            if mul_val != -10000:
                self.mask_filter_value = float(mul_val)
mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
else:
matched_path_index, _, _ = self.model.match_parent_paths(
add_qk,
[
(["Add", "Slice"], [1, 0]),
(["Add", "RelativePositionBias"], [1, 0]),
],
output_name_to_node,
)
if matched_path_index < 0:
logger.debug("Skip MultiHeadAttention fusion since attention bias pattern not matched")
return
res_pos_bias = add_qk.input[1]
key = None
past_key = None
present_key = None
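        # Match the key path: for cross attention the key is projected from the encoder output (or read
        # from the cross-attention cache); for self attention it is concatenated with the self KV cache.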
if self.static_kv == 1:
k_nodes = self.model.match_parent_path(
matmul_qk,
["Transpose", "Reshape", "MatMul"],
[1, 0, 0],
)
if k_nodes is not None:
transpose_k, reshape_k, _ = k_nodes
key = reshape_k.input[0]
present_key_transpose_nodes = input_name_to_nodes[reshape_k.output[0]]
for present_key_transpose_node in present_key_transpose_nodes:
present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0])
if present_key_candidate is not None:
present_key = present_key_candidate.name
break
if present_key is None:
return
if "present_key_cross" not in present_key:
return
else:
k_nodes = self.model.match_parent_path(
matmul_qk,
["Transpose"],
[1],
)
if k_nodes is None:
return
transpose_k = k_nodes[0]
past_key = transpose_k.input[0]
if past_key in output_name_to_node:
return
if "past_key_cross" not in past_key:
return
else:
idx, k_nodes, _ = self.model.match_parent_paths(
matmul_qk,
[
(["Transpose", "Concat", "Reshape", "MatMul"], [1, 0, 1, 0]),
(["Transpose", "Concat", "Transpose", "Reshape", "MatMul"], [1, 0, 1, 0, 0]),
],
output_name_to_node,
)
past_key_transpose_node = None
present_key_transpose_nodes = None
if k_nodes is not None:
concat_k, reshape_k = k_nodes[1], k_nodes[-2]
key = reshape_k.input[0]
if idx == 0:
past_key_transpose_node = output_name_to_node[concat_k.input[0]]
past_key = past_key_transpose_node.input[0]
else:
past_key = concat_k.input[0]
if past_key in output_name_to_node:
return
if "past_key_self" not in past_key:
return
if idx == 0:
present_key_transpose_nodes = input_name_to_nodes[concat_k.output[0]]
for present_key_transpose_node in present_key_transpose_nodes:
present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0])
if present_key_candidate is not None:
present_key = present_key_candidate.name
break
else:
present_key = concat_k.output[0]
if present_key is None:
return
if "present_key_self" not in present_key:
return
else:
k_nodes = self.model.match_parent_path(
matmul_qk,
["Transpose", "Reshape", "MatMul"],
[1, 0, 0],
)
if k_nodes is None:
return
_, reshape_k, _ = k_nodes
key = reshape_k.input[0]
present_key_transpose_nodes = input_name_to_nodes[reshape_k.output[0]]
for present_key_transpose_node in present_key_transpose_nodes:
present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0])
if present_key_candidate is not None:
present_key = present_key_candidate.name
break
if present_key is None:
return
if "present_key_self" not in present_key:
return
q_nodes = self.model.match_parent_path(
matmul_qk,
["Transpose", "Reshape", "MatMul"],
[0, 0, 0],
)
if q_nodes is None:
return
transpose_q, reshape_q, matmul_q = q_nodes
if matmul_q.input[0] != input_shape_node.input[0]:
return
q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
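        # For cross attention with cached K/V, feed the cache directly as key/value so the fused
        # MultiHeadAttention needs no past inputs.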
if self.static_kv == 1 and past_key is not None:
key = past_key
value = past_value
past_key = None
past_value = None
if not (key and value and q_num_heads > 0 and q_hidden_size > 0):
return
new_node = self.create_mha_node(
query=matmul_q.output[0],
key=key,
value=value,
mask_index=mask_index,
attn_bias=res_pos_bias,
past_key=past_key,
past_value=past_value,
output=reshape_qkv.output[0],
present_key=present_key,
present_value=present_value,
num_heads=q_num_heads,
hidden_size=q_hidden_size,
)
if new_node:
self.nodes_to_add.append(new_node)
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
            # Since present_* tensors are graph outputs, update the graph to avoid a circular dependency.
if present_key or present_value:
for graph_output in [present_key, present_value]:
if not (graph_output and self.model.find_graph_output(graph_output)):
print(f"{graph_output=} does not exist in graph output")
return
assert graph_output in output_name_to_node
output_name_to_node[graph_output].output[0] = graph_output + "_copy"
self.model.replace_input_of_all_nodes(graph_output, graph_output + "_copy")
self.nodes_to_remove.append(reshape_qkv)
self.prune_graph = False
class FusionRelativePositionBiasBlock(Fusion):
def __init__(self, model: OnnxModel):
super().__init__(model, "RelativePositionBias", ["Softmax"])
def fuse(self, node, input_name_to_nodes, output_name_to_node):
compute_bias_nodes = self.model.match_parent_path(
node,
["Add", "Add", "Slice", "Unsqueeze", "Transpose", "Gather", "Where"],
[0, 1, 0, 0, 0, 0, 1],
output_name_to_node,
)
if compute_bias_nodes is None:
compute_bias_nodes = self.model.match_parent_path(
node,
["Add", "Add", "Slice", "Unsqueeze", "Transpose", "Gather", "Add", "Where"],
[0, 1, 0, 0, 0, 0, 1, 1],
output_name_to_node,
)
if compute_bias_nodes is None:
return
gather = compute_bias_nodes[5]
where = compute_bias_nodes[-1]
        slice_node = compute_bias_nodes[2]
unsqueeze = compute_bias_nodes[3]
        # Nodes are not removed until the whole graph has been processed, so check nodes_to_remove
        # to avoid fusing the same block twice when it is shared by multiple layers.
if unsqueeze in self.nodes_to_remove:
return
compute_buckets_nodes = self.model.match_parent_path(
where,
["Min", "ConstantOfShape", "Shape", "Add", "Cast", "Mul", "Div", "Log", "Div"],
[2, 1, 0, 0, 0, 0, 0, 0, 0],
output_name_to_node,
)
if compute_buckets_nodes is None:
return
        # This value is used to compute max_distance later.
log_max = self.model.get_constant_value(compute_buckets_nodes[-3].input[1])
div = compute_buckets_nodes[-1]
range_nodes = self.model.match_parent_path(
div,
["Cast", "Neg", "Min", "ConstantOfShape", "Shape", "Sub", "Unsqueeze", "Range"],
[0, 0, 0, 1, 0, 0, 0, 0],
output_name_to_node,
)
is_bidirectional = False
if range_nodes is None:
range_nodes = self.model.match_parent_path(
div, ["Cast", "Abs", "Sub", "Unsqueeze", "Range"], [0, 0, 0, 0, 0], output_name_to_node
)
is_bidirectional = True
if range_nodes is None:
return
range_node = range_nodes[-1]
        # Double check the constants related to max_distance and relative_attention_num_buckets.
        # Most T5 models use max_distance=128, so we hardcode it until we see a model with a different value.
# The log_max is the value of the following formula:
# math.log(max_distance / (relative_attention_num_buckets // (4 if is_bidirectional else 2)))
# See https://github.com/huggingface/transformers/blob/608e163b527eaee41e650ffb9eb4c422d2679902/src/transformers/models/t5/modeling_t5.py#L397.
# Here is the value based on max_distance=128 and relative_attention_num_buckets=32:
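        # For example, bidirectional with defaults: log_max = log(128 / (32 // 4)) = log(16) ~= 2.7726,
        # and max_distance = round(exp(2.7726) * 8) = 128.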
max_distance = int(np.round(np.exp(log_max) * (32 // (4 if is_bidirectional else 2))))
if max_distance != 128:
logger.warning(
f"max_distance is {max_distance}, which is different from the default value 128. "
"Please double check the model configuration."
)
node_name = self.model.create_node_name(
"RelativePositionBias", name_prefix="RelPosBias_" + ("encoder" if is_bidirectional else "decoder")
)
table_weight_i = self.model.get_initializer(gather.input[0])
if table_weight_i is None:
return
table_weight = NumpyHelper.to_array(table_weight_i)
table_weight_t = np.transpose(table_weight)
bias_table = helper.make_tensor(
name=node_name + "_bias_table_weight",
data_type=TensorProto.FLOAT,
dims=[np.shape(table_weight)[0], np.shape(table_weight)[1]],
vals=table_weight_t.tobytes(),
raw=True,
)
self.model.add_initializer(bias_table, self.this_graph_name)
# Relative position is like the following in encoder:
# seq_len
# |
# Range(0, *)
# / \
# Unsqueeze(axes=0) Unsqueeze(axes=1)
# \ /
# Sub
# |
# Abs
#
# Relative position is like the following in decoder:
# past_seq_len seq_len
# \ /
# Add
# / \
# Range(0, *) Range(0, *)
# \ /
# Sub
# Note that the graph will slice the attention bias to get last seq_len rows.
#
        # In newer versions of transformers, the decoder pattern has changed as follows:
#
# total_seq_len Range(start=past_seq_len, end=total_seq_len)
# | |
# Range(0, *) Unsqueeze(axes=1)
# | |
# Unsqueeze(axes=0) Cast(to=int64)
# \ /
# Sub
        # Currently there is still a Slice to take the last seq_len rows, so the end result is the same.
        # But be careful: the shape of the bias tensor changes before the Slice.
        #
        # The RelativePositionBias operator requires query_length == key_length, so we pass in total_seq_len.
        # Here we use the end value of the Range node as the length to pass to the RelativePositionBias node.
        # TODO: Optimization opportunity: extend the RelativePositionBias op to support query_length != key_length
        # and only compute seq_len rows; then the Slice after the RelativePositionBias node can be removed.
inputs = [bias_table.name, range_node.input[1], range_node.input[1]]
# Use a new tensor name since the shape might be different as mentioned above.
bias_output = node_name + "_rel_pos_bias"
        slice_node.input[0] = bias_output
rpb_node = helper.make_node(
"RelativePositionBias",
inputs=inputs,
outputs=[bias_output],
name=node_name,
)
rpb_node.domain = "com.microsoft"
rpb_node.attribute.extend([helper.make_attribute("max_distance", max_distance)])
rpb_node.attribute.extend([helper.make_attribute("is_bidirectional", is_bidirectional)])
self.node_name_to_graph_name[rpb_node.name] = self.this_graph_name
self.nodes_to_add.append(rpb_node)
self.prune_graph = True
class T5OnnxModel(BertOnnxModel):
def __init__(self, model, num_heads: int = 0, hidden_size: int = 0):
super().__init__(model, num_heads, hidden_size)
self.attention_mask = AttentionMask(self)
# When the model has only one input (input_ids), there is no padding mask.
if len(self.model.graph.input) == 1:
from fusion_options import AttentionMaskFormat # noqa: PLC0415
self.attention_mask.mask_format = AttentionMaskFormat.NoMask
self.attention_fusion = FusionT5Attention(self, self.hidden_size, self.num_heads, self.attention_mask)
self.layer_norm_fusion = FusionSimplifiedLayerNormalization(self)
self.skip_layer_norm_fusion = FusionSkipSimplifiedLayerNormalization(self)
self.rpb_fusion = FusionRelativePositionBiasBlock(self)
def fuse_attention(self):
self.attention_fusion.apply()
def fuse_layer_norm(self):
self.layer_norm_fusion.apply()
def fuse_skip_layer_norm(self, shape_infer=True):
self.skip_layer_norm_fusion.apply()
    def adjust_rel_pos_bias_length_input(self):
        # The T5 encoder uses complex logic to compute the query and key lengths when there is only one
        # graph input (input_ids). We can get the length directly from the shape (the 2nd dimension) of input_ids.
for node in self.nodes():
if node.op_type == "RelativePositionBias":
nodes = self.match_parent_path(
node,
[
"Gather",
"Shape",
"Transpose",
"Reshape",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
"SimplifiedLayerNormalization",
"Gather",
],
[1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
)
# TODO: more validation on node attributes
if nodes is not None:
graph_input_names = [input.name for input in self.model.graph.input]
if nodes[-1].input[1] in graph_input_names:
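                        # Read the sequence length (2nd dimension of input_ids) via Shape -> Gather(index 1).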
node_name = self.create_node_name("Shape", name_prefix="Added_Shape_")
shape_node = helper.make_node(
"Shape",
inputs=[nodes[-1].input[1]],
outputs=[node_name + "_Output"],
name=node_name,
)
indices_1 = helper.make_tensor(
name="Constant_Index_1",
data_type=TensorProto.INT64,
dims=[1], # Shape of the tensor
vals=[1], # Tensor values
)
self.add_initializer(indices_1)
gather = helper.make_node(
"Gather",
inputs=[node_name + "_Output", "Constant_Index_1"],
outputs=[node_name + "_Output_Gather_1"],
name=self.create_node_name("Gather", name_prefix="Added_Gather_"),
axis=0,
)
self.add_node(shape_node)
self.add_node(gather)
node.input[1] = node_name + "_Output_Gather_1"
node.input[2] = node_name + "_Output_Gather_1"
break
# Remove get_extended_attention_mask() since it generates all zeros.
def remove_extended_mask_decoder_init(self):
nodes_to_remove = []
for node in self.nodes():
if node.op_type == "Add":
extended_mask_nodes = self.match_parent_path(
node,
[
"Mul",
"Sub",
"Mul",
"Unsqueeze",
"Cast",
"LessOrEqual",
"Tile",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
],
[1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
)
if extended_mask_nodes is None:
continue
rpb_nodes = self.match_parent_path(node, ["RelativePositionBias"], [0])
if rpb_nodes is None:
continue
rpb_node = rpb_nodes[0]
rpb_node.output[0] = node.output[0]
nodes_to_remove.extend(extended_mask_nodes)
nodes_to_remove.append(node)
self.remove_nodes(nodes_to_remove)
def remove_extended_mask_decoder(self):
nodes_to_remove = []
for node in self.nodes():
if node.op_type == "Add":
extended_mask_nodes = self.match_parent_path(
node,
[
"Mul",
"Sub",
"Mul",
"Unsqueeze",
"Concat",
"Cast",
"LessOrEqual",
"Tile",
"Concat",
"Unsqueeze",
"Gather",
"Shape",
],
[1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0],
)
if extended_mask_nodes is None:
continue
rpb_nodes = self.match_parent_path(node, ["Slice", "RelativePositionBias"], [0, 0])
if rpb_nodes is None:
continue
rpb_node = rpb_nodes[0]
rpb_node.output[0] = node.output[0]
nodes_to_remove.extend(extended_mask_nodes)
nodes_to_remove.append(node)
self.remove_nodes(nodes_to_remove)
def preprocess(self):
self.adjust_reshape_and_expand()
self.rpb_fusion.apply()
def postprocess(self):
# remove get_extended_attention_mask() since it generates all zeros.
self.remove_extended_mask_decoder_init()
self.remove_extended_mask_decoder()
        self.adjust_rel_pos_bias_length_input()
self.prune_graph()