# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from fusion_base import Fusion
from onnx import helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionGptAttentionNoPast(Fusion):
"""
Fuse GPT-2 Attention without past state into one Attention node.
This does not support attention_mask graph input right now.
"""
def __init__(self, model: OnnxModel, num_heads: int):
super().__init__(model, "Attention", ["LayerNormalization", "SkipLayerNormalization"], "without past")
# TODO: detect num_heads from graph like FusionAttention
self.num_heads = num_heads
self.mask_filter_value = None

    def create_attention_node(self, gemm, gemm_qkv, input, output):
attention_node_name = self.model.create_node_name("Attention")
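        # The fused Attention node consumes the layer input plus the packed QKV weight and bias
        # taken from the original attention Gemm (the c_attn projection in GPT-2).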
attention_node = helper.make_node(
"Attention",
inputs=[input, gemm.input[1], gemm.input[2]],
outputs=[attention_node_name + "_output"],
name=attention_node_name,
)
attention_node.domain = "com.microsoft"
attention_node.attribute.extend(
[
helper.make_attribute("num_heads", self.num_heads),
helper.make_attribute("unidirectional", 1),
]
)
if self.mask_filter_value is not None:
attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
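        # Rebuild the output projection that followed the original attention subgraph as
        # MatMul + Add, reusing its weight and bias, so it now consumes the fused Attention output.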
matmul_node = helper.make_node(
"MatMul",
inputs=[attention_node_name + "_output", gemm_qkv.input[1]],
outputs=[attention_node_name + "_matmul_output"],
name=attention_node_name + "_matmul",
)
add_node = helper.make_node(
"Add",
inputs=[attention_node_name + "_matmul_output", gemm_qkv.input[2]],
outputs=[output],
name=attention_node_name + "_add",
)
self.nodes_to_add.extend([attention_node, matmul_node, add_node])
self.node_name_to_graph_name[attention_node.name] = self.this_graph_name
self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
self.node_name_to_graph_name[add_node.name] = self.this_graph_name

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        # TODO(hasesh/tlwu): Investigate what fixes the following logic needs in order to fuse
        # the Attention sub-graph. After changes to other fusions, this matching stopped working.
return_indice = []
is_normalize_node_skiplayernorm = normalize_node.op_type == "SkipLayerNormalization"
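        # Match the output side of the attention block: walk from the (Skip)LayerNormalization
        # that follows it, back through the output-projection Gemm, to the MatMul of the attention
        # probabilities with V. The residual Add is matched explicitly when it was not already
        # fused into SkipLayerNormalization.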
qkv_nodes = None
if not is_normalize_node_skiplayernorm:
qkv_nodes = self.model.match_parent_path(
normalize_node,
["Add", "Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"],
[0, None, 0, 0, 0, 0, 0],
output_name_to_node=output_name_to_node,
return_indice=return_indice,
)
else:
qkv_nodes = self.model.match_parent_path(
normalize_node,
["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"],
[None, 0, 0, 0, 0, 0],
output_name_to_node=output_name_to_node,
return_indice=return_indice,
)
if qkv_nodes is None:
return
another_input = None
if not is_normalize_node_skiplayernorm:
(
add_qkv,
reshape_qkv,
gemm_qkv,
reshape_1,
reshape_2,
transpose_qkv,
matmul_qkv,
) = qkv_nodes
another_input = add_qkv.input[1 - return_indice[0]]
else:
(
reshape_qkv,
gemm_qkv,
reshape_1,
reshape_2,
transpose_qkv,
matmul_qkv,
) = qkv_nodes
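        # Match the V path: V reaches matmul_qkv through Transpose <- Reshape <- Split, and the
        # Split is fed by Reshape <- Gemm <- Reshape of the layer input (the packed QKV projection).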
v_nodes = self.model.match_parent_path(
matmul_qkv,
["Transpose", "Reshape", "Split", "Reshape", "Gemm", "Reshape"],
[1, 0, 0, 0, 0, 0],
)
if v_nodes is None:
logger.debug("fuse_attention: failed to match v path")
return
(
transpose_v,
reshape_v,
split_v,
reshape_after_gemm,
gemm,
reshape_before_gemm,
) = v_nodes
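        # The Reshape feeding the QKV Gemm should take its input from the (Skip)LayerNormalization
        # at the start of the block; an Add node is also tolerated here.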
layernorm_before_attention = self.model.get_parent(reshape_before_gemm, 0, output_name_to_node)
        if layernorm_before_attention is None:
            logger.debug("failed to get (skip)layernorm before gemm: no parent found")
            return
        if layernorm_before_attention.op_type not in ("LayerNormalization", "SkipLayerNormalization", "Add"):
            logger.debug(f"failed to get (skip)layernorm before gemm. Got {layernorm_before_attention.op_type}")
            return
        # `another_input` will be non-None only if:
        # (1) SkipLayerNorm fusion wasn't turned ON, or
        # (2) SkipLayerNorm fusion was turned ON but the upstream layer's LayerNorm + Add was not
        #     fused into a SkipLayerNorm. This can happen if the shapes into the Add node differ.
        # So, keep the following check whether SkipLayerNorm fusion is turned ON or OFF.
if another_input is not None:
if another_input not in layernorm_before_attention.input:
# match openai-gpt
if another_input not in layernorm_before_attention.output:
logger.debug("Add and (Skip)LayerNormalization shall have one same input")
return
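        # Match the scaled Q*K^T -> mask -> Softmax path. Three export variants are handled:
        # a Sub/Mul mask pattern (older GPT-2 exports), a Where-based mask (PyTorch 1.5.0 /
        # Transformers 2.9.0), and an Add-based pattern from openai-gpt.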
qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Sub", "Mul", "Div", "MatMul"], [0, 0, 0, 0, 0])
if qk_nodes is not None:
(softmax_qk, sub_qk, mul_qk, div_qk, matmul_qk) = qk_nodes
mask_nodes = self.model.match_parent_path(
sub_qk,
[
"Mul",
"Sub",
"Slice",
"Slice",
"Unsqueeze",
"Sub",
"Squeeze",
"Slice",
"Shape",
"Div",
],
[1, 0, 1, 0, 1, 0, 0, 0, 0, 0],
)
if mask_nodes is None:
logger.debug("fuse_attention: failed to match mask path")
return
div_mask = mask_nodes[-1]
if div_qk != div_mask:
logger.debug("fuse_attention: skip since div_qk != div_mask")
return
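            # If the mask multiplies by something other than the default -10000, remember the value
            # so it can be written to the mask_filter_value attribute of the fused node.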
if len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul":
_, mul_val = self.model.get_constant_input(mask_nodes[0])
if mul_val != -10000:
self.mask_filter_value = mul_val
else:
# New pattern for gpt2 from PyTorch 1.5.0 and Transformers 2.9.0.
qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Where", "Div", "MatMul"], [0, 0, 1, 0])
if qk_nodes is not None:
(softmax_qk, where_qk, div_qk, matmul_qk) = qk_nodes
mask_nodes = self.model.match_parent_path(
where_qk,
[
"Cast",
"Slice",
"Slice",
"Unsqueeze",
"Sub",
"Squeeze",
"Slice",
"Shape",
"Div",
],
[0, 0, 0, 1, 0, 0, 0, 0, 0],
)
if mask_nodes is None:
logger.debug("fuse_attention: failed to match mask path")
return
div_mask = mask_nodes[-1]
if div_qk != div_mask:
logger.debug("fuse_attention: skip since div_qk != div_mask")
return
else:
# match openai-gpt
qk_nodes = self.model.match_parent_path(
matmul_qkv,
["Softmax", "Add", "Mul", "Div", "MatMul"],
[0, 0, 0, 0, 0],
)
if qk_nodes is None:
logger.debug("fuse_attention: failed to match qk path")
return
(softmax_qk, add_qk, mul_qk, div_qk, matmul_qk) = qk_nodes
mask_nodes = self.model.match_parent_path(
mul_qk,
["Slice", "Slice", "Unsqueeze", "Squeeze", "Slice", "Shape", "Div"],
[1, 0, 2, 0, 0, 0, 0],
)
if mask_nodes is None:
logger.debug("fuse_attention: failed to match mask path")
return
div_mask = mask_nodes[-1]
if div_qk != div_mask:
logger.debug("fuse_attention: skip since div_qk != div_mask")
return
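        # Match the Q and K paths. Both must come from the same Split as V, i.e. Q, K and V all
        # originate from one packed QKV Gemm.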
q_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Split"], [0, 0, 0])
if q_nodes is None:
logger.debug("fuse_attention: failed to match q path")
return
(transpose_q, reshape_q, split_q) = q_nodes
if split_v != split_q:
logger.debug("fuse_attention: skip since split_v != split_q")
return
k_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Split"], [1, 0, 0])
if k_nodes is None:
logger.debug("fuse_attention: failed to match k path")
return
(transpose_k, reshape_k, split_k) = k_nodes
if split_v != split_k:
logger.debug("fuse_attention: skip since split_v != split_k")
return
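        # All paths matched: emit the fused Attention node, feeding it the output of the
        # (Skip)LayerNormalization before the block and producing the original Reshape output.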
self.create_attention_node(gemm, gemm_qkv, layernorm_before_attention.output[0], reshape_qkv.output[0])
# we rely on prune_graph() to clean old subgraph nodes:
# qk_nodes + q_nodes + k_nodes + v_nodes + mask_nodes + [reshape_qkv, transpose_qkv, matmul_qkv]
self.prune_graph = True