代码之家  ›  专栏  ›  技术社区  ›  Ahmad

mnist:获取混淆矩阵

  •  0
  • Ahmad  · 技术社区  · 6 年前

    我试图得到mnist数据集的混淆矩阵。

    这是我的密码:

    mnist = tf.keras.datasets.mnist
    
    (x_train, y_train),(x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    
    model = tf.keras.models.Sequential([
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation=tf.nn.tanh),
      tf.keras.layers.Dropout(0.2),
      tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    
    
    model.fit(x_train, y_train, epochs=1, callbacks=[history])
    
    test_predictions = model.predict(x_test)
    
    
    # Compute confusion matrix
    confusion = tf.confusion_matrix(y_test, test_predictions)
    

    问题是 test_prediction 是10000x 10矩阵,而y_测试是10000x 10矩阵。实际上,神经网络并不能为每个测试样本提供输出。如何计算这种情况下的混淆矩阵?

    那么,我该如何呈现混淆矩阵呢?我可以为此使用SCI工具包库吗?

    2 回复  |  直到 6 年前
        1
  •  2
  •   Biswadip Mandal    6 年前

    这可能是因为您的预测包含所有可能类的概率。您需要选择概率最高的类,这将产生与y_测试相同的维度。可以使用numpy中的argmax()方法。它的工作原理如下:

    import numpy as np
    a = np.array([[0.9,0.1,0],[0.2,0.3,0.5],[0.4,0.6,0]])
    np.argmax(a, axis=0)
    array([0, 2, 1])
    

    您可以使用sklearn生成混淆矩阵。你的代码会变成这样

    from sklearn.metrics import confusion_matrix
    import numpy as np
    
    confusion = confusion_matrix(y_test, np.argmax(test_predictions,axis=1))
    
        2
  •  1
  •   druskacik    6 年前

    如果您使用.predict_Classes方法而不仅仅是predict,您将得到概率最高的类的向量。

    然后,您可以使用sklearn中的混淆矩阵。

    test_predictions = model.predict_classes(x_test)
    
    from sklearn.metrics import confusion_matrix
    
    cm = confusion_matrix(y_true = y_test, y_pred = test_predictions)
    print(cm)
    

    这里的测试预测形状是(10000,)。

    打印结果如下:

    array([[ 967,    1,    1,    2,    0,    1,    5,    0,    2,    1],
       [   0, 1126,    3,    1,    0,    1,    1,    0,    3,    0],
       [   3,    2, 1001,    8,    1,    0,    3,    6,    8,    0],
       [   0,    0,    1, 1002,    0,    1,    0,    1,    5,    0],
       [   3,    1,    2,    2,  955,    2,    6,    1,    3,    7],
       [   3,    1,    0,   37,    1,  833,    9,    0,    6,    2],
       [   4,    3,    1,    1,    1,    3,  941,    0,    4,    0],
       [   2,    9,    8,    5,    0,    0,    0,  988,    8,    8],
       [   3,    1,    3,   10,    3,    2,    2,    3,  946,    1],
       [   3,    8,    0,   10,    8,    8,    1,    4,    5,  962]],
      dtype=int64)