
## @package rnn_cell
# Module caffe2.python.rnn_cell

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import functools
import inspect
import logging
import numpy as np
import random
from future.utils import viewkeys

from caffe2.proto import caffe2_pb2
from caffe2.python.attention import (
    apply_dot_attention,
    apply_recurrent_attention,
    apply_regular_attention,
    apply_soft_coverage_attention,
    AttentionType,
)
from caffe2.python import core, recurrent, workspace, brew, scope, utils
from caffe2.python.modeling.parameter_sharing import ParameterSharing
from caffe2.python.modeling.parameter_info import ParameterTags
from caffe2.python.modeling.initializers import Initializer
from caffe2.python.model_helper import ModelHelper


def _RectifyName(blob_reference_or_name):
    if blob_reference_or_name is None:
        return None
    if isinstance(blob_reference_or_name, str):
        return core.ScopedBlobReference(blob_reference_or_name)
    if not isinstance(blob_reference_or_name, core.BlobReference):
        raise Exception("Unknown blob reference type")
    return blob_reference_or_name


def _RectifyNames(blob_references_or_names):
    if blob_references_or_names is None:
        return None
    return [_RectifyName(i) for i in blob_references_or_names]
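
# Note (illustrative, not part of the original file): inside
# core.NameScope('lstm'), _RectifyName('hidden_t') returns a BlobReference
# already bound to 'lstm/hidden_t', so scopes entered later do not
# re-prefix it.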


class RNNCell(object):
    '''
    Base class for writing recurrent / stateful operations.

    One needs to implement 2 methods: apply_override
    and get_state_names_override.

    As a result, the base class will provide the apply_over_sequence method,
    which allows you to apply recurrent operations over a sequence of any
    length. (A minimal subclass sketch appears after get_state_names_override
    below.)

    Optionally, you can add input and output preparation steps by overriding
    the corresponding methods.
    '''
    def __init__(self, name=None, forward_only=False, initializer=None):
        self.name = name
        self.recompute_blobs = []
        self.forward_only = forward_only
        self._initializer = initializer

    @property
    def initializer(self):
        return self._initializer

    @initializer.setter
    def initializer(self, value):
        self._initializer = value

    def scope(self, name):
        return self.name + '/' + name if self.name is not None else name

    def apply_over_sequence(
        self,
        model,
        inputs,
        seq_lengths=None,
        initial_states=None,
        outputs_with_grads=None,
    ):
        if initial_states is None:
            with scope.NameScope(self.name):
                if self.initializer is None:
                    raise Exception("Either initial states "
                                    "or an initializer has to be set")
                initial_states = self.initializer.create_states(model)

        preprocessed_inputs = self.prepare_input(model, inputs)
        step_model = ModelHelper(name=self.name, param_model=model)
        input_t, timestep = step_model.net.AddScopedExternalInputs(
            'input_t',
            'timestep',
        )
        utils.raiseIfNotEqual(
            len(initial_states), len(self.get_state_names()),
            "Number of initial state values provided doesn't match the number "
            "of states"
        )
        states_prev = step_model.net.AddScopedExternalInputs(*[
            s + '_prev' for s in self.get_state_names()
        ])
        states = self._apply(
            model=step_model,
            input_t=input_t,
            seq_lengths=seq_lengths,
            states=states_prev,
            timestep=timestep,
        )

        external_outputs = set(step_model.net.Proto().external_output)
        for state in states:
            if state not in external_outputs:
                step_model.net.AddExternalOutput(state)

        if outputs_with_grads is None:
            outputs_with_grads = [self.get_output_state_index() * 2]

        # states_for_all_steps consists of a combination of states gathered
        # over all steps and the final states. It looks like this:
        # (state_1_all, state_1_final, state_2_all, state_2_final, ...)
        states_for_all_steps = recurrent.recurrent_net(
            net=model.net,
            cell_net=step_model.net,
            inputs=[(input_t, preprocessed_inputs)],
            initial_cell_inputs=list(zip(states_prev, initial_states)),
            links=dict(zip(states_prev, states)),
            timestep=timestep,
            scope=self.name,
            forward_only=self.forward_only,
            outputs_with_grads=outputs_with_grads,
            recompute_blobs_on_backward=self.recompute_blobs,
        )

        output = self._prepare_output_sequence(
            model,
            states_for_all_steps,
        )
        return output, states_for_all_steps
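
    # Typical use (illustrative sketch; all names here are hypothetical):
    # given 'inputs' of shape (T, batch_size, input_dim),
    #
    #   outputs, all_states = cell.apply_over_sequence(
    #       model, inputs,
    #       seq_lengths=seq_lengths,
    #       initial_states=[h0, c0],
    #   )
    #
    # 'outputs' is the (T, batch_size, D) sequence of the primary state and
    # 'all_states' alternates (state_all, state_final) for every state.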

    def apply(self, model, input_t, seq_lengths, states, timestep):
        input_t = self.prepare_input(model, input_t)
        states = self._apply(
            model, input_t, seq_lengths, states, timestep)
        output = self._prepare_output(model, states)
        return output, states

    def _apply(
        self,
        model, input_t, seq_lengths, states, timestep, extra_inputs=None
    ):
        '''
        This method uses apply_override provided by a custom cell.
        On top of that, it takes care of applying self.scope() to all the
        outputs, while all the inputs stay within the scope this function
        was called from.
        '''
        args = self._rectify_apply_inputs(
            input_t, seq_lengths, states, timestep, extra_inputs)
        with core.NameScope(self.name):
            return self.apply_override(model, *args)

    def _rectify_apply_inputs(
            self, input_t, seq_lengths, states, timestep, extra_inputs):
        '''
        Before applying a scope we make sure that all external blob names
        are converted to blob references, so that further scoping doesn't
        affect them.
        '''

        input_t, seq_lengths, timestep = _RectifyNames(
            [input_t, seq_lengths, timestep])
        states = _RectifyNames(states)
        if extra_inputs:
            extra_input_names, extra_input_sizes = zip(*extra_inputs)
            extra_input_names = _RectifyNames(extra_input_names)
            extra_inputs = list(zip(extra_input_names, extra_input_sizes))

        arg_names = inspect.getfullargspec(self.apply_override).args
        rectified = [input_t, seq_lengths, states, timestep]
        if 'extra_inputs' in arg_names:
            rectified.append(extra_inputs)
        return rectified

    def apply_override(
        self,
        model, input_t, seq_lengths, states, timestep, extra_inputs=None,
    ):
        '''
        A single step of a recurrent network to be implemented by each custom
        RNNCell.

        model: ModelHelper object new operators would be added to

        input_t: single input with shape (1, batch_size, input_dim)

        seq_lengths: blob containing sequence lengths, which would be passed
        to the LSTMUnit operator

        states: previous recurrent states

        timestep: current recurrent iteration. Could be used together with
        seq_lengths in order to determine whether some shorter sequences
        in the batch have already ended.

        extra_inputs: list of tuples (input, dim). Specifies additional
        inputs that are not subject to prepare_input(). (Useful when a cell
        is a component of a larger recurrent structure, e.g., attention.)
        '''
        raise NotImplementedError('Abstract method')

    def prepare_input(self, model, input_blob):
        '''
        If some operations in the _apply method depend only on the input,
        not on recurrent states, they can be computed in advance.

        model: ModelHelper object new operators would be added to

        input_blob: either the whole input sequence with shape
        (sequence_length, batch_size, input_dim) or a single input with shape
        (1, batch_size, input_dim).
        '''
        return input_blob
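
    # For example (illustrative), an LSTM-style cell can project the whole
    # input sequence here in one shot, as a single brew.fc over the
    # (T, batch_size, input_dim) blob with axis=2, so the projection is
    # not recomputed at every step.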

    def get_output_state_index(self):
        '''
        Return index into state list of the "primary" step-wise output.
        '''
        return 0

    def get_state_names(self):
        '''
        Returns recurrent state names with self.name scoping applied
        '''
        return [self.scope(name) for name in self.get_state_names_override()]

    def get_state_names_override(self):
        '''
        Override this function in your custom cell.
        It should return the names of the recurrent states.

        It is required by the apply_over_sequence method in order to
        allocate recurrent states for all steps with meaningful names.
        '''
        raise NotImplementedError('Abstract method')
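
    # Minimal subclass sketch (illustrative, not part of this file; 'D' is
    # the hidden dimension and 'hidden_t' is the only state):
    #
    #   class PlainTanhCell(RNNCell):
    #       def apply_override(self, model, input_t, seq_lengths, states,
    #                          timestep, extra_inputs=None):
    #           hidden_t_prev = states[0]
    #           h = brew.fc(model, hidden_t_prev, 'h2h', dim_in=D,
    #                       dim_out=D, axis=2)
    #           brew.sum(model, [h, input_t], h)
    #           return (model.net.Tanh(h, 'hidden_t'),)
    #
    #       def get_state_names_override(self):
    #           return ['hidden_t']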

    def get_output_dim(self):
        '''
        Specifies the dimension (number of units) of stepwise output.
        '''
        raise NotImplementedError('Abstract method')

    def _prepare_output(self, model, states):
        '''
        Allows arbitrary post-processing of primary output.
        '''
        return states[self.get_output_state_index()]

    def _prepare_output_sequence(self, model, state_outputs):
        '''
        Allows arbitrary post-processing of primary sequence output.

        (Note that state_outputs alternates between full-sequence and final
        output for each state, thus the index multiplier 2.)
        '''
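        # Example (illustrative): for an LSTM with states (hidden, cell),
        # state_outputs is (hidden_all, hidden_final, cell_all, cell_final),
        # so the default output index 0 selects hidden_all at position 0.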
        output_sequence_index = 2 * self.get_output_state_index()
        return state_outputs[output_sequence_index]


class LSTMInitializer(object):
    def __init__(self, hidden_size):
        self.hidden_size = hidden_size

    def create_states(self, model):
        return [
            model.create_param(
                param_name='initial_hidden_state',
                initializer=Initializer(operator_name='ConstantFill',
                                        value=0.0),
                shape=[self.hidden_size],
            ),
            model.create_param(
                param_name='initial_cell_state',
                initializer=Initializer(operator_name='ConstantFill',
                                        value=0.0),
                shape=[self.hidden_size],
            )
        ]
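
# Illustrative use (hypothetical names): attaching an initializer lets
# apply_over_sequence run without explicit initial states:
#
#   cell.initializer = LSTMInitializer(hidden_size=hidden_size)
#   outputs, _ = cell.apply_over_sequence(model, inputs)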


# based on https://pytorch.org/docs/master/nn.html#torch.nn.RNNCell
class BasicRNNCell(RNNCell):
    def __init__(
        self,
        input_size,
        hidden_size,
        forget_bias,
        memory_optimization,
        drop_states=False,
        initializer=None,
        activation=None,
        **kwargs
    ):
        super(BasicRNNCell, self).__init__(**kwargs)
        self.drop_states = drop_states
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.activation = activation

        if self.activation not in ['relu', 'tanh']:
            raise RuntimeError(
                'BasicRNNCell with unknown activation function (%s)'
                % self.activation)

    def apply_override(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
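        # This step computes
        #   hidden_t = activation(W * hidden_t_prev + input_t),
        # where input_t is assumed to already carry the input projection
        # (typically applied once for the whole sequence in prepare_input).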
        hidden_t_prev = states[0]

        gates_t = brew.fc(
            model,
            hidden_t_prev,
            'gates_t',
            dim_in=self.hidden_size,
            dim_out=self.hidden_size,
            axis=2,
        )

        brew.sum(model, [gates_t, input_t], gates_t)
        if self.activation == 'tanh':
            hidden_t = model.net.Tanh(gates_t, 'hidden_t')
        elif self.activation == 'relu':
            hidden_t = model.net.Relu(gates_t, 'hidden_t')
        else:
            raise RuntimeError(
                'BasicRNNCell with unknown activation function (%s)'
                % self.activation)
        # When seq_lengths is given, the full implementation additionally
        # masks the states of sequences that have already ended (honoring
        # self.drop_states); that handling is elided here.
        return (hidden_t,)