
Keras Tensors - Get Values With Indices Coming From Another Tensor

Suppose I have these two tensors: valueMatrix, shaped (?, 3), where ? is the batch size, and indexMatrix, shaped (?, 1). I want to retrieve the values from valueMatrix at the indices given in indexMatrix (one column index per row).

Solution 1:

import tensorflow as tf
valueMatrix = tf.constant([[7,15,5],[4,6,8]])
indexMatrix = tf.constant([[1],[0]])

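# Create the row indices with tf.range. tf.shape gives the batch size at run time;
# the static indexMatrix.shape[0] is None when the batch dimension is ?.
row_idx = tf.reshape(tf.range(tf.shape(indexMatrix)[0]), (-1, 1))
# Pair each row index with its column index: shape (batch, 1, 2)
idx = tf.stack([row_idx, indexMatrix], axis=-1)
# Extract the elements with gather_nd: shape (batch, 1)
values = tf.gather_nd(valueMatrix, idx)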

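# TensorFlow 2.x executes eagerly. In TensorFlow 1.x, evaluate the tensor instead with:
#   with tf.Session() as sess: print(sess.run(values))
print(values.numpy())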
#[[15]
# [ 4]]
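As an alternative to building the (row, column) index pairs by hand, tf.gather accepts a batch_dims argument (available in TensorFlow 2.x), which performs a per-row lookup directly. The sketch below is not the original answer's method: the function name gather_by_index is hypothetical, and the tf.function input signature with a None batch dimension is an assumption used to stand in for the ? batch size from the question.

```python
import tensorflow as tf

# Hypothetical helper: looks up one column per row, for any batch size.
# The None in each TensorSpec plays the role of the unknown "?" batch dimension.
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, 3], dtype=tf.int32),
    tf.TensorSpec(shape=[None, 1], dtype=tf.int32),
])
def gather_by_index(value_matrix, index_matrix):
    # batch_dims=1: for each row b, take value_matrix[b][index_matrix[b]].
    # Output shape is (batch, 1), matching the gather_nd version.
    return tf.gather(value_matrix, index_matrix, batch_dims=1)

values = gather_by_index(tf.constant([[7, 15, 5], [4, 6, 8]]),
                         tf.constant([[1], [0]]))
print(values.numpy())
```

Because the batch dimension is declared as None, the same traced function also works for batches of other sizes without retracing.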
