Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
sarus-llm / sarus_llm / data / tokenizers / phi_tokenizer.py
Size: Mime:
import typing as t
import logging
import os

from sentencepiece import SentencePieceProcessor

from .exceptions import (
    ConversationFormatError,
    UnrecognizedRoleError,
    MessageFormatError,
)
from pathlib import Path
from .utils import (
    SampleType,
    TokenSample,
    FinetuningAssistantMessage,
    SystemMessage,
    UserMessage,
    Roles,
    TrainingInstructSample,
)

logger: logging.Logger = logging.getLogger(name=__name__)

# Special tokens of the Phi3 tokenizer mapped to their ids: eleven
# consecutive ids starting at 32000, assigned in the order listed below.
PHI3_SPECIAL_TOKENS: t.Dict[str, int] = {
    token: 32_000 + offset
    for offset, token in enumerate(
        (
            "<|endoftext|>",  # 32000
            "<|assistant|>",  # 32001
            "<|placeholder1|>",  # 32002
            "<|placeholder2|>",  # 32003
            "<|placeholder3|>",  # 32004
            "<|placeholder4|>",  # 32005
            "<|system|>",  # 32006
            "<|end|>",  # 32007
            "<|placeholder5|>",  # 32008
            "<|placeholder6|>",  # 32009
            "<|user|>",  # 32010
        )
    )
}


class Tokenizer:
    """
    SentencePiece tokenizer configured with Phi3 Mini's special tokens.

    Plain text is encoded/decoded by the SentencePiece model, while the ids of
    the special tokens (EOS, padding, chat-role markers) are looked up in
    ``special_tokens``.

    Args:
        path (str): Path to pretrained tokenizer file.
        special_tokens (Optional[Dict[str, int]]): mapping containing special text tokens and
            their registered token IDs. If left as None, this will be set to the canonical
            Phi3 special tokens. The mapping is copied, so mutating it afterwards
            (or mutating the module-level default) does not affect this tokenizer.
    """

    # Ids reserved by the Phi3 team for its special and placeholder tokens;
    # they are always dropped when decoding.
    _PHI3_RESERVED_IDS = range(32_000, 32_065)

    def __init__(
        self,
        path: str,
        special_tokens: t.Optional[t.Dict[str, int]] = None,
    ):
        spm_model = SentencePieceProcessor()
        spm_model.load(path)
        self._spm_model = spm_model
        # Copy so that per-instance edits never leak into PHI3_SPECIAL_TOKENS
        # (or into the caller's dict).
        self.special_tokens: t.Dict[str, int] = dict(
            special_tokens
            if special_tokens is not None
            else PHI3_SPECIAL_TOKENS
        )

    @property
    def vocab_size(self) -> int:
        """Vocabulary size of the underlying SentencePiece model."""
        return self._spm_model.vocab_size()

    @property
    def bos_id(self) -> int:
        """BOS id as defined by the SentencePiece model."""
        return self._spm_model.bos_id()

    @property
    def eos_id(self) -> int:
        """EOS id: the ``<|endoftext|>`` special token, not SentencePiece's EOS."""
        return self.special_tokens["<|endoftext|>"]

    @property
    def pad_id(self) -> int:
        """Padding id: the ``<|endoftext|>`` special token (same as ``eos_id``)."""
        return self.special_tokens["<|endoftext|>"]

    def encode(
        self,
        text: str,
        add_bos: bool = True,
        add_eos: bool = True,
    ) -> t.List[int]:
        """Encode ``text`` with the SentencePiece model.

        Args:
            text (str): Text to encode.
            add_bos (bool): Prepend the SentencePiece BOS id.
            add_eos (bool): Append the SentencePiece EOS id (note: this is the
                SentencePiece model's EOS, not ``self.eos_id``).

        Returns:
            List[int]: The token ids.
        """
        return self._spm_model.encode(text, add_bos=add_bos, add_eos=add_eos)

    def decode(self, ids: t.List[int]) -> str:
        """Decode token IDs to a string, skipping special tokens.

        A token id is dropped before decoding when it is any of:

        * registered in ``special_tokens``;
        * inside the Phi3 reserved range 32000-32064 (the special and
          placeholder tokens added by the Phi3 team);
        * at or beyond the SentencePiece vocabulary size, i.e. it has no
          piece the SentencePiece model could decode.

        Args:
            ids (List[int]): The input token IDs to be decoded.

        Returns:
            str: The decoded text.
        """
        skipped_ids = set(self.special_tokens.values())
        spm_vocab_size = self._spm_model.vocab_size()
        ids_for_decode = [
            token_id
            for token_id in ids
            if token_id not in skipped_ids
            and token_id not in self._PHI3_RESERVED_IDS
            and token_id < spm_vocab_size
        ]
        return self._spm_model.decode(ids_for_decode)


