代码之家  ›  专栏  ›  技术社区  ›  Vincent Stimper

更新TensorFlow中的张量切片

  •  2
  • Vincent Stimper  · 技术社区  · 7 年前

    我想更新一个三维张量的切片。跟随 How to do slice assignment in Tensorflow 我想做点什么

    import tensorflow as tf
    
    with tf.Session() as sess:
        init_val = tf.Variable(tf.zeros((2, 3, 3)))
        indices = tf.constant([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1]])
        update = tf.scatter_nd_add(init_val, indices, tf.ones(4))
    
        init = tf.global_variables_initializer()
        sess.run(init)
        print(sess.run(update))
    

    这是可行的,但是由于我的实际问题更复杂,我希望通过定义切片的开始和大小,以某种方式自动生成一组索引,例如如果您要使用 tf.slice(...) . 你有什么想法吗?事先谢谢!

    我使用的是TensorFlow 1.12,它是当前最新的版本。

    1 回复  |  直到 7 年前
        1
  •  1
  •   javidcf    7 年前

    tf.strided_slice 支架通过A var 用于指示切片引用的变量的参数,因此当传递该变量时,它将返回一个可分配对象(我不确定为什么它们不只是根据输入的类型执行此操作,而是执行其他操作)。你可以这样做:

    import tensorflow as tf
    import numpy as np
    
    var = tf.Variable(np.ones((3, 4), dtype=np.float32))
    s = tf.strided_slice(var, [0, 2], [2, 3], var=var, name='var_slice')
    s2 = s.assign([[2], [3]])
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        print(sess.run(s2))
    

    输出:

    [[1. 1. 2. 1.]
     [1. 1. 3. 1.]
     [1. 1. 1. 1.]]
    

    请注意 tf.跨步切片 您提供开始和结束索引(不包括结束),与中不同 tf.slice ,给出开始和大小。另外,正如目前的代码,您必须为slice或assign操作提供一个name值(我认为这应该是一个bug,并且会发生这种情况,因为API的那部分几乎只在内部使用)。

    推荐文章