Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
sarus-llm / sarus_llm / data / data_collator.py
Size: Mime:
from __future__ import annotations
import typing as t
import torch


class DataCollator(t.Protocol):
    """Structural (duck-typed) interface for batch collators.

    Any object exposing a compatible ``collate_batch`` satisfies this
    protocol — no explicit inheritance required.
    """

    def collate_batch(
        self, inputs: t.List[t.Dict[str, t.Any]]
    ) -> t.Dict[str, torch.Tensor]:
        """Turn a list of per-sample dicts into a dict of batched tensors."""
        ...


class TrainingDataCollator:
    """Packs tokenized samples into right-padded batches for training."""

    def __init__(self, tokenizer_pad_token: int = 0) -> None:
        # Token id used to right-pad sequences up to the batch max length.
        self.tokenizer_pad_token = tokenizer_pad_token

    def collate_batch(
        self, inputs: t.List[t.Dict[str, t.Any]]
    ) -> t.Dict[str, torch.Tensor]:
        """Pad samples to the longest sequence in the batch and build labels.

        No attention/padding mask is returned: padding is on the right, so
        the model's causal mask already prevents real tokens from attending
        to padding. Instead, every position whose sample mask is False
        (including padded positions) gets a label of -100 so it is ignored
        by the loss.

        Args:
            inputs: list of samples, each a dict with "tokens" (list[int])
                and "masks" (list[bool], True where the token should
                contribute to the loss).

        Returns:
            dict with "tokens" (long tensor, batch x max_len) and "labels"
            (same shape, -100 at masked/padded positions).

        Raises:
            ValueError: if ``inputs`` is empty.
        """
        if not inputs:
            raise ValueError("Cannot collate an empty batch")

        max_length = max(len(el["tokens"]) for el in inputs)
        input_ids = []
        labels_mask = []
        for token_sample in inputs:
            curr_input_ids, loss_mask = self._pad(
                tokens=token_sample["tokens"],
                mask=token_sample["masks"],
                max_length=max_length,
            )
            input_ids.append(curr_input_ids)
            labels_mask.append(loss_mask)

        output_ids = torch.tensor(input_ids, dtype=torch.long)
        labels = output_ids.clone()
        # -100 is the conventional ignore_index of torch.nn.CrossEntropyLoss.
        labels[~torch.tensor(labels_mask, dtype=torch.bool)] = -100
        return {
            "tokens": output_ids,
            "labels": labels,
        }

    def _pad(
        self, tokens: t.List[int], mask: t.List[bool], max_length: int
    ) -> t.Tuple[t.List[int], t.List[bool]]:
        """Right-pad ``tokens`` with the pad token and ``mask`` with False.

        Raises:
            ValueError: if ``tokens`` is longer than ``max_length``.
        """
        if len(tokens) > max_length:
            # Was an assert; raise so the check survives `python -O`.
            raise ValueError("Found sequence longer than the expected one")
        n_missing = max_length - len(tokens)
        return (
            tokens + n_missing * [self.tokenizer_pad_token],
            mask + n_missing * [False],
        )


class PreferenceDataCollator(TrainingDataCollator):
    """DataCollator for preference data.

    Each input example is a dict with two fields:
    - "chosen": a training sample (dict with "tokens" and "masks")
    - "rejected": a training sample (same structure)

    Both sides are collated independently, then right-padded to a common
    length so the four output tensors share the same sequence dimension.
    #TODO: use pydantic everywhere
    """

    def collate_batch(
        self, inputs: t.List[t.Dict[str, t.Any]]
    ) -> t.Dict[str, torch.Tensor]:
        """Collate chosen and rejected samples to one common length.

        Returns:
            dict with "chosen_tokens", "chosen_labels", "rejected_tokens"
            and "rejected_labels", all long tensors of shape
            (batch, common_max_len). Extra token positions are filled with
            the pad token; extra label positions with -100 (ignored by the
            loss).
        """
        chosen = super().collate_batch([el["chosen"] for el in inputs])
        rejected = super().collate_batch([el["rejected"] for el in inputs])

        # Chosen and rejected batches may have different max lengths;
        # pad both up to the larger one.
        size = max(chosen["tokens"].shape[1], rejected["tokens"].shape[1])

        return {
            "chosen_tokens": self._pad_right_to(
                chosen["tokens"], size, self.tokenizer_pad_token
            ),
            "chosen_labels": self._pad_right_to(chosen["labels"], size, -100),
            "rejected_tokens": self._pad_right_to(
                rejected["tokens"], size, self.tokenizer_pad_token
            ),
            "rejected_labels": self._pad_right_to(
                rejected["labels"], size, -100
            ),
        }

    @staticmethod
    def _pad_right_to(
        tensor: torch.Tensor, size: int, fill_value: int
    ) -> torch.Tensor:
        """Right-pad a (batch, seq) long tensor to ``size`` columns with
        ``fill_value``. A no-op concat when the tensor already has ``size``
        columns."""
        padding = torch.full(
            (tensor.shape[0], size - tensor.shape[1]),
            fill_value,
            dtype=torch.long,
        )
        return torch.cat([tensor, padding], dim=1)


class ServingDataCollator:
    """Packs tokenized samples into left-padded batches for inference."""

    def __init__(self, tokenizer_pad_token: int = 0) -> None:
        # Token id used to left-pad sequences up to the batch max length.
        self.tokenizer_pad_token = tokenizer_pad_token

    def collate_batch(
        self, inputs: t.List[t.Dict[str, t.Any]]
    ) -> t.Dict[str, torch.Tensor]:
        """Pad samples on the left to the longest sequence in the batch.

        Unlike training, padding is on the left (presumably so that
        autoregressive generation appends new tokens contiguously on the
        right — confirm against the serving code). A padding mask is
        returned instead of labels.

        Args:
            inputs: list of samples, each a dict with "tokens" (list[int])
                and "masks" (list[bool]).

        Returns:
            dict with "tokens" (long tensor, batch x max_len) and
            "padding_mask" (bool tensor, False at the left-padded
            positions, the sample's own mask values elsewhere).

        Raises:
            ValueError: if ``inputs`` is empty.
        """
        if not inputs:
            raise ValueError("Cannot collate an empty batch")

        max_length = max(len(el["tokens"]) for el in inputs)
        input_ids = []
        padding_mask = []
        for token_sample in inputs:
            curr_input_ids, curr_mask = self._pad(
                tokens=token_sample["tokens"],
                mask=token_sample["masks"],
                max_length=max_length,
            )
            input_ids.append(curr_input_ids)
            padding_mask.append(curr_mask)

        return {
            "tokens": torch.tensor(input_ids, dtype=torch.long),
            "padding_mask": torch.tensor(padding_mask, dtype=torch.bool),
        }

    def _pad(
        self, tokens: t.List[int], mask: t.List[bool], max_length: int
    ) -> t.Tuple[t.List[int], t.List[bool]]:
        """Left-pad ``tokens`` with the pad token and ``mask`` with False.

        Raises:
            ValueError: if ``tokens`` is longer than ``max_length``.
        """
        if len(tokens) > max_length:
            # Was an assert; raise so the check survives `python -O`.
            raise ValueError("Found sequence longer than the expected one")
        n_missing = max_length - len(tokens)
        return (
            n_missing * [self.tokenizer_pad_token] + tokens,
            n_missing * [False] + mask,
        )