from typing import Callable, List, Optional, cast
import torch
from sarus_llm.models.modules import TransformerDecoder


def multinomial_sample_one(probs: torch.Tensor) -> torch.Tensor:
    """Samples from a multinomial distribution."""
    return torch.multinomial(probs, num_samples=1)


def sample(
    logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None
) -> torch.Tensor:
    """Generic sample from a probability distribution."""
    # scale the logits based on temperature
    logits = logits / max(temperature, 1e-5)
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        # select the very last value from the top_k above as the pivot
        pivot = v.select(-1, -1).unsqueeze(-1)
        # set everything smaller than the pivot value to -inf since these
        # tokens should be pruned from the sampling pool
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    # change logits into probabilities
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return multinomial_sample_one(probs)
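

# A minimal usage sketch for ``sample`` (illustrative, not part of the
# original API): the toy vocab size of 8 and the batch size of 2 are
# assumptions made for the example. Temperature 0.7 sharpens the
# distribution and top_k=3 restricts sampling to the three largest logits.
def _example_sample() -> torch.Tensor:
    logits = torch.randn(2, 8)  # [bsz, vocab_size]
    return sample(logits, temperature=0.7, top_k=3)  # shape [2, 1]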


def generate_next_token(
    model: TransformerDecoder,
    input_pos: torch.Tensor,
    x: torch.Tensor,
    mask: torch.Tensor,
    triton_kernel: bool,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
) -> torch.Tensor:
    """Generates the next tokens."""
    # model produces logits in [bsz, seq_length, vocab_size];
    # we take the last position's logits and sample the next token from them,
    # which becomes the input to the next model call
    logits = model(
        x, input_pos=input_pos, mask=mask, triton_kernel=triton_kernel
    )[:, -1]
    return sample(logits, temperature, top_k)


def update_stop_tokens_tracker(
    tokens: torch.Tensor,
    stop_tokens: torch.Tensor,
    stop_token_reached: torch.Tensor,
) -> torch.Tensor:
    """Updates which sequences have reached a stop token."""
    # tokens: [bsz, 1]
    # stop_tokens: [num_stop_tokens]
    # stop_token_reached: [bsz]
    stop_token_reached_curr = torch.isin(tokens, stop_tokens).flatten()
    stop_token_reached |= stop_token_reached_curr
    return stop_token_reached
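

# A minimal sketch of the stop-token bookkeeping (illustrative, not part of
# the original API): with stop token 2, only the second sequence of the
# batch is flagged as finished, so the result is [False, True].
def _example_stop_tracking() -> torch.Tensor:
    tokens = torch.tensor([[5], [2]])           # [bsz, 1]
    stop_tokens = torch.tensor([2])             # [num_stop_tokens]
    reached = torch.zeros(2, dtype=torch.bool)  # [bsz]
    return update_stop_tokens_tracker(tokens, stop_tokens, reached)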


@torch.inference_mode()
def generate(
    model: TransformerDecoder,
    prompt: torch.Tensor,
    *,
    padding_mask: torch.Tensor,
    max_generated_tokens: int,
    pad_id: int = 0,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    stop_tokens: Optional[List[int]] = None,
    custom_generate_next_token: Optional[Callable] = None,
    triton_kernel: bool = False,
) -> List[List[int]]:
    """
    Generates tokens from a model conditioned on a prompt.

    Args:
        model (TransformerDecoder): model used for generation
        prompt (torch.Tensor): tensor with the token IDs associated with the given prompt,
            with shape either [seq_length] or [bsz x seq_length]
        padding_mask (torch.Tensor): boolean mask with the same shape as ``prompt``,
            False at (left) padding positions and True at real tokens
        max_generated_tokens (int): number of tokens to be generated
        pad_id (int): token ID to use for padding, default 0.
        temperature (float): value to scale the predicted logits by, default 1.0.
        top_k (Optional[int]): If specified, we prune the sampling to only token ids within the top_k probabilities,
            default None.
        stop_tokens (Optional[List[int]]): If specified, generation is stopped when any of these tokens are generated,
            default None.
        custom_generate_next_token (Optional[Callable]): If specified, we'll use the ``custom_generate_next_token``
            function. This is generally only useful if you want to specify a ``torch.compile`` version of the
            generate next token for performance reasons. If None, we use the default ``generate_next_token``
            function. Default is None.
        triton_kernel (bool): passed through to the model's forward call, default False.

    Returns:
        List[List[int]]: collection of lists of generated tokens
    """
    prompt = prompt.view(1, -1) if prompt.ndim == 1 else prompt
    padding_mask = (
        padding_mask.view(1, -1) if padding_mask.ndim == 1 else padding_mask
    )
    model.setup_caches(
        batch_size=prompt.shape[0],
        dtype=torch.float32,
        device=prompt.device,
        max_seq_len=prompt.shape[1] + max_generated_tokens,
    )

    # convert stop tokens to tensor for easy matching
    torch_stop_tokens = (
        torch.tensor(stop_tokens, device=prompt.device)
        if stop_tokens
        else None
    )

    bsz, prompt_length = prompt.size()
    generated_tokens = prompt.clone()

    # keeps track at a high level if we've already hit a stop token in a
    # sequence so we can early stop
    stop_token_reached = torch.zeros(
        bsz, dtype=torch.bool, device=prompt.device
    )

    # everything in stop_token_mask starts as 1s, and we'll set them to 0 for
    # sequences that already hit a stop token
    stop_token_mask = torch.ones(
        (bsz, prompt_length + 1), dtype=torch.int32, device=prompt.device
    )

    if custom_generate_next_token is None:
        custom_generate_next_token = generate_next_token

    input_pos = input_pos_from_padding_mask(padding_mask)
    attention_mask = expand_mask(
        padding_mask=padding_mask,
        max_seq_length=prompt.shape[1] + max_generated_tokens,
        indices_to_take=torch.arange(0, prompt_length, device=prompt.device),
        fix_padding=True,
    )

    # generate the first tokens conditioned on the prompt
    tokens = generate_next_token(
        model,
        mask=attention_mask,
        input_pos=input_pos,
        x=prompt,
        temperature=temperature,
        top_k=top_k,
        triton_kernel=triton_kernel,
    )
    generated_tokens = torch.cat([generated_tokens, tokens], dim=-1)

    # stop early if we reach a stop token in every seq
    if torch_stop_tokens is not None:
        stop_token_reached = update_stop_tokens_tracker(
            tokens, torch_stop_tokens, stop_token_reached
        )
        if stop_token_reached.all().item():
            return generated_tokens.tolist()

    input_pos = input_pos[:, -1][:, None] + 1
    curr_mask_idx = prompt_length
    for _ in range(max_generated_tokens - 1):
        # grow the padding mask by one position for the token generated at
        # the previous step (generated tokens are never padding)
        padding_mask = torch.cat(
            [
                padding_mask,
                torch.ones(
                    size=(padding_mask.shape[0], 1),
                    device=padding_mask.device,
                    dtype=torch.bool,
                ),
            ],
            dim=-1,
        )
        mask = expand_mask(
            padding_mask=padding_mask,
            max_seq_length=prompt.shape[1] + max_generated_tokens,
            indices_to_take=torch.tensor(
                [curr_mask_idx], device=prompt.device
            ),
            fix_padding=False,
        )
        curr_mask_idx += 1

        # update stop_token_mask if we reached a stop token in a previous step
        # by appending the logical not of stop_token_reached to the end of the
        # mask, reshaped to be bsz first
        if torch_stop_tokens is not None:
            stop_token_mask = torch.cat(
                [stop_token_mask, ~stop_token_reached.reshape(bsz, 1)], dim=-1
            )

        tokens = custom_generate_next_token(
            model,
            input_pos=input_pos,
            mask=mask,
            x=tokens,
            temperature=temperature,
            top_k=top_k,
            triton_kernel=triton_kernel,
        )
        generated_tokens = torch.cat([generated_tokens, tokens], dim=-1)
        input_pos += 1

        if torch_stop_tokens is not None:
            stop_token_reached = update_stop_tokens_tracker(
                tokens, torch_stop_tokens, stop_token_reached
            )
            if stop_token_reached.all().item():
                break

    # mask out generated tokens in seqs that already hit a stop token
    if torch_stop_tokens is not None:
        generated_tokens = generated_tokens * stop_token_mask

    # if pad_id is not 0, replace 0 with pad_id
    if pad_id != 0:
        generated_tokens[generated_tokens == 0] = pad_id

    return generated_tokens.tolist()
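

# A hedged end-to-end sketch (illustrative, not part of the original API):
# it assumes ``model`` is an already-initialized TransformerDecoder and that
# token id 2 acts as an end-of-sequence marker; both are assumptions made
# for the example, as are the prompt ids.
def _example_generate(model: TransformerDecoder) -> List[List[int]]:
    prompt = torch.tensor([[1, 5, 9, 2]])
    padding_mask = torch.ones_like(prompt, dtype=torch.bool)
    return generate(
        model,
        prompt,
        padding_mask=padding_mask,
        max_generated_tokens=16,
        stop_tokens=[2],
    )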


def expand_mask(
    padding_mask: torch.Tensor,
    max_seq_length: int,
    indices_to_take: torch.Tensor,
    fix_padding: bool,
) -> torch.Tensor:
    """Creates a 3D attention mask from a 2D padding mask. The returned mask
    has shape [b, q, s], where b is the batch size, q is the query length
    (the number of ``indices_to_take``) and s is the max sequence length."""
    increase_by = max(0, max_seq_length - padding_mask.shape[1])
    mask = torch.cat(
        [
            padding_mask,
            torch.zeros(
                (padding_mask.shape[0], increase_by),
                device=padding_mask.device,
                dtype=torch.bool,
            ),
        ],
        dim=-1,
    )
    expanded_mask = mask.unsqueeze(1) * mask.unsqueeze(2)
    causal_mask = torch.tril(
        torch.ones(
            (mask.shape[0], mask.shape[1], mask.shape[1]),
            dtype=torch.bool,
            device=padding_mask.device,
        )
    )
    out = (expanded_mask * causal_mask)[:, indices_to_take]
    if fix_padding:
        # need to add this because of a bug with torch sdpa
        # when a query does not attend to any key, which happens
        # with padding
        # https://github.com/pytorch/pytorch/issues/103749
        out += torch.eye(
            out.shape[1], out.shape[2], dtype=torch.bool, device=out.device
        )[None, :]
    return cast(torch.Tensor, out)
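

# A minimal sketch of the mask construction (illustrative, not part of the
# original API): one left-padded sequence of length 4 (first position is
# padding), expanded up to an assumed max length of 6. The [1, 4, 6] result
# combines the padding mask with a causal mask.
def _example_expand_mask() -> torch.Tensor:
    padding_mask = torch.tensor([[False, True, True, True]])
    return expand_mask(
        padding_mask=padding_mask,
        max_seq_length=6,
        indices_to_take=torch.arange(0, 4),
        fix_padding=True,
    )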


def input_pos_from_padding_mask(padding_mask: torch.Tensor) -> torch.Tensor:
    """Creates position ids from the padding mask; where the mask is 0
    (left padding), the position id is set to 0."""
    shift_input_pos = (
        torch.argmax(padding_mask.float(), dim=1)[:, None]
        .repeat((1, padding_mask.shape[1]))
        .to(padding_mask.device)
    )
    position_ids = torch.arange(
        0, padding_mask.shape[1], device=padding_mask.device
    ).repeat((padding_mask.shape[0], 1))
    position_ids -= shift_input_pos
    return torch.where(position_ids < 0, 0, position_ids)
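

# A minimal sketch of the position-id computation (illustrative, not part of
# the original API): for a sequence left-padded by two positions, the
# returned position ids are [0, 0, 0, 1, 2].
def _example_input_pos() -> torch.Tensor:
    padding_mask = torch.tensor([[False, False, True, True, True]])
    return input_pos_from_padding_mask(padding_mask)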