# Llama tokenizer module (package version 1.1.3).
import os
import typing as t
from logging import getLogger
from pathlib import Path
import tiktoken
from tiktoken.load import load_tiktoken_bpe
from .exceptions import (
ConversationFormatError,
UnrecognizedRoleError,
MessageFormatError,
)
from .utils import (
SampleType,
TokenSample,
FinetuningAssistantMessage,
SystemMessage,
UserMessage,
Roles,
TrainingInstructSample,
ChatMessage,
)
# Module-level logger named after this module (stdlib logging convention).
logger = getLogger(__name__)
class Tokenizer:
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.

    Builds a ``tiktoken.Encoding`` from a BPE rank file on disk and appends
    the Llama-3-style special tokens (BOS/EOS, header markers, end-of-turn,
    plus reserved slots) right after the base vocabulary.
    """

    # Mapping of special-token literal -> token id; populated in __init__.
    special_tokens: t.Dict[str, int]

    # Total count of special tokens appended after the base vocabulary.
    num_reserved_special_tokens = 256

    # Pre-tokenization pattern: contractions, letter runs, 1-3 digit groups,
    # punctuation (optionally preceded by a space), and whitespace runs.
    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501

    def __init__(self, model_path: str):
        """
        Initializes the Tokenizer with a Tiktoken model.

        Args:
            model_path (str): The path to the Tiktoken model file.

        Raises:
            FileNotFoundError: If ``model_path`` is not an existing file.
        """
        # Validate with a real exception: `assert` is stripped under -O.
        if not os.path.isfile(model_path):
            raise FileNotFoundError(model_path)

        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)

        # 10 named special tokens followed by generated reserved slots, so the
        # total equals num_reserved_special_tokens (10 + 246 = 256).
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [
            f"<|reserved_special_token_{i}|>"
            for i in range(5, self.num_reserved_special_tokens - 5)
        ]
        # Special-token ids start immediately after the base BPE vocabulary.
        self.special_tokens = {
            token: num_base_tokens + i
            for i, token in enumerate(special_tokens)
        }
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        logger.info(f"Reloaded tiktoken model from {model_path}")

        self.n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        # No dedicated pad token in the vocabulary; -1 is used as a sentinel.
        self.pad_id: int = -1
        # Token ids at which generation should stop (EOS and end-of-turn).
        self.stop_tokens = {
            self.special_tokens["<|end_of_text|>"],
            self.special_tokens["<|eot_id|>"],
        }
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: t.Union[t.Literal["all"], t.AbstractSet[str]] = frozenset(),
        disallowed_special: t.Union[t.Literal["all"], t.Collection[str]] = (),
    ) -> t.List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            s (str): The input string to be encoded.
            bos (bool): Whether to prepend the beginning-of-sequence token.
            eos (bool): Whether to append the end-of-sequence token.
            allowed_special ("all"|set[str]): special tokens allowed in the
                string (defaults to an empty, immutable frozenset).
            disallowed_special ("all"|set[str]): special tokens that raise an
                error when found in the string.

        Returns:
            list[int]: A list of token IDs.

        Raises:
            TypeError: If ``s`` is not a string.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
        - Setting `disallowed_special` to () will cause all text corresponding
          to special tokens to be encoded as natural text (instead of raising
          an error).
        - Setting `allowed_special` to "all" will treat all text corresponding
          to special tokens to be encoded as special tokens.
        """
        # Validate with a real exception: `assert` is stripped under -O.
        if not isinstance(s, str):
            raise TypeError(f"expected str, got {type(s).__name__}")

        # The tiktoken tokenizer can handle <=400k chars without
        # pyo3_runtime.PanicException.
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000

        # https://github.com/openai/tiktoken/issues/195
        # Here we iterate over subsequences and split if we exceed the limit
        # of max consecutive non-whitespace or whitespace characters.
        MAX_NO_WHITESPACES_CHARS = 25_000

        # Lazily chunk the input: hard 400k windows, each further split on
        # long same-type (space / non-space) runs.
        substrs = (
            substr
            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
            )
        )
        vals: t.List[int] = []
        for substr in substrs:
            vals.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if bos:
            vals.insert(0, self.bos_id)
        if eos:
            vals.append(self.eos_id)
        return vals

    def decode(self, tokens: t.Sequence[int]) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            tokens (Sequence[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # Typecast is safe here. Tiktoken doesn't do anything list-related
        # with the sequence.
        return self.model.decode(t.cast(t.List[int], tokens))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> t.Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than
        `max_consecutive_slice_len` consecutive whitespaces or consecutive
        non-whitespaces.

        Note: only same-type runs are capped; a yielded substring may be
        longer than the cap when it mixes spaces and non-spaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0
        for i in range(len(s)):
            is_now_space = s[i].isspace()
            if current_slice_is_space ^ is_now_space:
                # Character class flipped: reset the run counter.
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    # Run exceeded the cap: emit everything before it.
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        # Emit the trailing slice ("" for empty input, matching s[0:]).
        yield s[slice_start:]
class SarusLLamaTokenizer:
    """
    Tokenizer for Llama-style training data, handling both plain-text
    pre-training samples and chat-formatted instruct samples.

    Wraps the low-level ``Tokenizer`` above and returns ``TokenSample``
    objects pairing token ids with a parallel list of booleans
    (presumably a loss mask consumed by the training loop — TODO confirm).
    """

    def __init__(self):
        # Load the tiktoken model file shipped alongside this module.
        self.tokenizer = Tokenizer(
            os.path.join(str(Path(__file__).parent), "llama_tokenizer.model")
        )

    def tokenize_pretrain(self, sample: str, training: bool) -> TokenSample:
        """
        Tokenizes a plain-text pre-training sample.

        BOS is always prepended; EOS is appended only when ``training`` is
        True. Every token gets mask value True.
        """
        tokens = self.tokenizer.encode(sample, bos=True, eos=training)
        masks = [True] * len(tokens)
        return TokenSample(tokens, masks)

    def tokenize(
        self, sample: t.Union[str, TrainingInstructSample], training: bool
    ) -> TokenSample:
        """
        Dispatches to pre-train or instruct tokenization based on the
        runtime type of ``sample``.

        Raises:
            ValueError: if ``sample`` is neither ``str`` nor
                ``TrainingInstructSample``.
        """
        if isinstance(sample, str):
            return self.tokenize_pretrain(sample, training)
        elif isinstance(sample, TrainingInstructSample):
            return self.tokenize_instruct(sample, training)
        raise ValueError(
            f"`sample` has to be either of type `str` or `TrainingInstructSample`, not {type(sample)}."
        )

    def get_pretrain_sample(self, data: t.Dict[str, t.Any]) -> str:
        """
        Extracts the raw text of a pre-training sample from ``data``.

        Exactly one of the 'text' / 'content' keys must be present with a
        non-None string value.

        Raises:
            AssertionError: if both keys are present, neither holds a
                non-None value, or the value is not a string.
        """
        content_keys = ["text", "content"]
        assert not all(k in data for k in content_keys), (
            "Make sure to have either 'text' or 'content' in your data. Not both."
        )
        assert any(data.get(k) is not None for k in content_keys), (
            f"Must have one of 'text' or 'content' in your data. Only have {data.keys()}"
        )
        # Take whichever key is present (the last one scanned wins); the
        # asserts above guarantee at most one key exists with a real value.
        sample = None
        for key in content_keys:
            sample = data[key] if key in data else sample
        assert isinstance(sample, str), sample
        return sample

    def build_instruct_sample(
        self, data: t.Dict[str, t.Any], training: bool
    ) -> TrainingInstructSample:
        """
        Validates a raw conversation dict and converts it into a
        ``TrainingInstructSample``.

        ``data`` must contain exactly one of the 'messages'/'interactions'
        keys. Each message needs a 'role' from the ``Roles`` enum and exactly
        one of 'content'/'text'. An optional top-level 'system_prompt' (or a
        single message with role 'system') becomes a leading SystemMessage.

        NOTE(review): the ``training`` parameter is currently unused in this
        method.

        Raises:
            ConversationFormatError: malformed top-level conversation dict.
            MessageFormatError: malformed individual message, or more than
                one system prompt/message.
            UnrecognizedRoleError: role not in the ``Roles`` enum.
        """
        messages: t.List[
            t.Union[SystemMessage, UserMessage, FinetuningAssistantMessage]
        ] = []
        # optional data fields that might be set
        system_prompt = data.get("system_prompt")
        messages_keys = ["messages", "interactions"]
        content_keys = ["content", "text"]  # both are accepted
        allowed_roles = [role.value for role in Roles]
        # Exactly one of the messages keys must be present.
        if not any(messages_key in data for messages_key in messages_keys):
            err = f"The conversation does not contain one of '{', '.join(messages_keys)}' key, but only {', '.join(data.keys())}. Make sure that the conversation includes one of '{', '.join(messages_keys)}'."
            raise ConversationFormatError(err, str(data))
        if all(messages_key in data for messages_key in messages_keys):
            err = f"The conversation cannot contain both of '{', '.join(messages_keys)}' key, but only one of the two."
            raise ConversationFormatError(err, str(data))
        # Take whichever messages key is present; the checks above guarantee
        # exactly one of them exists.
        data_messages: t.Optional[t.List[t.Dict[str, t.Any]]] = None
        for key in messages_keys:
            data_messages = data[key] if key in data else data_messages
        assert data_messages is not None, "data_messages can't be None"
        for data_message in data_messages:
            if "role" not in data_message:
                err = f"A message does not contain a 'role' key, but only {', '.join(data_message.keys())}. Make sure that the message includes the key 'role'."
                raise MessageFormatError(err, str(data))
            role = data_message["role"]
            if all(key in data_message for key in content_keys):
                err = f"A {role} message contains both a 'text' and 'content' key. Make sure that there is only one of the two."
                raise MessageFormatError(err, str(data))
            # First non-None value among 'content'/'text'.
            content: t.Optional[str] = None
            for key in content_keys:
                content = (
                    content if content is not None else data_message.get(key)
                )
            # non-function call message must have content
            if content is None:
                err = (
                    f"A {role} message does not contain one of '{content_keys}' key, but only {', '.join(data_message.keys())}."
                    f" Make sure that the message includes one of '{content_keys}' keys."
                )
                raise MessageFormatError(err, str(data))
            if role not in allowed_roles:
                raise UnrecognizedRoleError(role, allowed_roles)
            if data_message["role"] == "user":
                assert content is not None
                messages.append(UserMessage(content=content))
            elif data_message["role"] == "assistant":
                # Optional per-message loss weight; None when absent.
                weight = data_message.get("weight")
                messages.append(
                    FinetuningAssistantMessage(content=content, weight=weight)
                )
            elif data_message["role"] == "system":
                # A 'system_prompt' field and a system message (or two
                # system messages) are mutually exclusive.
                if system_prompt is not None:
                    err = "Multiple messages with role 'system' encountered. Only one is allowed."
                    raise MessageFormatError(err, str(data))
                system_prompt = content
        # TODO: add validation for Llama messages
        # whether to train only on last assistant message
        only_last = data.get("only_last", False)
        if system_prompt is not None:
            messages.insert(0, SystemMessage(content=system_prompt))
        return TrainingInstructSample(
            messages=messages,
            only_last=only_last,
        )

    def tokenize_instruct(
        self, sample: TrainingInstructSample, training: bool
    ) -> TokenSample:
        """
        Encodes an instruct sample as BOS + encoded messages, with a boolean
        mask of the same length.

        When ``training`` is False, an empty assistant header is appended so
        the model can be prompted to complete the next turn.
        """
        tokens = []
        tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
        # BOS is masked False during training, True at inference.
        mask = [not training] * len(tokens)
        for message in sample.messages:
            curr_tok, curr_mask = self.encode_message(message, training)
            tokens.extend(curr_tok)
            mask.extend(curr_mask)
        if not training:
            # add empty message for chat completion
            last_tok = self.encode_header(
                FinetuningAssistantMessage(content="")
            )
            tokens.extend(last_tok)
            mask.extend([True] * len(last_tok))
        return TokenSample(tokens=tokens, masks=mask)

    def encode_header(self, message: ChatMessage) -> t.List[int]:
        """
        Encodes a message header:
        ``<|start_header_id|>{role}<|end_header_id|>\\n\\n``.
        """
        tokens = []
        tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
        tokens.extend(
            self.tokenizer.encode(message.role.value, bos=False, eos=False)
        )
        tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
        tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
        return tokens

    def encode_message(
        self, message: ChatMessage, training: bool
    ) -> t.Tuple[t.List[int], t.List[bool]]:
        """
        Encodes one message as header + stripped content + ``<|eot_id|>``,
        plus a mask of equal length.

        Mask semantics as written below: for assistant messages during
        training, header tokens are False and content/eot tokens are True;
        all other messages are fully False during training. At inference
        every token is True.
        """
        tokens = self.encode_header(message)
        # Number of leading (header) positions to mask out during training.
        n_false = len(tokens) if training else 0
        tokens.extend(
            self.tokenizer.encode(
                message.content.strip(), bos=False, eos=False
            )
        )
        tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
        mask = (
            [False] * n_false + [True] * (len(tokens) - n_false)
            if isinstance(message, FinetuningAssistantMessage)
            else [not training] * len(tokens)
        )
        return tokens, mask

    def decode(self, tokens: t.List[int]) -> str:
        """Decodes token IDs back into text via the wrapped tokenizer."""
        return self.tokenizer.decode(tokens)

    def stop_token(self, sample_type: SampleType) -> int:
        """
        Returns the stop-token id for the given sample type: EOS for
        pre-training samples, ``<|eot_id|>`` (end of turn) otherwise.
        """
        if sample_type == SampleType.PRETRAIN:
            return self.tokenizer.eos_id
        else:
            return self.tokenizer.special_tokens["<|eot_id|>"]