from __future__ import annotations
import typing as t
import torch
class DataCollator(t.Protocol):
def collate_batch(
self, inputs: t.List[t.Dict[str, t.Any]]
) -> t.Dict[str, torch.Tensor]: ...
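# Note (sketch): the protocol is structural, so the collators below satisfy
# DataCollator without inheriting from it. A hypothetical helper that accepts
# any of them could look like this:
def _collate_with(
    collator: DataCollator, samples: t.List[t.Dict[str, t.Any]]
) -> t.Dict[str, torch.Tensor]:
    return collator.collate_batch(samples)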
class TrainingDataCollator:
"""Class responsible to pack data to train a model"""
def __init__(self, tokenizer_pad_token: int = 0) -> None:
self.tokenizer_pad_token = tokenizer_pad_token
def collate_batch(
self, inputs: t.List[t.Dict[str, t.Any]]
) -> t.Dict[str, torch.Tensor]:
"""Pads list of samples to the maximum length of the batch
and create labels. No padding mask is created as during training
we can pad on the right side automatically, with the causal mask of the model.
So only the labels are adapted to mask certain values for the loss."""
max_length = max(len(el["tokens"]) for el in inputs)
input_ids = []
labels_mask = []
for token_sample in inputs:
curr_input_ids, loss_mask = self._pad(
tokens=token_sample["tokens"],
mask=token_sample["masks"],
max_length=max_length,
)
input_ids.append(curr_input_ids)
labels_mask.append(loss_mask)
output_ids = torch.tensor(input_ids, dtype=torch.long)
labels = output_ids.clone()
        # set labels to -100 wherever the loss mask is False (padding or masked-out tokens)
labels[~torch.tensor(labels_mask)] = -100
return {
"tokens": output_ids,
"labels": labels,
}
def _pad(
self, tokens: t.List[int], mask: t.List[bool], max_length: int
) -> t.Tuple[t.List[int], t.List[bool]]:
"""Returns padded token, mask for attention and mask for loss"""
assert max_length >= len(tokens), (
"Found sequence longer than the expected one"
)
n_missing = max_length - len(tokens)
return (
tokens + n_missing * [self.tokenizer_pad_token],
mask + n_missing * [False],
)
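# Usage sketch for TrainingDataCollator, assuming each sample carries integer
# "tokens" and a boolean "masks" list of equal length where True marks
# positions that should contribute to the loss. The function name below is
# purely illustrative.
def _example_training_collate() -> t.Dict[str, torch.Tensor]:
    collator = TrainingDataCollator(tokenizer_pad_token=0)
    batch = collator.collate_batch(
        [
            {"tokens": [5, 6, 7], "masks": [False, True, True]},
            {"tokens": [8, 9], "masks": [True, True]},
        ]
    )
    # batch["tokens"] has shape (2, 3): the second row is right-padded with 0.
    # batch["labels"] is -100 wherever the mask is False or padding was added.
    return batch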
class PreferenceDataCollator(TrainingDataCollator):
"""DataCollator for preference data. It should take as input
a list of examples, each example is a dict with two fields:
- chosen data
- rejected data.
#TODO: use pydantic everywhere"""
def collate_batch(
self, inputs: t.List[t.Dict[str, t.Any]]
) -> t.Dict[str, torch.Tensor]:
""""""
chosen = super().collate_batch([el["chosen"] for el in inputs])
rejected = super().collate_batch([el["rejected"] for el in inputs])
        # pad chosen and rejected tensors to the same sequence length
size = max(chosen["tokens"].shape[1], rejected["tokens"].shape[1])
return {
"chosen_tokens": torch.cat(
[
chosen["tokens"],
self.tokenizer_pad_token
* torch.ones(
chosen["tokens"].shape[0],
size - chosen["tokens"].shape[1],
dtype=torch.long,
),
],
dim=1,
),
"chosen_labels": torch.cat(
[
chosen["labels"],
-100
* torch.ones(
chosen["tokens"].shape[0],
size - chosen["tokens"].shape[1],
dtype=torch.long,
),
],
dim=1,
),
"rejected_tokens": torch.cat(
[
rejected["tokens"],
self.tokenizer_pad_token
* torch.ones(
rejected["tokens"].shape[0],
size - rejected["tokens"].shape[1],
dtype=torch.long,
),
],
dim=1,
),
"rejected_labels": torch.cat(
[
rejected["labels"],
-100
* torch.ones(
rejected["tokens"].shape[0],
size - rejected["tokens"].shape[1],
dtype=torch.long,
),
],
dim=1,
),
}
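# Usage sketch for PreferenceDataCollator, assuming each example wraps two
# samples in the same {"tokens", "masks"} format under the "chosen" and
# "rejected" keys. The function name below is purely illustrative.
def _example_preference_collate() -> t.Dict[str, torch.Tensor]:
    collator = PreferenceDataCollator(tokenizer_pad_token=0)
    batch = collator.collate_batch(
        [
            {
                "chosen": {"tokens": [5, 6, 7], "masks": [False, True, True]},
                "rejected": {"tokens": [5, 6], "masks": [False, True]},
            }
        ]
    )
    # All four returned tensors share the same sequence length: the shorter
    # side ("rejected" here) is right-padded with the pad token / -100.
    return batch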
class ServingDataCollator:
"""Class responsible to packing data to serve a model at inference"""
def __init__(self, tokenizer_pad_token: int = 0) -> None:
self.tokenizer_pad_token = tokenizer_pad_token
def collate_batch(
self, inputs: t.List[t.Dict[str, t.Any]]
) -> t.Dict[str, torch.Tensor]:
"""Pads list of samples to the maximum length of the batch
and create labels for huggingface"""
max_length = max(len(el["tokens"]) for el in inputs)
input_ids = []
padding_mask = []
for token_sample in inputs:
curr_input_ids, curr_mask = self._pad(
tokens=token_sample["tokens"],
mask=token_sample["masks"],
max_length=max_length,
)
input_ids.append(curr_input_ids)
padding_mask.append(curr_mask)
return {
"tokens": torch.tensor(input_ids, dtype=torch.long),
"padding_mask": torch.tensor(padding_mask, dtype=torch.bool),
}
def _pad(
self, tokens: t.List[int], mask: t.List[bool], max_length: int
) -> t.Tuple[t.List[int], t.List[bool]]:
"""Returns padded token, mask for attention and mask for loss"""
assert max_length >= len(tokens), (
"Found sequence longer than the expected one"
)
n_missing = max_length - len(tokens)
return (
n_missing * [self.tokenizer_pad_token] + tokens,
n_missing * [False] + mask,
)
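# Usage sketch for ServingDataCollator, assuming the same {"tokens", "masks"}
# sample format as the training collators. Serving pads on the left, so the
# returned "padding_mask" is False on the left padding and True on real
# tokens. The function name below is purely illustrative.
def _example_serving_collate() -> t.Dict[str, torch.Tensor]:
    collator = ServingDataCollator(tokenizer_pad_token=0)
    batch = collator.collate_batch(
        [
            {"tokens": [5, 6, 7], "masks": [True, True, True]},
            {"tokens": [8, 9], "masks": [True, True]},
        ]
    )
    # batch["tokens"][1] is [0, 8, 9]; batch["padding_mask"][1] is [False, True, True].
    return batch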