from caffe2.python import (
    core, gradient_checker, rnn_cell, workspace, scope, utils
)
from caffe2.python.attention import AttentionType
from caffe2.python.model_helper import ModelHelper, ExtractPredictorNet
from caffe2.python.rnn.rnn_cell_test_util import sigmoid, tanh, _prepare_rnn
from caffe2.proto import caffe2_pb2
import caffe2.python.hypothesis_test_util as hu
from functools import partial
from hypothesis import assume, given
from hypothesis import settings as ht_settings
import hypothesis.strategies as st
import numpy as np
import unittest


def lstm_unit(*args, **kwargs):
    # NumPy reference for a single LSTM step; mirrors the LSTMUnit operator.
    forget_bias = kwargs.get('forget_bias', 0.0)
    drop_states = kwargs.get('drop_states', False)
    sequence_lengths = kwargs.get('sequence_lengths', True)

    if sequence_lengths:
        hidden_t_prev, cell_t_prev, gates, seq_lengths, timestep = args
    else:
        hidden_t_prev, cell_t_prev, gates, timestep = args

    D = cell_t_prev.shape[2]
    G = gates.shape[2]
    N = gates.shape[1]
    t = (timestep * np.ones(shape=(N, D))).astype(np.int32)
    assert t.shape == (N, D)
    assert G == 4 * D
    # Resize to avoid broadcasting inconsistencies with NumPy
    gates = gates.reshape(N, 4, D)
    cell_t_prev = cell_t_prev.reshape(N, D)
    i_t = gates[:, 0, :].reshape(N, D)
    f_t = gates[:, 1, :].reshape(N, D)
    o_t = gates[:, 2, :].reshape(N, D)
    g_t = gates[:, 3, :].reshape(N, D)
    i_t = sigmoid(i_t)
    f_t = sigmoid(f_t + forget_bias)
    o_t = sigmoid(o_t)
    g_t = tanh(g_t)
    if sequence_lengths:
        seq_lengths = (np.ones(shape=(N, D)) *
                       seq_lengths.reshape(N, 1)).astype(np.int32)
        assert seq_lengths.shape == (N, D)
        valid = (t < seq_lengths).astype(np.int32)
    else:
        valid = np.ones(shape=(N, D))
    assert valid.shape == (N, D)
    cell_t = ((f_t * cell_t_prev) + (i_t * g_t)) * valid + \
        (1 - valid) * cell_t_prev * (1 - drop_states)
    assert cell_t.shape == (N, D)
    hidden_t = (o_t * tanh(cell_t)) * valid + hidden_t_prev * (
        1 - valid) * (1 - drop_states)
    hidden_t = hidden_t.reshape(1, N, D)
    cell_t = cell_t.reshape(1, N, D)
    return hidden_t, cell_t
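

# Illustrative sketch (not from the original test suite): how the lstm_unit
# reference above can be exercised on random data to sanity-check shapes.
# The helper name and the sizes N, D, T are hypothetical.
def _example_lstm_unit_step(N=2, D=3, T=4):
    np.random.seed(0)
    hidden_t_prev = np.random.randn(1, N, D).astype(np.float32)
    cell_t_prev = np.random.randn(1, N, D).astype(np.float32)
    gates = np.random.randn(1, N, 4 * D).astype(np.float32)
    seq_lengths = np.full(N, T, dtype=np.int32)
    # One step at timestep 0; every sequence is still "valid" at this point.
    hidden_t, cell_t = lstm_unit(
        hidden_t_prev, cell_t_prev, gates, seq_lengths, 0,
        forget_bias=0.0, drop_states=False,
    )
    assert hidden_t.shape == (1, N, D)
    assert cell_t.shape == (1, N, D)
    return hidden_t, cell_t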


def layer_norm_with_scale_and_bias_ref(X, scale, bias, axis=-1, epsilon=1e-4):
    # NumPy reference for layer normalization followed by an affine transform.
    left = np.prod(X.shape[:axis])
    reshaped = np.reshape(X, [left, -1])
    mean = np.mean(reshaped, axis=1).reshape([left, 1])
    stdev = np.sqrt(
        np.mean(np.square(reshaped), axis=1).reshape([left, 1]) -
        np.square(mean) + epsilon
    )
    norm = (reshaped - mean) / stdev
    norm = np.reshape(norm, X.shape)
    adjusted = scale * norm + bias
    return adjusted
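

# Illustrative sketch (not from the original test suite): with scale=1 and
# bias=0 the layer-norm reference above should yield roughly zero-mean,
# unit-variance rows over the last axis. Helper name and sizes are hypothetical.
def _example_layer_norm_ref(N=2, D=8):
    np.random.seed(0)
    X = np.random.randn(N, D)
    out = layer_norm_with_scale_and_bias_ref(X, scale=1.0, bias=0.0)
    assert np.allclose(out.mean(axis=-1), 0.0, atol=1e-3)
    assert np.allclose(out.std(axis=-1), 1.0, atol=1e-2)
    return out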


def layer_norm_lstm_reference(
    input,
    hidden_input,
    cell_input,
    gates_w,
    gates_b,
    gates_t_norm_scale,
    gates_t_norm_bias,
    seq_lengths,
    forget_bias,
    drop_states=False
):
    # NumPy reference for an LSTM whose gate pre-activations are layer-normalized.
    T = input.shape[0]
    N = input.shape[1]
    G = input.shape[2]
    D = hidden_input.shape[hidden_input.ndim - 1]
    hidden = np.zeros(shape=(T + 1, N, D))
    cell = np.zeros(shape=(T + 1, N, D))
    assert hidden.shape[0] == T + 1
    assert cell.shape[0] == T + 1
    assert hidden.shape[1] == N
    assert cell.shape[1] == N
    cell[0, :, :] = cell_input
    hidden[0, :, :] = hidden_input
    for t in range(T):
        input_t = input[t].reshape(1, N, G)
        hidden_t_prev = hidden[t].reshape(1, N, D)
        cell_t_prev = cell[t].reshape(1, N, D)
        gates = np.dot(hidden_t_prev, gates_w.T) + gates_b
        gates = gates + input_t
        gates = layer_norm_with_scale_and_bias_ref(
            gates, gates_t_norm_scale, gates_t_norm_bias
        )
        hidden_t, cell_t = lstm_unit(
            hidden_t_prev,
            cell_t_prev,
            gates,
            seq_lengths,
            t,
            forget_bias=forget_bias,
            drop_states=drop_states,
        )
        hidden[t + 1] = hidden_t
        cell[t + 1] = cell_t
    return (
        hidden[1:],
        hidden[-1].reshape(1, N, D),
        cell[1:],
        cell[-1].reshape(1, N, D)
    )


def lstm_reference(input, hidden_input, cell_input,
                   gates_w, gates_b, seq_lengths, forget_bias,
                   drop_states=False):
    # NumPy reference for a full LSTM pass over a (T, N, 4 * D) gate input.
    T = input.shape[0]
    N = input.shape[1]
    G = input.shape[2]
    D = hidden_input.shape[hidden_input.ndim - 1]
    hidden = np.zeros(shape=(T + 1, N, D))
    cell = np.zeros(shape=(T + 1, N, D))
    assert hidden.shape[0] == T + 1
    assert cell.shape[0] == T + 1
    assert hidden.shape[1] == N
    assert cell.shape[1] == N
    cell[0, :, :] = cell_input
    hidden[0, :, :] = hidden_input
    for t in range(T):
        input_t = input[t].reshape(1, N, G)
        hidden_t_prev = hidden[t].reshape(1, N, D)
        cell_t_prev = cell[t].reshape(1, N, D)
        gates = np.dot(hidden_t_prev, gates_w.T) + gates_b
        gates = gates + input_t
        hidden_t, cell_t = lstm_unit(
            hidden_t_prev,
            cell_t_prev,
            gates,
            seq_lengths,
            t,
            forget_bias=forget_bias,
            drop_states=drop_states,
        )
        hidden[t + 1] = hidden_t
        cell[t + 1] = cell_t
    return (
        hidden[1:],
        hidden[-1].reshape(1, N, D),
        cell[1:],
        cell[-1].reshape(1, N, D)
    )
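

# Illustrative sketch (not from the original test suite): run the NumPy LSTM
# reference above over a short random sequence and check the output shapes.
# The helper name and the sizes T, N, D are hypothetical.
def _example_lstm_reference(T=3, N=2, D=4):
    np.random.seed(0)
    G = 4 * D
    input = np.random.randn(T, N, G)
    hidden_input = np.random.randn(N, D)
    cell_input = np.random.randn(N, D)
    gates_w = np.random.randn(G, D)
    gates_b = np.random.randn(G)
    seq_lengths = np.array([T, T - 1], dtype=np.int32)
    h_all, h_last, c_all, c_last = lstm_reference(
        input, hidden_input, cell_input, gates_w, gates_b,
        seq_lengths, forget_bias=0.0,
    )
    assert h_all.shape == (T, N, D) and c_all.shape == (T, N, D)
    assert h_last.shape == (1, N, D) and c_last.shape == (1, N, D)
    return h_all, h_last, c_all, c_last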


def multi_lstm_reference(input, hidden_input_list, cell_input_list,
                         i2h_w_list, i2h_b_list, gates_w_list, gates_b_list,
                         seq_lengths, forget_bias, drop_states=False):
    # NumPy reference for a stack of LSTM layers; layer i consumes the hidden
    # states produced by layer i - 1.
    num_layers = len(hidden_input_list)
    assert len(cell_input_list) == num_layers
    assert len(i2h_w_list) == num_layers
    assert len(i2h_b_list) == num_layers
    assert len(gates_w_list) == num_layers
    assert len(gates_b_list) == num_layers

    for i in range(num_layers):
        layer_input = np.dot(input, i2h_w_list[i].T) + i2h_b_list[i]
        h_all, h_last, c_all, c_last = lstm_reference(
            layer_input,
            hidden_input_list[i],
            cell_input_list[i],
            gates_w_list[i],
            gates_b_list[i],
            seq_lengths,
            forget_bias,
            drop_states=drop_states,
        )
        input = h_all
    return h_all, h_last, c_all, c_last
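

# Illustrative sketch (not from the original test suite): stack two LSTM layers
# with the reference above. Layer 0 consumes the raw input; layer 1 consumes
# layer 0's hidden states, so its input-to-hidden weights act on D-dim vectors.
# The helper name and sizes are hypothetical.
def _example_multi_lstm_reference(T=3, N=2, D=4, input_dim=5):
    np.random.seed(0)
    G = 4 * D
    input = np.random.randn(T, N, input_dim)
    hidden_inputs = [np.random.randn(N, D) for _ in range(2)]
    cell_inputs = [np.random.randn(N, D) for _ in range(2)]
    i2h_w = [np.random.randn(G, input_dim), np.random.randn(G, D)]
    i2h_b = [np.random.randn(G) for _ in range(2)]
    gates_w = [np.random.randn(G, D) for _ in range(2)]
    gates_b = [np.random.randn(G) for _ in range(2)]
    seq_lengths = np.full(N, T, dtype=np.int32)
    h_all, h_last, c_all, c_last = multi_lstm_reference(
        input, hidden_inputs, cell_inputs, i2h_w, i2h_b,
        gates_w, gates_b, seq_lengths, forget_bias=0.0,
    )
    assert h_all.shape == (T, N, D) and h_last.shape == (1, N, D)
    return h_all, h_last, c_all, c_last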


def compute_regular_attention_logits(
    hidden_t,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    attention_weighted_encoder_context_t_prev,
    weighted_prev_attention_context_w,
    weighted_prev_attention_context_b,
    attention_v,
    weighted_encoder_outputs,
    encoder_outputs_for_dot_product,
    coverage_prev,
    coverage_weights,
):
    # Additive attention: v . tanh(W_enc * enc + W_dec * hidden).
    weighted_hidden_t = np.dot(
        hidden_t,
        weighted_decoder_hidden_state_t_w.T,
    ) + weighted_decoder_hidden_state_t_b
    attention_v = attention_v.reshape([-1])
    return np.sum(
        attention_v * np.tanh(weighted_encoder_outputs + weighted_hidden_t),
        axis=2,
    )
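

# Illustrative sketch (not from the original test suite): shape flow for the
# additive ("regular") attention logits above. With encoder length L, batch N,
# decoder state dim D and attention dim A, the result holds one logit per
# (encoder position, batch item). Helper name and sizes are hypothetical.
def _example_regular_attention_logits(L=5, N=2, D=4, A=3):
    np.random.seed(0)
    hidden_t = np.random.randn(1, N, D)
    w = np.random.randn(A, D)
    b = np.random.randn(A)
    weighted_encoder_outputs = np.random.randn(L, N, A)
    attention_v = np.random.randn(A)
    logits = compute_regular_attention_logits(
        hidden_t, w, b,
        None, None, None,  # recurrent-context inputs unused by this variant
        attention_v, weighted_encoder_outputs,
        None, None, None,  # dot-product/coverage inputs unused by this variant
    )
    assert logits.shape == (L, N)
    return logits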


def compute_recurrent_attention_logits(
    hidden_t,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    attention_weighted_encoder_context_t_prev,
    weighted_prev_attention_context_w,
    weighted_prev_attention_context_b,
    attention_v,
    weighted_encoder_outputs,
    encoder_outputs_for_dot_product,
    coverage_prev,
    coverage_weights,
):
    # Like the regular variant, but also conditions on the previous
    # attention-weighted encoder context.
    weighted_hidden_t = np.dot(
        hidden_t,
        weighted_decoder_hidden_state_t_w.T,
    ) + weighted_decoder_hidden_state_t_b
    weighted_prev_attention_context = np.dot(
        attention_weighted_encoder_context_t_prev,
        weighted_prev_attention_context_w.T
    ) + weighted_prev_attention_context_b
    attention_v = attention_v.reshape([-1])
    return np.sum(
        attention_v * np.tanh(
            weighted_encoder_outputs + weighted_hidden_t +
            weighted_prev_attention_context
        ),
        axis=2,
    )


def compute_dot_attention_logits(
    hidden_t,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    attention_weighted_encoder_context_t_prev,
    weighted_prev_attention_context_w,
    weighted_prev_attention_context_b,
    attention_v,
    weighted_encoder_outputs,
    encoder_outputs_for_dot_product,
    coverage_prev,
    coverage_weights,
):
    # Dot-product attention: logits are inner products between encoder outputs
    # and the (optionally projected) decoder hidden state.
    hidden_t_for_dot_product = np.transpose(hidden_t, axes=[1, 2, 0])
    if (
        weighted_decoder_hidden_state_t_w is not None and
        weighted_decoder_hidden_state_t_b is not None
    ):
        hidden_t_for_dot_product = np.matmul(
            weighted_decoder_hidden_state_t_w,
            hidden_t_for_dot_product,
        ) + np.expand_dims(weighted_decoder_hidden_state_t_b, axis=1)
    attention_logits_t = np.sum(
        np.matmul(
            encoder_outputs_for_dot_product,
            hidden_t_for_dot_product,
        ),
        axis=2,
    )
    return np.transpose(attention_logits_t)


def compute_coverage_attention_logits(
    hidden_t,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    attention_weighted_encoder_context_t_prev,
    weighted_prev_attention_context_w,
    weighted_prev_attention_context_b,
    attention_v,
    weighted_encoder_outputs,
    encoder_outputs_for_dot_product,
    coverage_prev,
    coverage_weights,
):
    # Additive attention augmented with a coverage term over encoder positions.
    weighted_hidden_t = np.dot(
        hidden_t,
        weighted_decoder_hidden_state_t_w.T,
    ) + weighted_decoder_hidden_state_t_b
    coverage_part = coverage_prev.T * coverage_weights
    encoder_part = weighted_encoder_outputs + coverage_part
    attention_v = attention_v.reshape([-1])
    return np.sum(
        attention_v * np.tanh(encoder_part + weighted_hidden_t),
        axis=2,
    )


def lstm_with_attention_reference(
    input,
    initial_hidden_state,
    initial_cell_state,
    initial_attention_weighted_encoder_context,
    gates_w,
    gates_b,
    decoder_input_lengths,
    encoder_outputs_transposed,
    weighted_prev_attention_context_w,
    weighted_prev_attention_context_b,
    weighted_decoder_hidden_state_t_w,
    weighted_decoder_hidden_state_t_b,
    weighted_encoder_outputs,
    coverage_weights,
    attention_v,
    attention_zeros,
    compute_attention_logits,
):
    encoder_outputs = np.transpose(encoder_outputs_transposed, axes=[2, 0, 1])
    encoder_outputs_for_dot_product = np.transpose(
        encoder_outputs_transposed,
        [0, 2, 1],
    )
    decoder_input_length = input.shape[0]
    batch_size = input.shape[1]
    decoder_input_dim = input.shape[2]
    decoder_state_dim = initial_hidden_state.shape[2]
    encoder_output_dim = encoder_outputs.shape[2]
    hidden = np.zeros(