代码之家  ›  专栏  ›  技术社区  ›  Atinesh

无法将数据输入tensorflow图

  •  0
  • Atinesh  · 技术社区  · 6 年前

    我已经在计算机上训练了一个神经网络模型 MNIST 使用脚本的数据集 mnist_3.1_convolutional_bigger_dropout.py 本条规定 tutorial

    我想在自定义数据集上测试经过训练的模型,因此我编写了一个小脚本 predict.py 加载经过训练的模型并向其提供数据。我尝试了两种预处理图像的方法,以便它们与MNIST格式兼容。

    • 方法1 :将图像大小调整为28x28
    • 方法2 :提到的技巧 here 使用

    这两种方法都会导致错误

    InvalidArgumentError(回溯见上文):必须为带有dtype float的占位符张量“占位符_2”输入一个值

    预测py

    # Importing libraries
    from scipy.misc import imread
    import tensorflow as tf
    import numpy as np
    import cv2 as cv
    import glob
    
    from test import imageprepare
    
    files = glob.glob('data2/*.*')
    #print(files)
    
    # Method 1
    '''
    img_data = []
    for fl in files:
        img = imageprepare(fl)
        img = img.reshape(img.shape[0], img.shape[1], 1)
        img_data.append(img)
    '''
    
    # Method 2
    dig_cont = [cv.imread(fl, 0) for fl in files]
    #print(len(dig_cont))
    
    img_data = []
    for i in range(len(dig_cont)):
        img = cv.resize(dig_cont[i], (28, 28))
        img = img.reshape(img.shape[0], img.shape[1], 1)
        img_data.append(img)
    
    
    print("Restoring Model ...")
    
    sess = tf.Session()
    
    # Step-1: Recreate the network graph. At this step only graph is created.
    tf_saver = tf.train.import_meta_graph('model/model.meta')
    
    # Step-2: Now let's load the weights saved using the restore method.
    tf_saver.restore(sess, tf.train.latest_checkpoint('model'))
    
    print("Model restored")
    
    x = tf.get_default_graph().get_tensor_by_name('X:0')
    print('x :', x.shape)
    y = tf.get_default_graph().get_tensor_by_name('Y:0')
    print('y :', y.shape)
    
    dict_data = {x: img_data}
    
    result = sess.run(y, feed_dict=dict_data)
    print(result)
    print(result.shape)
    
    sess.close()
    
    0 回复  |  直到 6 年前
        1
  •  0
  •   Atinesh    6 年前

    问题解决了,我忘了传递变量的值 pkeep .我必须做以下改变才能让它工作。

    dict_data = {x: img_data, pkeep: 1.0}
    

    而不是

    dict_data = {x: img_data}