# sarus-llm / sarus_llm / data / tokenizers / llama_tokenizer.py
# NOTE: package-index page boilerplate ("Why Gemfury? ...") that had been
# scraped into this header was removed; it was not part of the module.
import os
import typing as t
from logging import getLogger
from pathlib import Path

import tiktoken
from tiktoken.load import load_tiktoken_bpe

from .exceptions import (
    ConversationFormatError,
    UnrecognizedRoleError,
    MessageFormatError,
)
from .utils import (
    SampleType,
    TokenSample,
    FinetuningAssistantMessage,
    SystemMessage,
    UserMessage,
    Roles,
    TrainingInstructSample,
    ChatMessage,
)

logger = getLogger(__name__)


class Tokenizer:
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.

    Loads a Llama-3-style BPE vocabulary with tiktoken and registers the
    Llama 3 special tokens (BOS/EOS, header delimiters, end-of-turn, and
    reserved slots) directly above the base token ids.
    """

    # Filled in __init__: maps each special-token string to its token id.
    special_tokens: t.Dict[str, int]

    # Total number of special-token slots layered on top of the base vocab.
    num_reserved_special_tokens = 256

    # Llama 3 pre-tokenization pattern: contractions, letter runs, digit
    # groups of at most three, punctuation runs, and whitespace handling.
    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501

    def __init__(self, model_path: str):
        """
        Initializes the Tokenizer with a Tiktoken model.

        Args:
            model_path (str): The path to the Tiktoken model file.

        Raises:
            AssertionError: If `model_path` does not point to a file.
        """
        assert os.path.isfile(model_path), model_path

        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)
        # Ten named special tokens first, then the remaining reserved slots
        # (indices 5 .. num_reserved_special_tokens - 6), for a total of
        # num_reserved_special_tokens special tokens.
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [
            f"<|reserved_special_token_{i}|>"
            for i in range(5, self.num_reserved_special_tokens - 5)
        ]
        # Special-token ids start right after the base BPE ranks.
        self.special_tokens = {
            token: num_base_tokens + i
            for i, token in enumerate(special_tokens)
        }
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        # Lazy %-args keep log-record construction cheap when disabled.
        logger.info("Reloaded tiktoken model from %s", model_path)

        self.n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        self.pad_id: int = -1
        # Generation stops on either end-of-text or end-of-turn.
        self.stop_tokens = {
            self.special_tokens["<|end_of_text|>"],
            self.special_tokens["<|eot_id|>"],
        }
        logger.info(
            "#words: %d - BOS ID: %d - EOS ID: %d",
            self.n_words,
            self.bos_id,
            self.eos_id,
        )

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: t.Union[
            t.Literal["all"], t.AbstractSet[str]
        ] = frozenset(),  # immutable default (was a mutable `set()`)
        disallowed_special: t.Union[t.Literal["all"], t.Collection[str]] = (),
    ) -> t.List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_special ("all"|AbstractSet[str]): special tokens allowed
                to appear literally in the string.
            disallowed_special ("all"|Collection[str]): special tokens that
                raise an error when found in the string.

        Returns:
            list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will treat all text corresponding
          to special tokens to be encoded as special tokens.
        """
        # `type(s) is str` (not isinstance) deliberately rejects str
        # subclasses, mirroring the upstream tokenizer.
        assert type(s) is str

        # The tiktoken tokenizer can handle <=400k chars without
        # pyo3_runtime.PanicException.
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000

        # https://github.com/openai/tiktoken/issues/195
        # Here we iterate over subsequences and split if we exceed the limit
        # of max consecutive non-whitespace or whitespace characters.
        MAX_NO_WHITESPACES_CHARS = 25_000

        # Lazily produce chunks small enough for tiktoken to encode safely.
        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        vals: t.List[int] = []
        for substr in substrs:
            vals.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            vals.insert(0, self.bos_id)
        if eos:
            vals.append(self.eos_id)
        return vals

    def decode(self, tokens: t.Sequence[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
             tokens(List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
        return self.model.decode(t.cast(t.List[int], tokens))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> t.Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than
        `max_consecutive_slice_len` consecutive whitespaces or consecutive
        non-whitespaces.

        Always yields at least one (possibly empty) substring.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            if current_slice_is_space ^ is_now_space:
                # Run type flipped (space <-> non-space): restart the count.
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    # Cut just before the character that overflowed the run.
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]


class SarusLLamaTokenizer:
    """
    Llama-3 tokenizer front-end for Sarus training data.

    Handles two kinds of samples:
    - plain strings (pretraining), tokenized with BOS/EOS handling, and
    - `TrainingInstructSample` conversations, rendered in the Llama 3 chat
      format (<|start_header_id|>role<|end_header_id|>, a blank line, the
      content, then <|eot_id|>).

    Each tokenization returns a `TokenSample` of token ids plus one boolean
    per token. NOTE(review): the booleans appear to be a loss mask (True =
    train on this token) — confirm against `TokenSample` consumers.
    """

    def __init__(self):
        # The tiktoken model file is expected to sit next to this module.
        self.tokenizer = Tokenizer(
            os.path.join(str(Path(__file__).parent), "llama_tokenizer.model")
        )

    def tokenize_pretrain(self, sample: str, training: bool) -> TokenSample:
        """
        Tokenize a raw pretraining string.

        BOS is always prepended; EOS is appended only when `training` is
        True. Every token gets mask value True.
        """
        tokens = self.tokenizer.encode(sample, bos=True, eos=training)
        masks = [True] * len(tokens)
        return TokenSample(tokens, masks)

    def tokenize(
        self, sample: t.Union[str, TrainingInstructSample], training: bool
    ) -> TokenSample:
        """
        Dispatch on sample type: plain strings take the pretraining path,
        `TrainingInstructSample`s take the instruct (chat) path.

        Raises:
            ValueError: If `sample` is neither of the two supported types.
        """
        if isinstance(sample, str):
            return self.tokenize_pretrain(sample, training)
        elif isinstance(sample, TrainingInstructSample):
            return self.tokenize_instruct(sample, training)

        raise ValueError(
            f"`sample` has to be either of type `str` or `TrainingInstructSample`, not {type(sample)}."
        )

    def get_pretrain_sample(self, data: t.Dict[str, t.Any]) -> str:
        """
        Extract the raw text from a pretraining record.

        Exactly one of the keys 'text' or 'content' must be present with a
        non-None string value; that value is returned.

        Raises:
            AssertionError: If both keys are present, neither holds a
                non-None value, or the value is not a string.
        """
        content_keys = ["text", "content"]
        # Reject records that carry both keys at once.
        assert not all(k in data for k in content_keys), (
            "Make sure to have either 'text' or 'content' in your data. Not both."
        )
        # Require at least one key to hold a non-None value.
        assert any(data.get(k) is not None for k in content_keys), (
            f"Must have one of 'text' or 'content' in your data. Only have {data.keys()}"
        )

        # Take the value of whichever key is present (the asserts above
        # guarantee exactly one is present and non-None).
        sample = None
        for key in content_keys:
            sample = data[key] if key in data else sample
        assert isinstance(sample, str), sample
        return sample

    def build_instruct_sample(
        self, data: t.Dict[str, t.Any], training: bool
    ) -> TrainingInstructSample:
        """
        Validate a raw conversation dict and convert it into a
        `TrainingInstructSample`.

        Expected shape: exactly one of the keys 'messages'/'interactions'
        holding a list of message dicts; each message has a 'role' (one of
        the `Roles` values) and exactly one of 'content'/'text'. The system
        prompt may come from a top-level 'system_prompt' field or from at
        most one 'system' message; if set, it is inserted as the first
        message. 'only_last' (default False) marks training on the last
        assistant message only.

        NOTE(review): the `training` parameter is currently unused here.

        Raises:
            ConversationFormatError: Bad top-level structure.
            MessageFormatError: Bad individual message or duplicate system
                prompt.
            UnrecognizedRoleError: Role outside the allowed `Roles` values.
        """
        messages: t.List[
            t.Union[SystemMessage, UserMessage, FinetuningAssistantMessage]
        ] = []
        # optional data fields that might be set
        system_prompt = data.get("system_prompt")

        messages_keys = ["messages", "interactions"]
        content_keys = ["content", "text"]  # both are accepted
        allowed_roles = [role.value for role in Roles]

        if not any(messages_key in data for messages_key in messages_keys):
            err = f"The conversation does not contain one of '{', '.join(messages_keys)}' key, but only {', '.join(data.keys())}. Make sure that the conversation includes one of '{', '.join(messages_keys)}'."
            raise ConversationFormatError(err, str(data))

        if all(messages_key in data for messages_key in messages_keys):
            err = f"The conversation cannot contain both of '{', '.join(messages_keys)}' key, but only one of the two."
            raise ConversationFormatError(err, str(data))

        # Take whichever messages key is present (exactly one, per the
        # checks above).
        data_messages: t.Optional[t.List[t.Dict[str, t.Any]]] = None
        for key in messages_keys:
            data_messages = data[key] if key in data else data_messages

        assert data_messages is not None, "data_messages can't be None"

        for data_message in data_messages:
            if "role" not in data_message:
                err = f"A message does not contain a 'role' key, but only {', '.join(data_message.keys())}. Make sure that the message includes the key 'role'."
                raise MessageFormatError(err, str(data))

            role = data_message["role"]

            if all(key in data_message for key in content_keys):
                err = f"A {role} message contains both a 'text' and 'content' key. Make sure that there is only one of the two."
                raise MessageFormatError(err, str(data))

            # Take the first non-None value among 'content'/'text'.
            content: t.Optional[str] = None
            for key in content_keys:
                content = (
                    content if content is not None else data_message.get(key)
                )

            # non-function call message must have content
            if content is None:
                err = (
                    f"A {role} message does not contain one of '{content_keys}' key, but only {', '.join(data_message.keys())}."
                    f" Make sure that the message includes one of '{content_keys}' keys."
                )
                raise MessageFormatError(err, str(data))

            if role not in allowed_roles:
                raise UnrecognizedRoleError(role, allowed_roles)

            if data_message["role"] == "user":
                assert content is not None
                messages.append(UserMessage(content=content))

            elif data_message["role"] == "assistant":
                # Optional 'weight' field, passed through as-is to
                # FinetuningAssistantMessage (semantics defined there).
                weight = data_message.get("weight")
                messages.append(
                    FinetuningAssistantMessage(content=content, weight=weight)
                )
            elif data_message["role"] == "system":
                if system_prompt is not None:
                    err = "Multiple messages with role 'system' encountered. Only one is allowed."
                    raise MessageFormatError(err, str(data))

                system_prompt = content

        # TODO: add validation for Llama messages
        # whether to train only on last assistant message
        only_last = data.get("only_last", False)
        if system_prompt is not None:
            # The system prompt always becomes the first message.
            messages.insert(0, SystemMessage(content=system_prompt))
        return TrainingInstructSample(
            messages=messages,
            only_last=only_last,
        )

    def tokenize_instruct(
        self, sample: TrainingInstructSample, training: bool
    ) -> TokenSample:
        """
        Render a conversation in the Llama 3 chat format and tokenize it.

        Emits <|begin_of_text|> followed by each encoded message. When not
        training, an empty assistant header is appended so a model can
        complete the next turn. Per-message masks come from
        `encode_message`; the leading BOS token gets `not training`.
        """
        tokens = []
        tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
        mask = [not training] * len(tokens)
        for message in sample.messages:
            curr_tok, curr_mask = self.encode_message(message, training)
            tokens.extend(curr_tok)
            mask.extend(curr_mask)
        if not training:
            # add empty message for chat completion
            last_tok = self.encode_header(
                FinetuningAssistantMessage(content="")
            )
            tokens.extend(last_tok)
            mask.extend([True] * len(last_tok))
        return TokenSample(tokens=tokens, masks=mask)

    def encode_header(self, message: ChatMessage) -> t.List[int]:
        """
        Tokenize a message header:
        <|start_header_id|>{role}<|end_header_id|> followed by a blank line.
        """
        tokens = []
        tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
        tokens.extend(
            self.tokenizer.encode(message.role.value, bos=False, eos=False)
        )
        tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
        tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
        return tokens

    def encode_message(
        self, message: ChatMessage, training: bool
    ) -> t.Tuple[t.List[int], t.List[bool]]:
        """
        Tokenize one message (header + stripped content + <|eot_id|>) and
        build its per-token mask.

        During training: assistant messages get False on the header tokens
        and True on content/eot; all other roles are fully False. Outside
        training, every token is True.
        """
        tokens = self.encode_header(message)
        # Number of leading (header) tokens to mask out when training.
        n_false = len(tokens) if training else 0
        tokens.extend(
            self.tokenizer.encode(
                message.content.strip(), bos=False, eos=False
            )
        )
        tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
        mask = (
            [False] * n_false + [True] * (len(tokens) - n_false)
            if isinstance(message, FinetuningAssistantMessage)
            else [not training] * len(tokens)
        )
        return tokens, mask

    def decode(self, tokens: t.List[int]) -> str:
        """Decode a list of token ids back into text."""
        return self.tokenizer.decode(tokens)

    def stop_token(self, sample_type: SampleType) -> int:
        """
        Return the generation stop token id for the given sample type:
        <|end_of_text|> for pretraining samples, <|eot_id|> otherwise.
        """
        if sample_type == SampleType.PRETRAIN:
            return self.tokenizer.eos_id
        else:
            return self.tokenizer.special_tokens["<|eot_id|>"]