Repository URL to install this package:
|
Version:
1.14.0 ▾
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Importer for an exported TensorFlow model.
This module provides a function to create a SessionBundle containing both the
Session and MetaGraph.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from tensorflow.contrib.session_bundle import constants
from tensorflow.contrib.session_bundle import manifest_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.util.deprecation import deprecated
@deprecated("2017-06-30",
            "No longer supported. Switch to SavedModel immediately.")
def maybe_session_bundle_dir(export_dir):
  """Reports whether `export_dir` looks like a session bundle export.

  Only the presence of the meta graph file
  (`constants.META_GRAPH_DEF_FILENAME`) directly inside `export_dir` is
  checked; checkpoint and asset files are not inspected.

  Args:
    export_dir: string path to the export directory, for example
      'model/00000123'.

  Returns:
    The result of `file_io.file_exists` on the meta graph file path, i.e.
    true when that file is present.
  """
  return file_io.file_exists(
      os.path.join(export_dir, constants.META_GRAPH_DEF_FILENAME))
def _read_meta_graph_def(export_dir):
  """Parses the serialized MetaGraphDef stored in `export_dir`.

  Args:
    export_dir: the directory that contains files exported by exporter.

  Returns:
    A `meta_graph_pb2.MetaGraphDef` parsed from
    `constants.META_GRAPH_DEF_FILENAME`.

  Raises:
    RuntimeError: if the meta graph file does not exist.
  """
  meta_graph_filename = os.path.join(export_dir,
                                     constants.META_GRAPH_DEF_FILENAME)
  if not file_io.file_exists(meta_graph_filename):
    raise RuntimeError("Expected meta graph file missing %s" %
                       meta_graph_filename)
  meta_graph_def = meta_graph_pb2.MetaGraphDef()
  meta_graph_def.ParseFromString(
      file_io.read_file_to_string(meta_graph_filename, binary_mode=True))
  return meta_graph_def


def _find_restore_files(export_dir):
  """Returns the checkpoint path (relative to `export_dir`) to restore from.

  The V2 layout is detected by the presence of
  `constants.VARIABLES_INDEX_FILENAME_V2`; its data files are restored through
  the `constants.VARIABLES_FILENAME` prefix. For the V1 layout, the single
  file `constants.VARIABLES_FILENAME` is preferred, falling back to the
  sharded `constants.VARIABLES_FILENAME_PATTERN`.

  Args:
    export_dir: the directory that contains files exported by exporter.

  Returns:
    The relative checkpoint path/pattern, or "" when no checkpoint files were
    found (in which case nothing is restored).
  """
  index_filename = os.path.join(export_dir,
                                constants.VARIABLES_INDEX_FILENAME_V2)
  if file_io.file_exists(index_filename):
    # Checkpoint v2: restore via the "export" prefix if any data file exists.
    v2_pattern = os.path.join(export_dir,
                              constants.VARIABLES_FILENAME_PATTERN_V2)
    if file_io.get_matching_files(v2_pattern):
      return constants.VARIABLES_FILENAME
    return ""

  # Checkpoint v1, single (non-sharded) file.
  if file_io.file_exists(os.path.join(export_dir,
                                      constants.VARIABLES_FILENAME)):
    return constants.VARIABLES_FILENAME

  # Checkpoint v1, sharded files: restore with the shard pattern.
  v1_pattern = os.path.join(export_dir, constants.VARIABLES_FILENAME_PATTERN)
  if file_io.get_matching_files(v1_pattern):
    return constants.VARIABLES_FILENAME_PATTERN
  return ""


def _maybe_use_serving_graph_def(meta_graph_def):
  """Replaces `meta_graph_def.graph_def` with the serving GraphDef, if any.

  The serving GraphDef is stored as a packed `Any` under the
  `constants.GRAPH_KEY` collection. `meta_graph_def` is modified in place.

  Raises:
    RuntimeError: if that collection does not hold exactly one GraphDef.
  """
  collection_def = meta_graph_def.collection_def
  if constants.GRAPH_KEY not in collection_def:
    return
  graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
  if len(graph_def_any) != 1:
    raise RuntimeError("Expected exactly one serving GraphDef in : %s" %
                       meta_graph_def)
  graph_def = graph_pb2.GraphDef()
  graph_def_any[0].Unpack(graph_def)
  meta_graph_def.graph_def.CopyFrom(graph_def)


