Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
sarus-llm / sarus_llm / generation.py
Size: Mime:
from typing import Callable, List, Optional, cast

import torch
from sarus_llm.models.modules import TransformerDecoder


def multinomial_sample_one(probs: torch.Tensor) -> torch.Tensor:
    """Draw a single index from each distribution in ``probs``.

    Args:
        probs: non-negative weights; the last dimension indexes the outcomes.

    Returns:
        The drawn indices, with the last dimension of ``probs`` replaced by a
        dimension of size 1.
    """
    drawn = probs.multinomial(num_samples=1)
    return drawn


def sample(
    logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None
) -> torch.Tensor:
    """Draw one token id per distribution in ``logits``.

    The logits are divided by ``temperature``, floored at 1e-5 so that a
    temperature of 0 (or below) does not divide by zero. With very small
    temperatures, sampling is effectively greedy for well-separated logits.

    When ``top_k`` is given, every logit strictly below the k-th largest value
    of its row is replaced by -inf before the softmax. Only the top-k ids, plus
    any ties with the k-th value, can then be drawn.

    Args:
        logits: unnormalised scores; the last dimension indexes the vocabulary.
        temperature: softmax temperature, default 1.0.
        top_k: optional number of highest-scoring ids to sample from. Values
            larger than the vocabulary size are clamped to it.

    Returns:
        The sampled ids, with the last dimension replaced by a single sample.
    """
    divisor = max(temperature, 1e-5)
    scaled = logits / divisor
    if top_k is not None:
        k = min(top_k, scaled.size(-1))
        # smallest value still inside the top-k, kept as a trailing dim of 1
        kth_largest = torch.topk(scaled, k).values.select(-1, -1).unsqueeze(-1)
        # prune everything below the k-th value to -inf (probability 0)
        scaled = torch.where(scaled < kth_largest, -float("Inf"), scaled)
    probabilities = torch.nn.functional.softmax(scaled, dim=-1)
    # one draw per distribution (same as multinomial_sample_one)
    return torch.multinomial(probabilities, num_samples=1)


def generate_next_token(
    model: TransformerDecoder,
    input_pos: torch.Tensor,
    x: torch.Tensor,
    mask: torch.Tensor,
    triton_kernel: bool,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
) -> torch.Tensor:
    """Run one forward pass and sample a token from the final position.

    The model is expected to return logits shaped
    ``[bsz, seq_length, vocab_size]``. Only the logits of the last sequence
    position are used to pick the next token, via ``sample``.

    Args:
        model: decoder to call.
        input_pos: position ids forwarded to the model.
        x: input token ids forwarded to the model.
        mask: attention mask forwarded to the model.
        triton_kernel: flag forwarded unchanged to the model.
        temperature: sampling temperature passed to ``sample``.
        top_k: optional top-k pruning passed to ``sample``.

    Returns:
        The sampled token ids, one per batch row.
    """
    all_logits = model(
        x, input_pos=input_pos, mask=mask, triton_kernel=triton_kernel
    )
    final_step_logits = all_logits[:, -1]
    return sample(final_step_logits, temperature, top_k)


def update_stop_tokens_tracker(
    tokens: torch.Tensor,
    stop_tokens: torch.Tensor,
    stop_token_reached: torch.Tensor,
) -> torch.Tensor:
    """Record which sequences have now produced a stop token.

    Expected shapes: ``tokens`` is ``[bsz, 1]``, ``stop_tokens`` is
    ``[num_stop_tokens]`` and ``stop_token_reached`` is ``[bsz]``.

    ``stop_token_reached`` is updated IN PLACE (OR-ed with this step's hits),
    so once a sequence is flagged it stays flagged. The same tensor is
    returned for convenience.
    """
    hit_now = torch.isin(tokens, stop_tokens).flatten()
    return stop_token_reached.bitwise_or_(hit_now)


