Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
sarus-llm / sarus_llm / data / tokenizers / utils.py
Size: Mime:
from dataclasses import dataclass
from enum import Enum
from pydantic import Field
import typing as t
import pydantic
from typing_extensions import TypeAlias

# Token ids making up one encoded sample.
Sequence = t.List[int]
# Boolean flags; presumably one per token of a ``Sequence`` (not enforced here).
Mask = t.List[bool]


# Closed set of speaker roles a chat message can carry. The ``str`` mixin makes
# every member compare equal to its raw value, e.g. ``Roles.user == "user"``.
class Roles(str, Enum):
    system = "system"  # used by SystemMessage
    user = "user"  # used by UserMessage
    assistant = "assistant"  # used by AssistantMessage and its subclasses


# Kinds of content chunk a message may contain; only plain text for now.
# The ``str`` mixin lets ``ChunkTypes.text == "text"`` hold.
class ChunkTypes(str, Enum):
    text = "text"  # sole member; default ``type`` of ContentChunk


# One typed piece of message content. ``type`` has a default (the only
# ChunkTypes member), so ``text`` is the only field that must be supplied.
class ContentChunk(pydantic.BaseModel):
    type: ChunkTypes = ChunkTypes.text  # content kind of this chunk
    text: str  # the chunk's textual payload (required)


# Common base of every chat message: a ``role`` limited to the three Roles
# members. Subclasses narrow ``role`` to a single literal with a default, which
# is what lets ``role`` serve as the discriminator of the message unions.
class BaseMessage(pydantic.BaseModel):
    role: t.Literal[Roles.system, Roles.user, Roles.assistant]


# Message authored by the user; ``role`` is pinned to ``Roles.user``.
class UserMessage(BaseMessage):
    role: t.Literal[Roles.user] = Roles.user
    # Either a plain string or a list of typed ContentChunk pieces.
    content: t.Union[str, t.List[ContentChunk]]


# System message; ``role`` is pinned to ``Roles.system``.
class SystemMessage(BaseMessage):
    role: t.Literal[Roles.system] = Roles.system
    # Same content shape as UserMessage: a string or a list of ContentChunk.
    content: t.Union[str, t.List[ContentChunk]]


# Message produced by the assistant; ``role`` is pinned to ``Roles.assistant``.
class AssistantMessage(BaseMessage):
    role: t.Literal[Roles.assistant] = Roles.assistant
    # Plain text only (no ContentChunk list); may be omitted, defaulting to None.
    content: t.Optional[str] = None


# Assistant message variant used in FinetuningMessage; adds an optional weight.
class FinetuningAssistantMessage(AssistantMessage):
    # NOTE(review): presumably a per-message training/loss weight; None means
    # "not set" - confirm against the consumer of this field.
    weight: t.Optional[float] = None


# Discriminated union of the standard message models: pydantic picks the
# concrete class from the value of the ``role`` field.
ChatMessage = t.Annotated[
    t.Union[SystemMessage, UserMessage, AssistantMessage],
    Field(discriminator="role"),
]

# Same union for finetuning data, where assistant turns are parsed as
# FinetuningAssistantMessage (which also carries an optional ``weight``).
FinetuningMessage = t.Annotated[
    t.Union[
        SystemMessage,
        UserMessage,
        FinetuningAssistantMessage,
    ],
    Field(discriminator="role"),
]

# Type variable bounded by the ChatMessage union; InstructRequest is generic
# over it so the accepted message type can be specialised.
ChatMessageType = t.TypeVar("ChatMessageType", bound=ChatMessage)

# Per-role type variables for generic classes that may override the concrete
# message types (e.g. substituting a subclass of AssistantMessage).
UserMessageType = t.TypeVar("UserMessageType", bound=UserMessage)
AssistantMessageType = t.TypeVar("AssistantMessageType", bound=AssistantMessage)
SystemMessageType = t.TypeVar("SystemMessageType", bound=SystemMessage)

# Generic alias: union of the three per-role type variables above.
UATS: TypeAlias = t.Union[
    UserMessageType,
    AssistantMessageType,
    SystemMessageType,
]


# Generic over ChatMessageType so callers/subclasses can plug in their own
# message union; unparameterised, messages validate against the bound.
class InstructRequest(pydantic.BaseModel, t.Generic[ChatMessageType]):
    """A valid request to be tokenized"""

    # Conversation turns, in list order; each is validated as ChatMessageType.
    messages: t.List[ChatMessageType]


@dataclass(init=True, repr=True, eq=True)
class TokenSample:
    """Token ids of an encoded sample paired with boolean masks.

    Attributes:
        tokens: Token ids (the ``Sequence`` alias, i.e. ``List[int]``).
        masks: Boolean flags (the ``Mask`` alias, i.e. ``List[bool]``);
            presumably one per entry of ``tokens`` - lengths are not checked.
    """

    tokens: Sequence
    masks: Mask


# Category of a sample. The ``str`` mixin makes members compare equal to
# their raw values, e.g. ``SampleType.INSTRUCT == "instruct"``.
class SampleType(str, Enum):
    PRETRAIN = "pretrain"
    INSTRUCT = "instruct"


# InstructRequest extended with a training-time option.
class TrainingInstructSample(InstructRequest):
    # NOTE(review): presumably, when True, only the final turn contributes to
    # the training target - confirm with the tokenizer that reads this flag.
    only_last: bool = False