# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# This script converts a Longformer model from HuggingFace transformers 4.0 or later to ONNX.
# It translates LongformerSelfAttention to the LongformerAttention operator in ONNX Runtime.
#
# Before running this script, prepare a Python environment in Linux with PyTorch 1.9.0 and the other required packages installed.
# Then run "python setup.py install" in the ./torch_extensions directory. If your Python version is not 3.8, you will need to
# update this script with the correct name of longformer_attention.cpython-*.so (search for TODO below).
#
# It is tested in Ubuntu 18.04 with python 3.8, onnxruntime-gpu 1.11.0, PyTorch 1.9.0, transformers 4.18.0.
# Warning: Exporting with PyTorch 1.10 or a newer version might fail, but those versions are fine for benchmarking.
#
# Example commands to export longformer base model in Linux:
# conda create -n longformer python=3.8
# conda activate longformer
# python3 -m pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
# python3 -m pip install coloredlogs flatbuffers numpy packaging sympy protobuf==3.20.1 onnx==1.12.0 transformers==4.18.0
# python3 -m pip install -i https://test.pypi.org/simple/ ort-nightly-gpu
# cd ./torch_extensions
# rm -rf build
# python setup.py install
# cd ..
# python convert_to_onnx.py --model longformer-base-4096 --precision fp16 --optimize_onnx
# python convert_to_onnx.py --model longformer-base-4096 --precision fp16 --optimize_onnx --no_merge_qkv
#
# A GPU is not needed for this script. You can run it on CPU. For --optimize_onnx, you can use either the onnxruntime or the onnxruntime-gpu package.
#
# For inference of the ONNX model, you will need onnxruntime-gpu 1.7.0 or a newer version.
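#
# The snippet below is a minimal inference sketch (commented out so it does not run on import).
# The model path, sequence length, and execution provider are illustrative assumptions, not values produced by this script:
#   import numpy as np
#   import onnxruntime
#   session = onnxruntime.InferenceSession("longformer-base-4096.onnx", providers=["CUDAExecutionProvider"])
#   inputs = {
#       "input_ids": np.ones((1, 512), dtype=np.int64),
#       "attention_mask": np.ones((1, 512), dtype=np.int64),
#       "global_attention_mask": np.zeros((1, 512), dtype=np.int64),
#   }
#   inputs["global_attention_mask"][:, 0] = 1  # mark the first token as a global attention token
#   last_state, pooler = session.run(None, inputs)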
import argparse
import inspect
from pathlib import Path
import torch
import transformers
from longformer_helper import PRETRAINED_LONGFORMER_MODELS
from onnx import load_model
from onnx_model_bert import BertOnnxModel
from packaging import version
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args
from torch_onnx_export_helper import torch_onnx_export
from transformers import LongformerModel, LongformerSelfAttention
# Supports format 0 or 1: format 0 stacks the q/k/v weights on dimension 0 (shape (3, hidden_size, hidden_size));
# format 1 merges them on dimension 1 (shape (hidden_size, 3 * hidden_size)).
weight_bias_format = 0
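# Symbolic function that maps the PyTorch custom op onnxruntime::LongformerAttention
# to the com.microsoft::LongformerAttention contrib op when exporting to ONNX.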
@parse_args("v", "v", "v", "v", "v", "v", "v", "i", "i")
def my_longformer_attention(
g,
input,
weight,
bias,
mask,
global_weight,
global_bias,
global_mask,
num_heads,
window,
):
return g.op(
"com.microsoft::LongformerAttention",
input,
weight,
bias,
mask,
global_weight,
global_bias,
global_mask,
num_heads_i=num_heads,
window_i=window,
)
# The namespace is onnxruntime, which is registered in longformer_attention.cpp.
register_custom_op_symbolic("onnxruntime::LongformerAttention", my_longformer_attention, 9)
# TODO: search the directory to find correct output filename of "python setup.py install" when python version is not 3.8
torch.ops.load_library(
r"./torch_extensions/build/lib.linux-x86_64-3.8/longformer_attention.cpython-38-x86_64-linux-gnu.so"
)
def parse_arguments():
"""Parse arguments
Returns:
args: Namespace
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model",
required=False,
type=str,
default="longformer-base-4096",
help="Checkpoint directory or pre-trained model names in the list: "
+ ", ".join(PRETRAINED_LONGFORMER_MODELS.keys()),
)
parser.add_argument(
"--export_padding",
required=False,
action="store_true",
help="Export padding logic to ONNX graph. If not enabled, user need pad input so that sequence length is multiple of window size.",
)
parser.set_defaults(export_padding=False)
parser.add_argument(
"--no_merge_qkv",
required=False,
action="store_true",
help="Stack the weights of q, k and v on dimension 0 instead of dimension 1.",
)
parser.set_defaults(no_merge_qkv=False)
parser.add_argument(
"-o",
"--optimize_onnx",
required=False,
action="store_true",
help="Use optimizer.py to optimize onnx model.",
)
parser.set_defaults(optimize_onnx=False)
parser.add_argument(
"-p",
"--precision",
required=False,
type=str,
default="fp32",
choices=["fp32", "fp16"],
help="Precision of model to run: fp32 for full precision, fp16 for mixed precision",
)
args = parser.parse_args()
return args
# Create dummy inputs for the ONNX export.
def get_dummy_inputs(config, export_padding, device):
    # When the sequence length is a multiple of the window size, no padding logic is exported to the ONNX graph,
    # so use a length that is not a multiple when export_padding is requested.
sequence_length = config.attention_window[0] + 1 if export_padding else config.attention_window[0]
# Create dummy inputs
input_ids = torch.arange(sequence_length).unsqueeze(0).to(device)
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)
attention_mask[:, sequence_length - 1] = 0 # last token is masked
global_attention_mask = torch.zeros(input_ids.shape, dtype=torch.long, device=device)
global_attention_mask[:, 0] = 1 # first token is global token
return input_ids, attention_mask, global_attention_mask
# A replacement for LongformerSelfAttention.forward that calls the LongformerAttention custom op.
# For transformers 4.0.0.
def my_longformer_self_attention_forward_4(
self,
hidden_states,
attention_mask=None,
is_index_masked=None,
is_index_global_attn=None,
is_global_attn=None,
):
global_mask = is_index_global_attn.int()
# The following check is based on the dummy inputs (only the first token is global).
assert (
len(global_mask.shape) == 2
and global_mask.shape[0] == 1
and global_mask.count_nonzero().item() == 1
and global_mask.tolist()[0][0] == 1
)
input_mask = is_index_masked.float()
    # TODO: The filtering value may be -10000.0 or -inf. Check the HuggingFace implementation.
    input_mask = input_mask.masked_fill(is_index_masked, -10000.0)
    # Another way to generate the mask: input_mask = torch.masked_fill(attention_mask, is_index_global_attn, 0.0)
# TODO: add postprocessing of ONNX model to calculate based on graph input: input_mask = (attention_mask - 1) * 10000.0
# TODO: add postprocessing of ONNX model to use graph input directly: global_mask = global_attention_mask
# The following check is based on the dummy inputs (only the last token is masked).
assert (
len(input_mask.shape) == 2
and input_mask.shape[0] == 1
and input_mask.count_nonzero().item() == 1
and input_mask.tolist()[0][-1] == -10000.0
)
weight = torch.stack(
(
self.query.weight.transpose(0, 1),
self.key.weight.transpose(0, 1),
self.value.weight.transpose(0, 1),
),
dim=weight_bias_format,
)
if weight_bias_format == 1:
# shape is (hidden_size, 3*hidden_size) for format 1, otherwise (3, hidden_size, hidden_size) by default
weight = weight.reshape(self.embed_dim, 3 * self.embed_dim)
global_weight = torch.stack(
(
self.query_global.weight.transpose(0, 1),
self.key_global.weight.transpose(0, 1),
self.value_global.weight.transpose(0, 1),
),
dim=weight_bias_format,
)
if weight_bias_format == 1:
global_weight = global_weight.reshape(self.embed_dim, 3 * self.embed_dim)
if weight_bias_format == 1:
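        # Format 1: concatenate the q, k, v biases into a single tensor of shape (3 * hidden_size,); same for the global biases.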
bias = torch.stack((self.query.bias, self.key.bias, self.value.bias), dim=0)
bias = bias.reshape(3 * self.embed_dim)
global_bias = torch.stack((self.query_global.bias, self.key_global.bias, self.value_global.bias), dim=0)
global_bias = global_bias.reshape(3 * self.embed_dim)
else:
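        # Format 0: pack the q, k, v, global_k and global_v biases into one tensor of shape (5 * hidden_size,);
        # the global query bias is passed as a separate tensor.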
bias = torch.stack(
(self.query.bias, self.key.bias, self.value.bias, self.key_global.bias, self.value_global.bias), dim=0
)
bias = bias.reshape(5 * self.embed_dim)
global_bias = self.query_global.bias
global_bias = global_bias.reshape(1 * self.embed_dim)
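    # Invoke the custom op; the symbolic registered above exports it as com.microsoft::LongformerAttention.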
attn_output = torch.ops.onnxruntime.LongformerAttention(
hidden_states,
weight,
bias,
input_mask,
global_weight,
global_bias,
global_mask,
self.num_heads,
self.one_sided_attn_window_size,
)
assert attn_output.size() == hidden_states.size(), "Unexpected size"
outputs = (attn_output,)
return outputs
# For transformers 4.3.0
def my_longformer_self_attention_forward_4_3(
self,
hidden_states,
attention_mask=None,
is_index_masked=None,
is_index_global_attn=None,
is_global_attn=None,
output_attentions=False,
):
assert output_attentions is False
return my_longformer_self_attention_forward_4(
self,
hidden_states,
attention_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
)
# For transformers 4.3.2 or later versions
def my_longformer_self_attention_forward_4_3_2(
self,
hidden_states,
attention_mask=None,
layer_head_mask=None,
is_index_masked=None,
is_index_global_attn=None,
is_global_attn=None,
output_attentions=False,
):
assert output_attentions is False
assert layer_head_mask is None
return my_longformer_self_attention_forward_4(
self,
hidden_states,
attention_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
)
def export_longformer(model: LongformerModel, onnx_model_path: str, export_padding: bool):
"""Export longformer model to ONNX
Args:
model (LongformerModel): longformer model
onnx_model_path (str): output onnx path
        export_padding (bool): whether to export padding logic to ONNX so that the input sequence can have any length.
Raises:
RuntimeError: This tool requires transformers 4.0.0 or later.
RuntimeError: LongformerSelfAttention.forward arguments are different.
"""
input_ids, attention_mask, global_attention_mask = get_dummy_inputs(
model.config, export_padding, device=torch.device("cpu")
)
_ = model(
input_ids,
attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
)
if version.parse(transformers.__version__) < version.parse("4.0.0"):
raise RuntimeError("This tool requires transformers 4.0.0 or later.")
    # Replace LongformerSelfAttention.forward with our implementation for exporting the ONNX model.
key = " ".join(inspect.getfullargspec(LongformerSelfAttention.forward).args)
args_to_func = {
"self hidden_states attention_mask layer_head_mask is_index_masked is_index_global_attn is_global_attn output_attentions": my_longformer_self_attention_forward_4_3_2,
"self hidden_states attention_mask is_index_masked is_index_global_attn is_global_attn output_attentions": my_longformer_self_attention_forward_4_3,
"self hidden_states attention_mask is_index_masked is_index_global_attn is_global_attn": my_longformer_self_attention_forward_4,
}
if key not in args_to_func:
print(
"Current arguments",
inspect.getfullargspec(LongformerSelfAttention.forward).args,
)
raise RuntimeError(
"LongformerSelfAttention.forward arguments are different. Please install supported version (like transformers 4.3.0)."
)
# Store for restoring later
original_forward = LongformerSelfAttention.forward
LongformerSelfAttention.forward = args_to_func[key]
example_inputs = (input_ids, attention_mask, global_attention_mask)
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
torch_onnx_export(
model,
example_inputs,
onnx_model_path,
opset_version=12,
input_names=["input_ids", "attention_mask", "global_attention_mask"],
output_names=["last_state", "pooler"],
dynamic_axes={
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
"global_attention_mask": {0: "batch_size", 1: "sequence_length"},
"last_state": {0: "batch_size", 1: "sequence_length"},
"pooler": {0: "batch_size", 1: "sequence_length"},
},
custom_opsets={"com.microsoft": 1},
)
print(f"ONNX model exported to {onnx_model_path}")
# Restore original implementation:
LongformerSelfAttention.forward = original_forward
def optimize_longformer(onnx_model_path: str, fp32_model_path: str, fp16_model_path=None):
"""Optimize longformer onnx model
Args:
onnx_model_path (str): path of original ONNX model.
fp32_model_path (str): path of optimized fp32 model.
fp16_model_path (str, optional): path of optimized fp16 model. Defaults to None.
"""
model = load_model(onnx_model_path, format=None, load_external_data=True)
optimizer = BertOnnxModel(model)
optimizer.optimize()
use_external_data_format = False
if fp32_model_path:
optimizer.save_model_to_file(fp32_model_path, use_external_data_format)
print(f"optimized fp32 model saved to {fp32_model_path}")
if fp16_model_path:
optimizer.convert_float_to_float16(keep_io_types=True)
optimizer.save_model_to_file(fp16_model_path, use_external_data_format)
print(f"optimized fp16 model saved to {fp16_model_path}")
def main(args):
model_name = args.model
onnx_model_path = model_name + ".onnx"
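    # --no_merge_qkv selects format 0 (q/k/v stacked on dimension 0); otherwise format 1 merges them on dimension 1.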
global weight_bias_format # noqa: PLW0603
weight_bias_format = 0 if args.no_merge_qkv else 1
model = LongformerModel.from_pretrained(PRETRAINED_LONGFORMER_MODELS[model_name])
export_longformer(model, onnx_model_path, args.export_padding)
if args.optimize_onnx or args.precision != "fp32":
fp32_model_path = model_name + f"_f{weight_bias_format}" + "_fp32.onnx"
fp16_model_path = model_name + f"_f{weight_bias_format}" + "_fp16.onnx" if args.precision == "fp16" else None
optimize_longformer(onnx_model_path, fp32_model_path, fp16_model_path)
if __name__ == "__main__":
args = parse_arguments()
main(args)