Repository URL to install this package:
|
Version:
1.14.0 ▾
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Shim for systems that need to load both SessionBundle and SavedModel.
This is intended to be used during migration to SavedModel.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from tensorflow.contrib.session_bundle import constants as legacy_constants
from tensorflow.contrib.session_bundle import manifest_pb2
from tensorflow.contrib.session_bundle import session_bundle
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import meta_graph
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import signature_constants
def _add_input_to_signature_def(tensor_name, map_key, signature_def):
  """Registers a tensor as a named input of `signature_def`.

  Args:
    tensor_name: string name of the graph tensor to expose as an input.
    map_key: string key under which the tensor is stored in
      `signature_def.inputs`.
    signature_def: object of type meta_graph_pb2.SignatureDef(); it is
      modified in place.

  Side effect: `signature_def.inputs[map_key]` is overwritten with a
  TensorInfo whose `name` field is `tensor_name`.
  """
  # Build the TensorInfo before touching the map so a failure here leaves
  # signature_def untouched.
  info = meta_graph_pb2.TensorInfo(name=tensor_name)
  slot = signature_def.inputs[map_key]
  slot.CopyFrom(info)
def _add_output_to_signature_def(tensor_name, map_key, signature_def):
  """Registers a tensor as a named output of `signature_def`.

  Args:
    tensor_name: string name of the graph tensor to expose as an output.
    map_key: string key under which the tensor is stored in
      `signature_def.outputs`.
    signature_def: object of type meta_graph_pb2.SignatureDef(); it is
      modified in place.

  Side effect: `signature_def.outputs[map_key]` is overwritten with a
  TensorInfo whose `name` field is `tensor_name`.
  """
  # Build the TensorInfo before touching the map so a failure here leaves
  # signature_def untouched.
  info = meta_graph_pb2.TensorInfo(name=tensor_name)
  slot = signature_def.outputs[map_key]
  slot.CopyFrom(info)
def _convert_default_signature_to_signature_def(signatures):
  """Up-converts the legacy default signature into a SignatureDef.

  Only regression and classification default signatures can be expressed as
  a SignatureDef; any other kind (e.g. a generic signature) is logged as an
  error and yields None.

  Args:
    signatures: object of type manifest_pb2.Signatures() whose
      `default_signature` field is read.

  Returns:
    A meta_graph_pb2.SignatureDef holding the converted default signature,
    or None if the default signature's type is not regression or
    classification.
  """
  legacy_default = signatures.default_signature
  converted = meta_graph_pb2.SignatureDef()
  kind = legacy_default.WhichOneof("type")

  if kind == legacy_constants.REGRESSION_SIGNATURE:
    regression = legacy_default.regression_signature
    converted.method_name = signature_constants.REGRESS_METHOD_NAME
    _add_input_to_signature_def(regression.input.tensor_name,
                                signature_constants.REGRESS_INPUTS,
                                converted)
    _add_output_to_signature_def(regression.output.tensor_name,
                                 signature_constants.REGRESS_OUTPUTS,
                                 converted)
    return converted

  if kind == legacy_constants.CLASSIFICATION_SIGNATURE:
    classification = legacy_default.classification_signature
    converted.method_name = signature_constants.CLASSIFY_METHOD_NAME
    _add_input_to_signature_def(classification.input.tensor_name,
                                signature_constants.CLASSIFY_INPUTS,
                                converted)
    # A classification signature exposes two outputs: classes and scores.
    _add_output_to_signature_def(classification.classes.tensor_name,
                                 signature_constants.CLASSIFY_OUTPUT_CLASSES,
                                 converted)
    _add_output_to_signature_def(classification.scores.tensor_name,
                                 signature_constants.CLASSIFY_OUTPUT_SCORES,
                                 converted)
    return converted

  logging.error(
      "Only classification and regression default signatures "
      "are supported for up-conversion. %s is not "
      "supported", kind)
  return None
