## @package app
# Module caffe2.python.mint.app
import argparse
import glob
import os
import sys

import flask
import numpy as np
import nvd3
import tornado.httpserver
import tornado.ioloop
import tornado.wsgi
# Directory containing this module; templates/ and static/ are resolved
# relative to it rather than to the current working directory.
__folder__ = os.path.abspath(os.path.dirname(__file__))
app = flask.Flask(
    __name__,
    static_folder=os.path.join(__folder__, "static"),
    template_folder=os.path.join(__folder__, "templates"),
)
# Parsed command-line options. Assigned by main() and read by the handlers.
args = None
def jsonify_nvd3(chart):
    """Build an nvd3 chart and return its HTML and JavaScript as JSON.

    Args:
        chart: an nvd3 chart object (e.g. one created by ``nvd3.lineChart``).

    Returns:
        A Flask JSON response with two keys:
        ``result``, which is ``chart.container``, and ``script``, which is the
        text between the first ``<script>`` tag in ``chart.htmlcontent`` and
        the next ``</script>`` tag, stripped of surrounding whitespace. If no
        ``<script>`` tag is found, ``script`` is an empty string.
    """
    chart.buildcontent()
    # Note(Yangqing): python-nvd3 does not seem to separate the built HTML part
    # and the script part. Luckily, it seems to be the case that the HTML part
    # is only a <div>, which can be accessed by chart.container, while the
    # script part occupies the rest of the html content, which we can then
    # locate with chart.htmlcontent.find('<script>').
    open_tag = '<script>'
    close_tag = '</script>'
    html = chart.htmlcontent
    tag_pos = html.find(open_tag)
    if tag_pos < 0:
        # No script block: do not slice from find()'s -1 sentinel.
        script = ''
    else:
        script_start = tag_pos + len(open_tag)
        # Look for the closing tag only after the opening one.
        script_end = html.find(close_tag, script_start)
        if script_end < 0:
            script_end = len(html)
        script = html[script_start:script_end].strip()
    return flask.jsonify(
        result=chart.container,
        script=script
    )
