
Why does my trained neural network produce the same output?

  • codex · 6 years ago

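    I trained my neural network with Encog 3.3: an MLP trained with resilient propagation (I tried RPROP because the learning rate and momentum for backpropagation are hard to tune), 10 columns (9 inputs plus the ideal value), 1 hidden layer with 7 neurons, 1 output neuron, sigmoid activation, a training set of about 80k rows and a test set of about 96 rows. I stopped training at errors of 0.01 and 0.007 (I built 2 models that differ only in this error threshold; all the other settings above are the same). I also applied min-max normalization to the data. Yet the network outputs the same prediction for every row. Maybe my evaluation code is wrong? Or some other part of the code? Any help would be greatly appreciated.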

    Full code:

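    import java.io.File;

    import org.encog.Encog;
    import org.encog.engine.network.activation.ActivationSigmoid;
    import org.encog.ml.data.MLData;
    import org.encog.ml.data.MLDataPair;
    import org.encog.ml.data.MLDataSet;
    import org.encog.neural.networks.BasicNetwork;
    import org.encog.neural.networks.layers.BasicLayer;
    import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
    import org.encog.persist.EncogDirectoryPersistence;
    import org.encog.platformspecific.j2se.data.SQLNeuralDataSet;

    public class ANN
    {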
    //training
    //public final static String SQL = "SELECT load_input, day_of_week, weekend_day, type_of_day, week_num, time, day_date, month, year, ideal_value FROM sample WHERE (year,month,day_date,time) between (2012,4,1,1) and (2014,9,29, 96) ORDER BY ID";
    //testing
    public final static String SQL = "SELECT load_input, day_of_week, weekend_day, type_of_day, week_num, time, day_date, month, year, ideal_value FROM sample WHERE (year,month,day_date,time) between (2014,9,30,1) and (2014,9,30, 92) ORDER BY ID";
    //validation
    //public final static String SQL = "SELECT load_input, day_of_week, weekend_day, type_of_day, week_num, time, day_date, month, year, ideal_value FROM sample WHERE (year,month,day_date,time) between (2014,9,30,93) and (2014,9,30, 96) ORDER BY ID";
    public final static int INPUT_SIZE = 9;
    public final static int IDEAL_SIZE = 1;
    public final static String SQL_DRIVER = "org.postgresql.Driver";
    public final static String SQL_URL = "jdbc:postgresql://localhost/ANN";
    public final static String SQL_UID = "postgres";
    public final static String SQL_PWD = "";
    
    public static void main(String args[])
    {
        // train network (will add customizable params later)
        //train(trainingData());
        // evaluate the saved network; trainingData() runs whichever SQL query is active above (currently the test query)
        evaluate(trainingData());
        Encog.getInstance().shutdown();
    }
    public static void evaluate(MLDataSet testSet)
    {
        BasicNetwork network = (BasicNetwork)EncogDirectoryPersistence.loadObject(new File("directory"));
    
        // test the neural network
        System.out.println("Neural Network Results:");
        for (MLDataPair pair : testSet) {
            final MLData output = network.compute(pair.getInput());
            // print every input value, then the prediction and the ideal value
            StringBuilder line = new StringBuilder();
            for (int i = 0; i < pair.getInput().size(); i++) {
                line.append(pair.getInput().getData(i)).append(",");
            }
            System.out.println(line + "Predicted=" + output.getData(0) + ", Actual=" + pair.getIdeal().getData(0));
        }
    }
    public static BasicNetwork Mynetwork()
    {
        // Basic neural network template. The input layer shouldn't have an activation function:
        // an activation transforms data coming from a previous layer, and there is no layer before the input.
        BasicNetwork network = new BasicNetwork();
        // input layer with 9 neurons (INPUT_SIZE).
        // The 'true' parameter adds a bias neuron to this layer; the bias neuron feeds the next layer.
        network.addLayer(new BasicLayer(null, true, ANN.INPUT_SIZE));
        // hidden layer with 5 neurons, plus a bias neuron feeding the output layer
        network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 5));
        // output layer with 1 neuron (IDEAL_SIZE), no bias
        network.addLayer(new BasicLayer(new ActivationSigmoid(), false, ANN.IDEAL_SIZE));
        network.getStructure().finalizeStructure();
        network.reset();
    
        return network;
    }
    public static void train(MLDataSet trainingSet)
    {
        // Build the network once and keep the reference: every call to Mynetwork()
        // creates a new, untrained network.
        final BasicNetwork network = Mynetwork();
        //Backpropagation(network, dataset, learning rate, momentum)
        //final Backpropagation train = new Backpropagation(network, trainingSet, 0.1, 0.9);
        final ResilientPropagation train = new ResilientPropagation(network, trainingSet);
        //final QuickPropagation train = new QuickPropagation(network, trainingSet, 0.9);

        int epoch = 1;

        do {
            train.iteration();
            System.out.println("Epoch #" + epoch + " Error:" + train.getError());
            epoch++;
        } while (train.getError() > 0.01);
        train.finishTraining();

        System.out.println("Saving network");
        // save the network that was actually trained, not a fresh Mynetwork()
        EncogDirectoryPersistence.saveObject(new File("directory"), network);
        System.out.println("Saving Done");
    }
    public static MLDataSet trainingData()
    {
        MLDataSet trainingSet = new SQLNeuralDataSet(
                ANN.SQL,
                ANN.INPUT_SIZE,
                ANN.IDEAL_SIZE,
                ANN.SQL_DRIVER,
                ANN.SQL_URL,
                ANN.SQL_UID,
                ANN.SQL_PWD);
    
        return trainingSet;
    }
    }
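    In `Mynetwork()`, the `true`/`false` argument of each `BasicLayer` decides whether that layer gets an extra bias neuron, which connects to every neuron of the next layer. As a rough sanity check of what the 9-5-1 topology contains, here is a plain-Java sketch (the `WeightCount` class and `countWeights` method are hypothetical, not part of Encog) that counts the weights under that rule: (9 + 1) × 5 + (5 + 1) × 1 = 56.

```java
// Hypothetical helper: counts the weights of a fully connected feed-forward net
// in which every layer flagged hasBias adds one constant bias neuron that
// connects to all neurons of the next layer.
class WeightCount {
    static int countWeights(int[] layerSizes, boolean[] hasBias) {
        int total = 0;
        for (int i = 0; i < layerSizes.length - 1; i++) {
            // neurons feeding layer i + 1: layer i's neurons plus its bias neuron, if any
            int fanIn = layerSizes[i] + (hasBias[i] ? 1 : 0);
            total += fanIn * layerSizes[i + 1];
        }
        return total;
    }

    public static void main(String[] args) {
        // Mirrors Mynetwork(): 9 inputs (bias), 5 hidden (bias), 1 output (no bias).
        int[] sizes = {9, 5, 1};
        boolean[] bias = {true, true, false};
        System.out.println(countWeights(sizes, bias)); // prints 56
    }
}
```

    The output layer's `false` flag adds nothing, because there is no layer after it for a bias neuron to feed.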
    

    Here are my results:

    Predicted=0.4451817588640455, Actual=0.5260616667545941
    Predicted=0.4451817588640455, Actual=0.5196499668339777
    Predicted=0.4451817588640455, Actual=0.5083828048375548
    Predicted=0.4451817588640455, Actual=0.49985462144799725
    Predicted=0.4451817588640455, Actual=0.49085956670499675
    Predicted=0.4451817588640455, Actual=0.485008112408512
    Predicted=0.4451817588640455, Actual=0.47800504210686795
    Predicted=0.4451817588640455, Actual=0.4693212349328293
    (...and so on with the same "predicted")
    

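    The results I expected (for demonstration purposes I replaced the "Predicted" values with arbitrary numbers, just to show that the network is actually predicting something):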

    Predicted=0.4451817588640455, Actual=0.5260616667545941
    Predicted=0.5123312331212122, Actual=0.5196499668339777
    Predicted=0.435234234234254365, Actual=0.5083828048375548
    Predicted=0.673424556563455, Actual=0.49985462144799725
    Predicted=0.2344673345345544235, Actual=0.49085956670499675
    Predicted=0.123346457544324, Actual=0.485008112408512
    Predicted=0.5673452342342342, Actual=0.47800504210686795
    Predicted=0.678435234423423423, Actual=0.4693212349328293
    

    Update (output including the input values):

    0.5386671932975533,1100000.0,0.0,1.0,40.0,1.0,30.0,9.0,2014.0,Predicted=0.4451817588640455, Actual=0.5260616667545941
    0.5260616667545941,1100000.0,0.0,1.0,40.0,2.0,30.0,9.0,2014.0,Predicted=0.4451817588640455, Actual=0.5196499668339777
    0.5196499668339777,1100000.0,0.0,1.0,40.0,3.0,30.0,9.0,2014.0,Predicted=0.4451817588640455, Actual=0.5083828048375548
    0.5083828048375548,1100000.0,0.0,1.0,40.0,4.0,30.0,9.0,2014.0,Predicted=0.4451817588640455, Actual=0.49985462144799725
    0.49985462144799725,1100000.0,0.0,1.0,40.0,5.0,30.0,9.0,2014.0,Predicted=0.4451817588640455, Actual=0.49085956670499675
    0.49085956670499675,1100000.0,0.0,1.0,40.0,6.0,30.0,9.0,2014.0,Predicted=0.4451817588640455, Actual=0.485008112408512
    0.485008112408512,1100000.0,0.0,1.0,40.0,7.0,30.0,9.0,2014.0,Predicted=0.4451817588640455, Actual=0.47800504210686795
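    These rows show why every prediction is identical: only the first column (`load_input`) lies in [0, 1], while the others are raw values such as 1100000.0 (`day_of_week`), 40.0 (`week_num`) and 2014.0 (`year`). The sketch below feeds the first two rows into a single sigmoid hidden neuron, once raw and once rescaled. The weights, the bias and the rescaled values are hypothetical (the rescaled values assume ranges such as time 1 to 96 and year 2012 to 2014); the point is only that with raw inputs the weighted sum is dominated by the large columns, the sigmoid saturates, and the small row-to-row changes in `load_input` and `time` disappear. A saturated sigmoid also has a derivative s(1 - s) of almost zero, so training gets very little gradient to escape that state.

```java
// Hypothetical illustration: one sigmoid hidden neuron fed the first two rows of
// the update, once with raw inputs and once with every column rescaled to [0, 1].
class SaturationDemo {
    // Hypothetical weights and bias (not taken from the trained network).
    static final double[] W = {0.5, -0.01, 0.3, 0.2, 0.05, 0.1, 0.02, 0.1, 0.01};
    static final double BIAS = 0.1;

    // Column order as in the SQL: load_input, day_of_week, weekend_day, type_of_day,
    // week_num, time, day_date, month, year. These are rows 1 and 2 of the update.
    static final double[] RAW1 = {0.5386671932975533, 1100000.0, 0.0, 1.0, 40.0, 1.0, 30.0, 9.0, 2014.0};
    static final double[] RAW2 = {0.5260616667545941, 1100000.0, 0.0, 1.0, 40.0, 2.0, 30.0, 9.0, 2014.0};

    // The same rows rescaled under ASSUMED ranges (week 1..53, time 1..96, day 1..31,
    // month 1..12, year 2012..2014; day_of_week mapped to a placeholder 0.5).
    static final double[] SCALED1 = {0.5386671932975533, 0.5, 0.0, 1.0, 0.75, 0.0, 0.967, 0.727, 1.0};
    static final double[] SCALED2 = {0.5260616667545941, 0.5, 0.0, 1.0, 0.75, 0.0105, 0.967, 0.727, 1.0};

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Weighted sum of the inputs plus the bias, squashed by the sigmoid.
    static double neuron(double[] x) {
        double z = BIAS;
        for (int i = 0; i < x.length; i++) {
            z += W[i] * x[i];
        }
        return sigmoid(z);
    }

    public static void main(String[] args) {
        // Raw: z is about -10975.7, so exp(-z) overflows and both rows give exactly 0.0.
        System.out.println("raw:    " + neuron(RAW1) + " vs " + neuron(RAW2));
        // Rescaled: z stays near 0.7, where the sigmoid still responds to load and time.
        System.out.println("scaled: " + neuron(SCALED1) + " vs " + neuron(SCALED2));
    }
}
```

    With real trained weights the saturated value need not be exactly 0.0 or 1.0, but the effect is the same: every hidden neuron outputs (nearly) the same value for every row, so the output neuron does too.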
    
    1 answer

  • codex · 6 years ago

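    Fixed it by properly normalizing all of the input features. I had assumed it was enough to normalize only the main input being predicted (the load value) and leave the factors that influence it (day of week, week number, time, date, month, year, and so on) unchanged.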
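    A minimal sketch of per-column min-max scaling in plain Java, assuming the rows are already loaded into arrays (the `MinMaxScaler` class is hypothetical, not an Encog API). Each column is mapped with x' = (x - min) / (max - min), where min and max come from the training rows only and are then reused for the test rows, so both sets share one scale. The sample rows are made up.

```java
import java.util.Arrays;

// Hypothetical per-column min-max scaler: learns each column's min and max from
// the training rows, then maps any row (training or test) onto that [0, 1] scale.
class MinMaxScaler {
    private final double[] min;
    private final double[] max;

    MinMaxScaler(double[][] trainingRows) {
        int cols = trainingRows[0].length;
        min = new double[cols];
        max = new double[cols];
        Arrays.fill(min, Double.POSITIVE_INFINITY);
        Arrays.fill(max, Double.NEGATIVE_INFINITY);
        for (double[] row : trainingRows) {
            for (int c = 0; c < cols; c++) {
                min[c] = Math.min(min[c], row[c]);
                max[c] = Math.max(max[c], row[c]);
            }
        }
    }

    double[] transform(double[] row) {
        double[] out = new double[row.length];
        for (int c = 0; c < row.length; c++) {
            double range = max[c] - min[c];
            // A constant column carries no information; map it to 0 instead of dividing by zero.
            out[c] = range == 0.0 ? 0.0 : (row[c] - min[c]) / range;
        }
        return out;
    }

    public static void main(String[] args) {
        // Made-up training rows: load, time, month, year, weekend flag (constant here).
        double[][] train = {
            {0.25, 1.0, 1.0, 2012.0, 0.0},
            {0.75, 96.0, 12.0, 2014.0, 0.0},
            {0.50, 48.0, 6.0, 2013.0, 0.0},
        };
        MinMaxScaler scaler = new MinMaxScaler(train);
        // A test row is scaled with the TRAINING bounds.
        System.out.println(Arrays.toString(scaler.transform(new double[] {0.5, 2.0, 9.0, 2014.0, 0.0})));
    }
}
```

    Test values outside the training range will land slightly below 0 or above 1; that is usually acceptable, and fitting the bounds on the test set instead would give the two sets inconsistent scales.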
