代码之家  ›  专栏  ›  技术社区  ›  AaronDT

Keras:从自定义生成器中获取一个批次

  •  1
  • AaronDT  · 技术社区  · 7 年前

    我用这里的代码创建了一个自定义生成器 https://github.com/keras-team/keras/issues/3386#issuecomment-237555199 作为模板。而不是使用 .flow .flow_from_dataframe . 因此,我的发电机看起来是这样的:

    def createGenerator( X, I):
    
        while True:
    
            # suffled indices    
            idx = np.random.permutation( X.shape[0])
            # create image generator
            datagen=ImageDataGenerator(preprocessing_function=preprocess_input,
                                  horizontal_flip = True,
                                  rotation_range=20)
    
    
            batches = datagen.flow_from_dataframe(shuffle=False,dataframe=X[["url",target]], directory=bilder_resized_dir, x_col="url", y_col=target,
                                                has_ext=False, class_mode="categorical", color_mode=color_mode, target_size=(IMG_SIZE,IMG_SIZE), batch_size=batch_size,classes=class_list)
            idx0 = 0
            for batch in batches:
                idx1 = idx0 + batch[0].shape[0]
    
                yield [batch[0], I[ idx[ idx0:idx1 ] ]], batch[1]
    
                idx0 = idx1
                if idx1 >= X.shape[0]:
                    break
    

    next() 就像这样:

    combo_gen = createGenerator(X1_dataframe ,X2_aux_input)
    x,y = next(combo_gen)
    

    但是,这会导致以下错误:

    ---------------------------------------------------------------------------
    StopIteration                             Traceback (most recent call last)
    <ipython-input-274-4a67da5d38d9> in <module>
    ----> 1 x,y = next(combo_gen)
    
    StopIteration: 
    

    我怎样才能从这台发电机上得到一批呢?

    0 回复  |  直到 7 年前