Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ python / modeling / compute_statistics_for_blobs.py






from caffe2.python import core, schema
from caffe2.python.modeling.net_modifier import NetModifier

import numpy as np


class ComputeStatisticsForBlobs(NetModifier):
    """
    This class modifies the net passed in by adding ops to compute statistics
    for certain blobs. For each blob in the list, its min, max, mean and standard
    deviation will be computed.

    Args:
        blobs: list of blobs to compute statistics for
        logging_frequency: frequency for printing statistics to logs
    """

    def __init__(self, blobs, logging_frequency):
        self._blobs = blobs
        self._logging_frequency = logging_frequency
        # Suffix appended to a blob's name to form both the statistics blob
        # prefix and the output record field name.
        self._field_name_suffix = '_summary'

    def modify_net(self, net, init_net=None, grad_map=None, blob_to_device=None,
                   modify_output_record=False):
        """
        Add Cast, Summarize and Print ops to `net` for every configured blob.

        Args:
            net: the net to which the statistics ops are added
            init_net: not used by this modifier
            grad_map: not used by this modifier
            blob_to_device: not used by this modifier
            modify_output_record: if True, each statistics blob is also
                registered in the output record of `net` under the field
                name '<blob><suffix>'

        Raises:
            AssertionError: if one of the configured blobs is not defined
                in `net`
        """

        for blob_name in self._blobs:
            blob = core.BlobReference(blob_name)
            assert net.BlobIsDefined(blob), (
                'blob {} is not defined in net {} whose proto is {}'.format(
                    blob, net.Name(), net.Proto())
            )

            # Cast to float first so blobs of other data types can be fed to
            # the Summarize op.
            cast_blob = net.Cast(blob, to=core.DataType.FLOAT)
            stats_name = net.NextScopedBlob(prefix=blob + self._field_name_suffix)
            # to_file=0: keep the statistics in an output blob.
            stats = net.Summarize(cast_blob, stats_name, to_file=0)
            net.Print(stats, [], every_n=self._logging_frequency)

            if modify_output_record:
                output_field_name = str(blob) + self._field_name_suffix
                # np.float was merely an alias for the builtin float (i.e.
                # float64) and has been removed from NumPy 1.24 onwards;
                # np.float64 yields the same dtype on every NumPy version.
                output_scalar = schema.Scalar((np.float64, (1,)), stats)

                if net.output_record() is None:
                    # First field: create the output record.
                    net.set_output_record(
                        schema.Struct((output_field_name, output_scalar))
                    )
                else:
                    net.AppendOutputRecordField(
                        output_field_name,
                        output_scalar)

    def field_name_suffix(self):
        """Return the suffix used for statistics blob and field names."""
        return self._field_name_suffix