python - How to pick the last valid output values from TensorFlow RNN


I'm training an LSTM cell on batches of sequences that have different lengths. tf.nn.rnn has a convenient parameter, sequence_length, but after calling it, I don't know how to pick the output rows corresponding to the last time step of each item in the batch.

My code is basically as follows:

    import tensorflow as tf

    lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units, input_size)
    lstm_outputs, state = tf.nn.rnn(lstm_cell, input_list, dtype=tf.float32,
                                    sequence_length=sequence_lengths)

lstm_outputs is a list with the LSTM output at each time step. However, each item in my batch has a different length, so I would like to create a tensor containing the last valid LSTM output for each item in the batch.

If I could use numpy indexing, I would just do this:

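    # all_outputs has shape [max_length, batch_size, num_units] (time-major).
    all_outputs = tf.pack(lstm_outputs)
    # The last valid step of each item is at index sequence_lengths - 1.
    last_outputs = all_outputs[sequence_lengths - 1, tf.range(batch_size), :]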
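To pin down exactly what that indexing is meant to compute, here is a small plain-Python reference sketch. It uses no TensorFlow, and the function name last_valid_outputs is hypothetical. It assumes time-major nested lists, as you would get from stacking lstm_outputs, and takes step length - 1 for each batch item:

```python
def last_valid_outputs(all_outputs, sequence_lengths):
    """Reference semantics for picking the last valid RNN output.

    all_outputs: time-major nested list; all_outputs[t][b] is the output
        vector of batch item b at time step t (like tf.pack(lstm_outputs)).
    sequence_lengths: list of ints, the true length of each batch item.
    Returns one output vector per batch item.
    """
    # Time steps are zero-indexed, so the last valid step of item b is
    # sequence_lengths[b] - 1, not sequence_lengths[b].
    return [all_outputs[length - 1][b]
            for b, length in enumerate(sequence_lengths)]


# Three time steps, batch of two, one output unit per step.
outputs = [
    [[1.0], [10.0]],   # t = 0
    [[2.0], [20.0]],   # t = 1
    [[0.0], [30.0]],   # t = 2 (item 0 is already padding here)
]
print(last_valid_outputs(outputs, [2, 3]))  # [[2.0], [30.0]]
```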

But it turns out that, for the time being, TensorFlow doesn't support this kind of indexing (I'm aware of the feature request).

So, how could I get these values?

A more acceptable workaround was published by danijar on the feature request page linked in the question. It doesn't need to evaluate the tensors, which is a big plus.

I got it to work with TensorFlow 0.8. Here is the code:

    import tensorflow as tf

    def extract_last_relevant(outputs, length):
        """
        Args:
            outputs: [Tensor(batch_size, output_neurons)]: A list containing the
                output activations of each example in the batch for each time
                step, as returned by tensorflow.models.rnn.rnn.
            length: Tensor(batch_size): The used sequence length of each example
                in the batch, with all later time steps being zeros. Should be
                of type tf.int32.

        Returns:
            Tensor(batch_size, output_neurons): The last relevant output
            activation for each example in the batch.
        """
        # Stack to [max_length, batch_size, num_neurons], then make it batch-major.
        # (tf.pack was renamed tf.stack in TensorFlow 1.0.)
        output = tf.transpose(tf.pack(outputs), perm=[1, 0, 2])
        # Query shape.
        batch_size = tf.shape(output)[0]
        max_length = int(output.get_shape()[1])
        num_neurons = int(output.get_shape()[2])
        # Index into the flattened array as a workaround: row b * max_length + t
        # of the flat matrix holds the output of example b at time step t.
        index = tf.range(0, batch_size) * max_length + (length - 1)
        flat = tf.reshape(output, [-1, num_neurons])
        relevant = tf.gather(flat, index)
        return relevant
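The core of this workaround is index arithmetic. After the transpose, the tensor is batch-major with shape [batch_size, max_length, num_neurons]. Flattening the first two axes puts the output of example b at step t in row b * max_length + t. The following plain-Python sketch lets you check that arithmetic by hand. It uses no TensorFlow, and the helper name extract_last_relevant_py is hypothetical. It mirrors the transpose, reshape and gather steps on nested lists:

```python
def extract_last_relevant_py(outputs, length):
    """Plain-Python mirror of extract_last_relevant (illustration only).

    outputs: time-major list; outputs[t][b] is a list of num_neurons floats.
    length: list of ints, the used sequence length of each batch item.
    """
    max_length = len(outputs)
    batch_size = len(outputs[0])
    # Like tf.transpose(tf.pack(outputs), perm=[1, 0, 2]):
    # batch_major[b][t] is the output of item b at step t.
    batch_major = [[outputs[t][b] for t in range(max_length)]
                   for b in range(batch_size)]
    # Like tf.reshape(output, [-1, num_neurons]): item b, step t -> row b * max_length + t.
    flat = [row for item in batch_major for row in item]
    # Like tf.range(0, batch_size) * max_length + (length - 1).
    index = [b * max_length + (length[b] - 1) for b in range(batch_size)]
    # Like tf.gather(flat, index).
    return [flat[i] for i in index]


outputs = [
    [[1.0], [10.0]],   # t = 0
    [[2.0], [20.0]],   # t = 1
    [[0.0], [30.0]],   # t = 2 (item 0 is padding here)
]
print(extract_last_relevant_py(outputs, [2, 3]))  # [[2.0], [30.0]]
```

Here the flat rows are [1.0], [2.0], [0.0], [10.0], [20.0], [30.0], and the computed indices are 1 and 5. The sketch also shows why every length must be at least 1. A length of 0 for any item after the first would silently select the last step of the previous item instead of raising an error.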
