代码之家  ›  专栏  ›  技术社区  ›  kawingkelvin

tensorflow 2.0 keras使用ImageDataGenerator+flow\u(来自\u目录+tf.data.Dataset)进行培训时出现与“形状”相关的错误

  •  0
  • kawingkelvin  · 技术社区  · 6 年前

    我正在尝试将生成器包装到tf.data.Dataset中(只是为了了解这一点)。 这是我的片段。希望有人能发现我做错了什么。

    img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
    
    gen = img_gen.flow_from_directory(data_path, target_size=(224, 224), batch_size=32)
    
    dataset = tf.data.Dataset.from_generator(
                lambda: gen,
                output_types = (tf.float32, tf.float32),
                output_shapes = ([32, 224, 224, 3], [32, 6]),
    )
    
    model.fit(dataset, 
              steps_per_epoch = gen.n // 32, 
              epochs=10)
    

    值错误: generator 生成了形状元素(112242243),其中形状元素(322243)是预期的。

    0 回复  |  直到 6 年前
        1
  •  0
  •   kawingkelvin    6 年前

    如果我更改此项,我“似乎”解决了此问题:

    dataset = tf.data.Dataset.from_generator(
        lambda: gen,
        output_types = (tf.float32, tf.float32),
        # output_shapes = ([32, 224, 224, 3], [32, 6]),
        output_shapes = ([None, 224, 224, 3], [None, 6]),
    )
    

    我启动了.fit,它显然在工作,每一个时代的损失都在下降,准确性都在上升。

    如果有人有更好的方法或解释,请让我知道。

    推荐文章