我解决了，下面是我的答案。在我的真实数据集上，这个解决方案的运行速度确实比我在问题中给出的旧方案快得多（耗时不到原来的 1/10）。
我避免了把数据 collect 到 driver，也避免了在 reduce 中对数据集进行 map 和 join。
// Peer lookup table: one row per id, holding the ids it is compared against.
// In this sample every peer list also contains the row's own id.
val idPeersDS = Seq(
  (1, Seq(1, 2, 3)),
  (2, Seq(2, 1, 6)),
  (3, Seq(3, 1, 2)),
  (4, Seq(4, 5, 6)),
  (5, Seq(5, 4, 6)),
  (6, Seq(6, 1, 2))
).toDF("id", "peers")
// Sample measurements keyed by (id, type1, type2); not every id has every
// (type1, type2) combination. The metric is written as an integer literal
// and cast to double after the columns are named.
val infoDS = Seq(
  (1, "A", "X", 10),
  (1, "A", "Y", 20),
  (1, "B", "X", 30),
  (1, "B", "Y", 40),
  (2, "A", "Y", 10),
  (2, "B", "X", 20),
  (2, "B", "Y", 30),
  (3, "A", "X", 40),
  (4, "B", "Y", 10),
  (5, "A", "X", 20),
  (5, "B", "X", 30),
  (6, "A", "Y", 40),
  (6, "B", "Y", 10)
).toDF("id", "type1", "type2", "metric")
  .withColumn("metric", $"metric".cast("double"))
// Exiting paste mode, now interpreting.
idPeersDS: org.apache.spark.sql.DataFrame = [id: int, peers: array<int>]
infoDS: org.apache.spark.sql.DataFrame = [id: int, type1: string ... 2 more fields]
scala> idPeersDS.show
+---+---------+
| id| peers|
+---+---------+
| 1|[1, 2, 3]|
| 2|[2, 1, 6]|
| 3|[3, 1, 2]|
| 4|[4, 5, 6]|
| 5|[5, 4, 6]|
| 6|[6, 1, 2]|
+---+---------+
scala> infoDS.show
+---+-----+-----+------+
| id|type1|type2|metric|
+---+-----+-----+------+
| 1| A| X| 10.0|
| 1| A| Y| 20.0|
| 1| B| X| 30.0|
| 1| B| Y| 40.0|
| 2| A| Y| 10.0|
| 2| B| X| 20.0|
| 2| B| Y| 30.0|
| 3| A| X| 40.0|
| 4| B| Y| 10.0|
| 5| A| X| 20.0|
| 5| B| X| 30.0|
| 6| A| Y| 40.0|
| 6| B| Y| 10.0|
+---+-----+-----+------+
scala> val infowithpeers = infoDS.join(idPeersDS, "id")
infowithpeers: org.apache.spark.sql.DataFrame = [id: int, type1: string ... 3 more fields]
scala> infowithpeers.show
+---+-----+-----+------+---------+
| id|type1|type2|metric| peers|
+---+-----+-----+------+---------+
| 1| A| X| 10.0|[1, 2, 3]|
| 1| A| Y| 20.0|[1, 2, 3]|
| 1| B| X| 30.0|[1, 2, 3]|
| 1| B| Y| 40.0|[1, 2, 3]|
| 2| A| Y| 10.0|[2, 1, 6]|
| 2| B| X| 20.0|[2, 1, 6]|
| 2| B| Y| 30.0|[2, 1, 6]|
| 3| A| X| 40.0|[3, 1, 2]|
| 4| B| Y| 10.0|[4, 5, 6]|
| 5| A| X| 20.0|[5, 4, 6]|
| 5| B| X| 30.0|[5, 4, 6]|
| 6| A| Y| 40.0|[6, 1, 2]|
| 6| B| Y| 10.0|[6, 1, 2]|
+---+-----+-----+------+---------+
scala> // Merges a list of id -> metric maps into a single map; on a duplicate key the later entry wins.
scala> val joinMap = udf { maps: Seq[Map[Int,Double]] => maps.foldLeft(Map.empty[Int,Double])(_ ++ _) }
joinMap: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,MapType(IntegerType,DoubleType,false),Some(List(ArrayType(MapType(IntegerType,DoubleType,false),true))))
scala> val zScoreCal = udf { (metric: Double, zScoreMetrics: Seq[Double]) =>
| // Z-score of `metric` against the population of values in `zScoreMetrics`.
| // The array argument is typed as Seq, the Scala type Spark documents for array
| // columns, so the UDF does not depend on the concrete runtime collection class.
| // An empty population has no mean or spread; report 0.0 rather than NaN.
| if (zScoreMetrics.isEmpty) 0.0 else {
| val ds = new DescriptiveStatistics(zScoreMetrics.toArray)
| val mean = ds.getMean()
| // Population (not sample) standard deviation.
| val sd = math.sqrt(ds.getPopulationVariance())
| // Zero spread would divide by zero; report 0.0 instead.
| if (sd == 0.0) 0.0 else (metric - mean) / sd
| }
| }
zScoreCal: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function2>,DoubleType,Some(List(DoubleType, ArrayType(DoubleType,false))))
scala> :paste
// Entering paste mode (ctrl-D to finish)
// Attach a single-entry map (id -> metric) to every row.
val infowithpeersidmetric = infowithpeers.withColumn("idmetric", map($"id", $"metric"))
// For each (type1, type2) group, merge the per-row maps into one id -> metric lookup.
// The idmetric column built on the previous line is reused here instead of
// rebuilding the same map expression a second time.
val idsingrpdf = infowithpeersidmetric
  .groupBy("type1", "type2")
  .agg(joinMap(collect_list($"idmetric")) as "idsingrp")