@torch.inference_mode()
def generate(
    model: TransformerDecoder,
    prompt: torch.Tensor,
    *,
    padding_mask: torch.Tensor,
    max_generated_tokens: int,
    pad_id: int = 0,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    stop_tokens: Optional[List[int]] = None,
    custom_generate_next_token: Optional[Callable] = None,
    triton_kernel: bool = False,
) -> List[List[int]]:
    """
    Generates tokens from a model conditioned on a prompt.

    Args:
        model (TransformerDecoder): model used for generation. Its caches are
            set up here via ``model.setup_caches`` (float32, on the prompt's
            device) for ``prompt_length + max_generated_tokens`` positions.
        prompt (torch.Tensor): tensor with the token IDs associated with the given prompt,
            with shape either [seq_length] or [bsz x seq_length]
        padding_mask (torch.Tensor): mask aligned with ``prompt`` (shape
            [seq_length] or [bsz x seq_length]), set where the prompt token is
            real. Position ids are counted from the first set entry of each
            row, and the mask is used to build the attention masks passed to
            the model.
        max_generated_tokens (int): number of tokens to be generated
        pad_id (int): token ID written into the positions generated after a
            sequence has hit a stop token (only used when ``stop_tokens`` is
            set), default 0.
        temperature (float): value to scale the predicted logits by, default 1.0.
        top_k (Optional[int]): If specified, we prune the sampling to only token ids within the top_k probabilities,
            default None.
        stop_tokens (Optional[List[int]]): If specified, generation is stopped when any of these tokens are generated,
            default None.
        custom_generate_next_token (Optional[Callable]): If specified, we'll use the ``custom_generate_next_token function``.
            This is generally only useful if you want to specify a ``torch.compile`` version of the generate next token for
            performance reasons. If None, we use the default ``generate_next_token`` function. Default is None.
            It is only used for the single-token decoding steps; the prefill
            step over the whole prompt always uses ``generate_next_token``.
        triton_kernel (bool): forwarded unchanged to every model call, default False.

    Returns:
        List[List[int]]: one list per sequence: the prompt tokens followed by
        the generated tokens.

    Note:
        At least one token is always generated, because the prefill step runs
        before the decoding loop, even when ``max_generated_tokens`` is 0.
    """

    prompt = prompt.view(1, -1) if prompt.ndim == 1 else prompt
    padding_mask = (
        padding_mask.view(1, -1) if padding_mask.ndim == 1 else padding_mask
    )
    model.setup_caches(
        batch_size=prompt.shape[0],
        dtype=torch.float32,
        device=prompt.device,
        max_seq_len=prompt.shape[1] + max_generated_tokens,
    )

    # convert stop tokens to tensor for easy matching
    torch_stop_tokens = (
        torch.tensor(stop_tokens, device=prompt.device)
        if stop_tokens
        else None
    )
    bsz, prompt_length = prompt.size()
    generated_tokens = prompt.clone()
    # keeps track at a high level if we've already hit a stop token in a sequence so we can early stop
    stop_token_reached = torch.zeros(
        bsz, dtype=torch.bool, device=prompt.device
    )
    # stop_token_mask covers the prompt plus the first generated token with 1s;
    # one column is appended per decoding step, 0 for sequences that already
    # hit a stop token
    stop_token_mask = torch.ones(
        (bsz, prompt_length + 1), dtype=torch.int32, device=prompt.device
    )
    if custom_generate_next_token is None:
        custom_generate_next_token = generate_next_token
    input_pos = input_pos_from_padding_mask(padding_mask)
    attention_mask = expand_mask(
        padding_mask=padding_mask,
        max_seq_length=prompt.shape[1] + max_generated_tokens,
        indices_to_take=torch.arange(0, prompt_length, device=prompt.device),
        fix_padding=True,
    )

    # generate the first tokens conditioned on the prompt (prefill step; always
    # uses the default generate_next_token)
    tokens = generate_next_token(
        model,
        mask=attention_mask,
        input_pos=input_pos,
        x=prompt,
        temperature=temperature,
        top_k=top_k,
        triton_kernel=triton_kernel,
    )
    generated_tokens = torch.cat([generated_tokens, tokens], dim=-1)
    # stop early if we reach a stop token in every seq
    if torch_stop_tokens is not None:
        stop_token_reached = update_stop_tokens_tracker(
            tokens, torch_stop_tokens, stop_token_reached
        )
        if stop_token_reached.all().item():
            return generated_tokens.tolist()
    # continue from the position right after each sequence's last prompt token
    input_pos = input_pos[:, -1][:, None] + 1
    # index (in the padded sequence) of the token fed to the model this step
    curr_mask_idx = prompt_length
    for _ in range(max_generated_tokens - 1):
        # the token fed in at this step is always real: extend the padding mask
        padding_mask = torch.cat(
            [
                padding_mask,
                torch.ones(
                    size=(padding_mask.shape[0], 1),
                    device=padding_mask.device,
                    dtype=torch.bool,
                ),
            ],
            dim=-1,
        )
        mask = expand_mask(
            padding_mask=padding_mask,
            max_seq_length=prompt.shape[1] + max_generated_tokens,
            indices_to_take=torch.tensor(
                [curr_mask_idx], device=prompt.device
            ),
            fix_padding=False,
        )
        curr_mask_idx += 1
        # update stop_token_mask if we reached a stop token in a previous step
        # by appending the logical not of stop_token_reached to the end of the mask
        # reshaped to be bsz first
        if torch_stop_tokens is not None:
            stop_token_mask = torch.cat(
                [stop_token_mask, ~stop_token_reached.reshape(bsz, 1)], dim=-1
            )
        tokens = custom_generate_next_token(
            model,
            input_pos=input_pos,
            mask=mask,
            x=tokens,
            temperature=temperature,
            top_k=top_k,
            triton_kernel=triton_kernel,
        )

        generated_tokens = torch.cat([generated_tokens, tokens], dim=-1)
        input_pos += 1

        if torch_stop_tokens is not None:
            stop_token_reached = update_stop_tokens_tracker(
                tokens, torch_stop_tokens, stop_token_reached
            )
            if stop_token_reached.all().item():
                break

    # mask out generated tokens in seqs that already hit a stop token
    if torch_stop_tokens is not None:
        # Only positions generated after a stop token become pad_id. Real
        # tokens whose id happens to be 0 (in the prompt or the generation)
        # are left untouched.
        generated_tokens = generated_tokens.masked_fill(
            stop_token_mask == 0, pad_id
        )

    return generated_tokens.tolist()


