Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ python / mint / app.py

## @package app
# Module caffe2.python.mint.app
import argparse
import glob
import os
import sys

import flask
import numpy as np
import nvd3
import tornado.httpserver
import tornado.ioloop
import tornado.wsgi

# Absolute path of the directory containing this module; the template and
# static folders passed to Flask below are resolved relative to it.
__folder__ = os.path.abspath(os.path.dirname(__file__))

# The Flask application object; the routes defined in this module are
# registered on it through @app.route.
app = flask.Flask(
    __name__,
    template_folder=os.path.join(__folder__, "templates"),
    static_folder=os.path.join(__folder__, "static")
)
# Parsed command-line arguments. Stays None until main() stores the result of
# parser.parse_args() here; the visualize_* helpers and index() read it.
args = None


def jsonify_nvd3(chart):
    """Build an nvd3 chart and return its container and script as JSON.

    Args:
        chart: A python-nvd3 chart object. This function calls its
            ``buildcontent()`` method and reads its ``container`` and
            ``htmlcontent`` attributes.

    Returns:
        The ``flask.jsonify`` response with two keys: ``result`` (the chart
        container) and ``script`` (the stripped text between the first
        ``<script>`` tag and the next ``</script>`` tag, or an empty string
        when the built content holds no ``<script>`` tag).
    """
    chart.buildcontent()
    # Note(Yangqing): python-nvd3 does not seem to separate the built HTML part
    # from the script part. Luckily, the HTML part seems to be only a <div>,
    # which can be accessed through chart.container, while the script part
    # occupies the rest of the html content, which we can locate with
    # chart.htmlcontent.find('<script>').
    html = chart.htmlcontent
    open_tag = '<script>'
    tag_pos = html.find(open_tag)
    if tag_pos < 0:
        # str.find returns -1 when the tag is absent; adding the tag length to
        # that would slice from an arbitrary offset and return unrelated HTML.
        script = ''
    else:
        script_start = tag_pos + len(open_tag)
        # Look for the closing tag only after the opening one, and fall back
        # to the end of the content if it is missing.
        script_end = html.find('</script>', script_start)
        if script_end < 0:
            script_end = len(html)
        script = html[script_start:script_end].strip()
    return flask.jsonify(result=chart.container, script=script)


