python - How to pick the last valid output values from a TensorFlow RNN
I'm training an LSTM cell on batches of sequences that have different lengths. tf.nn.rnn has the convenient parameter sequence_length, but after calling it, I don't know how to pick the output rows corresponding to the last time step of each item in the batch.
My code is as follows:
    lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units, input_size)
    lstm_outputs, state = tf.nn.rnn(lstm_cell, input_list, dtype=tf.float32,
                                    sequence_length=sequence_lengths)
lstm_outputs is a list holding the LSTM output at each time step. However, each item in my batch has a different length, so I would like to create a tensor containing the last valid LSTM output for each item in the batch.
If I could use numpy indexing, I would just do something like this:
    all_outputs = tf.pack(lstm_outputs)
    last_outputs = all_outputs[sequence_lengths - 1, tf.range(batch_size), :]
But it turns out that, for the time being, TensorFlow doesn't support this (I'm aware of the feature request).
So, how do I get these values?
A more acceptable workaround was published by danijar on the feature request page linked in the question. It doesn't need to evaluate the tensors, which is a big plus.
I got it to work with TensorFlow 0.8. Here is the code:
    import tensorflow as tf

    def extract_last_relevant(outputs, length):
        """
        Args:
            outputs: [Tensor(batch_size, output_neurons)]: A list containing the
                output activations of each example in the batch for each time
                step, as returned by tensorflow.models.rnn.rnn.
            length: Tensor(batch_size): The used sequence length of each example
                in the batch, with all later time steps being zeros. Should be
                of type tf.int32.

        Returns:
            Tensor(batch_size, output_neurons): The last relevant output
                activation for each example in the batch.
        """
        # Stack the per-step outputs and make the result batch-major:
        # [max_length, batch, neurons] -> [batch, max_length, neurons].
        output = tf.transpose(tf.pack(outputs), perm=[1, 0, 2])
        # Query shape.
        batch_size = tf.shape(output)[0]
        max_length = int(output.get_shape()[1])
        num_neurons = int(output.get_shape()[2])
        # Index into the flattened array as a workaround.
        index = tf.range(0, batch_size) * max_length + (length - 1)
        flat = tf.reshape(output, [-1, num_neurons])
        relevant = tf.gather(flat, index)
        return relevant
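To see why the index arithmetic in that workaround picks the right rows, here is a small pure-Python sketch of the same idea. It is only an illustration, not part of danijar's code: the helper name last_relevant_flat and the toy values are made up, and nested lists stand in for tensors so that it runs without TensorFlow. It assumes every sequence length is at least 1.

```python
# Pure-Python illustration of the flattened-index trick (no TensorFlow needed).
# The function name and the toy data are hypothetical, chosen for this example.

def last_relevant_flat(outputs, lengths):
    """outputs: time-major nested list [max_length][batch_size][num_neurons].
    lengths: valid sequence length (>= 1) of each batch item."""
    max_length = len(outputs)
    batch_size = len(outputs[0])
    # Counterpart of tf.transpose(..., perm=[1, 0, 2]): make it batch-major.
    batch_major = [[outputs[t][b] for t in range(max_length)]
                   for b in range(batch_size)]
    # Counterpart of tf.reshape(output, [-1, num_neurons]): one row per (b, t).
    flat = [row for seq in batch_major for row in seq]
    # The row for item b at time t sits at b * max_length + t,
    # and the last valid step is t = length - 1.
    index = [b * max_length + (lengths[b] - 1) for b in range(batch_size)]
    # Counterpart of tf.gather(flat, index).
    return [flat[i] for i in index]


# 3 time steps, batch of 2, 1 neuron; each value encodes (time, batch) as 10*t + b.
outputs = [[[10 * t + b] for b in range(2)] for t in range(3)]
lengths = [1, 3]
result = last_relevant_flat(outputs, lengths)
print(result)  # [[0], [21]]
```

Item 0 has length 1, so its last valid output is the one from time step 0; item 1 has length 3, so it comes from time step 2. Because the flattened array is batch-major, each batch item owns a contiguous run of max_length rows, which is what makes a single gather sufficient.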