def expand_mask(
    padding_mask: torch.Tensor,
    max_seq_length: int,
    indices_to_take: torch.Tensor,
    fix_padding: bool,
) -> torch.Tensor:
    """Create a per-query attention mask from a 2D padding mask.

    The ``[b, s0]`` padding mask is right-extended with zeros up to
    ``max_seq_length`` columns; it is never truncated. Let
    ``s = max(max_seq_length, s0)``.

    For every query position ``q`` in ``indices_to_take``, key ``k`` is
    allowed iff ``k <= q`` (causal) and both positions ``q`` and ``k`` are set
    in the extended padding mask. For a 1D ``indices_to_take`` the result has
    shape ``[b, len(indices_to_take), s]``: b is the batch size, then the
    queries, then the keys.

    Only the requested rows are built, so the cost is ``O(b * q * s)`` rather
    than materialising the full ``[b, s, s]`` mask on every call. This
    matters because generation calls this once per decoded token.

    Args:
        padding_mask: ``[b, s0]`` mask, set where the token is real.
        max_seq_length: number of key positions to extend the mask to.
        indices_to_take: integer query positions to return. Negative values
            count from the end, as in regular indexing.
        fix_padding: if True, each query is additionally allowed to attend to
            its own position, so that fully padded query rows are not empty.

    Returns:
        The mask; its dtype follows ``padding_mask`` (bool for a bool mask).
    """
    increase_by = max(0, max_seq_length - padding_mask.shape[1])
    mask = torch.cat(
        [
            padding_mask,
            torch.zeros(
                (padding_mask.shape[0], increase_by),
                device=padding_mask.device,
                dtype=torch.bool,
            ),
        ],
        dim=-1,
    )
    seq_len = mask.shape[1]
    # Gather the query-side validity first; this raises on out-of-range
    # indices exactly as indexing the full [b, s, s] mask would.
    query_valid = mask[:, indices_to_take]  # [b, *idx_shape]
    # Absolute query positions (negative indices counted from the end) on the
    # mask's device, for the causal comparison.
    query_pos = indices_to_take.to(mask.device)
    query_pos = torch.where(query_pos < 0, query_pos + seq_len, query_pos)
    key_pos = torch.arange(seq_len, device=mask.device)
    causal = key_pos <= query_pos.unsqueeze(-1)  # [*idx_shape, s]
    # Key-side validity broadcast over every query dimension: [b, 1..., s].
    key_valid = mask.reshape(mask.shape[0], *([1] * query_pos.ndim), seq_len)
    out = query_valid.unsqueeze(-1) * key_valid * causal
    if fix_padding:
        # need to add this because of a bug with torch sdpa
        # when a query does not attend to any key, which happens
        # with padding
        # https://github.com/pytorch/pytorch/issues/103749
        # The entry is added at each query's own position. This matches the
        # previous identity-matrix fix when indices_to_take == arange(q).
        out += key_pos == query_pos.unsqueeze(-1)
    return cast(torch.Tensor, out)


def input_pos_from_padding_mask(padding_mask: torch.Tensor) -> torch.Tensor:
    """Create position ids from a 2D padding mask.

    For a 0/1 (or bool) mask, each row is numbered relative to its first set
    entry. That entry gets position 0 and the following columns get 1, 2, ...
    Every column before it is clamped to 0, which gives position 0 for left
    padding. Columns after the first set entry keep counting even where the
    mask is unset (e.g. right padding). A row with no set entry is numbered
    from column 0.

    Args:
        padding_mask: ``[bsz, seq_len]`` mask.

    Returns:
        ``[bsz, seq_len]`` integer tensor of position ids.
    """
    # argmax returns the first maximal index, i.e. the first set column
    # (0 when the whole row is unset).
    offsets = padding_mask.float().argmax(dim=1, keepdim=True)  # [bsz, 1]
    seq_len = padding_mask.shape[1]
    columns = torch.arange(seq_len, device=padding_mask.device)
    relative = columns.unsqueeze(0) - offsets
    return relative.clamp(min=0)