Repository URL to install this package:
|
Version:
1.1.3 ▾
|
from __future__ import annotations
import typing as t
class ModelProvider(t.Protocol):
    """Structural interface for obtaining and publishing models.

    Implementations retrieve models/tokenizers from the Sarus pretrained
    ones and push them back once training is finished.
    """

    def torch_model(self, device: t.Union[int, str]) -> t.Any:
        """Load the pretrained model and the optimizer.

        Args:
            device: Device to load onto, given as an int or a str.
                Presumably a torch device index or name -- confirm
                against the concrete implementations.

        Returns:
            Implementation-defined; the protocol only promises ``t.Any``.
        """
        ...

    def push_to_large_object_storage(
        self, finetuned_uri: str, local_dir: str
    ) -> None:
        """Push the result of training to large object storage.

        Args:
            finetuned_uri: Presumably the destination URI of the
                fine-tuned artefacts -- confirm against implementations.
            local_dir: Presumably the local directory holding the files
                to upload -- confirm against implementations.
        """
        ...
class LargeObjectStorage(t.Protocol):
    """Structural interface to the client bucket's large object storage.

    Implementations download and push the weights of a PEFT model
    from/to that storage.
    """

    def cache_pretrained_peft_locally(self) -> None:
        """Cache the pretrained PEFT weights locally.

        Presumably downloads them from large object storage -- confirm
        against the concrete implementations.
        """
        ...

    def missing_local_pretrained_cache(self) -> bool:
        """Report whether the local pretrained cache is missing.

        Returns:
            Presumably ``True`` when the cache is absent -- confirm
            against the concrete implementations.
        """
        ...

    def push_to_large_object_storage(
        self, finetuned_uri: str, local_dir: str
    ) -> None:
        """Push weights to large object storage.

        Args:
            finetuned_uri: Presumably the destination URI of the
                fine-tuned weights -- confirm against implementations.
            local_dir: Presumably the local directory holding the files
                to upload -- confirm against implementations.
        """
        ...

    # NOTE(review): the name suggests this yields a directory path, yet the
    # declared return type is None -- confirm whether ``-> str`` was intended.
    def local_cache_dir(self) -> None:
        """Local cache directory hook; declared to return nothing."""
        ...