代码之家  ›  专栏  ›  技术社区  ›  Regressor

使用Scala连接spark数据帧中的数据

  •  0
  • Regressor  · 技术社区  · 7 年前

    我在Scala中有一个Spark数据框,如下所示-

    val df = Seq(
    (0,0,0,0.0,0),
    (1,0,0,0.1,1),
    (0,1,0,0.11,1),
    (0,0,1,0.12,1),
    (1,1,0,0.24,2),
    (1,0,1,0.27,2),
    (0,1,1,0.3,2),
    (1,1,1,0.4,3)
    ).toDF("A","B","C","rate","total")
    

    下面是它的样子

    scala> df.show
    +---+---+---+----+-----+
    |  A|  B|  C|rate|total|
    +---+---+---+----+-----+
    |  0|  0|  0| 0.0|    0|
    |  1|  0|  0| 0.1|    1|
    |  0|  1|  0|0.11|    1|
    |  0|  0|  1|0.12|    1|
    |  1|  1|  0|0.24|    2|
    |  1|  0|  1|0.27|    2|
    |  0|  1|  1| 0.3|    2|
    |  1|  1|  1| 0.4|    3|
    +---+---+---+----+-----+
    

    A、 在这种情况下,B和C是通道。0和1分别表示通道的缺失和存在。2^3显示数据帧中的8个组合,其中一列“total”给出这3个通道的行和。

    这些信道发生的个别概率可通过以下公式给出-

    scala> val oneChannelCase = df.filter($"total" === 1).toDF()
    
    scala> oneChannelCase.show()
    +---+---+---+----+-----+
    |  A|  B|  C|rate|total|
    +---+---+---+----+-----+
    |  1|  0|  0| 0.1|    1|
    |  0|  1|  0|0.11|    1|
    |  0|  0|  1|0.12|    1|
    +---+---+---+----+-----+
    

    然而,我只对这些通道的成对概率感兴趣,由-

    scala> val probs = df.filter($"total" === 2).toDF()
    
    scala> probs.show()
    +---+---+---+----+-----+
    |  A|  B|  C|rate|total|
    +---+---+---+----+-----+
    |  1|  1|  0|0.24|    2|
    |  1|  0|  1|0.27|    2|
    |  0|  1|  1| 0.3|    2|
    +---+---+---+----+-----+
    

    我想做的是在这些“probs”数据框中添加3个新列,以显示各个概率。下面是我正在寻找的输出-

    A   B   C   rate    prob_A   prob_B   prob_C
    1   1   0   0.24      0.1      0.11      0
    1   0   1   0.27      0.1      0         0.12                     
    0   1   1   0.3       0        0.11      0.12 
    

    为了让事情更清楚,输出结果的第一行显示A=1,B=1,C=0。因此,A=0.1、B=0.11和C=0的个别概率分别附加到probs数据帧。类似地,对于第二行,A=1,B=0,C=1显示A=0.1,B=0和C=0的个别概率。12分别附加到probs数据帧。

    这是我试过的-

    scala> val channels = df.columns.filter(v => !(v.contains("rate") |  v.contains("total")))
    #channels: Array[String] = Array(A, B, C)
    
    
    scala> val pivotedProb = channels.map(v => f"case when $v = 1 then rate else 0 end as prob_${v}")
    
    scala> val param = pivotedProb.mkString(",")
    
    scala> val probs = spark.sql(f"select *, $param from df")
    
    scala> probs.show()
    +---+---+---+----+-----+------+------+------+
    |  A|  B|  C|rate|total|prob_A|prob_B|prob_C|
    +---+---+---+----+-----+------+------+------+
    |  0|  0|  0| 0.0|    0|   0.0|   0.0|   0.0|
    |  1|  0|  0| 0.1|    1|   0.1|   0.0|   0.0|
    |  0|  1|  0|0.11|    1|   0.0|  0.11|   0.0|
    |  0|  0|  1|0.12|    1|   0.0|   0.0|  0.12|
    |  1|  1|  0|0.24|    2|  0.24|  0.24|   0.0|
    |  1|  0|  1|0.27|    2|  0.27|   0.0|  0.27|
    |  0|  1|  1| 0.3|    2|   0.0|   0.3|   0.3|
    |  1|  1|  1| 0.4|    3|   0.4|   0.4|   0.4|
    +---+---+---+----+-----+------+------+------+
    

    这给了我错误的输出。

    请帮忙。

    1 回复  |  直到 7 年前
        1
  •  2
  •   Leo C    7 年前

    如果我正确理解您的要求,请使用 foldLeft 要遍历通道列,可以1)生成 ratesMap 从单通道数据帧,和,2)向双通道数据帧添加列,列值等于通道和相应 费率MAP 值:

    val df = Seq(
      (0, 0, 0, 0.0, 0),
      (1, 0, 0, 0.1, 1),
      (0, 1, 0, 0.11, 1),
      (0, 0, 1, 0.12, 1),
      (1, 1, 0, 0.24, 2),
      (1, 0, 1, 0.27, 2),
      (0, 1, 1, 0.3, 2),
      (1, 1, 1, 0.4, 3)
    ).toDF("A", "B", "C", "rate", "total")
    
    val oneChannelDF = df.filter($"total" === 1)
    val twoChannelDF = df.filter($"total" === 2)
    
    val channels = df.columns.filter(v => !(v.contains("rate") || v.contains("total")))
    // channels: Array[String] = Array(A, B, C)
    
    val ratesMap = channels.foldLeft( Map[String, Double]() ){ (acc, c) =>
      acc + (c -> oneChannelDF.select("rate").where(col(c) === 1).head.getDouble(0))
    }
    // ratesMap: scala.collection.immutable.Map[String,Double] = Map(A -> 0.1, B -> 0.11, C -> 0.12)
    
    val probsDF = channels.foldLeft( twoChannelDF ){ (acc, c) =>
      acc.withColumn( "prob_" + c, col(c) * ratesMap.getOrElse(c, 0.0) )
    }
    
    probsDF.show
    // +---+---+---+----+-----+------+------+------+
    // |  A|  B|  C|rate|total|prob_A|prob_B|prob_C|
    // +---+---+---+----+-----+------+------+------+
    // |  1|  1|  0|0.24|    2|   0.1|  0.11|   0.0|
    // |  1|  0|  1|0.27|    2|   0.1|   0.0|  0.12|
    // |  0|  1|  1| 0.3|    2|   0.0|  0.11|  0.12|
    // +---+---+---+----+-----+------+------+------+