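# Package-page header captured together with this file (not source code):
# Repository URL to install this package: (URL missing from the capture)
# Version: 1.1.3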
from dataclasses import dataclass
from enum import Enum
from pydantic import Field
import typing as t
import pydantic
from typing_extensions import TypeAlias
# Plain list aliases used by TokenSample: a list of ints and a list of bools.
Sequence, Mask = t.List[int], t.List[bool]
class Roles(str, Enum):
    """Speaker roles a chat message can carry.

    Mixing in ``str`` makes every member compare equal to its string value
    (e.g. ``Roles.user == "user"``).
    """

    system = "system"  # instructions/preamble
    user = "user"  # end-user turn
    assistant = "assistant"  # model turn
class ChunkTypes(str, Enum):
    """Kinds of content chunk; ``text`` is the only kind defined here."""

    text = "text"
class ContentChunk(pydantic.BaseModel):
    """One piece of structured message content.

    ``type`` defaults to ``ChunkTypes.text`` (the only chunk kind defined in
    this module); ``text`` is the required string payload.
    """

    type: ChunkTypes = ChunkTypes.text
    text: str
class BaseMessage(pydantic.BaseModel):
    """Common base for chat messages; declares only the ``role`` field.

    Each concrete subclass narrows ``role`` to a single ``Roles`` member with
    a matching default, so ``role`` can serve as the discriminator of the
    message unions.
    """

    # Any of the three roles here; subclasses narrow it to one literal.
    role: t.Literal[Roles.system, Roles.user, Roles.assistant]
class UserMessage(BaseMessage):
    """A user turn; ``content`` is either a string or a list of chunks."""

    role: t.Literal[Roles.user] = Roles.user
    content: t.Union[str, t.List[ContentChunk]]
class SystemMessage(BaseMessage):
    """A system message; ``content`` is either a string or a list of chunks."""

    role: t.Literal[Roles.system] = Roles.system
    content: t.Union[str, t.List[ContentChunk]]
class AssistantMessage(BaseMessage):
    """An assistant turn.

    Unlike user/system messages, ``content`` is a plain string only (no chunk
    list), and it may be omitted, in which case it defaults to ``None``.
    """

    role: t.Literal[Roles.assistant] = Roles.assistant
    content: t.Union[str, None] = None
class FinetuningAssistantMessage(AssistantMessage):
    """Assistant message variant used by the ``FinetuningMessage`` union.

    Adds an optional ``weight`` (default ``None``). How it is consumed is not
    visible in this module; presumably it is a per-message training weight.
    TODO confirm.
    """

    weight: t.Union[float, None] = None
# Discriminated union of the chat message models. Pydantic picks the member
# class from the value of the ``role`` field (``Field(discriminator="role")``);
# each member pins ``role`` to a distinct ``Roles`` literal.
ChatMessage = t.Annotated[
    t.Union[
        SystemMessage,
        UserMessage,
        AssistantMessage,
    ],
    Field(discriminator="role"),
]
# Same shape as ChatMessage, except the assistant member is
# FinetuningAssistantMessage (which adds an optional ``weight``).
FinetuningMessage = t.Annotated[
    t.Union[SystemMessage, UserMessage, FinetuningAssistantMessage],
    Field(discriminator="role"),
]
# TypeVar for generic containers of chat messages (InstructRequest is generic
# over it); bounded by the ChatMessage union.
ChatMessageType = t.TypeVar("ChatMessageType", bound=ChatMessage)
# Per-role TypeVars for type hints in generic classes that may substitute
# their own (sub)classes for the concrete message models.
UserMessageType = t.TypeVar("UserMessageType", bound=UserMessage)
AssistantMessageType = t.TypeVar("AssistantMessageType", bound=AssistantMessage)
SystemMessageType = t.TypeVar("SystemMessageType", bound=SystemMessage)
# Union of the three per-role TypeVars (User, Assistant, System -> "UATS").
# Because it is built from TypeVars, it is a generic alias whose parameters
# are those TypeVars.
UATS: TypeAlias = t.Union[
    UserMessageType, AssistantMessageType, SystemMessageType
]
class InstructRequest(pydantic.BaseModel, t.Generic[ChatMessageType]):
    """
    A valid request to be tokenized.

    Generic over the message type: ``InstructRequest[SomeMessage]`` narrows
    the element type of ``messages``.
    """

    messages: t.List[ChatMessageType]
@dataclass
class TokenSample:
    """A tokenized sample: a list of ints plus a list of bools.

    Field order (``tokens``, then ``masks``) is also the positional order of
    the generated ``__init__``.
    """

    tokens: Sequence  # list of ints (presumably token ids)
    masks: Mask  # list of bools; presumably one entry per token (TODO confirm)
class SampleType(str, Enum):
    """Category of a sample: pretraining text or instruction data."""

    PRETRAIN = "pretrain"
    INSTRUCT = "instruct"
class TrainingInstructSample(InstructRequest):
    """An InstructRequest used as a training sample, plus an ``only_last`` flag.

    NOTE(review): ``InstructRequest`` is subclassed without a type argument,
    so pydantic presumably validates ``messages`` against the TypeVar's bound
    (``ChatMessage``) rather than ``FinetuningMessage``. Confirm this is
    intended.
    """

    # Defaults to False. Its semantics are not visible here (presumably: only
    # the last assistant message is trained on). TODO confirm.
    only_last: bool = False