# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# -------------------------------------------------------------------------
import logging
import random
import torch
from transformers import MT5Config, T5Config

logger = logging.getLogger(__name__)


class T5Encoder(torch.nn.Module):
    """Wraps a T5/MT5 encoder so that forward returns only the last hidden state."""

    def __init__(self, encoder, config: T5Config | MT5Config):
        super().__init__()
        self.encoder = encoder
        self.config = config

    def forward(self, input_ids, attention_mask):
        # The wrapped encoder returns a tuple/ModelOutput; element 0 is last_hidden_state.
        return self.encoder(input_ids, attention_mask)[0]
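

# Illustrative sketch (not part of the original module): build a T5Encoder
# from a Hugging Face checkpoint. The checkpoint name "t5-small" is an
# assumption; any T5/MT5 model exposing `.encoder` and `.config` works.
def _example_build_t5_encoder(model_name: str = "t5-small") -> T5Encoder:
    from transformers import T5ForConditionalGeneration

    model = T5ForConditionalGeneration.from_pretrained(model_name)
    return T5Encoder(model.encoder, model.config)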


class T5EncoderInputs:
    def __init__(self, input_ids, attention_mask):
        self.input_ids: torch.LongTensor = input_ids
        self.attention_mask: torch.LongTensor = attention_mask

    @staticmethod
    def create_dummy(
        batch_size: int,
        sequence_length: int,
        vocab_size: int,
        device: torch.device,
        use_int32_inputs: bool = False,
    ) -> "T5EncoderInputs":
"""Create dummy inputs for T5 encoder.
Args:
batch_size (int): batch size
sequence_length (int): sequence length
vocab_size (int): vocabulary size
device (torch.device): device of output tensors
Returns:
T5EncoderInputs: dummy inputs for encoder
"""
        dtype = torch.int32 if use_int32_inputs else torch.int64

        input_ids = torch.randint(
            low=0,
            high=vocab_size - 1,  # torch.randint treats high as exclusive
            size=(batch_size, sequence_length),
            dtype=dtype,
            device=device,
        )

        attention_mask = torch.ones([batch_size, sequence_length], dtype=dtype, device=device)
        if sequence_length >= 2:
            # Simulate left padding by zeroing a random-length prefix of each row's mask.
            for i in range(batch_size):
                padding_position = random.randint(0, sequence_length - 1)
                attention_mask[i, :padding_position] = 0

        return T5EncoderInputs(input_ids, attention_mask)

    def to_list(self) -> list:
        # Keep only the tensors that are present, in the order the encoder expects.
        return [v for v in [self.input_ids, self.attention_mask] if v is not None]
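

if __name__ == "__main__":
    # Smoke test (sketch): run the wrapped encoder on dummy inputs and log the
    # output shape. Requires downloading the assumed "t5-small" checkpoint.
    logging.basicConfig(level=logging.INFO)
    encoder = _example_build_t5_encoder()
    dummy = T5EncoderInputs.create_dummy(
        batch_size=2,
        sequence_length=8,
        vocab_size=encoder.config.vocab_size,
        device=torch.device("cpu"),
    )
    last_hidden_state = encoder(*dummy.to_list())
    logger.info("encoder output shape: %s", tuple(last_hidden_state.shape))  # (2, 8, d_model)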