Repository URL to install this package:
|
Version:
2.12.0 ▾
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for RNN cells."""
from keras.engine import base_layer
from keras.layers.rnn import rnn_utils
# isort: off
from tensorflow.python.util.tf_export import keras_export
@keras_export("keras.layers.AbstractRNNCell")
class AbstractRNNCell(base_layer.Layer):
"""Abstract object representing an RNN cell.
See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
for details about the usage of RNN API.
This is the base class for implementing RNN cells with custom behavior.
Every `RNNCell` must have the properties below and implement `call` with
the signature `(output, next_state) = call(input, state)`.
Examples:
```python
class MinimalRNNCell(AbstractRNNCell):
def __init__(self, units, **kwargs):
self.units = units
super(MinimalRNNCell, self).__init__(**kwargs)
@property
def state_size(self):
return self.units
def build(self, input_shape):
self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
initializer='uniform',
name='kernel')
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units),
initializer='uniform',
name='recurrent_kernel')
self.built = True
def call(self, inputs, states):
prev_output = states[0]
h = backend.dot(inputs, self.kernel)
output = h + backend.dot(prev_output, self.recurrent_kernel)
return output, output
```
This definition of cell differs from the definition used in the literature.
In the literature, 'cell' refers to an object with a single scalar output.
This definition refers to a horizontal array of such units.
An RNN cell, in the most abstract setting, is anything that has
a state and performs some operation that takes a matrix of inputs.
This operation results in an output matrix with `self.output_size` columns.
If `self.state_size` is an integer, this operation also results in a new
state matrix with `self.state_size` columns. If `self.state_size` is a
(possibly nested tuple of) TensorShape object(s), then it should return a
matching structure of Tensors having shape `[batch_size].concatenate(s)`
for each `s` in `self.batch_size`.
"""
def call(self, inputs, states):
"""The function that contains the logic for one RNN step calculation.
Args:
inputs: the input tensor, which is a slide from the overall RNN input
by the time dimension (usually the second dimension).
states: the state tensor from previous step, which has the same shape
as `(batch, state_size)`. In the case of timestep 0, it will be the
initial state user specified, or zero filled tensor otherwise.
Returns:
A tuple of two tensors:
1. output tensor for the current timestep, with size `output_size`.
2. state tensor for next step, which has the shape of `state_size`.
"""
raise NotImplementedError
@property
def state_size(self):
"""size(s) of state(s) used by this cell.
It can be represented by an Integer, a TensorShape or a tuple of
Integers or TensorShapes.
"""
raise NotImplementedError
@property
def output_size(self):
"""Integer or TensorShape: size of outputs produced by this cell."""
raise NotImplementedError
def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
return rnn_utils.generate_zero_filled_state_for_cell(
self, inputs, batch_size, dtype
)