# Repository URL to install this package: (stripped from package-index page)
# Version: 0.7.1+cu122
from typing import List, Optional, Union
from torch import LongTensor
from transformers import PreTrainedTokenizer
def postprocess_generation_ids(
input_ids: LongTensor,
output_ids: LongTensor,
num_return_sequences: int,
tokenizer: Optional[PreTrainedTokenizer] = None,
pad_token_ids: Optional[int] = None,
) -> List[List[Union[str, List[int]]]]:
outputs = []
for idx, start in enumerate(range(0, len(output_ids), num_return_sequences)):
sub_output_ids = output_ids[start : start + num_return_sequences]
sub_generated_ids = sub_output_ids[..., input_ids[idx].size(0) :]
if tokenizer:
decoded_bach = (
generated_text
for generated_text in tokenizer.batch_decode(sub_generated_ids, clean_up_tokenization_spaces=True)
)
decoded_bach = list(decoded_bach)
outputs.append(decoded_bach)
else:
sub_generated_ids = sub_output_ids.cpu().numpy().tolist()
for i, one_sub_generated_ids in enumerate(sub_generated_ids):
if pad_token_ids is not None and pad_token_ids in one_sub_generated_ids:
one_sub_generated_ids = one_sub_generated_ids[: one_sub_generated_ids.index(pad_token_ids)]
sub_generated_ids[i] = one_sub_generated_ids
outputs.append(sub_generated_ids)
return outputs
# Explicit public API: only the post-processing helper is exported.
__all__ = ["postprocess_generation_ids"]