def _convert_named_signatures_to_signature_def(signatures):
  """Up-converts the legacy "inputs"/"outputs" named signatures.

  The named signatures stored under the PREDICT_INPUTS and PREDICT_OUTPUTS
  keys are merged into a single predict-style SignatureDef: every entry of
  the input generic signature's map becomes a SignatureDef input, and every
  entry of the output generic signature's map becomes a SignatureDef output.

  Args:
    signatures: object of type manifest_pb2.Signatures(). Note that indexing
      its `named_signatures` map may insert empty entries for missing keys.

  Returns:
    A meta_graph_pb2.SignatureDef with method_name set to the predict method.

  Raises:
    RuntimeError: if either the input or the output named signature is not a
      GenericSignature.
  """
  converted = meta_graph_pb2.SignatureDef()
  named = signatures.named_signatures
  inputs_sig = named[signature_constants.PREDICT_INPUTS]
  outputs_sig = named[signature_constants.PREDICT_OUTPUTS]
  # TODO(pdudnik): what if there are other signatures? Mimic cr/140900781 once
  # it is submitted.
  inputs_kind = inputs_sig.WhichOneof("type")
  outputs_kind = outputs_sig.WhichOneof("type")
  both_generic = (inputs_kind == legacy_constants.GENERIC_SIGNATURE and
                  outputs_kind == legacy_constants.GENERIC_SIGNATURE)
  if not both_generic:
    raise RuntimeError("Named input and output signatures can only be "
                       "up-converted if they are generic signature. "
                       "Input signature type is %s, output signature type is "
                       "%s" % (inputs_kind, outputs_kind))

  converted.method_name = signature_constants.PREDICT_METHOD_NAME
  for alias, binding in inputs_sig.generic_signature.map.items():
    _add_input_to_signature_def(binding.tensor_name, alias, converted)
  for alias, binding in outputs_sig.generic_signature.map.items():
    _add_output_to_signature_def(binding.tensor_name, alias, converted)
  return converted
def _convert_signatures_to_signature_defs(metagraph_def):
  """Produce default and named upconverted SignatureDef objects from Signatures.

  Reads the legacy Session Bundle `Signatures` message stored (packed in an
  `Any`) as the first entry of the signatures collection, and converts its
  default and named signatures into SignatureDefs.

  Args:
    metagraph_def: object of type meta_graph_pb2.MetaGraphDef containing legacy
      format Session Bundle signatures. It is not modified.

  Returns:
    A tuple (default_signature_def, named_signature_def):
      default_signature_def: object of type SignatureDef which contains an
        upconverted version of the default signature in metagraph_def, or
        None if there is none or it cannot be converted.
      named_signature_def: object of type SignatureDef which contains an
        upconverted version of the named signatures in metagraph_def, or None
        if there are fewer than two named signatures.

  Raises:
    RuntimeError: if the signatures collection is missing or empty, if its
      first entry does not hold a manifest_pb2.Signatures message, or if the
      named signatures cannot be up-converted.
  """
  collection_def = metagraph_def.collection_def
  # Test membership explicitly: subscripting a protobuf map with a missing key
  # inserts an empty entry, which would mutate the caller's MetaGraphDef and
  # then fail with an opaque IndexError below.
  if legacy_constants.SIGNATURES_KEY not in collection_def:
    raise RuntimeError("Legacy graph does not contain a %s collection, unable "
                       "to up-convert signatures." %
                       legacy_constants.SIGNATURES_KEY)
  packed_values = collection_def[legacy_constants.SIGNATURES_KEY].any_list.value
  if not packed_values:
    raise RuntimeError("Legacy graph has an empty %s collection, unable to "
                       "up-convert signatures." %
                       legacy_constants.SIGNATURES_KEY)

  signatures_proto = manifest_pb2.Signatures()
  # Any.Unpack returns False (instead of raising) when the packed type does
  # not match; without this check the conversion would silently proceed with
  # an empty Signatures message.
  if not packed_values[0].Unpack(signatures_proto):
    raise RuntimeError("The %s collection does not contain a Signatures "
                       "message (found type_url %r), unable to up-convert "
                       "signatures." % (legacy_constants.SIGNATURES_KEY,
                                        packed_values[0].type_url))

  default_signature_def = None
  named_signature_def = None
  if signatures_proto.HasField("default_signature"):
    default_signature_def = _convert_default_signature_to_signature_def(
        signatures_proto)
  # Named up-conversion needs both the "inputs" and "outputs" entries, so at
  # least two named signatures must be present.
  if len(signatures_proto.named_signatures) > 1:
    named_signature_def = _convert_named_signatures_to_signature_def(
        signatures_proto)
  return default_signature_def, named_signature_def
