Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
libnvinfer-samples / usr / src / tensorrt / samples / sampleMLP / convert_weights.py
Size: Mime:
#!/usr/bin/python

# Script to convert from TensorFlow weights to TensorRT weights for multilayer perceptron sample.
# Change the remap to properly remap the weights to the name from your trained model
# to the sample expected format.

import sys
import struct
import argparse

try:
    from tensorflow.python import pywrap_tensorflow as pyTF
except ImportError as err:
    sys.stderr.write("""Error: Failed to import module ({})""".format(err))
    sys.exit()

parser = argparse.ArgumentParser(description='TensorFlow to TensorRT Weight Dumper')

parser.add_argument('-m', '--model', required=True, help='The checkpoint file basename, example basename(model.ckpt-766908.data-00000-of-00001) -> model.ckpt-766908')
parser.add_argument('-o', '--output', required=True, help='The weight file to dump all the weights to.')

opt = parser.parse_args()

print "Outputting the trained weights in TensorRT's wts v2 format. This format is documented as:"
print "Line 0: <number of buffers in the file>"
print "Line 1-Num: [buffer name] [buffer type] [(buffer shape{e.g. (1, 2, 3)}] <buffer shaped size bytes of data>"

inputbase = opt.model
outputbase = opt.output

# This dictionary translates from the TF weight names to the weight names expected 
# by the sampleMLP sample. This is the location that needs to be changed if training
# something other than what is specified in README.txt.
remap = {
    'Variable': 'hiddenWeights0',
    'Variable_1': 'hiddenWeights1',
    'Variable_2': 'outputWeights',
    'Variable_3': 'hiddenBias0',
    'Variable_4': 'hiddenBias1',
    'Variable_5': 'outputBias'
}

def float_to_hex(f):
    return hex(struct.unpack('<I', struct.pack('<f', f))[0])

try:
   reader = pyTF.NewCheckpointReader(inputbase)
   tensorDict = reader.get_variable_to_shape_map()
   outputFileName = outputbase + ".wts2"
   outputFile = open(outputFileName, 'w')
   count = 0

   for key in sorted(tensorDict):
       # Don't count weights that aren't used for inferencing.
       if ("Adam" in key or "power" in key):
           continue
       count += 1
   outputFile.write("%s\n"%(count))

   for key in sorted(tensorDict):
       # In order to save space, we don't dump weights that aren't required.
       if ("Adam" in key or "power" in key):
           continue
       tensor = reader.get_tensor(key)
       file_key = remap[key.replace('/','_')]
       val = tensor.shape
       print("%s 0 %s "%(file_key, val))
       flat_tensor = tensor.flatten()
       outputFile.write("%s 0 %s "%(file_key, val))
       outputFile.write(flat_tensor.tobytes())
       outputFile.write("\n");
   outputFile.close()

except Exception as error:
    print(str(error))