代码之家  ›  专栏  ›  技术社区  ›  user9238790

通过循环从数据帧传递单行以进行预测

  •  1
  • user9238790  · 技术社区  · 8 年前

    我使用传递行索引 iloc 并使用指定位置 n . 相反,如何修改代码以从 class_zero ,并打印它的每个预测。

    import numpy as np
    import pandas as pd
    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier
    
    
    X, y = make_classification(n_samples=1000,
                               n_features=6,
                               n_informative=3,
                               n_classes=2,
                               random_state=0,
                               shuffle=False)
    
    # Creating a dataFrame
    df = pd.DataFrame({'Feature 1':X[:,0],
                                      'Feature 2':X[:,1],
                                      'Feature 3':X[:,2],
                                      'Feature 4':X[:,3],
                                      'Feature 5':X[:,4],
                                      'Feature 6':X[:,5],
                                      'Class':y})
    
    y_train = df['Class']
    X_train = df.drop('Class', axis=1)
    class_zero = df.loc[df['Class'] == 0]
    
    n = 5  #instead of specifying 5 which is where class_zero = 0, I want to pass directly the class_zero from the list I created
    #and print for each one
    
    rf = RandomForestClassifier()
    rf.fit(X_train, y_train)
    instances = X_train.iloc[n].values.reshape(1, -1)
    
    predictValue = rf.predict(instances)
    actualValue = y_train.iloc[n]
    
    print('##')
    print(n)
    print(predictValue)
    print(actualValue)
    print('##')
    
    1 回复  |  直到 8 年前
        1
  •  1
  •   Vivek Kumar    8 年前

    可以将class==0的行索引用作中的列表 iloc()

    更改class\u zero如下:

    class_zero = df.index[df['Class'] == 0].tolist()
    

    而你做的重塑是错误的。保持这样:

    instances = X_train.iloc[class_zero].values
    

    编辑以供评论:

    for n in class_zero:
        instances = X_train.iloc[n].values.reshape(1,-1)
    
        predictValue = rf.predict(instances)
        actualValue = y_train.iloc[n]
    
        print('##')
        print(n)
        print(predictValue)
        print(actualValue)
        print('##')
    
    推荐文章