Repository URL to install this package:
|
Version:
0.7.1+cu122 ▾
|
import copy
from typing import Any, Dict, List, Optional

import rouge
from torch import LongTensor
from transformers import GenerationConfig

from ._base import BaseTask
from ._utils.generation_utils import postprocess_generation_ids
class TextSummarizationTask(BaseTask):
    """Evaluate a generative model on text summarization, scored with ROUGE.

    Summaries are produced with ``model.generate`` and compared against the
    decoded reference labels using ``rouge.Rouge``. Both predictions and
    references are lowercased and whitespace-stripped before scoring.
    """

    def __init__(
        self,
        model,
        tokenizer,
        data_name_or_path: str,
        prompt_col_name: str,
        label_col_name: str,
        device: Optional[str] = None,
        **kwargs,
    ):
        """Build the task; all arguments are forwarded to ``BaseTask``.

        ``merge_prompt_label`` is always forced to ``False``, overriding any
        value the caller passed in ``kwargs``.
        """
        # NOTE(review): presumably this keeps prompt and label as separate
        # sequences so generation sees only the prompt — confirm in BaseTask.
        kwargs["merge_prompt_label"] = False
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            data_name_or_path=data_name_or_path,
            prompt_col_name=prompt_col_name,
            label_col_name=label_col_name,
            device=device,
            **kwargs,
        )

    def _predict(self, batch_data: Dict[str, Any], *args, **kwargs) -> List[str]:
        """Generate one summary string per example in ``batch_data``.

        Args:
            batch_data: Must provide ``"input_ids"`` and ``"attention_mask"``,
                which are passed straight to ``self.model.generate``.
            **kwargs: Must contain ``"generation_config"``; a ``KeyError`` is
                raised otherwise.

        Returns:
            For each input, the first sequence returned by
            ``postprocess_generation_ids``, lowercased and stripped.
        """
        generation_config = kwargs["generation_config"]
        output_ids = self.model.generate(
            input_ids=batch_data["input_ids"],
            attention_mask=batch_data["attention_mask"],
            generation_config=generation_config,
        )
        # NOTE(review): assumes postprocess_generation_ids yields, per input, a
        # list of num_return_sequences decoded strings — verify in _utils.
        decoded = postprocess_generation_ids(
            input_ids=batch_data["input_ids"],
            output_ids=output_ids,
            num_return_sequences=generation_config.num_return_sequences,
            tokenizer=self.tokenizer,
        )
        return [each[0].lower().strip() for each in decoded]

    def _parse_labels(self, label_ids: LongTensor) -> List[str]:
        """Decode each row of ``label_ids`` into a reference summary string.

        Every position equal to ``-100`` (the placeholder this task treats as
        padding) is removed before decoding. The text is then lowercased and
        stripped so it matches the normalisation applied to predictions.
        """
        labels = []
        for one_label_ids in label_ids:
            # Mask out every -100 instead of slicing off a leading count, so
            # the result is correct whether the placeholder is on the left,
            # the right, or both; -100 is never a valid id to decode.
            one_label_ids = one_label_ids[one_label_ids != -100]
            # NOTE(review): decode is called without skip_special_tokens, so
            # special tokens present in labels remain in the reference text.
            label = self.tokenizer.decode(one_label_ids).lower().strip()
            labels.append(label)
        return labels

    def _metric(self, pred: List[Any], label: List[Any]) -> Dict[str, Dict[str, float]]:
        """Return ROUGE scores of ``pred`` against ``label``, averaged over examples.

        NOTE(review): per the ``rouge`` package docs the result is presumably
        keyed by ``"rouge-1"``/``"rouge-2"``/``"rouge-l"``, each mapping to
        recall/precision/F values; that package is also reported to raise on
        empty hypotheses — confirm against the installed version.
        """
        metric = rouge.Rouge()
        return metric.get_scores(hyps=pred, refs=label, avg=True)

    def run(self, generation_config: Optional[GenerationConfig] = None) -> Dict[str, float]:
        """Run the evaluation via ``BaseTask.run``.

        Args:
            generation_config: Decoding settings. When ``None``, greedy decoding
                is used (``num_beams=1``, ``do_sample=False``,
                ``max_new_tokens=128``). A caller-supplied config is
                deep-copied, so the overrides below never modify the caller's
                object.

        In every case ``num_return_sequences`` is forced to 1, and
        ``eos_token_id``/``pad_token_id`` are taken from ``self.tokenizer``.
        """
        if generation_config is None:
            generation_config = GenerationConfig(num_beams=1, do_sample=False, max_new_tokens=128)
        else:
            # Isolate the caller's object from the in-place overrides below.
            generation_config = copy.deepcopy(generation_config)
        # _predict keeps only the first sequence per input, so one is enough.
        generation_config.num_return_sequences = 1
        generation_config.eos_token_id = self.tokenizer.eos_token_id
        generation_config.pad_token_id = self.tokenizer.pad_token_id
        return super().run(generation_config=generation_config)
# Public API of this module: only the task class is exported by `from ... import *`.
__all__ = ["TextSummarizationTask"]