代码之家  ›  专栏  ›  技术社区  ›  David Parks

Tensorflow'tf。层。batch\u normalization`不会将更新操作添加到`tf。GraphKeys。更新\u操作`

  •  9
  • David Parks  · 技术社区  · 8 年前

    以下代码(复制/粘贴可运行)说明了使用 tf.layers.batch_normalization .

    import tensorflow as tf
    bn = tf.layers.batch_normalization(tf.constant([0.0]))
    print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    
    > []     # UPDATE_OPS collection is empty
    

    使用TF 1.5,文件(以下引用)明确指出 UPDATE\u OPS不应为空 在这种情况下( https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization

    注:培训时,moving\u mean和moving\u variance需要 已更新。默认情况下,更新操作放置在 tf.GraphKeys.UPDATE_OPS ,因此需要将它们作为依赖项添加到 列车运行。例如:

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss)
    
    1 回复  |  直到 8 年前
        1
  •  7
  •   David Parks    8 年前

    只需将代码更改为培训模式(通过设置 training 标记为 True )如中所述 quote :

    注:何时 训练 ,需要更新moving\u mean和moving\u variance。默认情况下,更新操作放置在tf中。GraphKeys。更新\u OPS,因此需要将其作为依赖项添加到train\u OPS。

     import tensorflow as tf
     bn = tf.layers.batch_normalization(tf.constant([0.0]), training=True)
     print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    

    将输出:

    [< tf.Tensor 'batch_normalization/AssignMovingAvg:0' shape=(1,) dtype=float32_ref>, 
     < tf.Tensor 'batch_normalization/AssignMovingAvg_1:0' shape=(1,) dtype=float32_ref>]
    

    Gamma和Beta最终进入TRAINABLE\u VARIABLES集合:

    print(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
    
    [<tf.Variable 'batch_normalization/gamma:0' shape=(1,) dtype=float32_ref>, 
     <tf.Variable 'batch_normalization/beta:0' shape=(1,) dtype=float32_ref>]
    
    推荐文章