代码之家  ›  专栏  ›  技术社区  ›  A T

使用tf。Keras中的指标?

  •  3
  • A T  · 技术社区  · 8 年前

    我特别感兴趣的是 specificity_at_sensitivity .查看 the Keras docs :

    from keras import metrics
    
    model.compile(loss='mean_squared_error',
                  optimizer='sgd',
                  metrics=[metrics.mae, metrics.categorical_accuracy])
    

    但它看起来像 metrics 列表必须具有arity 2的功能,接受 (y_true, y_pred) 返回一个张量值。


    编辑:目前我是这样做的:

    from sklearn.metrics import confusion_matrix
    
    predictions = model.predict(x_test)
    y_test = np.argmax(y_test, axis=-1)
    predictions = np.argmax(predictions, axis=-1)
    c = confusion_matrix(y_test, predictions)
    print('Confusion matrix:\n', c)
    print('sensitivity', c[0, 0] / (c[0, 1] + c[0, 0]))
    print('specificity', c[1, 1] / (c[1, 1] + c[1, 0]))
    

    这种方法的缺点是,我只在培训结束时才得到我关心的输出。更愿意每10个时代左右获取一次指标。

    2 回复  |  直到 8 年前
        1
  •  2
  •   rvinas    8 年前

    我发现了一个相关的问题 github ,看来 tf.metrics Keras模型仍然不支持。但是,如果您对使用 tf.metrics.specificity_at_sensitivity ,我建议以下解决方法(灵感来自 BogdanRuzh's 解决方案):

    def specificity_at_sensitivity(sensitivity, **kwargs):
        def metric(labels, predictions):
            # any tensorflow metric
            value, update_op = tf.metrics.specificity_at_sensitivity(labels, predictions, sensitivity, **kwargs)
    
            # find all variables created for this metric
            metric_vars = [i for i in tf.local_variables() if 'specificity_at_sensitivity' in i.name.split('/')[2]]
    
            # Add metric variables to GLOBAL_VARIABLES collection.
            # They will be initialized for new session.
            for v in metric_vars:
                tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, v)
    
            # force to update metric values
            with tf.control_dependencies([update_op]):
                value = tf.identity(value)
                return value
        return metric
    
    
    model.compile(loss='mean_squared_error',
                  optimizer='sgd',
                  metrics=[metrics.mae,
                           metrics.categorical_accuracy,
                           specificity_at_sensitivity(0.5)])
    

    更新:

    你可以用 model.evaluate 在培训后检索指标。

        2
  •  0
  •   Wendong Zheng    8 年前

    我不认为只有两个传入参数有严格的限制。py函数只有三个传入参数,但k选择默认值5。

    def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
        return K.mean(K.in_top_k(y_pred, K.cast(K.max(y_true, axis=-1), 'int32'), k), axis=-1)