
tidymodels: loss_accuracy does not provide variable importance results

  • bioblackgeorge · 2 years ago

    Using the iris dataset, I tuned a knn classifier with iterative search for multiclass classification. However, using loss_accuracy in DALEX::model_parts() for variable importance returns empty results.

    Any ideas would be greatly appreciated. Many thanks for your support!

    library(tidyverse)
    library(tidymodels)
    library(DALEXtra)
    tidymodels_prefer()
    
    df <- iris 
    
    # split
    set.seed(2023)
    splits <- initial_split(df, strata = Species, prop = 4/5)
    df_train <- training(splits)
    df_test  <-  testing(splits)
    
    # workflow
    df_rec <- recipe(Species ~ ., data = df_train) 
    
    knn_model <- nearest_neighbor(neighbors = tune()) %>% 
      set_engine("kknn") %>% 
      set_mode("classification")
    
    df_wflow <- workflow() %>%
      add_model(knn_model) %>%
      add_recipe(df_rec) 
    
    # cross-validation
    set.seed(2023)
    knn_res <-
      df_wflow %>%
      tune_bayes(
        metrics = metric_set(accuracy),
        resamples = vfold_cv(df_train, strata = "Species", v = 2),
        control = control_bayes(verbose = TRUE, save_pred = TRUE))
    
    # fit
    best_k <- knn_res %>%
      select_best("accuracy")
    
    knn_mod <- df_wflow %>%
      finalize_workflow(best_k) %>%
      fit(df_train)
    
    # variable importance
    knn_exp <- explain_tidymodels(extract_fit_parsnip(knn_mod), 
                       data = df_rec %>% prep() %>% bake(new_data = NULL, all_predictors()),
                       y = df_train$Species)
    
    set.seed(2023)
    vip <- model_parts(knn_exp, type = "variable_importance", loss_function = loss_accuracy)
    plot(vip) # empty plot
    
    
    
    
  • Answer by EmilHvitfeldt · 2 years ago

    You are getting 0 for all the results because, according to {DALEX}, the model type is "multiclass".

    It would work if the type were "classification".

    knn_exp$model_info$type
    #> [1] "multiclass"
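
    One possible workaround, shown here only as a minimal sketch (it assumes a probability-based loss is acceptable for your purposes): since the explainer's type is "multiclass" and its predict function returns class probabilities, a loss that consumes probabilities, such as loss_cross_entropy() from {DALEX}, can be passed to model_parts() instead of loss_accuracy().

    # Minimal sketch (assumption): use a probability-based loss with the
    # "multiclass" explainer. loss_cross_entropy() looks up the probability
    # assigned to the observed class, so it never compares a factor with a
    # numeric matrix.
    set.seed(2023)
    vip_ce <- model_parts(knn_exp,
                          type = "variable_importance",
                          loss_function = loss_cross_entropy)
    plot(vip_ce)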
    

    This means that the predictions produced will be predicted probabilities (here we get 1s and 0s because the model is heavily overfit).

    predicted <- knn_exp$predict_function(knn_exp$model, newdata = df_train)
    predicted
    #>      setosa versicolor virginica
    #> [1,]      1          0         0
    #> [2,]      1          0         0
    #> [3,]      1          0         0
    #> [4,]      1          0         0
    #> [5,]      1          0         0
    #> [6,]      1          0         0
    #> ...
    

    When you use loss_accuracy() as the loss function, it is computed with the following calculation:

    loss_accuracy
    #> function (observed, predicted, na.rm = TRUE) 
    #> mean(observed == predicted, na.rm = na.rm)
    #> <bytecode: 0x159276bb8>
    #> <environment: namespace:DALEX>
    #> attr(,"loss_name")
    #> [1] "Accuracy"
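
    As a minimal, hypothetical illustration (these vectors are made up and are not taken from the model above), loss_accuracy() behaves as intended when both arguments are class labels:

    # Hypothetical vectors: both arguments are class labels with the same levels,
    # so the element-wise comparison inside loss_accuracy() is meaningful.
    obs <- factor(c("setosa", "versicolor", "virginica"))
    prd <- factor(c("setosa", "versicolor", "versicolor"), levels = levels(obs))
    loss_accuracy(obs, prd)
    #> [1] 0.6666667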
    

    If we go through the calculation step by step, we can see why this becomes a problem. First, observed is defined as the outcome factor:

    observed <- df_train$Species
    observed
    #>   [1] setosa     setosa     setosa     setosa     setosa     setosa    
    #>   [7] setosa     setosa     setosa     setosa     setosa     setosa    
    #>  [13] setosa     setosa     setosa     setosa     setosa     setosa    
    #>  [19] setosa     setosa     setosa     setosa     setosa     setosa    
    #>  [25] setosa     setosa     setosa     setosa     setosa     setosa    
    #>  [31] setosa     setosa     setosa     setosa     setosa     setosa    
    #>  [37] setosa     setosa     setosa     setosa     versicolor versicolor
    #>  [43] versicolor versicolor versicolor versicolor versicolor versicolor
    #>  [49] versicolor versicolor versicolor versicolor versicolor versicolor
    #>  [55] versicolor versicolor versicolor versicolor versicolor versicolor
    #>  [61] versicolor versicolor versicolor versicolor versicolor versicolor
    #>  [67] versicolor versicolor versicolor versicolor versicolor versicolor
    #>  [73] versicolor versicolor versicolor versicolor versicolor versicolor
    #>  [79] versicolor versicolor virginica  virginica  virginica  virginica 
    #>  [85] virginica  virginica  virginica  virginica  virginica  virginica 
    #>  [91] virginica  virginica  virginica  virginica  virginica  virginica 
    #>  [97] virginica  virginica  virginica  virginica  virginica  virginica 
    #> [103] virginica  virginica  virginica  virginica  virginica  virginica 
    #> [109] virginica  virginica  virginica  virginica  virginica  virginica 
    #> [115] virginica  virginica  virginica  virginica  virginica  virginica 
    #> Levels: setosa versicolor virginica
    

    Since observed is a factor vector and predicted is a numeric matrix, we get a logical matrix that is all FALSE, because the values are never equal.

    head(observed == predicted)
    #>      setosa versicolor virginica
    #> [1,]  FALSE      FALSE     FALSE
    #> [2,]  FALSE      FALSE     FALSE
    #> [3,]  FALSE      FALSE     FALSE
    #> [4,]  FALSE      FALSE     FALSE
    #> [5,]  FALSE      FALSE     FALSE
    #> [6,]  FALSE      FALSE     FALSE
    

    So when we take the mean, we get the expected 0.

    mean(observed == predicted)
    #> [1] 0
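
    As a final sketch (an assumption about what a label-based comparison would look like, reusing observed and predicted from above), collapsing the probability matrix to hard class labels first gives the kind of comparison loss_accuracy() implicitly expects:

    # Sketch: pick the most probable class for each row, then compare class labels.
    pred_class <- colnames(predicted)[max.col(predicted, ties.method = "first")]
    mean(observed == pred_class)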
    