class SarusPhi3Tokenizer:
    """Phi3 tokenizer wrapper that builds and tokenizes Sarus samples.

    Raw data dicts are turned into either a pretraining string
    (``get_pretrain_sample``) or a ``TrainingInstructSample`` conversation
    (``build_instruct_sample``). ``tokenize`` then produces a ``TokenSample``
    made of token ids and a parallel boolean mask.
    """

    def __init__(self):
        # The SentencePiece model file is loaded from this module's directory.
        self.tokenizer = Tokenizer(
            os.path.join(str(Path(__file__).parent), "phi3.model")
        )

    def get_pretrain_sample(self, data: t.Dict[str, t.Any]) -> str:
        """Extract the raw text of a pretraining sample.

        ``data`` must contain exactly one of the ``"text"`` / ``"content"``
        keys, holding a string.

        Raises:
            AssertionError: if both keys are present, if neither holds a
                non-None value, or if the value is not a string.
        """
        content_keys = ["text", "content"]
        assert not all(k in data for k in content_keys), (
            "Make sure to have either 'text' or 'content' in your data. Not both."
        )
        assert any(data.get(k) is not None for k in content_keys), (
            f"Must have one of 'text' or 'content' in your data. Only have {data.keys()}"
        )

        # At most one of the keys is present (checked above); take its value.
        sample = None
        for key in content_keys:
            if key in data:
                sample = data[key]
        assert isinstance(sample, str), sample
        return sample

    def build_instruct_sample(
        self, data: t.Dict[str, t.Any], training: bool
    ) -> TrainingInstructSample:
        """Parse a conversation dict into a ``TrainingInstructSample``.

        ``data`` must contain exactly one of ``"messages"`` / ``"interactions"``.
        That value is a list of message dicts, each with a ``"role"`` and
        exactly one of ``"content"`` / ``"text"``. Assistant messages may
        carry an optional ``"weight"``.

        An optional top-level ``"system_prompt"``, or a single ``"system"``
        message, becomes a leading ``SystemMessage``. ``"only_last"``
        (default False) is passed through to the sample. ``training`` is
        currently unused.

        Raises:
            ConversationFormatError: if neither or both message keys are present.
            MessageFormatError: if a message lacks a role or content, has both
                content keys, or more than one system prompt is provided.
            UnrecognizedRoleError: if a role is not a value of ``Roles``.
        """
        messages: t.List[
            t.Union[SystemMessage, UserMessage, FinetuningAssistantMessage]
        ] = []
        # Optional top-level system prompt; a "system" message may set it instead.
        system_prompt = data.get("system_prompt")

        messages_keys = ["messages", "interactions"]
        content_keys = ["content", "text"]  # both are accepted
        allowed_roles = [role.value for role in Roles]

        if not any(messages_key in data for messages_key in messages_keys):
            err = f"The conversation does not contain one of '{', '.join(messages_keys)}' key, but only {', '.join(data.keys())}. Make sure that the conversation includes one of '{', '.join(messages_keys)}'."
            raise ConversationFormatError(err, str(data))

        if all(messages_key in data for messages_key in messages_keys):
            err = f"The conversation cannot contain both of '{', '.join(messages_keys)}' key, but only one of the two."
            raise ConversationFormatError(err, str(data))

        # Exactly one of the message keys is present (checked above).
        data_messages: t.Optional[t.List[t.Dict[str, t.Any]]] = None
        for key in messages_keys:
            if key in data:
                data_messages = data[key]

        assert data_messages is not None, "data_messages can't be None"

        for data_message in data_messages:
            if "role" not in data_message:
                err = f"A message does not contain a 'role' key, but only {', '.join(data_message.keys())}. Make sure that the message includes the key 'role'."
                raise MessageFormatError(err, str(data))

            role = data_message["role"]

            if all(key in data_message for key in content_keys):
                err = f"A {role} message contains both a 'text' and 'content' key. Make sure that there is only one of the two."
                raise MessageFormatError(err, str(data))

            # First non-None value among the accepted content keys.
            content: t.Optional[str] = None
            for key in content_keys:
                content = (
                    content if content is not None else data_message.get(key)
                )

            # Every message must have content.
            if content is None:
                err = (
                    f"A {role} message does not contain one of '{content_keys}' key, but only {', '.join(data_message.keys())}."
                    f" Make sure that the message includes one of '{content_keys}' keys."
                )
                raise MessageFormatError(err, str(data))

            if role not in allowed_roles:
                raise UnrecognizedRoleError(role, allowed_roles)

            if role == "user":
                messages.append(UserMessage(content=content))
            elif role == "assistant":
                messages.append(
                    FinetuningAssistantMessage(
                        content=content, weight=data_message.get("weight")
                    )
                )
            elif role == "system":
                if system_prompt is not None:
                    err = "Multiple messages with role 'system' encountered. Only one is allowed."
                    raise MessageFormatError(err, str(data))

                system_prompt = content
            # NOTE(review): any other role accepted by `Roles` is silently
            # dropped here.

        # TODO: add validation for Phi3 messages
        # whether to train only on last assistant message
        only_last = data.get("only_last", False)
        if system_prompt is not None:
            messages.insert(0, SystemMessage(content=system_prompt))
        return TrainingInstructSample(
            messages=messages,
            only_last=only_last,
        )

    def tokenize(
        self, sample: t.Union[str, TrainingInstructSample], training: bool
    ) -> TokenSample:
        """Tokenize a pretraining string or an instruct conversation.

        Raises:
            ValueError: if ``sample`` is neither a ``str`` nor a
                ``TrainingInstructSample``.
        """
        if isinstance(sample, str):
            return self.tokenize_pretrain(sample, training)
        elif isinstance(sample, TrainingInstructSample):
            return self.tokenize_instruct(sample, training)

        raise ValueError(
            f"`sample` has to be either of type `str` or `TrainingInstructSample`, not {type(sample)}."
        )

    def tokenize_pretrain(self, sample: str, training: bool) -> TokenSample:
        """Tokenize raw text with BOS (and EOS when training); every mask is True."""
        tokens = self.tokenizer.encode(
            sample,
            add_bos=True,
            add_eos=training,
        )
        masks = [True] * len(tokens)
        return TokenSample(tokens, masks)

    def tokenize_instruct(
        self, sample: TrainingInstructSample, training: bool
    ) -> TokenSample:
        """Tokenize a conversation using the Phi3 chat layout.

        Each message is laid out as: its role token (``<|user|>``,
        ``<|assistant|>`` or ``<|system|>``), a newline, the content (trailing
        spaces stripped), ``<|end|>`` and a newline. A BOS token opens every
        turn (the first message and each message after an assistant reply),
        and an EOS token follows every assistant message. When not training,
        ``<|assistant|>`` and a newline are appended so that generation
        continues with the assistant reply.

        The mask is True for tokens that are trained on (presumably consumed
        downstream as a loss mask). When ``training`` is False, every mask
        value is True. When training, only assistant content (including its
        ``<|end|>``, newline and EOS) is True, further restricted as follows:

        * if ``sample.only_last`` is set, only the last assistant message
          is trained on;
        * an assistant message whose ``weight`` is 0 is not trained on.

        Raises:
            ValueError: on an unknown message role, or when ``training`` is
                False and the last message is an assistant message (there is
                no user turn left to answer).
        """
        tokenizer = self.tokenizer
        special = tokenizer.special_tokens
        messages = list(sample.messages)

        only_last = bool(getattr(sample, "only_last", False))
        assistant_indices = [
            idx for idx, msg in enumerate(messages) if msg.role == "assistant"
        ]
        last_assistant_idx = assistant_indices[-1] if assistant_indices else -1

        start_of_turn = True
        tokenized_messages: t.List[int] = []
        mask: t.List[bool] = []

        # The chat template in HF adds a bunch of newlines
        new_line_token_id = tokenizer.encode(
            "\n", add_bos=False, add_eos=False
        )

        for idx, message in enumerate(messages):
            if message.role == "user":
                role_token = special["<|user|>"]
            elif message.role == "assistant":
                role_token = special["<|assistant|>"]
            elif message.role == "system":
                role_token = special["<|system|>"]
            else:
                raise ValueError(
                    f"Unknown role '{message.role}' for message: '{message.content}'"
                )

            is_assistant = message.role == "assistant"
            # Prompt scaffolding (BOS, role token, newline) is never trained on.
            prompt_flag = not training
            if is_assistant:
                weight = getattr(message, "weight", None)
                content_flag = not training or (
                    (weight is None or weight != 0)
                    and (not only_last or idx == last_assistant_idx)
                )
            else:
                content_flag = not training

            # Prepend BOS on start of new turns
            if start_of_turn:
                tokenized_messages.append(tokenizer.bos_id)
                mask.append(prompt_flag)

            tokenized_messages.append(role_token)
            mask.append(prompt_flag)
            tokenized_messages.extend(new_line_token_id)
            mask.extend([prompt_flag] * len(new_line_token_id))

            tokens = tokenizer.encode(
                message.content.rstrip(" "),
                add_bos=False,
                add_eos=False,
            )
            tokens = tokens + [special["<|end|>"]] + new_line_token_id
            tokenized_messages.extend(tokens)
            mask.extend([content_flag] * len(tokens))

            # An assistant message ends the turn: close it with EOS, and the
            # next message starts a new turn (with BOS).
            if is_assistant:
                tokenized_messages.append(tokenizer.eos_id)
                mask.append(content_flag)
            start_of_turn = is_assistant

        if not training:
            if messages and messages[-1].role == "assistant":
                raise ValueError(
                    "When sampling, the last sentence should be a user input"
                )
            tokenized_messages.append(special["<|assistant|>"])
            mask.append(True)
            tokenized_messages.extend(new_line_token_id)
            mask.extend([True] * len(new_line_token_id))
        return TokenSample(tokenized_messages, mask)

    def decode(self, tokens: t.List[int]) -> str:
        """Decode token ids to text, dropping special tokens."""
        return self.tokenizer.decode(tokens)

    def stop_token(self, sample_type: SampleType) -> int:
        """Return the generation stop token (EOS) regardless of ``sample_type``."""
        return self.tokenizer.eos_id