代码之家  ›  专栏  ›  技术社区  ›  Shayan RC

使用argmax获得索引的散射更新张量

  •  0
  • Shayan RC  · 技术社区  · 6 年前

    我试图用另一个值来更新张量的最大值,如下所示:

    actions = tf.argmax(output, axis=1)
    gen_targets = tf.scatter_nd_update(output, actions, q_value)
    

    我有个错误: AttributeError: 'Tensor' object has no attribute 'handle' 在…上 scatter_nd_update .

    这个 output actions 占位符是否声明为:

    output = tf.placeholder('float', shape=[None, num_action])
    reward = tf.placeholder('float', shape=[None])
    

    我做错了什么?正确的方法是什么?

    1 回复  |  直到 6 年前
        1
  •  2
  •   abhuse    6 年前

    您正在尝试更新的值 output 这是一种 tf.placeholder .占位符是不可变的对象,不能更新占位符的值。你试图更新的张量应该是变量类型,例如。 tf.Variable ,以便 tf.scatter_nd_update() 能够更新其值。 解决这个问题的一种方法是创建一个变量,然后使用 tf.assign() .因为占位符的一个维度是 None 并且可能在运行时具有任意大小,您可能需要设置 validate_shape 争论 tf.assign() False ,这样占位符的形状就不需要与变量的形状匹配。作业结束后,学生们会 var_output 将匹配通过占位符输入的对象的实际形状。

    output = tf.placeholder('float', shape=[None, num_action])
    # dummy variable initialization
    var_output = tf.Variable(0, dtype=output.dtype)
    
    # assign value of placeholder to the var_output
    var_output = tf.assign(var_output, output, validate_shape=False)
    # ...
    gen_targets = tf.scatter_nd_update(var_output, actions, q_value)
    # ...
    sess.run(gen_targets, feed_dict={output: feed_your_placeholder_here})