import typing as t
import logging
import os
from sentencepiece import SentencePieceProcessor
from .exceptions import (
ConversationFormatError,
UnrecognizedRoleError,
MessageFormatError,
)
from pathlib import Path
from .utils import (
SampleType,
TokenSample,
FinetuningAssistantMessage,
SystemMessage,
UserMessage,
Roles,
TrainingInstructSample,
)
logger = logging.getLogger(__name__)
PHI3_SPECIAL_TOKENS = {
"<|endoftext|>": 32000,
"<|assistant|>": 32001,
"<|placeholder1|>": 32002,
"<|placeholder2|>": 32003,
"<|placeholder3|>": 32004,
"<|placeholder4|>": 32005,
"<|system|>": 32006,
"<|end|>": 32007,
"<|placeholder5|>": 32008,
"<|placeholder6|>": 32009,
"<|user|>": 32010,
}
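# These IDs sit just past the 32,000-entry base SentencePiece vocabulary;
# Tokenizer.decode() below strips this whole reserved range before decoding.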
class Tokenizer:
"""
SentencePiece tokenizer configured with Phi3 Mini's special tokens.
Args:
path (str): Path to pretrained tokenizer file.
special_tokens (Optional[Dict[str, int]]): mapping containing special text tokens and
their registered token IDs. If left as None, this will be set to the canonical
Phi3 special tokens.
max_seq_len (Optional[int]): A max sequence length to truncate tokens to.
Default: None
"""
def __init__(
self,
path: str,
special_tokens: t.Optional[t.Dict[str, int]] = None,
):
spm_model = SentencePieceProcessor()
spm_model.load(path)
self._spm_model = spm_model
self.special_tokens = (
special_tokens
if special_tokens is not None
else PHI3_SPECIAL_TOKENS
)
@property
def vocab_size(self) -> int:
return self._spm_model.vocab_size()
@property
def bos_id(self) -> int:
return self._spm_model.bos_id()
@property
def eos_id(self) -> int:
return self.special_tokens["<|endoftext|>"]
@property
def pad_id(self) -> int:
return self.special_tokens["<|endoftext|>"]
def encode(
self,
text: str,
add_bos: bool = True,
add_eos: bool = True,
) -> t.List[int]:
return self._spm_model.encode(text, add_bos=add_bos, add_eos=add_eos)
def decode(self, ids: t.List[int]) -> str:
"""Decode token IDs to strings.
Args:
ids (List[int]): The input token IDs to be decoded.
Returns:
str: The decoded text.
"""
        ids_for_decode = []
        for token_id in ids:
            # Skip the special and placeholder tokens added by the Phi3 team
            # (IDs 32000-32064), which lie outside the base SentencePiece
            # vocabulary
            if 32_000 <= token_id <= 32_064:
                continue
            ids_for_decode.append(token_id)
return self._spm_model.decode(ids_for_decode)
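# A minimal round-trip sketch for Tokenizer (a usage illustration, assuming a
# local SentencePiece model file; the path here is hypothetical):
#
#     tok = Tokenizer("phi3.model")
#     ids = tok.encode("Hello world", add_bos=True, add_eos=False)
#     ids.append(tok.special_tokens["<|end|>"])
#     tok.decode(ids)  # IDs in [32000, 32064] are dropped -> "Hello world"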
class SarusPhi3Tokenizer:
def __init__(self):
self.tokenizer = Tokenizer(
os.path.join(str(Path(__file__).parent), "phi3.model")
)
def get_pretrain_sample(self, data: t.Dict[str, t.Any]) -> str:
content_keys = ["text", "content"]
        assert not all(k in data for k in content_keys), (
            "Make sure to have either 'text' or 'content' in your data, not both."
        )
        assert any(data.get(k) is not None for k in content_keys), (
            f"Must have one of 'text' or 'content' in your data. Only found {data.keys()}"
        )
        # take whichever key is present (the asserts above guarantee exactly one)
        sample = None
        for key in content_keys:
            sample = data[key] if key in data else sample
assert isinstance(sample, str), sample
return sample
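    # Record shapes accepted by get_pretrain_sample, per the asserts above
    # (exactly one of the two keys must be set):
    #
    #     {"text": "raw document ..."}
    #     {"content": "raw document ..."}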
def build_instruct_sample(
self, data: t.Dict[str, t.Any], training: bool
) -> TrainingInstructSample:
messages: t.List[
t.Union[SystemMessage, UserMessage, FinetuningAssistantMessage]
] = []
# optional data fields that might be set
system_prompt = data.get("system_prompt")
messages_keys = ["messages", "interactions"]
content_keys = ["content", "text"] # both are accepted
allowed_roles = [role.value for role in Roles]
if not any(messages_key in data for messages_key in messages_keys):
err = f"The conversation does not contain one of '{', '.join(messages_keys)}' key, but only {', '.join(data.keys())}. Make sure that the conversation includes one of '{', '.join(messages_keys)}'."
raise ConversationFormatError(err, str(data))
if all(messages_key in data for messages_key in messages_keys):
err = f"The conversation cannot contain both of '{', '.join(messages_keys)}' key, but only one of the two."
raise ConversationFormatError(err, str(data))
        # take whichever key is present (the checks above guarantee exactly one)
        data_messages: t.Optional[t.List[t.Dict[str, t.Any]]] = None
        for key in messages_keys:
            data_messages = data[key] if key in data else data_messages
assert data_messages is not None, "data_messages can't be None"
for data_message in data_messages:
if "role" not in data_message:
err = f"A message does not contain a 'role' key, but only {', '.join(data_message.keys())}. Make sure that the message includes the key 'role'."
raise MessageFormatError(err, str(data))
role = data_message["role"]
if all(key in data_message for key in content_keys):
err = f"A {role} message contains both a 'text' and 'content' key. Make sure that there is only one of the two."
raise MessageFormatError(err, str(data))
content: t.Optional[str] = None
for key in content_keys:
content = (
content if content is not None else data_message.get(key)
)
# non-function call message must have content
if content is None:
                err = (
                    f"A {role} message contains none of the '{content_keys}' keys; it only contains {', '.join(data_message.keys())}."
                    f" Make sure that the message includes one of the '{content_keys}' keys."
                )
raise MessageFormatError(err, str(data))
if role not in allowed_roles:
raise UnrecognizedRoleError(role, allowed_roles)
if data_message["role"] == "user":
assert content is not None
messages.append(UserMessage(content=content))
elif data_message["role"] == "assistant":
weight = data_message.get("weight")
messages.append(
FinetuningAssistantMessage(content=content, weight=weight)
)
elif data_message["role"] == "system":
if system_prompt is not None:
err = "Multiple messages with role 'system' encountered. Only one is allowed."
raise MessageFormatError(err, str(data))
system_prompt = content
# TODO: add validation for Phi3 messages
# whether to train only on last assistant message
only_last = data.get("only_last", False)
if system_prompt is not None:
messages.insert(0, SystemMessage(content=system_prompt))
return TrainingInstructSample(
messages=messages,
only_last=only_last,
)
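    # A conversation accepted by build_instruct_sample, per the checks above
    # ('interactions' may replace 'messages', 'text' may replace 'content';
    # 'weight' and 'only_last' are optional):
    #
    #     {
    #         "system_prompt": "You are helpful.",
    #         "messages": [
    #             {"role": "user", "content": "Hi"},
    #             {"role": "assistant", "content": "Hello!", "weight": 1},
    #         ],
    #         "only_last": False,
    #     }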
def tokenize(
self, sample: t.Union[str, TrainingInstructSample], training: bool
) -> TokenSample:
if isinstance(sample, str):
return self.tokenize_pretrain(sample, training)
elif isinstance(sample, TrainingInstructSample):
return self.tokenize_instruct(sample, training)
raise ValueError(
f"`sample` has to be either of type `str` or `TrainingInstructSample`, not {type(sample)}."
)
def tokenize_pretrain(self, sample: str, training: bool) -> TokenSample:
tokens = self.tokenizer.encode(
sample,
add_bos=True,
add_eos=training,
)
masks = [True] * len(tokens)
return TokenSample(tokens, masks)
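    # Pretraining masks are all True, so every position contributes to the
    # loss; the (SentencePiece) EOS is only appended when training=True.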
def tokenize_instruct(
self, sample: TrainingInstructSample, training: bool
) -> TokenSample:
start_of_turn = True
end_of_turn = False
tokenized_messages = []
mask = []
        # The HF chat template separates role tokens and message contents
        # with newlines; mirror that here
new_line_token_id = self.tokenizer.encode(
"\n", add_bos=False, add_eos=False
)
for message in sample.messages:
# Prepend BOS on start of new turns
if start_of_turn:
tokenized_messages.append(self.tokenizer.bos_id)
mask.append(not training)
# Add special tokens
if message.role == "user":
tokenized_messages.append(
self.tokenizer.special_tokens["<|user|>"]
)
mask.append(not training)
elif message.role == "assistant":
tokenized_messages.append(
self.tokenizer.special_tokens["<|assistant|>"]
)
# If assistant message, this is the end of a turn
end_of_turn = True
mask.append(not training)
elif message.role == "system":
tokenized_messages.append(
self.tokenizer.special_tokens["<|system|>"]
)
mask.append(not training)
else:
raise ValueError(
f"Unknown role '{message.role}' for message: '{message.content}'"
)
# Add new line token
tokenized_messages.extend(new_line_token_id)
mask.extend([not training] * len(new_line_token_id))
# Tokenize current message, append with masks
tokens = self.tokenizer.encode(
message.content.rstrip(" "),
add_bos=False,
add_eos=False,
)
tokens = (
tokens
+ [self.tokenizer.special_tokens["<|end|>"]]
+ new_line_token_id
)
tokenized_messages.extend(tokens)
mask.extend([end_of_turn or not training] * len(tokens))
# If assistant message, append EOS at end
if end_of_turn:
tokenized_messages.append(self.tokenizer.eos_id)
mask.append(True)
end_of_turn = False
start_of_turn = True
else:
start_of_turn = False
if not training:
            assert not end_of_turn, (
                "When sampling, the last message should be a user input"
            )
tokenized_messages.append(
self.tokenizer.special_tokens["<|assistant|>"]
)
mask.append(True)
tokenized_messages.extend(new_line_token_id)
mask.extend([True] * len(new_line_token_id))
return TokenSample(tokenized_messages, mask)
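    # The loop above renders a full turn as (BOS only at the start of a turn):
    #
    #     <s><|system|>\n{system}<|end|>\n<|user|>\n{user}<|end|>\n
    #     <|assistant|>\n{reply}<|end|>\n<|endoftext|>
    #
    # During training the mask is True only for the assistant span (content,
    # <|end|>, trailing newline) and the final EOS, so the loss is computed
    # on assistant tokens alone.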
def decode(self, tokens: t.List[int]) -> str:
return self.tokenizer.decode(tokens)
def stop_token(self, sample_type: SampleType) -> int:
return self.tokenizer.eos_id
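# An end-to-end sketch (assumes the bundled `phi3.model` ships next to this
# module, and that TokenSample exposes the positional fields it is
# constructed with above as `.tokens`):
#
#     tokenizer = SarusPhi3Tokenizer()
#     sample = tokenizer.build_instruct_sample(
#         {
#             "messages": [
#                 {"role": "user", "content": "Hello!"},
#                 {"role": "assistant", "content": "Hi, how can I help?"},
#             ]
#         },
#         training=True,
#     )
#     token_sample = tokenizer.tokenize(sample, training=True)
#     print(tokenizer.decode(token_sample.tokens))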