
Can't use ModelCheckpoint with MobileNet in Keras

Asked by DeltaIV · 6 years ago

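I'm trying to train MobileNet on dummy data in Keras, inside a Docker container on a multi-GPU machine. Initially I was trying to train Xception, but I decided to switch to a smaller model so that people with less powerful machines can also reproduce my code. I get an error involving `ModelCheckpoint` that I can't understand: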

    import tensorflow as tf
    import keras.utils
    from keras.applications import MobileNet
    from keras.callbacks import ModelCheckpoint
    from keras.optimizers import Adam
    import numpy as np
    import os
    
    
    height = 224
    width = 224
    channels = 3
    epochs = 10
    num_classes = 10
    
    # Generate dummy data
    batch_size = 32  
    n_train = 256
    n_test = 64
    x_train = np.random.random((n_train, height, width, channels))
    y_train = keras.utils.to_categorical(np.random.randint(num_classes, size=(n_train, 1)), num_classes=num_classes)
    x_test = np.random.random((n_train, height, width, channels))
    y_test = keras.utils.to_categorical(np.random.randint(num_classes, size=(n_test, 1)), num_classes=num_classes)
    # Get input shape
    input_shape = x_train.shape[1:]
    
    # Instantiate model 
    model = MobileNet(weights=None,
                      input_shape=input_shape,
                      classes=num_classes)
    
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])
    # Viewing Model Configuration
    model.summary()
    
    # Model file name
    filepath = 'model_epoch_{epoch:02d}_loss_{loss:0.2f}_val_{val_loss:.2f}.hdf5'
    
    # Define save_best_only checkpointer
    checkpointer = ModelCheckpoint(filepath=filepath,
                                 monitor='val_acc',
                                 verbose=1,
                                 save_best_only=True)
    
    # Let's fit!
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(x_test, y_test),
              callbacks=[checkpointer])
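For context on the callback itself: in Keras 2.x, `ModelCheckpoint` builds each file name by calling `str.format` on `filepath` with the 1-based epoch number and the epoch's `logs` dict, and it looks up the `monitor` key in that same dict. A minimal sketch of that substitution, assuming this behaviour; the log values below are invented for illustration, not taken from a real run:

```python
# Same template as in the question.
filepath = 'model_epoch_{epoch:02d}_loss_{loss:0.2f}_val_{val_loss:.2f}.hdf5'

# Hypothetical end-of-epoch logs; the numbers are made up for illustration.
logs = {'loss': 2.3456, 'acc': 0.1, 'val_loss': 2.3101, 'val_acc': 0.125}
epoch = 2  # zero-based epoch index, as passed to callbacks

# Fill the template roughly the way the callback does (epoch shown 1-based).
name = filepath.format(epoch=epoch + 1, **logs)
print(name)  # model_epoch_03_loss_2.35_val_2.31.hdf5

# The monitored quantity ('val_acc' with metrics=['accuracy'] in Keras 2.1.x)
# must also be a key of the logs dict.
monitor = 'val_acc'
print(monitor in logs)  # True
```

Every placeholder in `filepath` must match a logged key, otherwise formatting fails with a `KeyError`; here all of them do, so the callback is not what raises the error.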
    

The error I get is:

    Traceback (most recent call last):
      File "very_basic_test.py", line 52, in <module>
        callbacks=[checkpointer])
      File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 1650, in fit
        batch_size=batch_size)
      File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 1490, in _standardize_user_data
        _check_array_lengths(x, y, sample_weights)
      File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 220, in _check_array_lengths
        'and ' + str(list(set_y)[0]) + ' target samples.')
    ValueError: Input arrays should have the same number of samples as target arrays. Found 256 input samples and 64 target samples.
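The message comes from a sample-count check: every input array and its target array must agree on their first (sample) dimension. A minimal sketch of the idea, using a hypothetical helper `check_same_length` (not Keras's actual implementation) and tiny 4x4 images in place of 224x224 to keep memory use low, since only the first dimension matters:

```python
import numpy as np


def check_same_length(x, y):
    """Hypothetical helper: raise if inputs and targets differ in sample count."""
    if len(x) != len(y):
        raise ValueError('Input arrays should have the same number of samples as '
                         'target arrays. Found %d input samples and %d target samples.'
                         % (len(x), len(y)))


n_train, n_test, num_classes = 256, 64, 10
x_test = np.random.random((n_train, 4, 4, 3))  # built with n_train, as in the question
y_test = np.zeros((n_test, num_classes))       # 64 one-hot target rows

message = None
try:
    check_same_length(x_test, y_test)
except ValueError as err:
    message = str(err)
print(message)
```

This reproduces the same "256 input samples and 64 target samples" mismatch reported in the traceback.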
    

Python, Keras and TensorFlow versions:

    python -c 'import keras; import tensorflow; import sys; print(sys.version, 'keras.__version__', 'tensorflow.__version__')'
    Using TensorFlow backend.
    ('2.7.12 (default, Dec  4 2017, 14:50:18) \n[GCC 5.4.0 20160609]', '2.1.6', '1.7.0')
    
Answer by nuric · 6 years ago

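The problem has nothing to do with the checkpoint callback; it's the data you pass in. Compare the number of samples, i.e. the size of the first dimension, of each input array and its targets: `x_train.shape[0]` vs `y_train.shape[0]`, and `x_test.shape[0]` vs `y_test.shape[0]`. Here `x_test` is created with `n_train` (256) samples while `y_test` has `n_test` (64), which is exactly the 256 vs 64 mismatch in the error. The traceback points at the `callbacks=[checkpointer])` line only because that is the last line of the `.fit` call.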
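A minimal sketch of the fix, assuming only NumPy: `np.eye(num_classes)[labels]` stands in for `keras.utils.to_categorical` (it gives the same one-hot rows for integer labels), and the images are shrunk to 8x8 so the example runs quickly. In the original script the only change needed is `n_test` instead of `n_train` when building `x_test`.

```python
import numpy as np

# Small spatial size for the sketch; the question uses 224x224x3.
height, width, channels = 8, 8, 3
num_classes = 10
n_train, n_test = 256, 64

x_train = np.random.random((n_train, height, width, channels))
y_train = np.eye(num_classes)[np.random.randint(num_classes, size=n_train)]
# The fix: the validation inputs must have n_test samples, like y_test.
x_test = np.random.random((n_test, height, width, channels))
y_test = np.eye(num_classes)[np.random.randint(num_classes, size=n_test)]

# Sanity-check the sample dimension of each input/target pair before calling fit.
for split, x, y in [('train', x_train, y_train), ('test', x_test, y_test)]:
    print(split, x.shape, y.shape)
    assert x.shape[0] == y.shape[0], split + ': sample counts differ'
```

Running a check like this right before `model.fit` surfaces the mismatch with a clearer message than the one raised inside Keras.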