Dynamic RNN In Keras: Use Custom RNN Cell To Track Other Outputs At Each Timestep
Is there a way to return multiple outputs for a given timestep when implementing a custom cell for an RNN in Keras? E.g. outputs with shapes: (sequences=[batch, timesteps, hidden_units], other=[batch, timesteps, arbitrary_units])?
Solution 1:
Figured it out. You can just make the output size a list of shapes, and the RNN will track each of the outputs. The class below also demonstrates the use of constants in the RNN call, because the previously mentioned paper passes an encoder latent space (z_enc) to the recurrent decoder:
import tensorflow as tf


class CustomMultiTimeStepGRUCell(tf.keras.layers.Layer):
    """Illustrates multiple sequence-like (n, timesteps, size) outputs."""

    def __init__(self, units, arbitrary_units, **kwargs):
        """Defines state for custom cell.

        :param units: <class 'int'> Hidden units for the RNN cell.
        :param arbitrary_units: <class 'int'> Hidden units for another
            dense network that outputs a tensor at each timestep in the
            unrolling of the RNN.
        """
        super().__init__(**kwargs)
        # Save args
        self.units = units
        self.arbitrary_units = arbitrary_units
        # Standard recurrent cell
        self.gru = tf.keras.layers.GRUCell(units=self.units)
        # For use with the 'constants' kwarg in the 'call' method
        self.concatenate = tf.keras.layers.Concatenate()
        self.dense_proj = tf.keras.layers.Dense(units=self.units)
        # For arbitrary computation at timestep t
        self.other_output = tf.keras.layers.Dense(units=self.arbitrary_units)
        # Hidden state size (i.e., h_t)...
        # it's useful to know in general that this refers to the following:
        # 'gru_cell = tf.keras.layers.GRUCell(units=state_size)'
        # 'seq, h_t = gru_cell(data, states)'
        # 'h_t.shape' -> '(?, state_size)'
        self.state_size = tf.TensorShape([self.units])
        # OUTPUT SIZE: PROBLEM SOLVED!!!!
        # This is the last dimension of the RNN sequence output.
        # Typically the last dimension matches the dimension of
        # self.state_size, and in fact the Keras RNN will infer
        # the output size from the state size if output size is not
        # specified. If the output size does not match the state size,
        # you have to specify it -- and as a list if the cell produces
        # multiple outputs per timestep.
        self.output_size = [tf.TensorShape([self.units]),
                            tf.TensorShape([self.arbitrary_units])]

    def call(self, input_at_t, states_at_t, constants):
        """Forward pass for custom RNN cell.

        :param input_at_t: (batch, features) tensor from
            (batch, t, features) inputs.
        :param states_at_t: <class 'tuple'> that has 1 element if
            using GRUCell (h_t), or 2 elements if using LSTMCell (h_t, c_t).
        :param constants: <class 'tuple'> Unchanging tensors to be used
            in the unrolling of the RNN.
        :return: <class 'tuple'> with two elements.
            (1) <class 'list'> Both elements of this list are tensors
                that are tracked for each timestep in the unrolling
                of the RNN.
            (2) Tensor representing the hidden state passed to the
                next cell.

        In the brief graphic below, a_t denotes the arbitrary output
        at each timestep. y_t = h_t_plus_1. x_t is some input at
        timestep t.

                 a_t    y_t
                  ^      ^
                __|______|__
        h_t    |            | h_t_plus_1
        -----> |            | ----------> .....
               |____________|
                      ^
                      |
                     x_t

        When all timesteps in x where x = {x_t}_{t=1}^{T} are processed
        by the RNN, the resulting shapes of the outputs, assuming there
        is only a single sample (batch = 1), would be the following:

            Y = (1, timesteps, units)
            A = (1, timesteps, arbitrary_units)
            h_t_plus_1 = (1, units)  # Last hidden state

        For a concrete example, see the end of this codeblock.
        """
        # Get correct inputs -- by default these args are tuples...
        # so you must index 0 to get the relevant element.
        # Note, if you are using LSTM, then the hidden states passed to
        # the next cell in the RNN will be a tuple with two elements,
        # i.e., (h_t, c_t) for the hidden and cell state, respectively.
        states_at_t = states_at_t[0]
        z_enc = constants[0]
        # Combine the states with z_enc
        combined = self.concatenate([states_at_t, z_enc])
        # Project to dimensions for the GRU cell
        special_states_at_t = self.dense_proj(combined)
        # Standard GRU call
        output_at_t, states_at_t_plus_1 = self.gru(input_at_t, special_states_at_t)
        # Get another output at t
        arbitrary_output_at_t = self.other_output(input_at_t)
        # The outputs
        return [output_at_t, arbitrary_output_at_t], states_at_t_plus_1
# Dims
batch = 4
timesteps = 3
features = 12
latent = 8
hidden_units = 10
arbitrary_units = 15
# Data
inputs = tf.random.normal(shape=(batch, timesteps, features))
h_t = tf.zeros(shape=(batch, hidden_units))
z_enc = tf.random.normal(shape=(batch, latent))
# An RNN cell to test multi-timestep outputs
custom_multistep_cell = CustomMultiTimeStepGRUCell(units=hidden_units, arbitrary_units=arbitrary_units)
custom_multistep_rnn = tf.keras.layers.RNN(custom_multistep_cell, return_sequences=True, return_state=True)
# Call cell
outputs, special_outputs, last_hidden = custom_multistep_rnn(inputs, initial_state=h_t, constants=z_enc)
print(outputs.shape)
print(special_outputs.shape)
print(last_hidden.shape)
>>> (4, 3, 10)
>>> (4, 3, 15)
>>> (4, 10)
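As the call docstring above notes, an LSTM-based cell carries two state elements (h_t, c_t) instead of one. Below is a minimal sketch (not from the original post; the class name and shapes are illustrative) of the same multiple-output pattern with tf.keras.layers.LSTMCell, where state_size becomes a list of two shapes and the RNN returns both final states:

```python
import tensorflow as tf

class MultiOutputLSTMCell(tf.keras.layers.Layer):
    """Same multi-output pattern, but with an LSTM's (h_t, c_t) states."""

    def __init__(self, units, arbitrary_units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.arbitrary_units = arbitrary_units
        self.lstm = tf.keras.layers.LSTMCell(units=units)
        self.other_output = tf.keras.layers.Dense(units=arbitrary_units)
        # LSTM tracks both h_t and c_t between timesteps.
        self.state_size = [tf.TensorShape([units]), tf.TensorShape([units])]
        # Two tracked outputs per timestep, as in the GRU version above.
        self.output_size = [tf.TensorShape([units]),
                            tf.TensorShape([arbitrary_units])]

    def call(self, input_at_t, states_at_t):
        # states_at_t is (h_t, c_t) here; LSTMCell consumes the pair
        # directly and returns the updated pair.
        output_at_t, states_at_t_plus_1 = self.lstm(input_at_t, states_at_t)
        arbitrary_output_at_t = self.other_output(input_at_t)
        return [output_at_t, arbitrary_output_at_t], states_at_t_plus_1

cell = MultiOutputLSTMCell(units=10, arbitrary_units=15)
rnn = tf.keras.layers.RNN(cell, return_sequences=True, return_state=True)
# With two state elements, the RNN returns (seq, other_seq, h_last, c_last).
seq, other_seq, h_last, c_last = rnn(tf.random.normal(shape=(4, 3, 12)))
print(seq.shape)        # (4, 3, 10)
print(other_seq.shape)  # (4, 3, 15)
print(h_last.shape)     # (4, 10)
print(c_last.shape)     # (4, 10)
```

The only changes from the GRU version are the list-valued state_size and the fact that the states tuple is passed to the inner cell whole rather than indexed at 0.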