我在Scala中有一个Spark数据框,如下所示-
val df = Seq(
(0,0,0,0.0,0),
(1,0,0,0.1,1),
(0,1,0,0.11,1),
(0,0,1,0.12,1),
(1,1,0,0.24,2),
(1,0,1,0.27,2),
(0,1,1,0.3,2),
(1,1,1,0.4,3)
).toDF("A","B","C","rate","total")
下面是它的样子
scala> df.show
+---+---+---+----+-----+
| A| B| C|rate|total|
+---+---+---+----+-----+
| 0| 0| 0| 0.0| 0|
| 1| 0| 0| 0.1| 1|
| 0| 1| 0|0.11| 1|
| 0| 0| 1|0.12| 1|
| 1| 1| 0|0.24| 2|
| 1| 0| 1|0.27| 2|
| 0| 1| 1| 0.3| 2|
| 1| 1| 1| 0.4| 3|
+---+---+---+----+-----+
A、 在这种情况下,B和C是通道。0和1分别表示通道的缺失和存在。2^3显示数据帧中的8个组合,其中一列“total”给出这3个通道的行和。
这些信道发生的个别概率可通过以下公式给出-
scala> val oneChannelCase = df.filter($"total" === 1).toDF()
scala> oneChannelCase.show()
+---+---+---+----+-----+
| A| B| C|rate|total|
+---+---+---+----+-----+
| 1| 0| 0| 0.1| 1|
| 0| 1| 0|0.11| 1|
| 0| 0| 1|0.12| 1|
+---+---+---+----+-----+
然而,我只对这些通道的成对概率感兴趣,由-
scala> val probs = df.filter($"total" === 2).toDF()
scala> probs.show()
+---+---+---+----+-----+
| A| B| C|rate|total|
+---+---+---+----+-----+
| 1| 1| 0|0.24| 2|
| 1| 0| 1|0.27| 2|
| 0| 1| 1| 0.3| 2|
+---+---+---+----+-----+
我想做的是在这些“probs”数据框中添加3个新列,以显示各个概率。下面是我正在寻找的输出-
A B C rate prob_A prob_B prob_C
1 1 0 0.24 0.1 0.11 0
1 0 1 0.27 0.1 0 0.12
0 1 1 0.3 0 0.11 0.12
为了让事情更清楚,输出结果的第一行显示A=1,B=1,C=0。因此,A=0.1、B=0.11和C=0的个别概率分别附加到probs数据帧。类似地,对于第二行,A=1,B=0,C=1显示A=0.1,B=0和C=0的个别概率。12分别附加到probs数据帧。
这是我试过的-
scala> val channels = df.columns.filter(v => !(v.contains("rate") | v.contains("total")))
scala> val pivotedProb = channels.map(v => f"case when $v = 1 then rate else 0 end as prob_${v}")
scala> val param = pivotedProb.mkString(",")
scala> val probs = spark.sql(f"select *, $param from df")
scala> probs.show()
+---+---+---+----+-----+------+------+------+
| A| B| C|rate|total|prob_A|prob_B|prob_C|
+---+---+---+----+-----+------+------+------+
| 0| 0| 0| 0.0| 0| 0.0| 0.0| 0.0|
| 1| 0| 0| 0.1| 1| 0.1| 0.0| 0.0|
| 0| 1| 0|0.11| 1| 0.0| 0.11| 0.0|
| 0| 0| 1|0.12| 1| 0.0| 0.0| 0.12|
| 1| 1| 0|0.24| 2| 0.24| 0.24| 0.0|
| 1| 0| 1|0.27| 2| 0.27| 0.0| 0.27|
| 0| 1| 1| 0.3| 2| 0.0| 0.3| 0.3|
| 1| 1| 1| 0.4| 3| 0.4| 0.4| 0.4|
+---+---+---+----+-----+------+------+------+
这给了我错误的输出。
请帮忙。