Repository URL to install this package:
|
Version:
1.23.2 ▾
|
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import os
import tempfile
from itertools import chain
from pathlib import Path
import numpy as np
import onnx
import torch
from float16 import convert_float_to_float16
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
from onnx import ModelProto, ValueInfoProto
from onnx_model import OnnxModel
from past_helper import PastKeyValuesHelper
from transformers import WhisperConfig
from whisper_inputs import (
convert_inputs_for_ort,
get_model_dynamic_axes,
get_sample_decoder_inputs,
group_past_key_values,
)
from onnxruntime import InferenceSession
logger = logging.getLogger(__name__)
class WhisperDecoder(torch.nn.Module):
"""A Whisper decoder with optional past key values"""
def __init__(self, config: WhisperConfig, model: torch.nn.Module, model_impl: str, no_beam_search_op: bool = False):
super().__init__()
self.config = config
self.device = model.device
self.model_impl = model_impl
self.no_beam_search_op = no_beam_search_op
self.decoder = None if model_impl == "openai" else model.model.decoder
self.proj_out = None if model_impl == "openai" else model.proj_out
self.model = model if model_impl == "openai" else None
self.max_source_positions = self.config.max_source_positions
self.num_heads = self.config.decoder_attention_heads
self.head_size = self.config.d_model // self.num_heads
def hf_forward(
self,
decoder_input_ids: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
past_key_values: list[tuple[torch.Tensor]] | None = None,
):
outputs = self.decoder(
encoder_hidden_states=encoder_hidden_states,
input_ids=decoder_input_ids,
past_key_values=past_key_values,
use_cache=True,
)
logits = self.proj_out(outputs.last_hidden_state)
present_key_values = outputs.past_key_values
if past_key_values is None:
# Return present_self_* and present_cross_* for decoder-init
return logits, present_key_values
# Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
# (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1),
# After: (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1), ...,
# (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1), ...
present_self, present_cross = PastKeyValuesHelper.group_by_self_and_cross(present_key_values)
# Return present_self_* for decoder-with-past since past_cross_* and present_cross_* are identical
return logits, present_self
def oai_forward(
self,
decoder_input_ids: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
past_key_values: list[tuple[torch.Tensor]] | None = None,
):
past_kv_cache = {}
if past_key_values is not None:
# Convert past KV caches (BxNxSxH --> BxSxNxH --> BxSxD) for OpenAI's forward pass
self_attn_kv_caches, cross_attn_kv_caches = group_past_key_values(past_key_values)
self_attn_kv_caches = [past_kv.transpose(1, 2) for past_kv in self_attn_kv_caches]
self_attn_kv_caches = [past_kv.reshape((*past_kv.shape[:2], -1)) for past_kv in self_attn_kv_caches]
cross_attn_kv_caches = [past_kv.transpose(1, 2) for past_kv in cross_attn_kv_caches]
cross_attn_kv_caches = [past_kv.reshape((*past_kv.shape[:2], -1)) for past_kv in cross_attn_kv_caches]
for idx, block in enumerate(self.model.decoder.blocks):
past_kv_cache[block.attn.key] = self_attn_kv_caches[2 * idx]
past_kv_cache[block.attn.value] = self_attn_kv_caches[2 * idx + 1]
past_kv_cache[block.cross_attn.key] = cross_attn_kv_caches[2 * idx]
past_kv_cache[block.cross_attn.value] = cross_attn_kv_caches[2 * idx + 1]
# Install OpenAI's hooks on the forward pass of each nn.Linear for key and value
# since the hooks will capture the output of the key and value MatMuls, which
# represent the current keys and values.
#
# For OpenAI's forward pass, the hook function will also perform the concat
# operation (past_kv + curr_kv --> pres_kv) if needed. However, the ONNX model
# will not contain this concat operation because the present KV caches aren't
# returned by OpenAI's forward pass.
kv_cache, hooks = self.model.install_kv_cache_hooks()
# Run forward pass
# NOTE: There is a bug with openai-whisper==20240930 with the introduction of SDPA.
# In the Whisper codebase, the following line
#
# is_causal = mask is not None and n_ctx > 1
#
# has been added where `mask` is a torch tensor. The right-hand side evaluates to `tensor(True/False)`
# but `is_causal` only accepts the boolean value. The fix is to apply `.item()` after the right-hand
# side has been evaluated. In other words, the line should be
#
# is_causal = (mask is not None and n_ctx > 1).item()
#
# instead.
logits = self.model.decoder(x=decoder_input_ids, xa=encoder_hidden_states, kv_cache=past_kv_cache)
# Re-do concat operation on self attention KV caches for ONNX export (if past self attention KV caches exist)
if past_key_values is not None:
for block in self.model.decoder.blocks:
kv_cache[block.attn.key] = torch.cat(
[past_kv_cache[block.attn.key], kv_cache[block.attn.key]], dim=1
).detach()
kv_cache[block.attn.value] = torch.cat(
[past_kv_cache[block.attn.value], kv_cache[block.attn.value]], dim=1
).detach()
present_self, present_cross = [], []
for block in self.model.decoder.blocks:
# Group self and cross values
present_self.append(kv_cache[block.attn.key])
present_self.append(kv_cache[block.attn.value])
if past_key_values is None:
# Return present_self_* and present_cross_* for decoder-init
present_cross.append(kv_cache[block.cross_attn.key])
present_cross.append(kv_cache[block.cross_attn.value])
# Convert present KV caches (BxSxD --> BxSxNxH --> BxNxSxH) after OpenAI's forward pass
present_self = [
present_kv.reshape((*present_kv.shape[:2], -1, self.head_size)).transpose(1, 2)
for present_kv in present_self
]
present_cross = [
present_kv.reshape((*present_kv.shape[:2], -1, self.head_size)).transpose(1, 2)
for present_kv in present_cross
]
# Remove OpenAI's hooks since they can persist after this function completes
for hook in hooks:
hook.remove()
if past_key_values is None:
# Return present_self_* and present_cross_* for decoder-init
present_key_values = PastKeyValuesHelper.group_by_layer(
present_self + present_cross, len(present_self) // 2
)
return logits, present_key_values
# Return present_self_* for decoder-with-past since past_cross_* and present_cross_* are identical
return logits, present_self
def forward(
self,
decoder_input_ids: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
past_key_values: list[tuple[torch.Tensor]] | None = None,
):
if self.model_impl == "openai":
return self.oai_forward(decoder_input_ids, encoder_hidden_states, past_key_values)
return self.hf_forward(decoder_input_ids, encoder_hidden_states, past_key_values)
def input_names(self):
if self.first_pass:
input_names = ["input_ids", "encoder_hidden_states"]
else:
input_names = [
"input_ids",
"encoder_hidden_states",
*list(
chain.from_iterable(
(f"past_key_self_{i}", f"past_value_self_{i}", f"past_key_cross_{i}", f"past_value_cross_{i}")
for i in range(self.config.decoder_layers)
)
),
]
return input_names
def output_names(self):
if self.first_pass:
output_names = [
"logits",
*list(
chain.from_iterable(
(
f"present_key_self_{i}",
f"present_value_self_{i}",
f"present_key_cross_{i}",
f"present_value_cross_{i}",
)
for i in range(self.config.decoder_layers)
)
),
]
else:
output_names = [
"logits",
*list(
chain.from_iterable(
(f"present_key_self_{i}", f"present_value_self_{i}") for i in range(self.config.decoder_layers)
)
),
]
return output_names
def dynamic_axes(self, input_names, output_names):
dynamic_axes = get_model_dynamic_axes(self.config, input_names, output_names)
if "input_ids" in dynamic_axes and not self.no_beam_search_op:
# Set dynamic axes for `input_ids` when using beam search op to {0: "batch_size"} only
del dynamic_axes["input_ids"][1]
return dynamic_axes
def inputs(self, use_fp16_inputs: bool, use_int32_inputs: bool, return_dict: bool = False):
inputs = get_sample_decoder_inputs(
self.config,
self.device,
batch_size=2,
past_sequence_length=(0 if self.first_pass else 6),
sequence_length=(6 if self.first_pass else 1),
use_fp16=use_fp16_inputs,
use_int32=use_int32_inputs,
)
if return_dict:
if self.first_pass:
del inputs["past_key_values"]
return inputs
if self.first_pass:
return (
inputs["decoder_input_ids"],
inputs["encoder_hidden_states"],
)
return (
inputs["decoder_input_ids"],
inputs["encoder_hidden_states"],
inputs["past_key_values"],
)
def fix_key_value_cache_dims(self, io: ValueInfoProto, is_cross: bool = False, is_output: bool = False):
# Shape should be (batch_size, num_heads, sequence_length, head_size) for self attention KV caches
# and (batch_size, num_heads, num_frames // 2, head_size) for cross attention KV caches
num_heads = io.type.tensor_type.shape.dim[1]
if "_dim_" in num_heads.dim_param:
num_heads.Clear()
num_heads.dim_value = self.num_heads
sequence_length = io.type.tensor_type.shape.dim[2]
if "_dim_" in sequence_length.dim_param:
sequence_length.Clear()
if is_cross:
sequence_length.dim_value = self.max_source_positions
else:
sequence_length.dim_param = "total_sequence_length" if is_output else "past_sequence_length"
head_size = io.type.tensor_type.shape.dim[3]
if "_dim_" in head_size.dim_param:
head_size.Clear()
head_size.dim_value = self.head_size
return io
def fix_io(self, io_list: RepeatedCompositeFieldContainer, is_output: bool = False):
# Fix order of inputs/outputs and each dim_value of input/output
reordered_io = []
self_attn_kv_caches = []
cross_attn_kv_caches = []
for io in io_list:
if "past" not in io.name and "present" not in io.name:
reordered_io.append(io)
elif "self" in io.name:
# Self attention KV caches
new_io = self.fix_key_value_cache_dims(io, is_cross=False, is_output=is_output)
if self.no_beam_search_op:
reordered_io.append(new_io)
else:
self_attn_kv_caches.append(new_io)
else:
# Cross attention KV caches
new_io = self.fix_key_value_cache_dims(io, is_cross=True, is_output=is_output)
if self.no_beam_search_op:
reordered_io.append(new_io)
else:
cross_attn_kv_caches.append(new_io)
if not self.no_beam_search_op:
reordered_io += self_attn_kv_caches + cross_attn_kv_caches
return reordered_io
def fix_inputs_and_outputs(self, model: ModelProto):
# ONNX exporter might mark dimensions like 'Transposepresent_value_self_1_dim_2' in shape inference.
# We now change the dim_values to the correct one.
reordered_inputs = self.fix_io(model.graph.input, is_output=False)
while len(model.graph.input) > 0:
model.graph.input.pop()
model.graph.input.extend(reordered_inputs)
reordered_outputs = self.fix_io(model.graph.output, is_output=True)
while len(model.graph.output) > 0:
model.graph.output.pop()
model.graph.output.extend(reordered_outputs)
return model
def fix_layernorm_weights(self, model: ModelProto, use_fp16_inputs: bool):
if self.model_impl == "openai" and use_fp16_inputs:
# Cast ONNX model to float16 to ensure LayerNorm weights are converted from
# float32 to float16 since exported model already has float16 weights everywhere
# except for LayerNorm ops. This happens because OpenAI always upcasts to float32
# when computing LayerNorm.
#
# Reference:
# https://github.com/openai/whisper/blob/90db0de1896c23cbfaf0c58bc2d30665f709f170/whisper/model.py#L41
model = convert_float_to_float16(model)
return model
def export_onnx(
self,
onnx_model_path: str,
provider: str,
verbose: bool = True,
use_external_data_format: bool = False,
use_fp16_inputs: bool = False,
use_int32_inputs: bool = True,
use_encoder_hidden_states: bool = False,
use_kv_cache_inputs: bool = True,
):
"""Export decoder to ONNX
Args:
onnx_model_path (str): path to save ONNX model
provider (str): provider to use for verifying parity on ONNX model
verbose (bool, optional): print verbose information. Defaults to True.
use_external_data_format (bool, optional): use external data format or not. Defaults to False.
use_fp16_inputs (bool, optional): use float16 inputs for the KV caches. Defaults to False.
use_int32_inputs (bool, optional): use int32 inputs for the decoder_input_ids. Defaults to True.
use_encoder_hidden_states (bool, optional): use encoder_hidden_states as model input for decoder-init/decoder-without-past models. Defaults to False.
use_kv_cache_inputs (bool, optional): use KV caches as model inputs for decoder-with-past models. Defaults to True.
"""
# Shape of decoder's tensors:
# Required Inputs:
# decoder_input_ids: (batch_size, sequence_length)
# Optional Inputs:
# encoder_hidden_states (comes from encoder's outputs): (batch_size, num_frames // 2, hidden_size)
# past_{key/value}_self_* (past self attention KV caches): (batch_size, num_heads, past_sequence_length, head_size)
# past_{key/value}_cross_* (past cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
# Outputs:
# logits: (batch_size, sequence_length, vocab_size)
# present_{key/value}_self_* (present self attention KV caches): (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
# present_{key/value}_cross_* (present cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
# For the first pass through the decoder (i.e. decoder-init/decoder-without-past)
self.first_pass = use_encoder_hidden_states and not use_kv_cache_inputs
# For subsequent passes through the decoder (i.e. decoder-with-past)
self.later_pass = not use_encoder_hidden_states and use_kv_cache_inputs
assert self.first_pass or self.later_pass, (
"Only one of `use_encoder_hidden_states` and `use_kv_cache_inputs` can be true at once."
)
inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs)
input_names = self.input_names()
output_names = self.output_names()
dynamic_axes = self.dynamic_axes(input_names, output_names)
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
with tempfile.TemporaryDirectory() as tmp_dir_name:
temp_onnx_model_path = os.path.join(tmp_dir_name, "decoder.onnx")
Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
out_path = temp_onnx_model_path if use_external_data_format else onnx_model_path
torch.onnx.export(
self,
args=inputs,
f=out_path,
export_params=True,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=17,
do_constant_folding=True,
verbose=verbose,
)
model = onnx.load_model(out_path, load_external_data=use_external_data_format)
model = self.fix_inputs_and_outputs(model)
model = self.fix_layernorm_weights(model, use_fp16_inputs)
OnnxModel.save(
model,
onnx_model_path,
save_as_external_data=use_external_data_format,
all_tensors_to_one_file=True,
)
self.verify_onnx(onnx_model_path, provider, use_fp16_inputs, use_int32_inputs)
def verify_onnx(
self,
onnx_model_path: str,
provider: str,
use_fp16_inputs: bool,
use_int32_inputs: bool,
):
"""Verify ONNX model outputs and PyTorch model outputs match
Args:
onnx_model_path (str): path to save ONNX model
provider (str): execution provider for ONNX model
use_fp16_inputs (bool, optional): use float16 inputs for the KV caches
use_int32_inputs (bool, optional): use int32 inputs for the decoder_input_ids
"""
# Shape of decoder's tensors:
# Required Inputs:
# decoder_input_ids: (batch_size, sequence_length)
# Optional Inputs:
# encoder_hidden_states (comes from encoder's outputs): (batch_size, num_frames // 2, hidden_size)
# past_{key/value}_self_* (past self attention KV caches): (batch_size, num_heads, past_sequence_length, head_size)
# past_{key/value}_cross_* (past cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
# Outputs:
# logits: (batch_size, sequence_length, vocab_size)
# present_{key/value}_self_* (present self attention KV caches): (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
# present_{key/value}_cross_* (present cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
# Run PyTorch model
inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs, return_dict=True)
pt_outputs = []
if self.first_pass:
out = self.forward(**inputs)
pt_outputs.append(out[0].detach().cpu().numpy())
for present_key_value_layer in out[1]:
for present_key_value in present_key_value_layer:
pt_outputs.append(present_key_value.detach().cpu().numpy())
else:
out = self.forward(**inputs)
pt_outputs.append(out[0].detach().cpu().numpy())
for present_self_key_value in out[1]:
pt_outputs.append(present_self_key_value.detach().cpu().numpy())
# Run ONNX model
sess = InferenceSession(onnx_model_path, providers=[provider])
ort_outputs = sess.run(None, convert_inputs_for_ort(inputs, sess))
# Calculate output difference
try:
for i, output_name in enumerate(self.output_names()):
diff = np.abs(pt_outputs[i] - ort_outputs[i])
logger.warning(f"Comparing {output_name}...")
logger.warning(f"Max diff: {np.max(diff)}")
except: # noqa: E722
pass