代码之家  ›  专栏  ›  技术社区  ›  McAngus

如何将“tf.nn.dynamic”与非rnn组件一起使用

  •  -1
  • McAngus  · 技术社区  · 7 年前

    我有一个架构,在输入RNN之前使用编码器。编码器输入形状为 [batch, height, width, channels] rnn输入是shape [batch, time, height, width, channels] . 我想将编码器的输出直接提供给RNN,但这会造成内存问题。我得养活她 batch*time ~= 3*100 (通过重塑)一次将图像放入编码器。我知道 tf.nn.dynamic_rnn 可以利用杠杆 swap_memory ,我也想在编码器中利用这个。下面是一些压缩代码:

    #image inputs [batch, time, height, width, channels]
    inputs = tf.placeholder(tf.float32, [batch, time, in_sh[0], in_sh[1], in_sh[2]])
    
    #This is where the trouble starts
    #merge batch and time
    inputs = tf.reshape(inputs, [batch*time, in_sh[0], in_sh[1], in_sh[2]])
    #build the encoder (and get shape of output)
    enc, enc_sh = build_encoder(inputs)
    #change back to time format
    enc = tf.reshape(enc, [batch, time, enc_sh[0], enc_sh[1], enc_sh[2]])
    
    #build rnn and get initial state (zero_state)
    rnn, initial_state = build_rnn()
    #use dynamic unrolling
    rnn_outputs, rnn_state = tf.nn.dynamic_rnn(
            rnn, enc,
            initial_state=initial_state,
            swap_memory=True,
            time_major=False)
    

    我目前使用的方法是先对所有图像运行编码器(并保存到磁盘),但我希望执行数据集增强(对图像),这在提取特征后是不可能的。

    1 回复  |  直到 7 年前
        1
  •  0
  •   McAngus    7 年前

    对于其他遇到这个问题的人。我做了一个包装 RNNCell 这就满足了我的需要。这个 model_fn 是一个函数,它使用输入创建子图并返回输出张量。不幸的是,输出形状必须是已知的(至少我不能让它工作否则)。

    class WrapperCell(tf.nn.rnn_cell.RNNCell):
        """A Wrapper for a non recurrent component that feeds into an RNN."""
    
        def __init__(self, model_fn, out_shape, reuse=None):
            super(WrapperCell, self).__init__(_reuse=reuse)
            self.out_shape = out_shape
            self.model_fn = model_fn
    
        @property
        def state_size(self):
            return tf.TensorShape([1])
    
        @property
        def output_size(self):
            return tf.TensorShape(self.out_shape)
    
        def call(self, x, h):
            mod = self.model_fn(x)
            return mod, tf.zeros_like(h)