代码之家  ›  专栏  ›  技术社区  ›  Shlomi Schwartz

Keras-还原特定时间戳的LSTM隐藏状态

  •  1
  • Shlomi Schwartz  · 技术社区  · 6 年前

    LSTM - Making predictions on partial sequence ). 如前一个问题所述,我已经训练了 有状态的

    [Feature 1,Feature 2, .... ,Feature 3][Label 1]
    [Feature 1,Feature 2, .... ,Feature 3][Label 2]
    ...
    [Feature 1,Feature 2, .... ,Feature 3][Label 100]
    

    型号代码:

    def build_model(num_samples, num_features, is_training):
        model = Sequential()
        opt = optimizers.Adam(lr=0.0005, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0001)
    
        batch_size = None if is_training else 1
        stateful = False if is_training else True
        first_lstm = LSTM(32, batch_input_shape=(batch_size, num_samples, num_features),  return_sequences=True,
                          activation='tanh', stateful=stateful)
    
        model.add(first_lstm)
        model.add(LeakyReLU())
        model.add(Dropout(0.2))
        model.add(LSTM(16, return_sequences=True, activation='tanh', stateful=stateful))
        model.add(Dropout(0.2))
        model.add(LeakyReLU())
        model.add(LSTM(8, return_sequences=True, activation='tanh', stateful=stateful))
        model.add(LeakyReLU())
        model.add(Dense(1, activation='sigmoid'))
    
        if is_training:
            model.compile(loss='binary_crossentropy', optimizer=opt,
                          metrics=['accuracy', f1])
        return model
    

    当预测时,模型是 无国籍的

    [Feature 1,Feature 2, .... ,Feature 10][Label 1] -> (model) -> probability
    

    打电话 model.reset_states() 模型加工完一批100个样品后。模型运行良好,效果良好。

    多源


    我的问题:

    当我测试我的模型时,我可以控制样本的顺序,我可以确保样本来自同一个来源。i、 e前100个样本都来自源1,然后在调用 模型。重置状态()

    但是,在我的生产环境中,示例以异步方式到达,例如:

    先从源1采集3个样本,然后从源2采集2个样本

    enter image description here


    我的问题是:

    0 回复  |  直到 6 年前
        1
  •  2
  •   Roni Gadot    6 年前

    您可以获取并设置内部状态,如下所示:

    import keras.backend as K
    
    def get_states(model):
        return [K.get_value(s) for s,_ in model.state_updates]
    
    def set_states(model, states):
        for (d,_), s in zip(model.state_updates, states):
            K.set_value(d, s)