// For each peer id, in order, look up its metric in the given id -> metric map.
// A peer with no entry in the map yields 0.0.
val metricsMap = udf { (peers: Seq[Int], values: Map[Int,Double]) =>
  for (peerId <- peers) yield values.getOrElse(peerId, 0.0)
}
// Exiting paste mode, now interpreting.
infowithpeersidmetric: org.apache.spark.sql.DataFrame = [id: int, type1: string ... 4 more fields]
idsingrpdf: org.apache.spark.sql.DataFrame = [type1: string, type2: string ... 1 more field]
metricsMap: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function2>,ArrayType(DoubleType,false),Some(List(ArrayType(IntegerType,false), MapType(IntegerType,DoubleType,false))))
scala> val infoWithMap = infowithpeers.
| // Bring the per-(type1, type2) id -> metric lookup onto every row.
| join(idsingrpdf, Seq("type1", "type2")).
| // Metrics of this row's peers within the same (type1, type2) group.
| withColumn("zScoreMetrics", metricsMap($"peers", $"idsingrp")).
| // Z-score of the row's own metric against its peers, rounded to 2 decimals.
| withColumn("zscore", round(zScoreCal($"metric", $"zScoreMetrics"), 2))
infoWithMap: org.apache.spark.sql.DataFrame = [type1: string, type2: string ... 6 more fields]
scala> infoWithMap.show
+-----+-----+---+------+---------+--------------------+------------------+------+
|type1|type2| id|metric| peers| idsingrp| zScoreMetrics|zscore|
+-----+-----+---+------+---------+--------------------+------------------+------+
| A| X| 1| 10.0|[1, 2, 3]|[3 -> 40.0, 5 -> ...| [10.0, 0.0, 40.0]| -0.39|
| A| Y| 1| 20.0|[1, 2, 3]|[2 -> 10.0, 6 -> ...| [20.0, 10.0, 0.0]| 1.22|
| B| X| 1| 30.0|[1, 2, 3]|[1 -> 30.0, 2 -> ...| [30.0, 20.0, 0.0]| 1.07|
| B| Y| 1| 40.0|[1, 2, 3]|[4 -> 10.0, 1 -> ...| [40.0, 30.0, 0.0]| 0.98|
| A| Y| 2| 10.0|[2, 1, 6]|[2 -> 10.0, 6 -> ...|[10.0, 20.0, 40.0]| -1.07|
| B| X| 2| 20.0|[2, 1, 6]|[1 -> 30.0, 2 -> ...| [20.0, 30.0, 0.0]| 0.27|
| B| Y| 2| 30.0|[2, 1, 6]|[4 -> 10.0, 1 -> ...|[30.0, 40.0, 10.0]| 0.27|
| A| X| 3| 40.0|[3, 1, 2]|[3 -> 40.0, 5 -> ...| [40.0, 10.0, 0.0]| 1.37|
| B| Y| 4| 10.0|[4, 5, 6]|[4 -> 10.0, 1 -> ...| [10.0, 0.0, 10.0]| 0.71|
| A| X| 5| 20.0|[5, 4, 6]|[3 -> 40.0, 5 -> ...| [20.0, 0.0, 0.0]| 1.41|
| B| X| 5| 30.0|[5, 4, 6]|[1 -> 30.0, 2 -> ...| [30.0, 0.0, 0.0]| 1.41|
| A| Y| 6| 40.0|[6, 1, 2]|[2 -> 10.0, 6 -> ...|[40.0, 20.0, 10.0]| 1.34|
| B| Y| 6| 10.0|[6, 1, 2]|[4 -> 10.0, 1 -> ...|[10.0, 40.0, 30.0]| -1.34|
+-----+-----+---+------+---------+--------------------+------------------+------+