## @package rnn_cell
# Module caffe2.python.rnn_cell
import functools
import inspect
import logging
import numpy as np
import random
from future.utils import viewkeys
from caffe2.proto import caffe2_pb2
from caffe2.python.attention import (
apply_dot_attention,
apply_recurrent_attention,
apply_regular_attention,
apply_soft_coverage_attention,
AttentionType,
)
from caffe2.python import core, recurrent, workspace, brew, scope, utils
from caffe2.python.modeling.parameter_sharing import ParameterSharing
from caffe2.python.modeling.parameter_info import ParameterTags
from caffe2.python.modeling.initializers import Initializer
from caffe2.python.model_helper import ModelHelper


def _RectifyName(blob_reference_or_name):
if blob_reference_or_name is None:
return None
if isinstance(blob_reference_or_name, str):
return core.ScopedBlobReference(blob_reference_or_name)
if not isinstance(blob_reference_or_name, core.BlobReference):
raise Exception("Unknown blob reference type")
return blob_reference_or_name


def _RectifyNames(blob_references_or_names):
if blob_references_or_names is None:
return None
return [_RectifyName(i) for i in blob_references_or_names]
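

# Illustrative note (not part of the original module): these helpers accept
# either plain strings or core.BlobReference objects and always hand back
# blob references, e.g.:
#
#   _RectifyName('hidden_t')    # -> core.ScopedBlobReference('hidden_t')
#   _RectifyName(blob_ref)      # -> blob_ref, unchanged
#   _RectifyNames(['a', 'b'])   # -> [ScopedBlobReference('a'),
#                               #     ScopedBlobReference('b')]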


class RNNCell(object):
'''
Base class for writing recurrent / stateful operations.
    One needs to implement two methods: apply_override
    and get_state_names_override.
    In return, the base class provides the apply_over_sequence method, which
    allows you to apply recurrent operations over a sequence of any length.
    Optionally, you can add input and output preparation steps by overriding
    the corresponding methods. (A minimal subclass is sketched right after
    this class definition.)
'''
def __init__(self, name=None, forward_only=False, initializer=None):
self.name = name
self.recompute_blobs = []
self.forward_only = forward_only
self._initializer = initializer
@property
def initializer(self):
return self._initializer
@initializer.setter
def initializer(self, value):
self._initializer = value
def scope(self, name):
return self.name + '/' + name if self.name is not None else name
def apply_over_sequence(
self,
model,
inputs,
seq_lengths=None,
initial_states=None,
outputs_with_grads=None,
):
if initial_states is None:
with scope.NameScope(self.name):
if self.initializer is None:
raise Exception("Either initial states "
"or initializer have to be set")
initial_states = self.initializer.create_states(model)
preprocessed_inputs = self.prepare_input(model, inputs)
step_model = ModelHelper(name=self.name, param_model=model)
input_t, timestep = step_model.net.AddScopedExternalInputs(
'input_t',
'timestep',
)
utils.raiseIfNotEqual(
len(initial_states), len(self.get_state_names()),
"Number of initial state values provided doesn't match the number "
"of states"
)
states_prev = step_model.net.AddScopedExternalInputs(*[
s + '_prev' for s in self.get_state_names()
])
states = self._apply(
model=step_model,
input_t=input_t,
seq_lengths=seq_lengths,
states=states_prev,
timestep=timestep,
)
external_outputs = set(step_model.net.Proto().external_output)
for state in states:
if state not in external_outputs:
step_model.net.AddExternalOutput(state)
if outputs_with_grads is None:
outputs_with_grads = [self.get_output_state_index() * 2]
        # states_for_all_steps consists of the combination of states gathered
        # for all steps and the final states. It looks like this:
        # (state_1_all, state_1_final, state_2_all, state_2_final, ...)
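        # For example, an LSTM cell with state names ['hidden', 'cell']
        # yields (hidden_all, hidden_final, cell_all, cell_final), which is
        # why the output state index is doubled above and in
        # _prepare_output_sequence below.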
states_for_all_steps = recurrent.recurrent_net(
net=model.net,
cell_net=step_model.net,
inputs=[(input_t, preprocessed_inputs)],
initial_cell_inputs=list(zip(states_prev, initial_states)),
links=dict(zip(states_prev, states)),
timestep=timestep,
scope=self.name,
forward_only=self.forward_only,
outputs_with_grads=outputs_with_grads,
recompute_blobs_on_backward=self.recompute_blobs,
)
output = self._prepare_output_sequence(
model,
states_for_all_steps,
)
return output, states_for_all_steps
def apply(self, model, input_t, seq_lengths, states, timestep):
input_t = self.prepare_input(model, input_t)
states = self._apply(
model, input_t, seq_lengths, states, timestep)
output = self._prepare_output(model, states)
return output, states
def _apply(
self,
model, input_t, seq_lengths, states, timestep, extra_inputs=None
):
        '''
        This method calls apply_override, which is provided by a custom cell.
        On top of that, it takes care of applying self.scope() to all the
        outputs, while all the inputs stay within the scope this function
        was called from.
        '''
args = self._rectify_apply_inputs(
input_t, seq_lengths, states, timestep, extra_inputs)
with core.NameScope(self.name):
return self.apply_override(model, *args)
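    # Note (added for illustration): for a cell constructed with name='lstm',
    # a blob created as 'hidden_t' inside apply_override comes out as
    # 'lstm/hidden_t', while the inputs, rectified to blob references by
    # _rectify_apply_inputs below, keep the caller's scope.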
def _rectify_apply_inputs(
self, input_t, seq_lengths, states, timestep, extra_inputs):
        '''
        Before applying a scope we make sure that all external blob names
        are converted to blob references, so that further scoping doesn't
        affect them.
        '''
input_t, seq_lengths, timestep = _RectifyNames(
[input_t, seq_lengths, timestep])
states = _RectifyNames(states)
        if extra_inputs:
            extra_input_names, extra_input_sizes = zip(*extra_inputs)
            extra_input_names = _RectifyNames(extra_input_names)
            extra_inputs = list(zip(extra_input_names, extra_input_sizes))
        arg_names = inspect.getfullargspec(self.apply_override).args
rectified = [input_t, seq_lengths, states, timestep]
if 'extra_inputs' in arg_names:
rectified.append(extra_inputs)
return rectified
    def apply_override(
        self,
        model, input_t, seq_lengths, states, timestep, extra_inputs=None,
    ):
'''
A single step of a recurrent network to be implemented by each custom
RNNCell.
model: ModelHelper object new operators would be added to
        input_t: single input with shape (1, batch_size, input_dim)
        seq_lengths: blob containing sequence lengths which would be passed to
            LSTMUnit operator
        states: previous recurrent states
        timestep: current recurrent iteration. Could be used together with
            seq_lengths in order to determine if some shorter sequences
            in the batch have already ended.
        extra_inputs: list of tuples (input, dim). Specifies additional inputs
            which are not subject to prepare_input() (useful when a cell is a
            component of a larger recurrent structure, e.g., attention).
'''
raise NotImplementedError('Abstract method')
def prepare_input(self, model, input_blob):
'''
        If some operations in the _apply method depend only on the input,
        not on the recurrent states, they can be computed in advance.
model: ModelHelper object new operators would be added to
input_blob: either the whole input sequence with shape
(sequence_length, batch_size, input_dim) or a single input with shape
(1, batch_size, input_dim).
'''
return input_blob
def get_output_state_index(self):
'''
Return index into state list of the "primary" step-wise output.
'''
return 0
def get_state_names(self):
'''
Returns recurrent state names with self.name scoping applied
'''
return [self.scope(name) for name in self.get_state_names_override()]
def get_state_names_override(self):
'''
Override this function in your custom cell.
It should return the names of the recurrent states.
        It is required by the apply_over_sequence method in order to allocate
        recurrent states for all steps with meaningful names.
'''
raise NotImplementedError('Abstract method')
def get_output_dim(self):
'''
Specifies the dimension (number of units) of stepwise output.
'''
raise NotImplementedError('Abstract method')
def _prepare_output(self, model, states):
'''
Allows arbitrary post-processing of primary output.
'''
return states[self.get_output_state_index()]
def _prepare_output_sequence(self, model, state_outputs):
'''
Allows arbitrary post-processing of primary sequence output.
(Note that state_outputs alternates between full-sequence and final
output for each state, thus the index multiplier 2.)
'''
output_sequence_index = 2 * self.get_output_state_index()
return state_outputs[output_sequence_index]
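

# A minimal subclass sketch (illustrative only; MyCell is not part of this
# module). A custom cell implements apply_override and
# get_state_names_override; apply_over_sequence then takes care of unrolling
# it over the whole sequence:
#
#   class MyCell(RNNCell):
#       def get_state_names_override(self):
#           return ['state']
#
#       def apply_override(self, model, input_t, seq_lengths, states,
#                          timestep, extra_inputs=None):
#           # New state = previous state + current input (assumes the input
#           # dimension equals the state dimension).
#           state_t = model.net.Sum([states[0], input_t], 'state_t')
#           return (state_t,)
#
#   # output, states = MyCell(name='my_cell').apply_over_sequence(
#   #     model, inputs, initial_states=[initial_state])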


class LSTMInitializer(object):
def __init__(self, hidden_size):
self.hidden_size = hidden_size
def create_states(self, model):
return [
model.create_param(
param_name='initial_hidden_state',
initializer=Initializer(operator_name='ConstantFill',
value=0.0),
shape=[self.hidden_size],
),
model.create_param(
param_name='initial_cell_state',
initializer=Initializer(operator_name='ConstantFill',
value=0.0),
shape=[self.hidden_size],
)
]
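

# Usage sketch (illustrative, not part of the original module; SomeLSTMCell
# is a stand-in for any two-state cell): when a cell is constructed with an
# initializer, apply_over_sequence can create zero-valued initial hidden and
# cell states on demand instead of requiring explicit initial_states:
#
#   cell = SomeLSTMCell(..., initializer=LSTMInitializer(hidden_size=40))
#   outputs, _ = cell.apply_over_sequence(model, inputs)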


# based on https://pytorch.org/docs/master/nn.html#torch.nn.RNNCell
class BasicRNNCell(RNNCell):
def __init__(
self,
input_size,
hidden_size,
forget_bias,
memory_optimization,
drop_states=False,
initializer=None,
activation=None,
**kwargs
):
        super(BasicRNNCell, self).__init__(initializer=initializer, **kwargs)
self.drop_states = drop_states
self.input_size = input_size
self.hidden_size = hidden_size
self.activation = activation
if self.activation not in ['relu', 'tanh']:
raise RuntimeError(
'BasicRNNCell with unknown activation function (%s)'
% self.activation)
def apply_override(
self,
model,
input_t,
seq_lengths,
states,
timestep,
extra_inputs=None,
):
hidden_t_prev = states[0]
gates_t = brew.fc(
model,
hidden_t_prev,
'gates_t',
dim_in=self.hidden_size,
dim_out=self.hidden_size,
axis=2,
)
brew.sum(model, [gates_t, input_t], gates_t)
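        # At this point gates_t = W_hh * h_prev + x_t; the input is expected
        # to be pre-projected to hidden_size (e.g. in prepare_input(), not
        # shown in this excerpt), so only the recurrent term is computed per
        # step.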
if self.activation == 'tanh':
hidden_t = model.net.Tanh(gates_t, 'hidden_t')
elif self.activation == 'relu':
hidden_t = model.net.Relu(gates_t, 'hidden_t')
else:
raise RuntimeError(
'BasicRNNCell with unknown activation function (%s)'
% self.activation)