Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

neilisaac / torch   python

Repository URL to install this package:

Version: 1.8.0 

/ python / operator_test / normalize_op_test.py





import functools

import numpy as np
from hypothesis import given, settings
from caffe2.python import core
import caffe2.python.hypothesis_test_util as hu
import copy


class TestNormalizeOp(hu.HypothesisTestCase):
    """Hypothesis-driven checks for the Caffe2 ``Normalize`` and ``NormalizeL1`` ops.

    Both tests draw tensors of rank 1-5 whose elements lie in [0.5, 1.0], so
    every norm computed along any axis is strictly positive.
    """

    @given(
        X=hu.tensor(
            min_dim=1, max_dim=5, elements=hu.floats(min_value=0.5, max_value=1.0)
        ),
        **hu.gcs
    )
    @settings(max_examples=10, deadline=None)
    def test_normalize(self, X, gc, dc):
        """Check ``Normalize`` (L2 normalization) against a NumPy reference.

        For every axis in ``[-X.ndim, X.ndim)`` (negative and positive forms),
        run a reference check, a cross-device consistency check, and a
        gradient check with respect to input 0 / output 0.
        """

        def ref_normalize(X, axis):
            # Divide by the L2 norm along ``axis``. The 1e-12 floor presumably
            # mirrors an epsilon used by the operator to avoid division by
            # zero -- TODO confirm against the operator implementation.
            norm = np.sqrt((X ** 2).sum(axis=axis, keepdims=True))
            return (X / np.maximum(norm, 1e-12),)

        for axis in range(-X.ndim, X.ndim):
            # A fresh copy per axis, presumably so that no check can observe
            # an input mutated by an earlier check.
            x = copy.copy(X)
            op = core.CreateOperator("Normalize", "X", "Y", axis=axis)
            self.assertReferenceChecks(
                gc, op, [x], functools.partial(ref_normalize, axis=axis)
            )
            self.assertDeviceChecks(dc, op, [x], [0])
            self.assertGradientChecks(gc, op, [x], 0, [0])

    @given(
        X=hu.tensor(
            min_dim=1, max_dim=5, elements=hu.floats(min_value=0.5, max_value=1.0)
        ),
        **hu.gcs
    )
    @settings(max_examples=10, deadline=None)
    def test_normalize_L1(self, X, gc, dc):
        """Check ``NormalizeL1`` against a NumPy reference.

        For every axis in ``[-X.ndim, X.ndim)``, run a reference check and a
        cross-device consistency check. No gradient check is performed here.
        """

        def ref(X, axis):
            # Divide by the L1 norm along ``axis``. No epsilon floor is needed
            # because the drawn elements are >= 0.5, so the norm is never zero.
            norm = np.abs(X).sum(axis=axis, keepdims=True)
            return (X / norm,)

        for axis in range(-X.ndim, X.ndim):
            op = core.CreateOperator("NormalizeL1", "X", "Y", axis=axis)
            self.assertReferenceChecks(gc, op, [X], functools.partial(ref, axis=axis))
            self.assertDeviceChecks(dc, op, [X], [0])