Assuming you're right that a UDF is needed, here is an example of how it would work in a similar setting:
>>> import random
>>> from pyspark.sql.functions import udf
>>> centers = {1: 2, 2: 3, 3: 4, 4: 5, 5: 6}
>>> choices = [1, 2, 3, 4, 5]
>>> l = [(random.random(), random.choice(choices)) for i in range(10)]
>>> df = spark.createDataFrame(l, ['features', 'prediction'])
>>> df.show()
+-------------------+----------+
| features|prediction|
+-------------------+----------+
| 0.4836744206538728| 3|
|0.38698675915124414| 4|
|0.18612684714681604| 3|
| 0.5056159922655895| 1|
| 0.7825023909896331| 4|
|0.49933715239708243| 5|
| 0.6673811293962939| 4|
| 0.7010166164833609| 3|
| 0.6867109795526414| 5|
|0.21975859257732422| 3|
+-------------------+----------+
>>> dist = udf(lambda features, prediction: features - centers[prediction])
>>> df.withColumn('dist', dist(df.features, df.prediction)).show()
+-------------------+----------+-------------------+
| features|prediction| dist|
+-------------------+----------+-------------------+
| 0.4836744206538728| 3| -3.516325579346127|
|0.38698675915124414| 4| -4.613013240848756|
|0.18612684714681604| 3| -3.813873152853184|
| 0.5056159922655895| 1|-1.4943840077344106|
| 0.7825023909896331| 4| -4.217497609010367|
|0.49933715239708243| 5| -5.500662847602918|
| 0.6673811293962939| 4|-4.3326188706037065|
| 0.7010166164833609| 3| -3.298983383516639|
| 0.6867109795526414| 5| -5.313289020447359|
|0.21975859257732422| 3| -3.780241407422676|
+-------------------+----------+-------------------+
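One caveat: udf defaults to a StringType return, so the dist column above actually holds strings. If you plan to do further arithmetic on it, declare the return type explicitly:

>>> from pyspark.sql.types import DoubleType
>>> dist = udf(lambda features, prediction: features - centers[prediction], DoubleType())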
You can change the lines where I create the UDF to something like the following (note that the cluster centers are pulled out of the lambda first: the fitted model wraps a JVM object and generally can't be serialized into a UDF closure, but the plain list of centers can):

centers = model.clusterCenters()
dist = udf(lambda features, prediction: float(features.squared_distance(centers[prediction])), DoubleType())
Since I don't have your actual data to test against, I hope this is correct!
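For completeness, here is a minimal end-to-end sketch of how this would slot into a clustering pipeline. It assumes a pyspark.ml.clustering.KMeans model and a DataFrame with a vector-valued features column; data_df and k=5 are placeholders, not values from your question:

# Sketch only -- data_df and k=5 are illustrative placeholders.
from pyspark.ml.clustering import KMeans
from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType

model = KMeans(k=5, seed=1).fit(data_df)   # data_df needs a vector 'features' column
predicted = model.transform(data_df)       # adds the 'prediction' column

centers = model.clusterCenters()           # plain numpy arrays, safe to capture in a UDF
dist = udf(lambda features, prediction:
           float(features.squared_distance(centers[prediction])),
           DoubleType())

predicted.withColumn('dist', dist(predicted.features, predicted.prediction)).show()

The float() cast matters because squared_distance returns a numpy.float64, which PySpark won't automatically coerce to the declared DoubleType.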