代码之家  ›  专栏  ›  技术社区  ›  matt

如何创建将被重用的tensorflow操作

  •  0
  • matt  · 技术社区  · 7 年前

    在输入/训练模型之前,我有一些数据要处理。对于这个例子,我想做一个max pool 2d。我用tensorflow写了一个简短的函数。

    import tensorflow
    import tensorflow.nn as nn
    
    def _tfMaxPool(arr, pool=(4,4), sess=None):
        op = nn.max_pool(arr, (1, 1, pool[0], 1), (1, 1, pool[0], 1 ), padding="VALID")
        op = nn.max_pool(op, (1, 1, 1, pool[1]), (1, 1, 1, pool[1]), padding="VALID")
        if sess is None:
            sess = tensorflow.Session();
    
        return sess.run(op)
    

    问题是这样每次都会在我的图中添加节点,这似乎会扰乱我的会话。另一种方法是创建模型。

    import keras
    
    seq = keras.Sequential([ 
                keras.layers.InputLayer((1, 512, 512)), 
                keras.layers.MaxPool2D((4, 4), (4, 4), data_format="channels_first")
                         ])
    def _tfMaxPool2(arr, pool=(4,4), sess=None):
        swapped = arr.swapaxes(0,1)
        return seq.predict(swapped).swapaxes(0,1)
    

    这个模型和我想要的几乎一模一样,但我认为我遗漏了一些基本的东西。

    1 回复  |  直到 7 年前
        1
  •  1
  •   saket    7 年前

    你为什么不在各种输入中重用你的图表呢?在下面的代码中,tfMaxpool只定义了一次。

    def _tfMaxPool(arr, pool=(4,4)):
        op = nn.max_pool(arr, (1, 1, pool[0], 1), (1, 1, pool[0], 1 ), padding="VALID")
        op = nn.max_pool(op, (1, 1, 1, pool[1]), (1, 1, 1, pool[1]), padding="VALID")
    
        return op
    
    input = tf.placeholder()
    output = _tfMaxPool(input)
    
    with tf.Session() as sess:
        sess.run(output, feed_dict={input:arr1})
        sess.run(output, feed_dict={input:arr2})