def visualize_summary(filename):
    """Build a min/max/mean/std line chart from a summary file.

    Columns 0-3 of every row are plotted as ``min``, ``max`` and ``mean``.
    Two more curves, ``m+std`` and ``m-std``, are column 2 plus/minus
    column 3. Rows are subsampled according to ``args.sample``.

    Args:
        filename: path of the text file, loaded with ``np.loadtxt``.

    Returns:
        A Flask JSON response with ``result`` and ``script`` keys. If the file
        cannot be loaded or has fewer than 4 columns, ``result`` holds an
        error message and ``script`` is empty.
    """
    try:
        # ndmin=2 keeps a single-row (or single-value) file 2-D so the column
        # indexing below still works.
        data = np.loadtxt(filename, ndmin=2)
    except Exception as e:
        # Return the same JSON shape as the success path. A bare string would
        # be served as text/html with the filename echoed back unescaped, and
        # the filename may come from the request URL.
        return flask.jsonify(
            result='Cannot load file {}: {}'.format(filename, str(e)),
            script=''
        )
    num_rows, num_cols = data.shape
    if num_cols < 4:
        return flask.jsonify(
            result='Summary file {} needs at least 4 columns, got {}.'.format(
                filename, num_cols),
            script=''
        )
    chart_name = os.path.splitext(os.path.basename(filename))[0]
    chart = nvd3.lineChart(
        name=chart_name + '_summary_chart',
        height=args.chart_height,
        y_axis_format='.03g'
    )
    if args.sample < 0:
        # A negative --sample is the approximate number of points to keep.
        # Integer division keeps step an int so xdata can index rows.
        step = max(num_rows // -args.sample, 1)
    else:
        # Clamp to 1 so that --sample 0 cannot produce a zero arange step.
        step = max(args.sample, 1)
    xdata = np.arange(0, num_rows, step)
    mean = data[xdata, 2]
    std = data[xdata, 3]
    chart.add_serie(x=xdata, y=data[xdata, 0], name='min')
    chart.add_serie(x=xdata, y=data[xdata, 1], name='max')
    chart.add_serie(x=xdata, y=mean, name='mean')
    chart.add_serie(x=xdata, y=mean + std, name='m+std')
    chart.add_serie(x=xdata, y=mean - std, name='m-std')
    return jsonify_nvd3(chart)
def visualize_print_log(filename):
    """Build a line chart from a numeric log file.

    A single-column file is drawn as one curve, plus ``running_min`` and
    ``running_max`` curves. These are the min and max over consecutive
    windows of ``step`` rows. A multi-column file is drawn as one curve per
    column, up to ``args.max_curves`` curves. Rows are subsampled according
    to ``args.sample``.

    Args:
        filename: path of the text file, loaded with ``np.loadtxt``.

    Returns:
        A Flask JSON response with ``result`` and ``script`` keys. If the file
        cannot be loaded, ``result`` holds the error message and ``script`` is
        empty.
    """
    try:
        data = np.loadtxt(filename)
        if data.ndim < 2:
            # Treat a 1-D load (or a 0-D load from a single value) as one
            # column of rows.
            data = data.reshape(-1, 1)
    except Exception as e:
        # Return the same JSON shape as the success path. A bare string would
        # be served as text/html with the filename echoed back unescaped, and
        # the filename may come from the request URL.
        return flask.jsonify(
            result='Cannot load file {}: {}'.format(filename, str(e)),
            script=''
        )
    chart_name = os.path.splitext(os.path.basename(filename))[0]
    chart = nvd3.lineChart(
        name=chart_name + '_log_chart',
        height=args.chart_height,
        y_axis_format='.03g'
    )
    num_rows, num_cols = data.shape
    if args.sample < 0:
        # A negative --sample is the approximate number of points to keep.
        # Integer division keeps step an int so xdata can index rows.
        step = max(num_rows // -args.sample, 1)
    else:
        # Clamp to 1 so that --sample 0 cannot produce a zero arange step.
        step = max(args.sample, 1)
    xdata = np.arange(0, num_rows, step)
    if num_cols == 1:
        # With a single curve, also show the running min and max per window
        # of `step` rows. Trailing rows that do not fill a window are dropped.
        trunc_size = num_rows // step
        running_mat = data[:trunc_size * step].reshape((trunc_size, step))
        chart.add_serie(
            x=xdata[:trunc_size],
            y=running_mat.min(axis=1),
            name='running_min'
        )
        chart.add_serie(
            x=xdata[:trunc_size],
            y=running_mat.max(axis=1),
            name='running_max'
        )
        chart.add_serie(x=xdata, y=data[xdata, 0], name=chart_name)
    else:
        for i in range(min(num_cols, args.max_curves)):
            chart.add_serie(
                x=xdata,
                y=data[xdata, i],
                name='{}[{}]'.format(chart_name, i)
            )
    return jsonify_nvd3(chart)
def visualize_file(filename):
    """Pass ``filename``, resolved under ``args.root``, to a chart builder.

    Names ending in ``summary`` go to ``visualize_summary``. Names ending in
    ``log`` go to ``visualize_print_log``. Any other name gets a JSON response
    whose ``result`` reports the file as unsupported and whose ``script`` is
    empty.

    Args:
        filename: file name relative to ``args.root``.
    """
    fullname = os.path.join(args.root, filename)
    if filename.endswith('summary'):
        return visualize_summary(fullname)
    if filename.endswith('log'):
        return visualize_print_log(fullname)
    return flask.jsonify(
        result='Unsupported file: {}'.format(filename),
        script=''
    )
@app.route('/')
def index():
    """Render index.html with the sorted base names of ``*.*`` files in args.root.

    The same list is passed to the template as both ``names`` and
    ``debug_messages``.
    """
    pattern = os.path.join(args.root, "*.*")
    names = [os.path.basename(path) for path in sorted(glob.glob(pattern))]
    return flask.render_template(
        'index.html',
        root=args.root,
        names=names,
        debug_messages=names,
    )
@app.route('/visualization/<string:name>')
def visualization(name):
    """Return whatever visualize_file produces for the requested file name."""
    return visualize_file(name)
def main(argv):
    """Parse command-line options and serve the app with Tornado.

    The parsed options are stored in the module-level ``args``, which the
    request handlers read. The function then blocks, running the Tornado
    IOLoop.

    Args:
        argv: command-line arguments, excluding the program name.
    """
    global args
    # ArgumentParser's first positional parameter is `prog` (the program name
    # shown in usage), so the description has to be passed by keyword.
    parser = argparse.ArgumentParser(description="The mint visualizer.")
    parser.add_argument(
        '-p',
        '--port',
        type=int,
        default=5000,
        help="The flask port to use."
    )
    parser.add_argument(
        '-r',
        '--root',
        type=str,
        default='.',
        help="The root folder to read files for visualization."
    )
    parser.add_argument(
        '--max_curves',
        type=int,
        default=5,
        help="The max number of curves to show in a dump tensor."
    )
    parser.add_argument(
        '--chart_height',
        type=int,
        default=300,
        help="The chart height for nvd3."
    )
    parser.add_argument(
        '-s',
        '--sample',
        type=int,
        default=-200,
        help="Sample every given number of data points. A negative "
        "number means the total points we will sample on the "
        "whole curve. Default: about 200 points."
    )
    args = parser.parse_args(argv)
    if args.sample == 0:
        # A step of 0 is meaningless and would crash np.arange in the handlers.
        parser.error("--sample must be non-zero.")
    server = tornado.httpserver.HTTPServer(tornado.wsgi.WSGIContainer(app))
    server.listen(args.port)
    print("Tornado server starting on port {}.".format(args.port))
    # IOLoop.instance() is deprecated; current() returns the same loop here.
    tornado.ioloop.IOLoop.current().start()
if __name__ == '__main__':
    # Drop argv[0] (the script path); main() expects only the options.
    main(argv=sys.argv[1:])