def visualize_summary(filename):
    """Plot a summary file as min, max, mean and mean +/- std line series.

    The file is read with ``np.loadtxt``. Columns 0-3 are used: columns 0, 1
    and 2 are plotted as 'min', 'max' and 'mean', and column 3 is added to and
    subtracted from column 2 for the 'm+std' and 'm-std' series.

    Args:
        filename: Path of the summary file to load.

    Returns:
        The response produced by ``jsonify_nvd3`` for the chart, or a plain
        error string if the file cannot be loaded.
    """
    try:
        # ndmin=2 keeps a file with a single row as a (1, n) array, so the
        # two-dimensional column indexing below still works.
        data = np.loadtxt(filename, ndmin=2)
    except Exception as e:
        return 'Cannot load file {}: {}'.format(filename, str(e))
    chart_name = os.path.splitext(os.path.basename(filename))[0]
    chart = nvd3.lineChart(
        name=chart_name + '_summary_chart',
        height=args.chart_height,
        y_axis_format='.03g'
    )
    num_rows = data.shape[0]
    if args.sample < 0:
        # A negative value is the total number of points to plot. Floor
        # division keeps the step an int: with true division np.arange would
        # return floats, which cannot be used as array indices.
        step = max(num_rows // -args.sample, 1)
    else:
        # A step of 0 is not a valid np.arange step; plot every point instead.
        step = max(args.sample, 1)
    xdata = np.arange(0, num_rows, step)
    # The data should have 4 columns: min, max, mean and std.
    sampled = data[xdata]
    mean = sampled[:, 2]
    std = sampled[:, 3]
    chart.add_serie(x=xdata, y=sampled[:, 0], name='min')
    chart.add_serie(x=xdata, y=sampled[:, 1], name='max')
    chart.add_serie(x=xdata, y=mean, name='mean')
    chart.add_serie(x=xdata, y=mean + std, name='m+std')
    chart.add_serie(x=xdata, y=mean - std, name='m-std')
    return jsonify_nvd3(chart)


def visualize_print_log(filename):
    """Plot the columns of a log file as line series.

    The file is read with ``np.loadtxt``; a one-dimensional result is treated
    as a single column. With one column, the curve is plotted together with
    the running minimum and maximum over each block of ``step`` consecutive
    rows. With several columns, the first ``args.max_curves`` columns are
    plotted as separate curves.

    Args:
        filename: Path of the log file to load.

    Returns:
        The response produced by ``jsonify_nvd3`` for the chart, or a plain
        error string if the file cannot be loaded.
    """
    try:
        data = np.loadtxt(filename)
        if data.ndim == 1:
            data = data[:, np.newaxis]
    except Exception as e:
        return 'Cannot load file {}: {}'.format(filename, str(e))
    chart_name = os.path.splitext(os.path.basename(filename))[0]
    chart = nvd3.lineChart(
        name=chart_name + '_log_chart',
        height=args.chart_height,
        y_axis_format='.03g'
    )
    num_rows = data.shape[0]
    if args.sample < 0:
        # A negative value is the total number of points to plot. Floor
        # division keeps the step an int: with true division np.arange would
        # return floats, which cannot be used as array indices.
        step = max(num_rows // -args.sample, 1)
    else:
        # A step of 0 is not a valid np.arange step; plot every point instead.
        step = max(args.sample, 1)
    xdata = np.arange(0, num_rows, step)
    # If there is only one curve, we also show the running min and max.
    if data.shape[1] == 1:
        # Number of complete blocks of `step` rows. This must be an int
        # because it is used as a slice bound and as a reshape dimension.
        trunc_size = num_rows // step
        # One row per block, so a reduction along axis 1 gives one value per
        # sampled x position.
        running_mat = data[:trunc_size * step].reshape((trunc_size, step))
        chart.add_serie(
            x=xdata[:trunc_size],
            y=running_mat.min(axis=1),
            name='running_min'
        )
        chart.add_serie(
            x=xdata[:trunc_size],
            y=running_mat.max(axis=1),
            name='running_max'
        )
        chart.add_serie(x=xdata, y=data[xdata, 0], name=chart_name)
    else:
        # Plot at most args.max_curves columns, one curve per column.
        for i in range(0, min(data.shape[1], args.max_curves)):
            chart.add_serie(
                x=xdata,
                y=data[xdata, i],
                name='{}[{}]'.format(chart_name, i)
            )

    return jsonify_nvd3(chart)


def visualize_file(filename):
    """Dispatch a file under ``args.root`` to the matching visualizer.

    Names ending in 'summary' go to ``visualize_summary`` and names ending in
    'log' go to ``visualize_print_log``; both receive the path joined with
    ``args.root``. Any other name yields a JSON response carrying a message
    in ``result`` and an empty ``script``.

    Args:
        filename: File name relative to ``args.root``.

    Returns:
        Whatever the selected visualizer returns, or the ``flask.jsonify``
        response for unsupported names.
    """
    path = os.path.join(args.root, filename)
    if filename.endswith('summary'):
        return visualize_summary(path)
    if filename.endswith('log'):
        return visualize_print_log(path)
    message = 'Unsupport file: {}'.format(filename)
    return flask.jsonify(result=message, script='')


@app.route('/')
def index():
    """Render the index page listing the files found in ``args.root``.

    Only entries matching the glob pattern ``*.*`` are listed. The paths are
    sorted and reduced to their base names, and the same list is handed to
    the template as both ``names`` and ``debug_messages``.
    """
    pattern = os.path.join(args.root, "*.*")
    names = [os.path.basename(path) for path in sorted(glob.glob(pattern))]
    return flask.render_template(
        'index.html',
        root=args.root,
        names=names,
        debug_messages=names
    )


@app.route('/visualization/<string:name>')
def visualization(name):
    """Return the visualization for the file called ``name``.

    The name comes from the URL and is passed unchanged to
    ``visualize_file``, whose return value becomes the response.
    """
    return visualize_file(name)


def main(argv):
    """Parse the command line and serve the Flask app through Tornado.

    The parsed arguments are stored in the module-level ``args`` so that the
    request handlers can read them. The call then starts the Tornado IO loop.

    Args:
        argv: List of command-line arguments, without the program name.
    """
    global args
    # The first positional parameter of ArgumentParser is `prog`, so the text
    # must be passed as `description` to show up as the description rather
    # than as the program name in the usage line.
    parser = argparse.ArgumentParser(description="The mint visualizer.")
    parser.add_argument(
        '-p',
        '--port',
        type=int,
        default=5000,
        help="The flask port to use."
    )
    parser.add_argument(
        '-r',
        '--root',
        type=str,
        default='.',
        help="The root folder to read files for visualization."
    )
    parser.add_argument(
        '--max_curves',
        type=int,
        default=5,
        help="The max number of curves to show in a dump tensor."
    )
    parser.add_argument(
        '--chart_height',
        type=int,
        default=300,
        help="The chart height for nvd3."
    )
    parser.add_argument(
        '-s',
        '--sample',
        type=int,
        default=-200,
        help="Sample every given number of data points. A negative "
        "number means the total points we will sample on the "
        "whole curve. Default 200 points."
    )
    args = parser.parse_args(argv)
    # Wrap the WSGI Flask app so that Tornado's HTTP server can serve it.
    server = tornado.httpserver.HTTPServer(tornado.wsgi.WSGIContainer(app))
    server.listen(args.port)
    print("Tornado server starting on port {}.".format(args.port))
    tornado.ioloop.IOLoop.instance().start()


# Script entry point: hand main() the command-line arguments without the
# program name.
if __name__ == '__main__':
    main(sys.argv[1:])