代码之家  ›  专栏  ›  技术社区  ›  A T

从凯拉斯的发电机那里得到X戥u测试,Y戥u测试?

  •  1
  • A T  · 技术社区  · 8 年前

    对于某些问题,验证数据不能是生成器,例如: TensorBoard histograms :

    如果打印直方图,则必须提供验证数据,并且不能是生成器。

    我当前的代码如下:

    image_data_generator = ImageDataGenerator()
    
    training_seq   = image_data_generator.flow_from_directory(training_dir)
    validation_seq = image_data_generator.flow_from_directory(validation_dir)
    testing_seq    = image_data_generator.flow_from_directory(testing_dir)
    
    model = Sequential(..)
    # ..
    model.compile(..)
    model.fit_generator(training_seq, validation_data=validation_seq, ..)
    

    我如何提供它作为 validation_data=(x_test, y_test) ?

    2 回复  |  直到 8 年前
        1
  •  2
  •   today    8 年前

    更新(22/06/2018):阅读OP提供的答案以简洁有效的解决方案。读我的来了解发生了什么。


    在python中,可以使用以下命令获取所有生成器数据:

    data = [x for x in generator]
    

    但是, ImageDataGenerators 不终止,因此上述方法将不起作用。但在这种情况下,我们可以使用相同的方法进行一些修改:

    data = []     # store all the generated data batches
    labels = []   # store all the generated label batches
    max_iter = 100  # maximum number of iterations, in each iteration one batch is generated; the proper value depends on batch size and size of whole data
    i = 0
    for d, l in validation_generator:
        data.append(d)
        labels.append(l)
        i += 1
        if i == max_iter:
            break
    

    现在我们有两个张量批列表。我们需要重塑它们,使其成为两个张量,一个用于数据(即 X 一个用于标签(即 y ):

    data = np.array(data)
    data = np.reshape(data, (data.shape[0]*data.shape[1],) + data.shape[2:])
    
    labels = np.array(labels)
    labels = np.reshape(labels, (labels.shape[0]*labels.shape[1],) + labels.shape[2:])
    
        2
  •  2
  •   A T    7 年前

    Python2.7和Python3.*解决方案:

    from platform import python_version_tuple
    
    if python_version_tuple()[0] == '3':
        xrange = range
        izip = zip
        imap = map
    else:
        from itertools import izip, imap
    
    import numpy as np
    
    # ..
    # other code as in question
    # ..
    
    x, y = izip(*(validation_seq[i] for i in xrange(len(validation_seq))))
    x_val, y_val = np.vstack(x), np.vstack(y)
    

    或支持 class_mode='binary' 然后:

    from keras.utils import to_categorical
    
    x_val = np.vstack(x)
    y_val = np.vstack(imap(to_categorical, y))[:,0] if class_mode == 'binary' else y
    

    完整的可运行代码: https://gist.github.com/AlecTaylor/7f6cc03ed6c3dd84548a039e2e0fd006

    推荐文章