你可以使用 `tf.scatter_nd_update` 来实现这一点。例如:
# Example 1: A and B are both variables. Overwrite selected entries of A
# with zeros via tf.scatter_nd_update.
# NOTE: this uses the TF 1.x graph API (tf.Session,
# tf.global_variables_initializer); assumes `import tensorflow as tf` has
# already been executed.
A = tf.Variable(
    [[1.0986123, 0.6931472, 0.0, 0.6931472, 0.0],
     [0.0, 0.0, 0.0, 0.0, 0.0],
     [3.7376697, 3.7612002, 3.7841897, 3.8066626, 3.8286414]],
    dtype=tf.float32)
# Each row of B is a full (row, column) index into A: here (2, 1) and (2, 2).
B = tf.Variable(
    [[2, 1],
     [2, 2]], dtype=tf.int64)
# One update value per index row of B; the updates are all zeros, so the
# addressed entries of A are set to 0.
C = tf.scatter_nd_update(A, B, tf.zeros(shape=tf.shape(B)[0]))
with tf.Session() as sess:
    # Both A and B are variables and must be initialized before C is run.
    sess.run(tf.global_variables_initializer())
    print(sess.run(C))
# Example 2: A and B are constant tensors rather than variables.
# NOTE: this uses the TF 1.x graph API (tf.Session,
# tf.global_variables_initializer); assumes `import tensorflow as tf` has
# already been executed.
A = tf.constant(
    [[1.0986123, 0.6931472, 0.0, 0.6931472, 0.0],
     [0.0, 0.0, 0.0, 0.0, 0.0],
     [3.7376697, 3.7612002, 3.7841897, 3.8066626, 3.8286414]],
    dtype=tf.float32)
# Each row of B is a full (row, column) index into A: here (2, 1) and (2, 2).
B = tf.constant(
    [[2, 1],
     [2, 2]], dtype=tf.int64)
# tf.scatter_nd_update writes in place and therefore needs a variable as its
# target, so the constant A is wrapped in a Variable initialized from it.
AV = tf.Variable(A)
# One update value per index row of B; the updates are all zeros, so the
# addressed entries of AV are set to 0.
C = tf.scatter_nd_update(AV, B, tf.zeros(shape=tf.shape(B)[0]))
with tf.Session() as sess:
    # AV must be initialized (from A) before C is run.
    sess.run(tf.global_variables_initializer())
    print(sess.run(C))