我正在使用Tensorflow框架进行分类预测。我的数据集包含大约1160个输出类。输出类值为6位数字。例如,789954。在使用Tensorflow对数据集进行训练和测试后,我得到了大约99%的准确率。
prediction=tf.argmax(logits,1)
print(prediction.eval(feed_dict={features : test_features, keep_prob: 1.0}))
prediction = np.asarray(prediction.eval(feed_dict={features : test_features, keep_prob: 1.0}))
prediction = np.reshape(prediction, (test_features.shape[0],1))
np.savetxt("prediction.csv", prediction, delimiter=",")
对于所有条目,csv文件中的结果值仅为0.00E+00。但我的期望是每个csv条目有6位代码。我想我的一个热门编码出了问题。
任何帮助都是值得赞赏的。
labels = tf.one_hot(labels, n_classes)
n_类=1160,所有值都是6位数字