def _build_asset_feed_dict(meta_graph_def, export_dir):
  """Maps each asset's bound tensor name to its file path under assets/.

  Returns:
    A dict suitable as `feed_dict` for the init op; empty when the meta graph
    has no `constants.ASSETS_KEY` collection.
  """
  collection_def = meta_graph_def.collection_def
  feed_dict = {}
  if constants.ASSETS_KEY not in collection_def:
    return feed_dict
  assets_dir = os.path.join(export_dir, constants.ASSETS_DIRECTORY)
  for asset in collection_def[constants.ASSETS_KEY].any_list.value:
    asset_pb = manifest_pb2.AssetFile()
    asset.Unpack(asset_pb)
    feed_dict[asset_pb.tensor_binding.tensor_name] = os.path.join(
        assets_dir, asset_pb.filename)
  return feed_dict


@deprecated("2017-06-30",
            "No longer supported. Switch to SavedModel immediately.")
def load_session_bundle_from_path(export_dir,
                                  target="",
                                  config=None,
                                  meta_graph_def=None):
  """Load session bundle from the given path.

  The function reads input from the export_dir, constructs the graph data to
  the default graph (which is reset first) and restores the parameters for the
  session created. If the meta graph carries an init op, it is run with the
  asset file paths fed to their bound tensors.

  Args:
    export_dir: the directory that contains files exported by exporter.
    target: The execution engine to connect to. See target in
      tf.compat.v1.Session()
    config: A ConfigProto proto with configuration options. See config in
      tf.compat.v1.Session()
    meta_graph_def: optional object of type MetaGraphDef. If this object is
      present, then it is used instead of parsing MetaGraphDef from
      export_dir. Note that it may be modified in place (its graph_def is
      replaced by the serving GraphDef when one is present).

  Returns:
    session: a tensorflow session created from the variable files.
    meta_graph: a meta graph proto saved in the exporter directory.

  Raises:
    RuntimeError: if the required files are missing or contain unrecognizable
      fields, i.e. the exported model is invalid, or if checkpoint files are
      present but the imported graph provides no saver to restore them.
  """
  if meta_graph_def is None:
    meta_graph_def = _read_meta_graph_def(export_dir)

  restore_files = _find_restore_files(export_dir)
  _maybe_use_serving_graph_def(meta_graph_def)

  ops.reset_default_graph()
  sess = session.Session(target, graph=None, config=config)
  try:
    # Import the graph into the (freshly reset) default graph used by sess.
    saver = saver_lib.import_meta_graph(meta_graph_def)

    if restore_files:
      # import_meta_graph returns None when the graph has neither a saver_def
      # nor variables; restoring would otherwise fail with AttributeError.
      if saver is None:
        raise RuntimeError(
            "Found checkpoint files %s but the meta graph has no saver or "
            "variables to restore them into." %
            os.path.join(export_dir, restore_files))
      saver.restore(sess, os.path.join(export_dir, restore_files))

    collection_def = meta_graph_def.collection_def
    init_op_tensor = None
    if constants.INIT_OP_KEY in collection_def:
      init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
      if len(init_ops) != 1:
        raise RuntimeError("Expected exactly one serving init op in : %s" %
                           meta_graph_def)
      init_op_tensor = ops.get_collection(constants.INIT_OP_KEY)[0]

    asset_tensor_dict = _build_asset_feed_dict(meta_graph_def, export_dir)

    # Compare against None: the collection entry may be a tf.Tensor, whose
    # truth value is undefined in graph mode (bool() raises TypeError).
    if init_op_tensor is not None:
      sess.run(fetches=[init_op_tensor], feed_dict=asset_tensor_dict)
  except Exception:
    # Do not leak the session (and its resources) when loading fails.
    sess.close()
    raise
  return sess, meta_graph_def