from caffe2.python import (
    core, gradient_checker, rnn_cell, workspace, scope, utils
)
from caffe2.python.attention import AttentionType
from caffe2.python.model_helper import ModelHelper, ExtractPredictorNet
from caffe2.python.rnn.rnn_cell_test_util import sigmoid, tanh, _prepare_rnn
from caffe2.proto import caffe2_pb2
import caffe2.python.hypothesis_test_util as hu

from functools import partial
from hypothesis import assume, given
from hypothesis import settings as ht_settings
import hypothesis.strategies as st
import numpy as np
import unittest


def lstm_unit(*args, **kwargs):
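    """NumPy reference for a single LSTM timestep.

    Positional args are (hidden_t_prev, cell_t_prev, gates, [seq_lengths,]
    timestep); `forget_bias`, `drop_states` and `sequence_lengths` come from
    kwargs. `gates` holds pre-activation values laid out as [i, f, o, g]
    along the last axis, and `forget_bias` is added to the forget gate before
    the sigmoid. Steps at or past a sequence's length either carry the
    previous state forward or zero it out when `drop_states` is set.
    Returns (hidden_t, cell_t), each shaped (1, N, D).
    """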
    forget_bias = kwargs.get('forget_bias', 0.0)
    drop_states = kwargs.get('drop_states', False)
    sequence_lengths = kwargs.get('sequence_lengths', True)

    if sequence_lengths:
        hidden_t_prev, cell_t_prev, gates, seq_lengths, timestep = args
    else:
        hidden_t_prev, cell_t_prev, gates, timestep = args
    D = cell_t_prev.shape[2]
    G = gates.shape[2]
    N = gates.shape[1]
    t = (timestep * np.ones(shape=(N, D))).astype(np.int32)
    assert t.shape == (N, D)
    assert G == 4 * D
    # Resize to avoid broadcasting inconsistencies with NumPy
    gates = gates.reshape(N, 4, D)
    cell_t_prev = cell_t_prev.reshape(N, D)
    i_t = gates[:, 0, :].reshape(N, D)
    f_t = gates[:, 1, :].reshape(N, D)
    o_t = gates[:, 2, :].reshape(N, D)
    g_t = gates[:, 3, :].reshape(N, D)
    i_t = sigmoid(i_t)
    f_t = sigmoid(f_t + forget_bias)
    o_t = sigmoid(o_t)
    g_t = tanh(g_t)
    if sequence_lengths:
        seq_lengths = (np.ones(shape=(N, D)) *
                       seq_lengths.reshape(N, 1)).astype(np.int32)
        assert seq_lengths.shape == (N, D)
        valid = (t < seq_lengths).astype(np.int32)
    else:
        valid = np.ones(shape=(N, D))
    assert valid.shape == (N, D)
    cell_t = ((f_t * cell_t_prev) + (i_t * g_t)) * (valid) + \
        (1 - valid) * cell_t_prev * (1 - drop_states)
    assert cell_t.shape == (N, D)
    hidden_t = (o_t * tanh(cell_t)) * valid + hidden_t_prev * (
        1 - valid) * (1 - drop_states)
    hidden_t = hidden_t.reshape(1, N, D)
    cell_t = cell_t.reshape(1, N, D)
    return hidden_t, cell_t


def layer_norm_with_scale_and_bias_ref(X, scale, bias, axis=-1, epsilon=1e-4):
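    """NumPy reference for layer normalization with a learned affine
    transform: X is normalized over all dimensions from `axis` onward
    (using a variance estimate stabilized by `epsilon`), then scaled by
    `scale` and shifted by `bias`.
    """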
    left = np.prod(X.shape[:axis])
    reshaped = np.reshape(X, [left, -1])
    mean = np.mean(reshaped, axis=1).reshape([left, 1])
    stdev = np.sqrt(
        np.mean(np.square(reshaped), axis=1).reshape([left, 1]) -
        np.square(mean) + epsilon
    )
    norm = (reshaped - mean) / stdev
    norm = np.reshape(norm, X.shape)
    adjusted = scale * norm + bias

    return adjusted


def layer_norm_lstm_reference(
    input,
    hidden_input,
    cell_input,
    gates_w,
    gates_b,
    gates_t_norm_scale,
    gates_t_norm_bias,
    seq_lengths,
    forget_bias,
    drop_states=False
):
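    """NumPy reference for a LayerNorm LSTM over a full sequence.

    `input` (T, N, 4*D) already holds the input-to-hidden projection; at each
    step the recurrent projection (gates_w, gates_b) is added, the summed gate
    pre-activations are layer-normalized with gates_t_norm_scale /
    gates_t_norm_bias, and lstm_unit advances the state. Returns
    (hidden_all, hidden_last, cell_all, cell_last).
    """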
    T = input.shape[0]
    N = input.shape[1]
    G = input.shape[2]
    D = hidden_input.shape[hidden_input.ndim - 1]
    hidden = np.zeros(shape=(T + 1, N, D))
    cell = np.zeros(shape=(T + 1, N, D))
    assert hidden.shape[0] == T + 1
    assert cell.shape[0] == T + 1
    assert hidden.shape[1] == N
    assert cell.shape[1] == N
    cell[0, :, :] = cell_input
    hidden[0, :, :] = hidden_input
    for t in range(T):
        input_t = input[t].reshape(1, N, G)
        hidden_t_prev = hidden[t].reshape(1, N, D)
        cell_t_prev = cell[t].reshape(1, N, D)
        gates = np.dot(hidden_t_prev, gates_w.T) + gates_b
        gates = gates + input_t

        gates = layer_norm_with_scale_and_bias_ref(
            gates, gates_t_norm_scale, gates_t_norm_bias
        )

        hidden_t, cell_t = lstm_unit(
            hidden_t_prev,
            cell_t_prev,
            gates,
            seq_lengths,
            t,
            forget_bias=forget_bias,
            drop_states=drop_states,
        )
        hidden[t + 1] = hidden_t
        cell[t + 1] = cell_t
    return (
        hidden[1:],
        hidden[-1].reshape(1, N, D),
        cell[1:],
        cell[-1].reshape(1, N, D)
    )


def lstm_reference(input, hidden_input, cell_input,
                   gates_w, gates_b, seq_lengths, forget_bias,
                   drop_states=False):
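    """NumPy reference for a single LSTM layer over a full sequence.

    `input` (T, N, 4*D) already holds the input-to-hidden projection; the
    recurrent projection (gates_w, gates_b) is added at every step before
    lstm_unit advances the state. Returns
    (hidden_all, hidden_last, cell_all, cell_last).
    """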
    T = input.shape[0]
    N = input.shape[1]
    G = input.shape[2]
    D = hidden_input.shape[hidden_input.ndim - 1]
    hidden = np.zeros(shape=(T + 1, N, D))
    cell = np.zeros(shape=(T + 1, N, D))
    assert hidden.shape[0] == T + 1
    assert cell.shape[0] == T + 1
    assert hidden.shape[1] == N
    assert cell.shape[1] == N
    cell[0, :, :] = cell_input
    hidden[0, :, :] = hidden_input
    for t in range(T):
        input_t = input[t].reshape(1, N, G)
        hidden_t_prev = hidden[t].reshape(1, N, D)
        cell_t_prev = cell[t].reshape(1, N, D)
        gates = np.dot(hidden_t_prev, gates_w.T) + gates_b
        gates = gates + input_t
        hidden_t, cell_t = lstm_unit(
            hidden_t_prev,
            cell_t_prev,
            gates,
            seq_lengths,
            t,
            forget_bias=forget_bias,
            drop_states=drop_states,
        )
        hidden[t + 1] = hidden_t
        cell[t + 1] = cell_t
    return (
        hidden[1:],
        hidden[-1].reshape(1, N, D),
        cell[1:],
        cell[-1].reshape(1, N, D)
    )


def multi_lstm_reference(input, hidden_input_list, cell_input_list,
                         i2h_w_list, i2h_b_list, gates_w_list, gates_b_list,
                         seq_lengths, forget_bias, drop_states=False):
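    """NumPy reference for a stack of LSTM layers: each layer applies its own
    input-to-hidden projection (i2h_w_list[i], i2h_b_list[i]), runs
    lstm_reference, and feeds its full hidden-state sequence to the next
    layer. Returns the (hidden_all, hidden_last, cell_all, cell_last) of the
    topmost layer.
    """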
    num_layers = len(hidden_input_list)
    assert len(cell_input_list) == num_layers
    assert len(i2h_w_list) == num_layers
    assert len(i2h_b_list) == num_layers
    assert len(gates_w_list) == num_layers
    assert len(gates_b_list) == num_layers

    for i in range(num_layers):
        layer_input = np.dot(input, i2h_w_list[i].T) + i2h_b_list[i]
        h_all, h_last, c_all, c_last = lstm_reference(
            layer_input,
            hidden_input_list[i],
            cell_input_list[i],
            gates_w_list[i],
            gates_b_list[i],
            seq_lengths,
            forget_bias,
            drop_states=drop_states,
        )
        input = h_all
    return h_all, h_last, c_all, c_last


def compute_regular_attention_logits(
    hidden_t,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    attention_weighted_encoder_context_t_prev,
    weighted_prev_attention_context_w,
    weighted_prev_attention_context_b,
    attention_v,
    weighted_encoder_outputs,
    encoder_outputs_for_dot_product,
    coverage_prev,
    coverage_weights,
):
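    """Regular (additive) attention logits: project the decoder hidden state,
    add it to the pre-projected encoder outputs, and reduce with
    sum(attention_v * tanh(...)) over the last axis. The unused arguments
    keep the signature shared with the other attention variants.
    """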
    weighted_hidden_t = np.dot(
        hidden_t,
        weighted_decoder_hidden_state_t_w.T,
    ) + weighted_decoder_hidden_state_t_b
    attention_v = attention_v.reshape([-1])
    return np.sum(
        attention_v * np.tanh(weighted_encoder_outputs + weighted_hidden_t),
        axis=2,
    )


def compute_recurrent_attention_logits(
    hidden_t,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    attention_weighted_encoder_context_t_prev,
    weighted_prev_attention_context_w,
    weighted_prev_attention_context_b,
    attention_v,
    weighted_encoder_outputs,
    encoder_outputs_for_dot_product,
    coverage_prev,
    coverage_weights,
):
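    """Recurrent attention logits: as the regular variant, but a projection
    of the previous attention-weighted encoder context is also added inside
    the tanh.
    """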
    weighted_hidden_t = np.dot(
        hidden_t,
        weighted_decoder_hidden_state_t_w.T,
    ) + weighted_decoder_hidden_state_t_b
    weighted_prev_attention_context = np.dot(
        attention_weighted_encoder_context_t_prev,
        weighted_prev_attention_context_w.T
    ) + weighted_prev_attention_context_b
    attention_v = attention_v.reshape([-1])
    return np.sum(
        attention_v * np.tanh(
            weighted_encoder_outputs + weighted_hidden_t +
            weighted_prev_attention_context
        ),
        axis=2,
    )


def compute_dot_attention_logits(
    hidden_t,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    attention_weighted_encoder_context_t_prev,
    weighted_prev_attention_context_w,
    weighted_prev_attention_context_b,
    attention_v,
    weighted_encoder_outputs,
    encoder_outputs_for_dot_product,
    coverage_prev,
    coverage_weights,
):
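    """Dot-product attention logits: matmul the (optionally projected)
    decoder hidden state against the encoder outputs; no tanh or attention_v
    term is involved.
    """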
    hidden_t_for_dot_product = np.transpose(hidden_t, axes=[1, 2, 0])
    if (
        weighted_decoder_hidden_state_t_w is not None and
        weighted_decoder_hidden_state_t_b is not None
    ):
        hidden_t_for_dot_product = np.matmul(
            weighted_decoder_hidden_state_t_w,
            hidden_t_for_dot_product,
        ) + np.expand_dims(weighted_decoder_hidden_state_t_b, axis=1)
    attention_logits_t = np.sum(
        np.matmul(
            encoder_outputs_for_dot_product,
            hidden_t_for_dot_product,
        ),
        axis=2,
    )
    return np.transpose(attention_logits_t)


def compute_coverage_attention_logits(
    hidden_t,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    attention_weighted_encoder_context_t_prev,
    weighted_prev_attention_context_w,
    weighted_prev_attention_context_b,
    attention_v,
    weighted_encoder_outputs,
    encoder_outputs_for_dot_product,
    coverage_prev,
    coverage_weights,
):
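    """Coverage attention logits: regular attention with an extra coverage
    term (coverage_prev.T * coverage_weights) added to the encoder part
    before the tanh.
    """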
    weighted_hidden_t = np.dot(
        hidden_t,
        weighted_decoder_hidden_state_t_w.T,
    ) + weighted_decoder_hidden_state_t_b
    coverage_part = coverage_prev.T * coverage_weights
    encoder_part = weighted_encoder_outputs + coverage_part
    attention_v = attention_v.reshape([-1])
    return np.sum(
        attention_v * np.tanh(encoder_part + weighted_hidden_t),
        axis=2,
    )


def lstm_with_attention_reference(
    input,
    initial_hidden_state,
    initial_cell_state,
    initial_attention_weighted_encoder_context,
    gates_w,
    gates_b,
    decoder_input_lengths,
    encoder_outputs_transposed,
    weighted_prev_attention_context_w,
    weighted_prev_attention_context_b,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    weighted_encoder_outputs,
    coverage_weights,
    attention_v,
    attention_zeros,
    compute_attention_logits,
):
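    """NumPy reference for an LSTM decoder with attention: the decoder LSTM
    is unrolled over `input`, with attention logits over the encoder outputs
    produced at each step by the `compute_attention_logits` callback.
    """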
    encoder_outputs = np.transpose(encoder_outputs_transposed, axes=[2, 0, 1])
    encoder_outputs_for_dot_product = np.transpose(
        encoder_outputs_transposed,
        [0, 2, 1],
    )
    decoder_input_length = input.shape[0]
    batch_size = input.shape[1]
    decoder_input_dim = input.shape[2]
    decoder_state_dim = initial_hidden_state.shape[2]
    encoder_output_dim = encoder_outputs.shape[2]
    hidden = np.zeros(