Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
Size: Mime:
import os

import tensorflow as tf

from sarus_query_builder.protobuf.query_pb2 import Model


def model_to_pb(model: tf.keras.Model, path: str = "/tmp/model") -> Model:
    """Serialize a Keras model into a ``Model`` protobuf message.

    The model is first written to ``path`` with ``model.save``. The files
    found there are then copied into the message:

    * ``saved_model.pb``    -> ``protobuf``
    * ``keras_metadata.pb`` -> ``keras_metadata``
    * each entry of ``assets/``    -> ``assets[<entry name>]``
    * each entry of ``variables/`` -> ``variables[<entry name>]``

    Args:
        model: Keras model to serialize. ``model.save(path)`` is expected to
            produce the directory layout listed above.
        path: Directory the model is saved to before being read back. The
            default is one fixed location, so concurrent calls that rely on
            it may overwrite each other's files.

    Returns:
        A ``Model`` message holding the raw bytes of the saved files.

    Raises:
        OSError: if one of the expected files or directories is missing or
            unreadable (for example, a sub-directory inside ``assets/``).
    """
    model.save(path)

    def read_bytes(*parts: str) -> bytes:
        # Read one file below ``path`` in binary mode and return its contents.
        with open(os.path.join(path, *parts), "rb") as handle:
            return handle.read()

    message = Model()
    message.protobuf = read_bytes("saved_model.pb")
    message.keras_metadata = read_bytes("keras_metadata.pb")

    # Both sub-directories are copied entry by entry into map fields keyed by
    # the bare entry name (no directory prefix).
    for subdir, target in (("assets", message.assets), ("variables", message.variables)):
        for name in os.listdir(os.path.join(path, subdir)):
            target[name] = read_bytes(subdir, name)

    return message


def _safe_join(directory: str, filename: str) -> str:
    """Return ``directory/filename``, refusing names that resolve outside ``directory``."""
    base = os.path.realpath(directory)
    target = os.path.realpath(os.path.join(base, filename))
    # commonpath catches "../" segments, absolute names and symlinks that
    # escape the base. target == base rejects empty names like "" or ".".
    if target == base or os.path.commonpath([base, target]) != base:
        raise ValueError(
            f"Refusing to write {filename!r} outside of {directory!r}"
        )
    return target


def pb_to_model(model_pb: Model, path: str = "/tmp/model") -> tf.keras.Model:
    """Rebuild a Keras model from a ``Model`` protobuf message.

    This is the inverse of ``model_to_pb``. It writes the serialized files
    back into a directory at ``path`` and loads that directory with
    ``tf.keras.models.load_model(path, compile=False)``.

    Args:
        model_pb: Message whose ``protobuf`` and ``keras_metadata`` fields hold
            the bytes of ``saved_model.pb`` / ``keras_metadata.pb``, and whose
            ``assets`` / ``variables`` maps hold filename -> bytes entries.
        path: Directory to materialize the model in. Files with the same
            names are overwritten; any other files already there are left
            in place.

    Returns:
        The loaded, uncompiled Keras model.

    Raises:
        ValueError: if an ``assets`` or ``variables`` key would resolve to a
            location outside its sub-directory (e.g. ``"../x"`` or an
            absolute path). This check runs before any file is written.
        OSError: if a file cannot be written; load errors from
            ``load_model`` propagate unchanged.
    """
    assets_dir = os.path.join(path, "assets")
    variables_dir = os.path.join(path, "variables")
    for directory in (path, assets_dir, variables_dir):
        os.makedirs(directory, exist_ok=True)

    # Map keys come from the message and are not trusted. Resolve and
    # validate all of them before writing anything, so a bad message leaves
    # no partial output behind.
    payloads = [
        (_safe_join(directory, filename), data)
        for directory, files in (
            (assets_dir, model_pb.assets),
            (variables_dir, model_pb.variables),
        )
        for filename, data in files.items()
    ]

    with open(os.path.join(path, "saved_model.pb"), "wb") as f:
        f.write(model_pb.protobuf)

    with open(os.path.join(path, "keras_metadata.pb"), "wb") as f:
        f.write(model_pb.keras_metadata)

    for target, data in payloads:
        # Allows keys such as "sub/file" that stay inside the directory.
        os.makedirs(os.path.dirname(target), exist_ok=True)
        with open(target, "wb") as f:
            f.write(data)

    # NOTE(review): if model_pb can originate from an untrusted party,
    # loading a Keras SavedModel may execute code embedded in the model
    # (e.g. serialized Lambda layers). Confirm the trust boundary.
    return tf.keras.models.load_model(path, compile=False)