代码之家  ›  专栏  ›  技术社区  ›  Vincent Stimper

利用TensorFlow中的反馈数据进行散点更新

  •  1
  • Vincent Stimper  · 技术社区  · 7 年前

    我正在尝试使用scatter-update更新张量的切片。我第一个熟悉这个函数的代码片段运行得非常好。

    import tensorflow as tf
    import numpy as np
    
    with tf.Session() as sess:
        init_val = tf.Variable(tf.zeros((3, 2)))
        indices = tf.constant([0, 1])
        update = tf.scatter_update(init_val, indices, tf.ones((2, 2)))
    
        init = tf.global_variables_initializer()
        sess.run(init)
        print(sess.run(update))
    

    但是当我尝试将初始值输入到图中时

    with tf.Session() as sess:
        x = tf.placeholder(tf.float32, shape=(3, 2))
        init_val = x
        indices = tf.constant([0, 1])
        update = tf.scatter_update(init_val, indices, tf.ones((2, 2)))
    
        init = tf.global_variables_initializer()
        sess.run(init)
        print(sess.run(update, feed_dict={x: np.zeros((3, 2))}))
    

    我知道这个奇怪的错误

    InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder_1' with dtype float and shape [3,2]
     [[{{node Placeholder_1}} = Placeholder[dtype=DT_FLOAT, shape=[3,2], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
    

    放弃 tf.Variable 围绕 x 当分配给 init_val 也没有帮助,因为我得到了错误

    AttributeError: 'Tensor' object has no attribute '_lazy_read'
    

    (见 this entry 在Github上。有人知道吗?事先谢谢!

    我在CPU上使用TensorFlow 1.12。

    1 回复  |  直到 7 年前
        1
  •  1
  •   javidcf    7 年前

    可以通过构建和更新张量和遮罩张量来在张量中替换:

    import tensorflow as tf
    import numpy as np
    
    with tf.Session() as sess:
        x = tf.placeholder(tf.float32, shape=(3, 2))
        init_val = x
        indices = tf.constant([0, 1])
        x_shape = tf.shape(x)
        indices = tf.expand_dims(indices, 1)
        replacement = tf.ones((2, 2))
        update = tf.scatter_nd(indices, replacement, x_shape)
        mask = tf.scatter_nd(indices, tf.ones_like(replacement, dtype=tf.bool), x_shape)
        result = tf.where(mask, update, x)
        print(sess.run(result, feed_dict={x: np.arange(6).reshape((3, 2))}))
    

    输出:

    [[1. 1.]
     [1. 1.]
     [4. 5.]]