
How do I access the parameters of the underlying model in an ML pipeline?

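  wishihadabettername · 7 years ago

    When I fit a LinearRegression directly, I can read the coefficients and intercept from the fitted model: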

    import org.apache.spark.ml.regression.LinearRegression
    val lr = new LinearRegression()
    val lrModel = lr.fit(df)
    
    lrModel: org.apache.spark.ml.regression.LinearRegressionModel = linReg_b22a7bb88404
    
    println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
    Coefficients: [0.9705748115939526] Intercept: 0.31041486689532866
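The question doesn't show how df is built. Below is a minimal, self-contained sketch of the direct fit, assuming Spark 2.x or later, a local SparkSession, and made-up toy data lying on the line y = 2x + 1. The column names features and label are LinearRegression's default input columns; the object name DirectFit and its fit method are hypothetical.

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.sql.SparkSession

object DirectFit {
  // Hypothetical helper: builds toy data and fits LinearRegression directly.
  def fit(spark: SparkSession): LinearRegressionModel = {
    // Made-up points on y = 2x + 1, using the default "features"/"label" columns.
    val df = spark.createDataFrame(Seq(
      (Vectors.dense(0.0), 1.0),
      (Vectors.dense(1.0), 3.0),
      (Vectors.dense(2.0), 5.0),
      (Vectors.dense(3.0), 7.0)
    )).toDF("features", "label")

    // Estimator.fit on LinearRegression is statically typed as LinearRegressionModel,
    // so coefficients and intercept are ordinary members of the result.
    new LinearRegression().fit(df)
  }

  def main(args: Array[String]): Unit = {
    // Local session for experimentation only.
    val spark = SparkSession.builder().master("local[*]").appName("direct-fit").getOrCreate()
    try {
      val lrModel = fit(spark)
      println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
    } finally spark.stop()
  }
}
```

Because fit is called on the estimator itself, the compiler knows the result is a LinearRegressionModel; that static type is exactly what is lost once the estimator is wrapped in a Pipeline.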
    

    However, if I use it inside a pipeline (as in the simplified example below):

    import org.apache.spark.ml.Pipeline
    val pipeline = new Pipeline().setStages(Array(lr))
    val lrModel = pipeline.fit(df)
    

    scala> lrModel
    res9: org.apache.spark.ml.PipelineModel = pipeline_99ca9cba48f8
    
    scala> println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
    <console>:68: error: value coefficients is not a member of org.apache.spark.ml.PipelineModel
           println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
                                             ^
    <console>:68: error: value intercept is not a member of org.apache.spark.ml.PipelineModel
           println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
    

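    I understand what this means (because of the pipeline I get a different class, PipelineModel), but I don't know how to get at the actual underlying model.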

  Answer by Jacek Laskowski · 7 years ago

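    The fitted LinearRegressionModel sits in the PipelineModel's stages array at exactly the same index as the corresponding LinearRegression stage in the Pipeline. With a single-stage pipeline that is index 0, so cast that stage back to its concrete type: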

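    import org.apache.spark.ml.regression.LinearRegressionModel
    val linRegModel = lrModel.stages(0).asInstanceOf[LinearRegressionModel]
    println(s"Coefficients: ${linRegModel.coefficients} Intercept: ${linRegModel.intercept}")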
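In a multi-stage pipeline the index still matches, but a hard-coded stages(0).asInstanceOf will throw a ClassCastException if the stage order ever changes. A minimal sketch, assuming Spark 2.x or later, a hypothetical raw DataFrame with numeric columns x and label, and a VectorAssembler as the first stage, that prints each fitted stage with its index and then looks the regression model up by type with collectFirst instead of by position. The names PipelineLookup, fitPipeline and findLinReg are illustrative, not part of Spark.

```scala
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.sql.SparkSession

object PipelineLookup {
  // Hypothetical two-stage pipeline: VectorAssembler (index 0), LinearRegression (index 1).
  def fitPipeline(spark: SparkSession): PipelineModel = {
    // Made-up toy data on y = 2x + 1.
    val raw = spark.createDataFrame(Seq((0.0, 1.0), (1.0, 3.0), (2.0, 5.0), (3.0, 7.0)))
      .toDF("x", "label")
    val assembler = new VectorAssembler().setInputCols(Array("x")).setOutputCol("features")
    val lr = new LinearRegression()
    new Pipeline().setStages(Array(assembler, lr)).fit(raw)
  }

  // Returns the first fitted LinearRegressionModel among the pipeline's stages, if any.
  // Does not descend into nested PipelineModel stages.
  def findLinReg(model: PipelineModel): Option[LinearRegressionModel] =
    model.stages.collectFirst { case m: LinearRegressionModel => m }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("pipeline-lookup").getOrCreate()
    try {
      val model = fitPipeline(spark)
      // Each fitted stage sits at the same index as the pipeline stage that produced it;
      // plain Transformers such as VectorAssembler are carried over unchanged.
      model.stages.zipWithIndex.foreach { case (stage, i) =>
        println(s"stage $i: ${stage.getClass.getSimpleName}")
      }
      findLinReg(model) match {
        case Some(lrModel) =>
          println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
        case None =>
          println("No LinearRegressionModel found in this pipeline")
      }
    } finally spark.stop()
  }
}
```

Returning an Option makes the "no such stage" case explicit rather than surfacing as a runtime cast failure; if a pipeline contains several regression models, collect instead of collectFirst would return all of them.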