# Repository URL to install this package: (not given)
# Version: 4.5.4.dev1
from types import TracebackType
import typing as t
from sarus_data_spec.bounds import Bounds
from sarus_data_spec.context.state import (
pop_global_context,
push_global_context,
)
from sarus_data_spec.context.typing import Context
from sarus_data_spec.factory import Factory
from sarus_data_spec.links import Links
from sarus_data_spec.manager.typing import Manager
from sarus_data_spec.marginals import Marginals
from sarus_data_spec.multiplicity import Multiplicity
from sarus_data_spec.predicate import Predicate
from sarus_data_spec.protobuf import type_name
from sarus_data_spec.scalar import random_seed
from sarus_data_spec.schema import Schema
from sarus_data_spec.size import Size
from sarus_data_spec.storage.typing import Storage
from sarus_data_spec.transform import derive_seed
import sarus_data_spec as s
import sarus_data_spec.protobuf as sp
import sarus_data_spec.typing as st
class Base(Context):
    """Base context holding a master seed and a protobuf wrapper factory.

    On construction, a ``Factory`` is created. A constructor is registered
    on it for each supported ``sarus_data_spec`` protobuf message type, keyed
    by ``type_name`` of that message class.

    ``storage`` and ``manager`` are not implemented here and raise
    ``NotImplementedError``; concrete contexts are expected to provide them.

    Instances can be used as context managers:

    - entering calls ``push_global_context(self)``;
    - exiting calls ``pop_global_context()``.
    """

    def __init__(self, seed: int = 1234) -> None:
        """Create the context and register the protobuf constructors.

        Args:
            seed: Master seed from which ``generate_seed`` derives seeds.
        """
        self.master_seed = seed
        self._factory = Factory()
        # Register relevant classes. Each constructor receives a raw
        # protobuf message and wraps it. ``t.cast`` is a no-op at runtime;
        # it only narrows the static type of the message.
        self.factory().register(
            type_name(sp.Dataset),
            lambda protobuf, store: s.Dataset(
                t.cast(sp.Dataset, protobuf), store
            ),
        )
        self.factory().register(
            type_name(sp.Scalar),
            lambda protobuf, store: s.Scalar(
                t.cast(sp.Scalar, protobuf), store
            ),
        )
        self.factory().register(
            type_name(sp.Status),
            lambda protobuf, store: s.Status(
                t.cast(sp.Status, protobuf), store
            ),
        )
        self.factory().register(
            type_name(sp.Transform),
            lambda protobuf, store: s.Transform(
                t.cast(sp.Transform, protobuf), store
            ),
        )
        self.factory().register(
            type_name(sp.Attribute),
            lambda protobuf, store: s.Attribute(
                t.cast(sp.Attribute, protobuf), store
            ),
        )
        self.factory().register(
            type_name(sp.VariantConstraint),
            lambda protobuf, store: s.VariantConstraint(
                t.cast(sp.VariantConstraint, protobuf), store
            ),
        )
        self.factory().register(
            type_name(sp.Predicate),
            # Predicate is built without a store. The optional, ignored
            # ``store`` parameter gives this constructor the same
            # ``(protobuf, store)`` call signature as every other
            # constructor registered here. Single-argument calls still work.
            lambda protobuf, store=None: Predicate(
                t.cast(sp.Predicate, protobuf)
            ),
        )
        self.factory().register(
            type_name(sp.Schema),
            lambda protobuf, store: Schema(t.cast(sp.Schema, protobuf), store),
        )
        self.factory().register(
            type_name(sp.Marginals),
            lambda protobuf, store: Marginals(
                t.cast(sp.Marginals, protobuf), store
            ),
        )
        self.factory().register(
            type_name(sp.Size),
            lambda protobuf, store: Size(t.cast(sp.Size, protobuf), store),
        )
        self.factory().register(
            type_name(sp.Multiplicity),
            lambda protobuf, store: Multiplicity(
                t.cast(sp.Multiplicity, protobuf), store
            ),
        )
        self.factory().register(
            type_name(sp.Bounds),
            lambda protobuf, store: Bounds(t.cast(sp.Bounds, protobuf), store),
        )
        self.factory().register(
            type_name(sp.Links),
            lambda protobuf, store: Links(t.cast(sp.Links, protobuf), store),
        )

    def generate_seed(self, salt: int = 0) -> st.Scalar:
        """Derive a seed Scalar from the master seed and ``salt``.

        ``random_seed(self.master_seed)`` builds a seed Scalar. The transform
        returned by ``derive_seed(salt)`` is then applied to it.

        Args:
            salt: Value passed to ``derive_seed`` so that different salts can
                yield different derived seeds.

        Returns:
            The derived seed, typed as ``st.Scalar``.
        """
        return t.cast(
            st.Scalar, derive_seed(salt)(random_seed(self.master_seed))
        )

    def factory(self) -> Factory:
        """Return the factory created in ``__init__``."""
        return self._factory

    def storage(self) -> Storage:
        """Not implemented in ``Base``; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def manager(self) -> Manager:
        """Not implemented in ``Base``; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def __enter__(self) -> Context:
        """Push this context with ``push_global_context`` and return it."""
        push_global_context(self)
        return self

    def __exit__(
        self,
        type: t.Optional[t.Type[BaseException]],
        value: t.Optional[BaseException],
        traceback: t.Optional[TracebackType],
    ) -> None:
        """Pop the context with ``pop_global_context``.

        Returns None, which is falsy. Any exception raised inside the
        ``with`` block is therefore propagated to the caller rather than
        suppressed.
        """
        pop_global_context()