我正在使用Keras中的序列生成器并行地从磁盘获取数据,但我得到了一个非常奇怪的错误。
这是我的序列生成器代码
class detracSequence(Sequence):
def __init__(self, x_set, y_set, bbox_set, batch_size):
self.x, self.y, self.bbox = x_set, y_set, bbox_set
self.batch_size = batch_size
def __len__(self):
return int(np.ceil(len(self.x) / float(self.batch_size)))
def __getitem__(self, idx):
print 'index range', idx*self.batch_size, 'till', (idx+1)*self.batch_size
batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
batch_bbox = self.bbox[idx * self.batch_size:(idx + 1) * self.batch_size]
# print batch_x
imgs = np.ndarray((self.batch_size,128,128,3))
for file_index in range(self.batch_size):
temp = cv2.imread(batch_x[file_index])
if temp.shape[0] == 0:
print '1', batch_x[file_index]
# print '1', temp.shape
#print(temp)
x1, x2, x3, x4 = batch_bbox[file_index,0],batch_bbox[file_index,1],batch_bbox[file_index,2], batch_bbox[file_index,3]
#print batch_x[file_index]
temp_ = temp[int(x2):int(x4),int(x1):int(x3)]
imgs[file_index] = cv2.resize(temp_,(128,128))
return imgs, np.array(batch_y)
这就是调用这个生成器的代码。
Xtrain_gen = detracSequence(X_train,y_train,training_coordinates, batch_size=32)
history = model.fit_generator(generator=Xtrain_gen, epochs=20, validation_data=Xvalidation_gen,callbacks=callbacks_list,use_multiprocessing=True)
现在,问题是IDX的值是由内部代码生成的。我的期望是它将处理索引绑定。但在方法上
获取项目
(self,idx),我得到一个idx的值,它给了我索引超出范围的错误,如下所示,它是位wierd。这是错误日志
Traceback (most recent call last):
File "finetuneInceptionV3.py", line 112, in <module>
history = model.fit_generator(generator=Xtrain_gen, epochs=20, validation_data=Xvalidation_gen,callbacks=callbacks_list,use_multiprocessing=True)
File "/home/sfarkya/tfenv/local/lib/python2.7/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
return func(*args, **kwargs)
File "/home/sfarkya/tfenv/local/lib/python2.7/site-packages/keras/engine/training.py", line 2192, in fit_generator
generator_output = next(output_generator)
File "/home/sfarkya/tfenv/local/lib/python2.7/site-packages/keras/utils/data_utils.py", line 584, in get
six.raise_from(StopIteration(e), e)
File "/home/sfarkya/tfenv/local/lib/python2.7/site-packages/six.py", line 737, in raise_from
raise value
StopIteration: list index out of range
现在,我不知道如何在不涉及源代码的情况下解决这个问题,但我不希望发生这种情况。有人能告诉我这里有没有遗失什么东西吗?