Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
Size: Mime:
import logging
import os
import pickle as pkl
import time
import traceback
import typing as t

from sarus_data_spec import typing as st
from sarus_data_spec.constants import (
    CACHE_PATH,
    CACHE_PROTO,
    CACHE_SCALAR_TASK,
    CACHE_TYPE,
    ScalarCaching,
)
from sarus_data_spec.manager.computations.local.base import LocalComputation
from sarus_data_spec.scalar import Scalar
from sarus_data_spec.status import DataSpecErrorStatus, error, ready
import sarus_data_spec.protobuf as sp

logger = logging.getLogger(__name__)


class CacheScalarComputation(LocalComputation[t.Tuple[str, str]]):
    """Cache the value of a scalar dataspec.

    The value is obtained from the computing manager's ``async_value_op``.
    Values implementing ``st.HasProtobuf`` are stored directly in the task
    status properties as a base64-encoded protobuf; any other value is
    pickled to a file (see ``cache_path``) and only that file's path is
    recorded in the properties.
    """

    task_name = CACHE_SCALAR_TASK

    async def prepare(self, dataspec: st.DataSpec) -> None:
        """Compute the scalar's value, cache it and record the task status.

        On success the task is marked ``ready`` with properties describing
        where the cached value lives (``CACHE_TYPE`` plus either
        ``CACHE_PROTO`` or ``CACHE_PATH``). On failure the task is marked
        ``error`` with the formatted traceback and a relaunch flag.

        Raises:
            DataSpecErrorStatus: re-raised unchanged if the failure already
                was a ``DataSpecErrorStatus``; otherwise a new one (with
                ``relaunch`` False) chained to the original exception.
        """
        logger.info("STARTING CACHE_SCALAR %s", dataspec.uuid())
        start = time.perf_counter()
        scalar = t.cast(Scalar, dataspec)
        try:
            value = await self.computing_manager().async_value_op(
                scalar=scalar
            )
            if isinstance(value, st.HasProtobuf):
                # Protobuf-backed values are stored inline in the status
                # properties, serialized as base64.
                properties = {
                    CACHE_PROTO: sp.to_base64(value.protobuf()),
                    CACHE_TYPE: sp.type_name(value.prototype()),
                }
            else:
                # Any other Python object is pickled to disk; only the file
                # path is kept in the properties.
                path = self.cache_path(scalar)
                properties = {
                    CACHE_TYPE: ScalarCaching.PICKLE.value,
                    CACHE_PATH: path,
                }

                with open(path, "wb") as f:
                    pkl.dump(value, f)

        except DataSpecErrorStatus as exception:
            # Already a status error: record it, keeping its relaunch flag,
            # and propagate it unchanged.
            error(
                dataspec=dataspec,
                manager=self.computing_manager(),
                task=self.task_name,
                properties={
                    "message": traceback.format_exc(),
                    "relaunch": str(exception.relaunch),
                },
            )
            raise

        except Exception as exc:
            # Unexpected failure: record it as non-relaunchable and convert
            # it to a DataSpecErrorStatus, keeping the original as the cause.
            message = traceback.format_exc()
            error(
                dataspec=dataspec,
                manager=self.computing_manager(),
                task=self.task_name,
                properties={
                    "message": message,
                    "relaunch": str(False),
                },
            )

            raise DataSpecErrorStatus((False, message)) from exc
        else:
            elapsed = time.perf_counter() - start
            logger.info(
                "FINISHED CACHE_SCALAR %s (%.2fs)", dataspec.uuid(), elapsed
            )
            ready(
                dataspec=dataspec,
                manager=self.computing_manager(),
                task=self.task_name,
                properties=properties,
            )

    async def result_from_stage_properties(
        self,
        dataspec: st.DataSpec,
        properties: t.Mapping[str, str],
        **kwargs: t.Any,
    ) -> t.Tuple[str, str]:
        """Return where the cached value is stored, as recorded by ``prepare``.

        The cache itself is not read here.

        Returns:
            A ``(cache_type, location)`` tuple. ``location`` is the pickle
            file path when ``cache_type`` is ``ScalarCaching.PICKLE.value``,
            and the base64-encoded protobuf otherwise.
        """
        cache_type = properties[CACHE_TYPE]
        if cache_type == ScalarCaching.PICKLE.value:
            return cache_type, properties[CACHE_PATH]
        return cache_type, properties[CACHE_PROTO]

    def cache_path(self, dataspec: st.DataSpec) -> str:
        """Return the pickle file path used to cache ``dataspec``.

        The file is named ``<uuid>.pkl`` and lives in the directory returned
        by the dataspec manager's ``parquet_dir()``.
        """
        return os.path.join(
            dataspec.manager().parquet_dir(), f"{dataspec.uuid()}.pkl"
        )