def _load_saved_model_from_session_bundle_path(export_dir, target, config):
  """Loads a legacy TF Exporter/SessionBundle export as an up-converted model.

  The export's MetaGraphDef is read from disk and its SignatureDef map is
  populated with SignatureDefs converted from the legacy signatures stored in
  its CollectionDef. If a default signature converts, it is stored under the
  default serving key. If named signatures convert, they are stored under
  the default serving key too, or under that key plus "_from_named" when a
  default signature already took it. The session is then created from the
  updated MetaGraphDef.

  Args:
    export_dir: the directory that contains files exported by exporter.
    target: The execution engine to connect to. See target in
      tf.compat.v1.Session()
    config: A ConfigProto proto with configuration options. See config in
      tf.compat.v1.Session()

  Returns:
    session: a tensorflow session created from the variable files.
    metagraph_def: The `MetaGraphDef` protocol buffer loaded in the provided
      session, with its SignatureDef field populated as described above.

  Raises:
    RuntimeError: If metagraph already contains signature_def and cannot be
      up-converted.
  """
  legacy_path = os.path.join(export_dir,
                             legacy_constants.META_GRAPH_DEF_FILENAME)
  upgraded = meta_graph.read_meta_graph_file(legacy_path)
  if upgraded.signature_def:
    raise RuntimeError("Legacy graph contains signature def, unable to "
                       "up-convert.")

  default_def, named_def = _convert_signatures_to_signature_defs(upgraded)
  serving_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
  if default_def:
    upgraded.signature_def[serving_key].CopyFrom(default_def)
  if named_def:
    # Avoid clobbering the default signature when both are present.
    named_key = serving_key + "_from_named" if default_def else serving_key
    upgraded.signature_def[named_key].CopyFrom(named_def)

  # The session must be built from the up-converted MetaGraphDef: the
  # session's Graph is populated from the meta graph it is loaded with, so
  # reusing a session loaded from the original MetaGraphDef would leave the
  # Graph out of sync with the returned MetaGraphDef.
  loaded_session, loaded_metagraph = (
      session_bundle.load_session_bundle_from_path(
          export_dir, target, config, meta_graph_def=upgraded))
  return loaded_session, loaded_metagraph
def load_session_bundle_or_saved_model_bundle_from_path(export_dir,
                                                        tags=None,
                                                        target="",
                                                        config=None):
  """Load session bundle from the given path.

  The function reads input from the export_dir, constructs the graph data to the
  default graph and restores the parameters for the session created. A
  SavedModel directory is loaded directly; a legacy SessionBundle directory is
  loaded and up-converted to carry SignatureDefs.

  Args:
    export_dir: the directory that contains files exported by exporter.
    tags: Set of string tags to identify the required MetaGraphDef when model is
      saved as SavedModel. These should correspond to the tags used when saving
      the variables using the SavedModel `save()` API. It is forwarded to
      `loader.load` unchanged and is ignored for SessionBundle exports.
      NOTE(review): the None default is presumably rejected by `loader.load`
      for SavedModel directories — confirm.
    target: The execution engine to connect to. See target in
      tf.compat.v1.Session()
    config: A ConfigProto proto with configuration options. See config in
      tf.compat.v1.Session()

  Returns:
    session: a tensorflow session created from the variable files.
    meta_graph: a meta graph proto saved in the exporter directory.

  Raises:
    RuntimeError: if the required files are missing or contain unrecognizable
      fields, i.e. the exported model is invalid.
  """
  if loader.maybe_saved_model_directory(export_dir):
    sess = session.Session(target, graph=None, config=config)
    try:
      metagraph_def = loader.load(sess, tags, export_dir)
    except Exception:
      # Don't leak the session (and the resources it holds) when loading
      # fails, e.g. because no MetaGraphDef matches `tags`.
      sess.close()
      raise
    return sess, metagraph_def

  if session_bundle.maybe_session_bundle_dir(export_dir):
    sess, metagraph_def = _load_saved_model_from_session_bundle_path(
        export_dir, target, config)
    return sess, metagraph_def

  raise RuntimeError("SessionBundle or SavedModelBundle not found at "
                     "specified export location: %s" % export_dir)