# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
"""
This converts a GPT-2 model to ONNX. Examples:
(1) Convert pretrained model 'gpt2' to ONNX
python convert_to_onnx.py -m gpt2 --output gpt2.onnx
(2) Convert pretrained model 'distilgpt2' to ONNX, and use optimizer to get float16 model.
python convert_to_onnx.py -m distilgpt2 --output distilgpt2_fp16.onnx -o -p fp16
(3) Convert a model checkpoint to ONNX, and run optimization and int8 quantization
python convert_to_onnx.py -m ./my_model_checkpoint/ --output my_model_int8.onnx -o -p int8
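(4) Convert pretrained model 'gpt2' to float16 on GPU, keeping LayerNormalization in float32
    (an illustrative combination of the flags defined in parse_arguments below):
    python convert_to_onnx.py -m gpt2 --output gpt2_fp16.onnx -o -p fp16 --use_gpu --op_block_list LayerNormalization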
"""
import argparse
import csv
import json
import logging
import os
import shutil
import sys
from pathlib import Path
import numpy
import torch
from benchmark_helper import (
Precision,
create_onnxruntime_session,
get_ort_environment_variables,
prepare_environment,
setup_logger,
)
from gpt2_helper import DEFAULT_TOLERANCE, MODEL_CLASSES, PRETRAINED_GPT2_MODELS, Gpt2Helper
from gpt2_tester import Gpt2Tester
from packaging import version
from quantize_helper import QuantizeHelper
from transformers import AutoConfig
from transformers import __version__ as transformers_version
from onnxruntime import __version__ as ort_version

logger = logging.getLogger("")


def parse_arguments(argv=None):
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model_name_or_path",
required=True,
type=str,
help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_GPT2_MODELS),
)
parser.add_argument(
"--model_class",
required=False,
type=str,
default="GPT2LMHeadModel",
choices=list(MODEL_CLASSES.keys()),
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
)
parser.add_argument(
"--cache_dir",
required=False,
type=str,
default=os.path.join(".", "cache_models"),
help="Directory to cache pre-trained models",
)
parser.add_argument(
"--output",
required=False,
type=str,
default=os.path.join(".", "onnx_models"),
help="Output directory, or model path ends with .onnx",
)
parser.add_argument(
"-o",
"--optimize_onnx",
required=False,
action="store_true",
help="Use optimizer.py to optimize onnx model",
)
parser.set_defaults(optimize_onnx=False)
parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference")
parser.set_defaults(use_gpu=False)
parser.add_argument(
"--provider",
required=False,
default=None,
choices=["dml", "rocm", "migraphx", "cuda", "tensorrt"],
help="use dml, rocm, cuda, tensorrt or migraphx for respective backend",
)
parser.add_argument(
"--tolerance",
required=False,
type=float,
default=0,
help="the absolute and relative tolerance for parity verification",
)
parser.add_argument(
"--input_test_file",
"-i",
required=False,
type=str,
default="",
help="Path to the file with inputs to test with",
)
parser.add_argument(
"-p",
"--precision",
required=False,
type=Precision,
default=Precision.FLOAT32,
choices=list(Precision),
help="Precision of model to run. fp32 for full precision, fp16 for half or mixed precision, and int8 for quantization",
)
parser.add_argument(
"-t",
"--test_cases",
required=False,
type=int,
default=1000,
help="Number of test cases per run for parity",
)
parser.add_argument(
"-r",
"--test_runs",
required=False,
type=int,
default=10,
help="Number of runs for parity. It is used for significance test.",
)
parser.add_argument("--verbose", required=False, action="store_true")
parser.set_defaults(verbose=False)
parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true")
parser.set_defaults(use_external_data_format=False)
parser.add_argument("--overwrite", required=False, action="store_true")
parser.set_defaults(overwrite=False)
parser.add_argument(
"--use_int64_inputs",
required=False,
action="store_true",
help="Use int32 instead of int64 for input_ids, position_ids and attention_mask.",
)
parser.set_defaults(use_int64_inputs=False)
parser.add_argument(
"-s",
"--stage",
type=int,
default=0,
required=False,
choices=[0, 1, 2],
help="Stage in generation: 1 (initial decoder), 2 (decoder), 0 (both). "
"1 - decode the first token when past_sequence_length is zero; "
"2 - decode the remaining tokens when past_sequence_length is not zero; "
"0 - one onnx model for both stages 1 and 2. "
"Note that we will optimize 1 and 2 differently for best performance.",
)
fp16_option_group = parser.add_argument_group(
        'float to float16 conversion parameters that work when "--precision fp16" is specified'
)
fp16_option_group.add_argument(
"-a",
"--auto_mixed_precision",
required=False,
action="store_true",
help="Convert to mixed precision automatically. Other float16 conversion parameters will be ignored.",
)
fp16_option_group.set_defaults(auto_mixed_precision=False)
fp16_option_group.add_argument(
"--keep_io_types",
required=False,
action="store_true",
help="Use float32 for past inputs, present and logits outputs.",
)
fp16_option_group.set_defaults(keep_io_types=False)
fp16_option_group.add_argument(
"--io_block_list",
nargs="+",
default=[],
help="List of inputs or outputs in float32 instead of float16",
)
fp16_option_group.add_argument(
"--op_block_list",
nargs="+",
default=[],
help="List of operators (like Add LayerNormalization SkipLayerNormalization EmbedLayerNormalization FastGelu) "
"to compute in float32 instead of float16.",
)
fp16_option_group.add_argument(
"--node_block_list",
nargs="+",
default=[],
help="List of node names to compute in float32 instead of float16.",
)
fp16_option_group.add_argument(
"--force_fp16_initializers",
required=False,
action="store_true",
help="Convert all float initializers to float16.",
)
fp16_option_group.set_defaults(force_fp16_initializers=False)
args = parser.parse_args(argv)
return args


def get_onnx_model_size(onnx_path: str, use_external_data_format: bool):
    if not use_external_data_format:
        return os.path.getsize(onnx_path)
    return sum(f.stat().st_size for f in Path(onnx_path).parent.rglob("*"))


def get_latency_name(batch_size, sequence_length, past_sequence_length):
return f"average_latency(batch_size={batch_size},sequence_length={sequence_length},past_sequence_length={past_sequence_length})"


def main(argv=None, experiment_name: str = "", run_id: str = "0", csv_filename: str = "gpt2_parity_results.csv"):
result = {}
    # The past_key_values name does not exist in transformers 3.0.2 or older.
    if version.parse(transformers_version) < version.parse("3.1.0"):
        raise RuntimeError("This tool requires transformers 3.1.0 or later.")
args = parse_arguments(argv)
setup_logger(args.verbose)
if not experiment_name:
experiment_name = " ".join(argv if argv else sys.argv[1:])
if args.tolerance == 0:
args.tolerance = DEFAULT_TOLERANCE[args.precision]
logger.info(f"Arguments:{args}")
cache_dir = args.cache_dir
output_dir = args.output if not args.output.endswith(".onnx") else os.path.dirname(args.output)
prepare_environment(cache_dir, output_dir, args.use_gpu)
if args.precision != Precision.FLOAT32:
assert args.optimize_onnx, "fp16/int8 requires --optimize_onnx"
if args.precision == Precision.FLOAT16:
assert args.use_gpu, "fp16 requires --use_gpu"
if args.precision == Precision.INT8:
assert not args.use_gpu, "quantization only supports CPU"
model_class = MODEL_CLASSES[args.model_class][0]
use_padding = MODEL_CLASSES[args.model_class][2]
gpt2helper = Gpt2Helper
config = AutoConfig.from_pretrained(args.model_name_or_path, cache_dir=cache_dir)
model = model_class.from_pretrained(args.model_name_or_path, config=config, cache_dir=cache_dir)
device = torch.device("cuda:0" if args.use_gpu else "cpu")
model.eval().to(device)
if (not args.use_external_data_format) and (config.n_layer > 24):
logger.info("Try --use_external_data_format when model size > 2GB")
onnx_model_paths = gpt2helper.get_onnx_paths(
output_dir,
args.model_name_or_path,
args.model_class,
new_folder=(args.precision == Precision.INT8),
remove_existing=["fp32", "fp16", "int8"],
) # Do not remove raw model to save time in parity test
raw_onnx_model = onnx_model_paths["raw"]
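    # int32 inputs are the default here (generally preferred by onnxruntime kernels);
    # --use_int64_inputs switches input_ids, position_ids and attention_mask to int64.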
int_data_type = torch.int64 if args.use_int64_inputs else torch.int32
if os.path.exists(raw_onnx_model) and not args.overwrite:
logger.warning(f"Skip exporting ONNX model since it existed: {raw_onnx_model}")
else:
logger.info(f"Exporting ONNX model to {raw_onnx_model}")
gpt2helper.export_onnx(
model,
device,
raw_onnx_model,
args.verbose,
args.use_external_data_format,
has_position_ids=use_padding,
has_attention_mask=use_padding,
input_ids_dtype=int_data_type,
position_ids_dtype=int_data_type,
attention_mask_dtype=int_data_type,
)
fp16_params = {"keep_io_types": args.keep_io_types}
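    # Note: the float16 converter's keep_io_types option also accepts an explicit list
    # of input/output names, so --io_block_list simply replaces the boolean here.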
if args.io_block_list:
fp16_params["keep_io_types"] = args.io_block_list
if args.node_block_list:
fp16_params["node_block_list"] = args.node_block_list
if args.op_block_list:
fp16_params["op_block_list"] = args.op_block_list
if args.force_fp16_initializers:
fp16_params["force_fp16_initializers"] = args.force_fp16_initializers
is_io_float16 = args.precision == Precision.FLOAT16 and not args.keep_io_types
optimized_ops = ""
all_ops = ""
if args.optimize_onnx or args.precision != Precision.FLOAT32:
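        # For int8, first optimize the fp32 graph; quantization to int8 happens below.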
output_path = onnx_model_paths[str(args.precision) if args.precision != Precision.INT8 else "fp32"]
logger.info(f"Optimizing model to {output_path}")
m = gpt2helper.optimize_onnx(
raw_onnx_model,
output_path,
args.precision == Precision.FLOAT16,
model.config.num_attention_heads,
model.config.hidden_size,
args.use_external_data_format,
auto_mixed_precision=args.auto_mixed_precision,
stage=args.stage,
**fp16_params,
)
nodes = m.nodes()
op_list = {node.op_type for node in nodes}
all_ops = ",".join(op_list)
        # Collect the fused (optimized) operators for reporting.
optimized_op_counter = m.get_fused_operator_statistics()
if optimized_op_counter:
optimized_ops = ",".join([key for key in optimized_op_counter if optimized_op_counter[key] > 0])
else:
output_path = raw_onnx_model
if args.precision == Precision.INT8:
logger.info("quantizing model...")
QuantizeHelper.quantize_onnx_model(output_path, onnx_model_paths["int8"], args.use_external_data_format)
model = QuantizeHelper.quantize_torch_model(model)
logger.info("finished quantizing model")
output_path = onnx_model_paths["int8"]
if args.output.endswith(".onnx") and output_path != args.output and not args.use_external_data_format:
shutil.move(output_path, args.output)
output_path = args.output
logger.info(f"Output path: {output_path}")
model_size_in_MB = int(get_onnx_model_size(output_path, args.use_external_data_format) / 1024 / 1024) # noqa: N806
provider = args.provider
session = create_onnxruntime_session(
output_path, args.use_gpu, provider, enable_all_optimization=True, verbose=args.verbose
)
if args.model_class == "GPT2LMHeadModel" and session is not None:
parity_result = gpt2helper.test_parity(
session,
model,
device,
is_io_float16,
rtol=args.tolerance,
atol=args.tolerance,
model_class=args.model_class,
has_position_ids=use_padding,
has_attention_mask=use_padding,
input_ids_dtype=int_data_type,
position_ids_dtype=int_data_type,
attention_mask_dtype=int_data_type,
test_cases_per_run=args.test_cases,
total_runs=args.test_runs,
stage=args.stage,
verbose=args.verbose,
)
# An example configuration for testing performance
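        # Stage 1 (prompt processing): a full 32-token sequence with empty past.
        # Stage 2 (token-by-token generation): one new token with a 32-token past.
        # Stage 0 exports a single model for both; it is benchmarked here like stage 2.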
batch_size = 8
sequence_length = 32 if args.stage == 1 else 1
past_sequence_length = 0 if args.stage == 1 else 32
latency = gpt2helper.test_performance(
session,
model,
device,
is_io_float16,
total_runs=100,
use_io_binding=True,
model_class=args.model_class,
has_position_ids=use_padding,
has_attention_mask=use_padding,
input_ids_dtype=int_data_type,
position_ids_dtype=int_data_type,
attention_mask_dtype=int_data_type,
batch_size=batch_size,
sequence_length=sequence_length,
past_sequence_length=past_sequence_length,
)
if args.precision == Precision.FLOAT16:
logger.info(f"fp16 conversion parameters:{fp16_params}")
# Write results to file
latency_name = get_latency_name(batch_size, sequence_length, past_sequence_length)
csv_file_existed = os.path.exists(csv_filename)
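    # Append to the CSV across runs; the header is written only when the file is new.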
with open(csv_filename, mode="a", newline="") as csv_file:
column_names = [
"experiment",
"run_id",
"model_name",
"model_class",
"stage",
"gpu",
"precision",
"optimizer",
"test_cases",
"runs",
"keep_io_types",
"io_block_list",
"op_block_list",
"node_block_list",
"force_fp16_initializers",
"auto_mixed_precision",
"optimized_operators",
"operators",
"environment_variables",
"onnxruntime",
latency_name,
"top1_match_rate",
"onnx_size_in_MB",
"diff_50_percentile",
"diff_90_percentile",
"diff_95_percentile",
"diff_99_percentile",
"diff_pass_rate",
"nan_rate",
"top1_match_rate_per_run",
]
csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
if not csv_file_existed:
csv_writer.writeheader()
row = {
"experiment": experiment_name,
"run_id": run_id,
"model_name": args.model_name_or_path,
"model_class": args.model_class,
"stage": args.stage,
"gpu": args.use_gpu,
"precision": args.precision,
"optimizer": args.optimize_onnx,
"test_cases": args.test_cases,
"runs": args.test_runs,
"keep_io_types": args.keep_io_types,
"io_block_list": args.io_block_list,
"op_block_list": args.op_block_list,
"node_block_list": args.node_block_list,
"force_fp16_initializers": args.force_fp16_initializers,
"auto_mixed_precision": args.auto_mixed_precision,
"optimized_operators": optimized_ops,
"operators": all_ops,
"environment_variables": get_ort_environment_variables(),
"onnxruntime": ort_version,
latency_name: f"{latency:.2f}",
"diff_50_percentile": parity_result["max_diff_percentile_50"],
"diff_90_percentile": parity_result["max_diff_percentile_90"],
"diff_95_percentile": parity_result["max_diff_percentile_95"],
"diff_99_percentile": parity_result["max_diff_percentile_99"],
"diff_pass_rate": parity_result["diff_pass_rate"],
"nan_rate": parity_result["nan_rate"],
"top1_match_rate": parity_result["top1_match_rate"],
"top1_match_rate_per_run": parity_result["top1_match_rate_per_run"],
"onnx_size_in_MB": f"{model_size_in_MB}",
}
logger.info(f"result: {row}")
result.update(row)
csv_writer.writerow(row)
if args.input_test_file:
test_inputs = []
# Each line of test file is a JSON string like:
# {"input_ids": [[14698, 257, 1310, 13688, 319, 326]]}
with open(args.input_test_file) as read_f:
            for line in read_f:
line = line.rstrip() # noqa: PLW2901
data = json.loads(line)
input_ids = torch.from_numpy(numpy.asarray(data["input_ids"], dtype=numpy.int64)).to(device)
if use_padding:
if "attention_mask" in data:
numpy_float = numpy.float16 if is_io_float16 else numpy.float32
attention_mask = torch.from_numpy(numpy.asarray(data["attention_mask"], dtype=numpy_float)).to(
device
)
else:
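                        # The test file marks padding tokens with -1: build the attention
                        # mask from them, then zero the padded ids before inference.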
padding = -1
attention_mask = (input_ids != padding).type(torch.float16 if is_io_float16 else torch.float32)
input_ids.masked_fill_(input_ids == padding, 0)
if "position_ids" in data:
position_ids = torch.from_numpy(numpy.asarray(data["position_ids"], dtype=numpy.int64)).to(
device
)
else:
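                        # Derive position_ids from the attention mask: cumulative sum minus
                        # one, with padded positions clamped to 0.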
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(position_ids < 0, 0)
inputs = {
"input_ids": input_ids.to(int_data_type),
"position_ids": position_ids.to(int_data_type),
"attention_mask": attention_mask.to(int_data_type),
}
else:
inputs = {"input_ids": input_ids.to(int_data_type)}
test_inputs.append(inputs)
Gpt2Tester.test_generation(
session,
model,
device,
test_inputs,
precision=args.precision,
model_class=args.model_class,
top_k=20,
top_k_no_order=True,
max_steps=24,
max_inputs=0,
verbose=args.verbose,
save_test_data=3,
save_test_data_dir=Path(output_path).parent,
)
logger.info(f"Done. Output model: {output_path}")
return result


if __name__ == "__main__":
main()