# Repository URL to install this package:
# Version: 2.7.2
import os
import tensorflow as tf
from sarus_query_builder.protobuf.query_pb2 import Model
def _read_bytes(file_path: str) -> bytes:
    """Return the full binary contents of ``file_path``."""
    with open(file_path, "rb") as f:
        return f.read()


def _iter_dir_files(directory: str):
    """Yield ``(filename, contents)`` for every entry directly inside ``directory``.

    A missing directory yields nothing, so a SavedModel without an
    ``assets``/``variables`` folder packs to an empty map instead of raising
    ``FileNotFoundError``. Entries are visited in sorted order so the result
    does not depend on ``os.listdir`` ordering. Sub-directories are not
    descended into; opening one raises an ``OSError``.
    """
    try:
        filenames = sorted(os.listdir(directory))
    except FileNotFoundError:
        return
    for filename in filenames:
        yield filename, _read_bytes(os.path.join(directory, filename))


def model_to_pb(model: tf.keras.Model, path: str = "/tmp/model") -> Model:
    """Save a tensorflow model to protobuf.

    The model is first written with ``model.save(path)``. The resulting
    directory is then packed into a ``Model`` message:

    * ``saved_model.pb``    -> ``model_pb.protobuf``
    * ``keras_metadata.pb`` -> ``model_pb.keras_metadata``
    * ``assets/<name>``     -> ``model_pb.assets[<name>]``
    * ``variables/<name>``  -> ``model_pb.variables[<name>]``

    Only the top level of ``assets/`` and ``variables/`` is captured.

    NOTE(review): ``path`` defaults to a fixed location shared by every call
    (and by ``pb_to_model``), so concurrent calls can clobber each other.
    Every file found under ``assets/``/``variables/`` is packed, including
    leftovers from an earlier save to the same path unless ``model.save``
    removes them. Confirm, or pass a fresh directory per call.

    Args:
        model: Keras model to serialize.
        path: Scratch directory the model is saved to before packing.

    Returns:
        The populated ``Model`` protobuf message.

    Raises:
        FileNotFoundError: If ``saved_model.pb`` or ``keras_metadata.pb`` is
            not present in ``path`` after saving.
    """
    model.save(path)
    model_pb = Model()
    model_pb.protobuf = _read_bytes(os.path.join(path, "saved_model.pb"))
    model_pb.keras_metadata = _read_bytes(os.path.join(path, "keras_metadata.pb"))
    # Map fields are filled entry by entry (map<string, bytes> item assignment).
    for filename, data in _iter_dir_files(os.path.join(path, "assets")):
        model_pb.assets[filename] = data
    for filename, data in _iter_dir_files(os.path.join(path, "variables")):
        model_pb.variables[filename] = data
    return model_pb
def _checked_member_path(directory: str, filename: str) -> str:
    """Return ``os.path.join(directory, filename)`` for a plain file name.

    Raises:
        ValueError: If ``filename`` is empty, ``.`` or ``..``, or contains a
            path separator or drive (e.g. ``../x`` or ``/etc/x``). Joining such
            a name would place the file outside ``directory``.
    """
    if filename in ("", ".", "..") or os.path.basename(filename) != filename:
        raise ValueError(
            f"Refusing to write model file outside {directory!r}: {filename!r}"
        )
    return os.path.join(directory, filename)


def pb_to_model(model_pb: Model, path: str = "/tmp/model") -> tf.keras.Model:
    """Load model from protobuf.

    This is the inverse of ``model_to_pb``. It writes the message's contents
    back out as a SavedModel directory under ``path``, then loads it with
    ``tf.keras.models.load_model(path, compile=False)``:

    * ``model_pb.protobuf``         -> ``saved_model.pb``
    * ``model_pb.keras_metadata``   -> ``keras_metadata.pb``
    * ``model_pb.assets[<name>]``   -> ``assets/<name>``
    * ``model_pb.variables[<name>]`` -> ``variables/<name>``

    Asset and variable names come from the message. They must be plain file
    names, so a crafted message cannot write outside ``path``. Files already
    present in ``path`` are overwritten when names collide and otherwise left
    in place.

    Args:
        model_pb: Message previously produced by ``model_to_pb``.
        path: Scratch directory the SavedModel is written to before loading.

    Returns:
        The loaded, uncompiled Keras model.

    Raises:
        ValueError: If an asset or variable name is not a plain file name.
    """
    assets_dir = os.path.join(path, "assets")
    variables_dir = os.path.join(path, "variables")

    # Validate every destination up front, so no file is written when any
    # name in the message is rejected.
    member_files = [
        (_checked_member_path(assets_dir, name), data)
        for name, data in model_pb.assets.items()
    ] + [
        (_checked_member_path(variables_dir, name), data)
        for name, data in model_pb.variables.items()
    ]

    for directory in (path, assets_dir, variables_dir):
        os.makedirs(directory, exist_ok=True)

    outputs = [
        (os.path.join(path, "saved_model.pb"), model_pb.protobuf),
        (os.path.join(path, "keras_metadata.pb"), model_pb.keras_metadata),
    ] + member_files
    for file_path, data in outputs:
        with open(file_path, "wb") as f:
            f.write(data)

    # SECURITY NOTE: TensorFlow's security policy treats models as programs.
    # Loading one can run model-defined computation, so only call this on
    # messages from a trusted source.
    return tf.keras.models.load_model(path, compile=False)