代码之家  ›  专栏  ›  技术社区  ›  Rafael Díaz

列车xgboost模型不平衡数据

  •  -2
  • Rafael Díaz  · 技术社区  · 6 年前

    我是一个使用xgboost软件包的新手,我正在尝试创建一个具有最大准确性的模型,并且敏感性和特异性是平衡的。我的问题是,基地是不平衡的约1:3和预测给我一个非常低的灵敏度。

    # Load the cats data set (MASS) and inspect the class balance of Sex.
    data(cats, package = "MASS")
    prop.table(table(cats[["Sex"]]))
            F         M 
    0.3263889 0.6736111 
    

    library(ggplot2)

    # Scatter plot of body weight vs heart weight, coloured/shaped by sex.
    ggplot(cats, aes(x = Bwt, y = Hwt, colour = Sex, shape = Sex)) +
      geom_point(size = 3)
    

    （图：Bwt 与 Hwt 的散点图，按 Sex 着色并区分形状）

    加载xgboost包

    library(xgboost)

    # Split data: 75% train / 25% test (seeded for reproducibility).
    set.seed(123)
    index <- sample(nrow(cats), size = nrow(cats) * 0.75)
    train <- cats[index, ]
    test <- cats[-index, ]

    # Predictor frames without the response column.
    train_x <- train
    test_x <- test
    train_x$Sex <- NULL
    test_x$Sex <- NULL

    # Encode the response as 0/1 (factor levels: "F" -> 0, "M" -> 1).
    y_train <- as.numeric(train$Sex) - 1
    y_test <- as.numeric(test$Sex) - 1
    # lapply() (not sapply()) is type-stable and keeps the data.frame shape.
    train_x[] <- lapply(train_x, as.numeric)
    test_x[] <- lapply(test_x, as.numeric)

    # Construct xgb.DMatrix objects from dense matrices.
    dtrain <- xgb.DMatrix(as.matrix(train_x), label = y_train)
    dtest <- xgb.DMatrix(as.matrix(test_x))
    

    ## xgboost parameters
    # BUG FIX: the original spelled "objectve", so the objective was silently
    # ignored and xgboost fell back to its default (regression) objective.
    xgb_params <- list(booster = "gbtree"
                    , objective = "binary:logistic" # fixed typo: was "objectve"
                    , eta = 0.1 # default 0.3, range [0, 1]
                    , gamma = 0
                    , max_depth = 7 # default 6; typical values: 3-10
                    , subsample = 1
                    , tree_method = "exact"
                    , scale_pos_weight = 5 # up-weight positives (classes ~1:3)
                    , base_score = median(y_train)
                    , seed = 2018) # NOTE(review): ignored by the R API — use set.seed()
    
    # Tune the number of boosting rounds with 7-fold cross-validation,
    # minimizing classification error (hence maximize = FALSE).
    xgbcv <- xgb.cv(params = xgb_params
                    , data = dtrain
                    , nrounds = 2000
                    , nfold = 7
                    , print_every_n = 5
                    , early_stopping_rounds = 40
                    , maximize = FALSE # always TRUE/FALSE, never T/F
                    , prediction = FALSE
                    , showsd = TRUE
                    , metrics = "error")
    
    # Train the final model for the CV-selected number of rounds.
    gb_dt <- xgb.train(params = xgb_params
                    , data = dtrain
                    , nrounds = xgbcv$best_iteration
                    , print_every_n = 2
                    , early_stopping_rounds = 40
                    , maximize = FALSE # always TRUE/FALSE, never T/F
                    , watchlist = list(train = dtrain))

    # predict() on an xgb.Booster already returns probabilities for
    # "binary:logistic"; `type = "response"` is not an argument of
    # predict.xgb.Booster and was silently ignored, so it is dropped.
    test_probs <- predict(gb_dt, dtest)
    test_preds <- as.numeric(test_probs > 0.5)
    
    # Map 0/1 predictions back to the original labels ("F"/"M").
    # The original initialized test_submit as numeric 0 and then assigned
    # strings, relying on silent coercion; ifelse() is direct and type-safe.
    test_submit <- ifelse(test_preds == 1, "M", "F")
    

    我计算混乱矩阵

    # Confusion matrix on the held-out test set ("F" is the positive class).
    caret::confusionMatrix(as.factor(test_submit), test$Sex)
    Confusion Matrix and Statistics
    
              Reference
    Prediction  F  M
             F  7  0
             M  7 22
    
                   Accuracy : 0.8056          
                     95% CI : (0.6398, 0.9181)
        No Information Rate : 0.6111          
        P-Value [Acc > NIR] : 0.01065         
    
                      Kappa : 0.55            
     Mcnemar's Test P-Value : 0.02334         
    
                Sensitivity : 0.5000          
                Specificity : 1.0000          
             Pos Pred Value : 1.0000          
             Neg Pred Value : 0.7586          
                 Prevalence : 0.3889          
             Detection Rate : 0.1944          
       Detection Prevalence : 0.1944          
          Balanced Accuracy : 0.7500          
    
           'Positive' Class : F 
    

    对于 scale_pos_weight 参数应该如何设置，我并不清楚。欢迎任何建议。 注: 我的兴趣是有一个最佳平衡的模型，尽可能正确地分类最大数量的雌猫。

    1 回复  |  直到 6 年前
        1
  •  0
  •   Rafael Díaz    6 年前

    将 max_depth = 6（树的最大深度，默认值）。然后使用以下代码在不同分类阈值上寻找最大灵敏度：

    library(caret)

    # Scan decision thresholds in [0.50, 0.99] and record the sensitivity
    # for the positive class "F" (first entry of byClass) at each cutoff.
    k <- seq(0.5, 0.99, 0.01)
    Sensitivity <- rep(0, length(k))

    for (i in seq_along(k)) {
      test_preds <- as.numeric(test_probs > k[i])
      # Fix the factor levels so confusionMatrix() cannot fail when one
      # class is entirely absent at an extreme threshold.
      test_submit <- factor(ifelse(test_preds == 1, "M", "F"),
                            levels = levels(test$Sex))
      tab <- confusionMatrix(test_submit, test$Sex)
      Sensitivity[i] <- tab$byClass[[1]] # <- instead of = for assignment
    }

    max(Sensitivity)
    [1] 0.8571429
    # Index of the cutoff that maximizes sensitivity for class "F".
    pos <- which.max(Sensitivity)

    # Re-classify the test set at the best threshold found.
    test_preds <- as.numeric(test_probs > k[pos])
    test_submit <- ifelse(test_preds == 1, "M", "F")
    

    # Confusion matrix at the tuned threshold; sensitivity improves at the
    # cost of specificity (see the output below).
    confusionMatrix(as.factor(test_submit), test$Sex)
    Confusion Matrix and Statistics
    
              Reference
    Prediction  F  M
             F 12  5
             M  2 17
    
                   Accuracy : 0.8056          
                     95% CI : (0.6398, 0.9181)
        No Information Rate : 0.6111          
        P-Value [Acc > NIR] : 0.01065         
    
                      Kappa : 0.6063          
     Mcnemar's Test P-Value : 0.44969         
    
                Sensitivity : 0.8571          
                Specificity : 0.7727          
             Pos Pred Value : 0.7059          
             Neg Pred Value : 0.8947          
                 Prevalence : 0.3889          
             Detection Rate : 0.3333          
       Detection Prevalence : 0.4722          
          Balanced Accuracy : 0.8149          
    
           'Positive' Class : F       
    

    注: 不过，其代价是特异性（specificity）会降低。