# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import numpy as np
import torch
from transformers import WhisperConfig
from onnxruntime import InferenceSession
logger = logging.getLogger(__name__)
# Create audio_features for encoder
# Shape is (batch_size, feature_size, sequence_length) = (batch_size, num_mel_filters, num_frames)
# where num_mel_filters is the config's num_mel_bins attribute and num_frames = (chunk_length * sample_rate) // hop_length.
#
# Hard-coded audio hyperparameters:
# SAMPLE_RATE = 16000
# N_FFT = 400
# HOP_LENGTH = 160
# CHUNK_LENGTH = 30 (i.e. 30-second chunk of audio)
# N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE = 30 * 16000 = 480000 (i.e. 480,000 samples in a 30-second chunk of audio)
# N_FRAMES = N_SAMPLES // HOP_LENGTH = 480000 // 160 = 3000 (i.e. 3000 frames in a mel spectrogram input)
#
# N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 = 160 * 2 = 320
# FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH = 16000 // 160 = 100 (i.e. 10 ms per audio frame)
# TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN = 16000 // 320 = 50 (i.e. 20 ms per audio token)
def get_sample_audio_features(
config: WhisperConfig,
device: torch.device,
batch_size: int,
sequence_length: int = 3000,
use_fp16: bool = False,
):
torch_dtype = torch.float16 if use_fp16 else torch.float32
audio_features = torch.randn(batch_size, config.num_mel_bins, sequence_length, device=device, dtype=torch_dtype)
return audio_features
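# Illustrative usage (a sketch, assuming the "openai/whisper-tiny" checkpoint, whose
# config has num_mel_bins == 80):
#
#   config = WhisperConfig.from_pretrained("openai/whisper-tiny")
#   audio_features = get_sample_audio_features(config, torch.device("cpu"), batch_size=2)
#   assert audio_features.shape == (2, 80, 3000)  # (batch_size, num_mel_filters, num_frames)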
# Create input_ids for decoder
# Shape is (batch_size, sequence_length) where sequence_length is the initial decoder sequence length
def get_sample_decoder_input_ids(
config: WhisperConfig,
device: torch.device,
batch_size: int,
sequence_length: int,
use_int32: bool = True,
):
torch_dtype = torch.int32 if use_int32 else torch.int64
decoder_input_ids = torch.randint(
low=0, high=config.vocab_size, size=(batch_size, sequence_length), device=device, dtype=torch_dtype
)
return decoder_input_ids
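# Illustrative usage (a sketch; the values are uniform-random token ids, not a real
# start-of-transcription prompt):
#
#   decoder_input_ids = get_sample_decoder_input_ids(config, torch.device("cpu"), batch_size=2, sequence_length=4)
#   assert decoder_input_ids.shape == (2, 4) and decoder_input_ids.dtype == torch.int32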
# Create encoder_hidden_states for decoder-init
# Shape is (batch_size, num_frames // 2, hidden_size)
def get_sample_encoder_hidden_states(
config: WhisperConfig,
device: torch.device,
batch_size: int,
use_fp16: bool = False,
):
torch_dtype = torch.float16 if use_fp16 else torch.float32
encoder_hidden_states = torch.randn(
batch_size, config.max_source_positions, config.d_model, device=device, dtype=torch_dtype
)
return encoder_hidden_states
# Create past_key_values
# Self-attention KV caches are of shape (batch_size, num_heads, past_sequence_length, head_size)
# Cross-attention KV caches are of shape (batch_size, num_heads, num_frames // 2, head_size)
def get_sample_past_key_values(
config: WhisperConfig,
device: torch.device,
batch_size: int,
past_seq_len: int,
use_fp16: bool = False,
):
num_heads = config.decoder_attention_heads
head_size = config.d_model // num_heads
    # max_source_positions is equal to num_frames // 2 = encoder's sequence_length // 2 = 3000 // 2 = 1500
    max_source_positions = config.max_source_positions
torch_dtype = torch.float16 if use_fp16 else torch.float32
self_attention_kv_caches = [
(
torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype),
torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype),
)
for _ in range(config.decoder_layers)
]
cross_attention_kv_caches = [
(
torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype),
torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype),
)
for _ in range(config.decoder_layers)
]
return flatten_past_key_values(self_attention_kv_caches, cross_attention_kv_caches)
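# Illustrative structure of the return value (a sketch, again assuming "openai/whisper-tiny":
# decoder_layers == 4, decoder_attention_heads == 6, d_model == 384, so head_size == 64):
#
#   past_key_values = get_sample_past_key_values(config, torch.device("cpu"), batch_size=2, past_seq_len=8)
#   assert len(past_key_values) == 4  # one 4-tuple of tensors per decoder layer
#   self_k, self_v, cross_k, cross_v = past_key_values[0]
#   assert self_k.shape == (2, 6, 8, 64)  # (batch_size, num_heads, past_sequence_length, head_size)
#   assert cross_k.shape == (2, 6, 1500, 64)  # (batch_size, num_heads, num_frames // 2, head_size)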
# Flatten KV caches into per-layer groups of 4 tensors, where each group is defined as:
# (self_attn_key_cache, self_attn_value_cache, cross_attn_key_cache, cross_attn_value_cache)
def flatten_past_key_values(
self_attn_kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
cross_attn_kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
):
past_key_values = []
for (self_k_cache, self_v_cache), (cross_k_cache, cross_v_cache) in zip(
self_attn_kv_caches, cross_attn_kv_caches, strict=False
):
layer_kv_caches = (self_k_cache, self_v_cache, cross_k_cache, cross_v_cache)
past_key_values.append(layer_kv_caches)
return past_key_values
# Group KV caches into two flat lists, where one list contains the self-attention KV caches
# and the other contains the cross-attention KV caches
def group_past_key_values(
kv_caches: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
):
self_attn_kv_caches, cross_attn_kv_caches = [], []
for self_k_cache, self_v_cache, cross_k_cache, cross_v_cache in kv_caches:
self_attn_kv_caches.append(self_k_cache)
self_attn_kv_caches.append(self_v_cache)
cross_attn_kv_caches.append(cross_k_cache)
cross_attn_kv_caches.append(cross_v_cache)
return self_attn_kv_caches, cross_attn_kv_caches
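# Illustrative relationship between the two layouts (a sketch): flatten_past_key_values
# zips per-layer (key, value) pairs into per-layer 4-tuples, while group_past_key_values
# splits those 4-tuples back into flat per-kind lists that interleave keys and values
# as [k_0, v_0, k_1, v_1, ...].
#
#   past_key_values = get_sample_past_key_values(config, torch.device("cpu"), batch_size=1, past_seq_len=2)
#   self_attn_kv_caches, cross_attn_kv_caches = group_past_key_values(past_key_values)
#   assert len(self_attn_kv_caches) == 2 * config.decoder_layers
#   assert len(cross_attn_kv_caches) == 2 * config.decoder_layers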
# Create alignment heads for timestamps
# Shape is (num_alignment_heads, 2)
def get_sample_alignment_heads(
config: WhisperConfig,
device: torch.device,
num_alignment_heads: int = 6,
use_int32: bool = True,
):
torch_dtype = torch.int32 if use_int32 else torch.int64
    alignment_heads = torch.ones((num_alignment_heads, 2), device=device, dtype=torch_dtype)  # dummy (layer, head) index pairs; real alignment heads are model-specific
return alignment_heads
# Create length of start-of-transcription sequence for timestamps
# Shape is (1)
def get_sample_sot_sequence_length(
device: torch.device,
sot_sequence_length: int,
use_int32: bool = False,
):
torch_dtype = torch.int32 if use_int32 else torch.int64
sot_length = torch.tensor([sot_sequence_length], device=device, dtype=torch_dtype)
return sot_length
# Create segment length for timestamps
# Shape is (1)
def get_sample_segment_length(
device: torch.device,
segment_length: int,
use_int32: bool = False,
):
torch_dtype = torch.int32 if use_int32 else torch.int64
segment_size = torch.tensor([segment_length], device=device, dtype=torch_dtype)
return segment_size
# Create QKs for timestamps
# Shape is (batch_size, num_heads, sequence_length, num_frames // 2)
def get_sample_QKs( # noqa: N802
config: WhisperConfig,
device: torch.device,
batch_size: int,
sequence_length: int,
use_fp16: bool = False,
):
num_heads = config.decoder_attention_heads
torch_dtype = torch.float16 if use_fp16 else torch.float32
QKs = [ # noqa: N806
torch.rand(
batch_size, num_heads, sequence_length, config.max_source_positions, device=device, dtype=torch_dtype
)
for _ in range(config.decoder_layers)
]
return QKs
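# Illustrative usage (a sketch): one cross-attention QK tensor per decoder layer.
#
#   QKs = get_sample_QKs(config, torch.device("cpu"), batch_size=1, sequence_length=4)
#   assert len(QKs) == config.decoder_layers
#   assert QKs[0].shape == (1, config.decoder_attention_heads, 4, config.max_source_positions)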
# Create inputs for encoder component of Whisper
def get_sample_encoder_inputs(
config: WhisperConfig,
device: torch.device,
batch_size: int,
sequence_length: int = 3000,
use_fp16: bool = False,
):
audio_features = get_sample_audio_features(config, device, batch_size, sequence_length, use_fp16)
return {"audio_features": audio_features}
# Create inputs for encoder component + first pass through decoder component of Whisper
def get_sample_encoder_decoder_init_inputs(
config: WhisperConfig,
device: torch.device,
batch_size: int,
decoder_sequence_length: int,
encoder_sequence_length: int = 3000,
use_fp16: bool = False,
use_int32: bool = True,
):
audio_features = get_sample_audio_features(config, device, batch_size, encoder_sequence_length, use_fp16)
decoder_input_ids = get_sample_decoder_input_ids(config, device, batch_size, decoder_sequence_length, use_int32)
return {"audio_features": audio_features, "decoder_input_ids": decoder_input_ids}
# Create inputs for decoder component of Whisper
# Inputs for first pass through the decoder (i.e. decoder-init): decoder_input_ids, encoder_hidden_states
# Inputs for subsequent passes through the decoder (i.e. decoder-with-past): decoder_input_ids, past_key_values
def get_sample_decoder_inputs(
config: WhisperConfig,
device: torch.device,
batch_size: int,
past_sequence_length: int,
sequence_length: int,
use_fp16: bool = False,
use_int32: bool = True,
):
decoder_input_ids = get_sample_decoder_input_ids(config, device, batch_size, sequence_length, use_int32)
encoder_hidden_states = get_sample_encoder_hidden_states(config, device, batch_size, use_fp16)
past_key_values = get_sample_past_key_values(config, device, batch_size, past_sequence_length, use_fp16)
return {
"decoder_input_ids": decoder_input_ids,
"encoder_hidden_states": encoder_hidden_states,
"past_key_values": past_key_values,
}
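# Illustrative usage (a sketch): a decoder-init call would use decoder_input_ids and
# encoder_hidden_states, while a decoder-with-past call would drop encoder_hidden_states
# and use past_key_values instead.
#
#   inputs = get_sample_decoder_inputs(config, torch.device("cpu"), batch_size=1, past_sequence_length=8, sequence_length=1)
#   with_past_inputs = {k: v for k, v in inputs.items() if k != "encoder_hidden_states"}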
# Create inputs for timestamps component of Whisper
def get_sample_jump_times_inputs(
config: WhisperConfig,
device: torch.device,
batch_size: int,
sequence_length: int,
num_alignment_heads: int,
sot_sequence_length: int,
segment_length: int,
use_fp16: bool = False,
use_int32: bool = True,
):
alignment_heads = get_sample_alignment_heads(config, device, num_alignment_heads, use_int32)
# lengths need to be int64 because subsequent 'Slice' ops only take int64 inputs
sot_sequence_length = get_sample_sot_sequence_length(device, sot_sequence_length)
segment_length = get_sample_segment_length(device, segment_length)
QKs = get_sample_QKs(config, device, batch_size, sequence_length, use_fp16) # noqa: N806
return {
"alignment_heads": alignment_heads,
"sot_sequence_length": sot_sequence_length,
"segment_length": segment_length,
"QKs": QKs,
}
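# Illustrative usage (a sketch; 6 alignment heads and a full 30-second segment of
# 3000 frames are assumed):
#
#   jump_times_inputs = get_sample_jump_times_inputs(
#       config, torch.device("cpu"), batch_size=1, sequence_length=4,
#       num_alignment_heads=6, sot_sequence_length=4, segment_length=3000,
#   )
#   assert set(jump_times_inputs) == {"alignment_heads", "sot_sequence_length", "segment_length", "QKs"}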
# Convert PyTorch inputs to ONNX Runtime inputs
def convert_inputs_for_ort(
inputs: dict,
model: InferenceSession,
):
self_attn_kv_caches, cross_attn_kv_caches = None, None
batch_size, num_heads, past_seq_len, head_size = 0, 0, 0, 0
    num_beams, max_seq_len = 1, 448  # 448 == Whisper's max_target_positions (maximum decoded sequence length)
if "past_key_values" in inputs:
(self_attn_kv_caches, cross_attn_kv_caches) = group_past_key_values(inputs["past_key_values"])
batch_size, num_heads, past_seq_len, head_size = self_attn_kv_caches[0].shape
ort_inputs = {}
    model_inputs = [model_input.name for model_input in model.get_inputs()]
use_buffer_sharing = "cache_indirection" in model_inputs
for name in model_inputs:
if name in {"audio_features", "encoder_input_ids"}:
# Encoder input
ort_inputs[name] = inputs["audio_features"].detach().cpu().numpy()
elif name == "encoder_hidden_states":
# Encoder output
ort_inputs[name] = inputs["encoder_hidden_states"].detach().cpu().numpy()
elif name in {"decoder_input_ids", "input_ids"}:
# Decoder input
ort_inputs[name] = inputs["decoder_input_ids"].detach().cpu().numpy()
elif "past_key_self" in name or "past_value_self" in name:
# Decoder input
orig_kv_cache = self_attn_kv_caches.pop(0).detach().cpu().numpy()
if use_buffer_sharing:
new_kv_cache = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=orig_kv_cache.dtype)
new_kv_cache[:batch_size, :num_heads, :past_seq_len, :head_size] = orig_kv_cache
ort_inputs[name] = new_kv_cache
else:
ort_inputs[name] = orig_kv_cache
elif "past_key_cross" in name or "past_value_cross" in name:
# Decoder input
orig_kv_cache = cross_attn_kv_caches.pop(0).detach().cpu().numpy()
ort_inputs[name] = orig_kv_cache
elif name == "past_sequence_length":
# Decoder input
ort_inputs[name] = np.array([past_seq_len], dtype=np.int32)
elif name == "cache_indirection":
# Decoder input
ort_inputs[name] = np.zeros((batch_size, num_beams, max_seq_len), dtype=np.int32)
elif name == "alignment_heads":
# Jump times input
ort_inputs[name] = inputs["alignment_heads"].detach().cpu().numpy()
elif name == "sot_sequence_length":
# Jump times input
ort_inputs[name] = inputs["sot_sequence_length"].detach().cpu().numpy()
elif name == "segment_length":
# Jump times input
ort_inputs[name] = inputs["segment_length"].detach().cpu().numpy()
elif "cross_qk" in name:
# Jump times input
ort_inputs[name] = inputs["QKs"].pop(0).detach().cpu().numpy()
else:
raise ValueError(f"Unknown name not recognized: {name}")
return ort_inputs
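# Illustrative usage (a sketch; "whisper_decoder.onnx" is a hypothetical exported decoder
# whose input names are assumed to match the branches handled above):
#
#   session = InferenceSession("whisper_decoder.onnx", providers=["CPUExecutionProvider"])
#   inputs = get_sample_decoder_inputs(config, torch.device("cpu"), batch_size=1, past_sequence_length=8, sequence_length=1)
#   ort_inputs = convert_inputs_for_ort(inputs, session)
#   outputs = session.run(None, ort_inputs)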
# Get dynamic axes for all inputs and outputs to the model
def get_model_dynamic_axes(
config: WhisperConfig,
input_names: list[str],
output_names: list[str],
):
dynamic_axes = {}
for name in input_names + output_names:
if name in {"audio_features", "encoder_input_ids"}:
# shape is (batch_size, num_mels, num_frames)
dynamic_axes[name] = {0: "batch_size"}
elif name in {"input_ids", "decoder_input_ids"}:
# shape is (batch_size, sequence_length)
dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
elif name == "alignment_heads":
# shape is (num_alignment_heads, 2)
dynamic_axes[name] = {0: "num_alignment_heads"}
elif name in {"sot_sequence_length", "segment_length"}:
# shape is (1)
pass
elif name == "logits":
# shape is (batch_size, sequence_length, vocab_size)
dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
elif name == "encoder_hidden_states":
# shape is (batch_size, num_frames // 2, hidden_size)
dynamic_axes[name] = {0: "batch_size"}
elif "past_key_self" in name or "past_value_self" in name:
# shape is (batch_size, num_heads, past_sequence_length, head_size)
dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length"}
elif "present_key_self" in name or "present_value_self" in name:
# shape is (batch_size, num_heads, past_sequence_length + sequence_length, head_size),
# which is equal to (batch_size, num_heads, total_sequence_length, head_size)
dynamic_axes[name] = {0: "batch_size", 2: "total_sequence_length"}
elif (
"past_key_cross" in name
or "past_value_cross" in name
or "present_key_cross" in name
or "present_value_cross" in name
):
# shape is (batch_size, num_heads, num_frames // 2, head_size)
dynamic_axes[name] = {0: "batch_size"}
elif "cross_qk" in name:
            # shape is (batch_size, num_heads, sequence_length, num_frames // 2)
dynamic_axes[name] = {0: "batch_size", 2: "sequence_length"}
elif "jump_times" in name:
# shape is (batch_size, max_length)
dynamic_axes[name] = {0: "batch_size", 1: "max_length"}
else:
            raise ValueError(f"Unknown input or output name found: {name}")
return dynamic_axes
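# Illustrative usage (a sketch; `encoder` is a hypothetical module being exported, and
# audio_features comes from get_sample_audio_features above): the returned mapping plugs
# into torch.onnx.export's dynamic_axes argument alongside matching input/output names.
#
#   input_names = ["audio_features"]
#   output_names = ["encoder_hidden_states"]
#   dynamic_axes = get_model_dynamic_axes(config, input_names, output_names)
#   torch.onnx.export(
#       encoder, (audio_features,), "whisper_encoder.onnx",
#       input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes,
#   )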