可以通过构建一个更新张量（update）和一个布尔遮罩张量（mask），再用 tf.where 把张量中指定行替换为新值：
import tensorflow as tf
import numpy as np
# Demo: replace rows 0 and 1 of a (3, 2) tensor with ones, keeping row 2 intact.
# NOTE: this uses the TensorFlow 1.x graph API (tf.Session / tf.placeholder).
# Under TF 2.x these live in tf.compat.v1, and placeholders require eager
# execution to be disabled.
with tf.Session() as sess:
    x = tf.placeholder(tf.float32, shape=(3, 2))
    x_shape = tf.shape(x)

    # Rows to overwrite. tf.scatter_nd expects indices shaped
    # (num_updates, index_depth), so expand [0, 1] to [[0], [1]].
    # Indices should be unique: scatter_nd sums values for repeated indices.
    indices = tf.expand_dims(tf.constant([0, 1]), 1)
    replacement = tf.ones((2, 2))

    # Dense tensor holding the replacement rows at `indices`, zeros elsewhere.
    update = tf.scatter_nd(indices, replacement, x_shape)
    # Boolean mask that is True exactly at the scattered positions.
    mask = tf.scatter_nd(indices, tf.ones_like(replacement, dtype=tf.bool), x_shape)
    # Take `update` where the mask is set and the original `x` everywhere else.
    result = tf.where(mask, update, x)

    # Feed a float32 array to match the placeholder's dtype explicitly.
    feed_value = np.arange(6, dtype=np.float32).reshape((3, 2))
    print(sess.run(result, feed_dict={x: feed_value}))
输出:
[[1. 1.]
 [1. 1.]
 [4. 5.]]