我在玩
save
和
load
功能
pyspark.ml.classification
模型。我创建了一个
RandomForestClassifier
,将值设置为几个参数并调用
节约
分类器的方法。成功保存。没有问题。
from pyspark.ml.classification import RandomForestClassifier
# save
rf = RandomForestClassifier()
rf.setImpurity('entropy')
rf.setPredictionCol('predme')
rf.write().overwrite().save('rf_test')
然后我尝试重新加载它,但我注意到它的参数没有我在保存之前设置的值。下面是我尝试的代码
# load
rf2 = RandomForestClassifier()
rf2.load('rf_test')
print(rf2.getImpurity()) # returns gini
print(rf2.getPredictionCol()) # returns prediction
我想我对这段代码应该如何工作以及它实际如何工作的理解是有区别的。
我该怎么做才能以我保存的方式取回物体?
编辑
我尝试过这里提到的方法。但那不起作用。这是我试过的
from pyspark.ml.classification import RandomForestClassifier
rf = RandomForestClassifier()
rf.setImpurity('entropy')
rf.setPredictionCol('predme')
rf.write().overwrite().save('rf_test')
rf2 = RandomForestClassifier
rf2.load('rf_test')
print(rf2.getImpurity())
它返回了以下内容
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: getImpurity() missing 1 required positional argument: 'self'