我试着在Keras的虚拟数据上训练MobileNet,在多GPU机器上的Docker容器中。
Initially I was trying to train Xception
,但我决定换一个更小的模型,这样即使是机器功能较弱的人也可以复制我的代码。我有点矛盾
ModelCheckpoint
我无法理解。
import tensorflow as tf
import keras.utils
from keras.applications import MobileNet
from keras.callbacks import ModelCheckpoint
from keras.optimizers import Adam
import numpy as np
import os
height = 224
width = 224
channels = 3
epochs = 10
num_classes = 10
# Generate dummy data
batch_size = 32
n_train = 256
n_test = 64
x_train = np.random.random((n_train, height, width, channels))
y_train = keras.utils.to_categorical(np.random.randint(num_classes, size=(n_train, 1)), num_classes=num_classes)
x_test = np.random.random((n_train, height, width, channels))
y_test = keras.utils.to_categorical(np.random.randint(num_classes, size=(n_test, 1)), num_classes=num_classes)
# Get input shape
input_shape = x_train.shape[1:]
# Instantiate model
model = MobileNet(weights=None,
input_shape=input_shape,
classes=num_classes)
model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])
# Viewing Model Configuration
model.summary()
# Model file name
filepath = 'model_epoch_{epoch:02d}_loss_{loss:0.2f}_val_{val_loss:.2f}.hdf5'
# Define save_best_only checkpointer
checkpointer = ModelCheckpoint(filepath=filepath,
monitor='val_acc',
verbose=1,
save_best_only=True)
# Let's fit!
model.fit(x_train, y_train,
batch_size=batch_size,
epochs=epochs,
validation_data=(x_test, y_test),
callbacks=[checkpointer])
我得到的错误是
Traceback (most recent call last):
File "very_basic_test.py", line 52, in <module>
callbacks=[checkpointer])
File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 1650, in fit
batch_size=batch_size)
File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 1490, in _standardize_user_data
_check_array_lengths(x, y, sample_weights)
File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 220, in _check_array_lengths
'and ' + str(list(set_y)[0]) + ' target samples.')
ValueError: Input arrays should have the same number of samples as target arrays. Found 256 input samples and 64 target samples.
python、keras和tensorflow版本:
python -c 'import keras; import tensorflow; import sys; print(sys.version, 'keras.__version__', 'tensorflow.__version__')'
Using TensorFlow backend.
('2.7.12 (default, Dec 4 2017, 14:50:18) \n[GCC 5.4.0 20160609]', '2.1.6', '1.7.0')