代码之家  ›  专栏  ›  技术社区  ›  mrgloom

如何在批处理生成器中使用模型?

  •  0
  • mrgloom  · 技术社区  · 6 年前

    我想用 model.predict 在批处理生成器中,实现这一点的可能方法是什么?

    class DataGenerator(keras.utils.Sequence):
        def __init__(self, model_name):
            # Load model
    
        # ...
    
        def on_epoch_end(self):
            # Load model
    
    1 回复  |  直到 6 年前
        1
  •  1
  •   Daniel Möller    6 年前

    根据我的经验,在训练时预测另一个模型会带来错误。

    您可能只需将培训模型附加到生成器模型之后。

    generator_model (the one you want to use inside the generator)    
    training_model (the one you want to train)
    

    那么

    generatorInput = Input(shapeOfTheGeneratorInput)
    generatorOutput = generator_model(generatorInput)
    trainingOutput = training_model(generatorOutput)
    
    entireModel = Model(generatorInput,trainingOutput)
    

    在编译之前,请确保生成器模型的所有层都不可处理:

    genModel = entireModel.layers[1]
    for l in genModel.layers:
        l.trainable = False
    
    entireModel.compile(optimizer=optimizer,loss=loss)
    

    现在定期使用发电机。


    发电机内部预测:

    class DataGenerator(keras.utils.Sequence):
    
        def __init__(self, model_name, modelInputs, batchSize):
            self.genModel = load_model(model_name)
            self.inputs = modelInputs
            self.batchSize = batchSize
    
    
        def __len__(self):
            l,rem = divmod(len(self.inputs), self.batchSize)
            return (l + (1 if rem > 0 else 0))
    
        def __getitem__(self,i):
    
            items = self.inputs[i*self.batchSize:(i+1)*self.batchSize]
            items = doThingsWithItems(items)
    
            predItems = self.genModel.predict_on_batch(items)
    
            #the following is the only reason not to chain models
            predItems = doMoreThingsWithItems(predItems)
    
            #do something to get Y_train_items as well
    
            return predItems, y_train_items
    

    for e in range(epochs):
        for i in range(batches):
            x,y = generator[i]
            model.train_on_batch(x,y)