代码之家  ›  专栏  ›  技术社区  ›  twfx

创建一个热编码器。CountVectorizer返回arrayType错误(intergerType,true)

  •  0
  • twfx  · 技术社区  · 7 年前

    我尝试为以下输入数据创建一个热编码器:

    +------+--------------------+
    |userid|     categoryIndexes|
    +------+--------------------+
    | 24868|              [7276]|
    | 35335|             [12825]|
    | 42634| .    [14550, 14550]|
    | 51183|              [7570]|
    | 61065|             [14782]|
    | 70292|              [7282]|
    | 72326|      [14883, 14877]|
    | 96632|             [14902]|
    | 99703|             [14889]|
    |121994|       [16000, 7417]|
    |144782|      [12139, 12139]|
    |175886|        [7305, 7305]|
    |221451|      [14889, 12139]|
    |226945|             [18097]|
    |250401|              [7278]|
    |256892|        [7383, 5514]|
    |270043|              [7442]|
    |272338|              [7306]|
    |284802|      [18310, 14898]|
    +------+--------------------+
    

    提到 Aggregating a One-Hot Encoded feature in pyspark Encode and assemble multiple features in PySpark 我试着用

    from pyspark.ml.feature import CountVectorizer
    
    df_user_catlist = df_order.groupBy("userid").agg(F.collect_list('level3_cat').alias('categoryIndexes'))
    cv = CountVectorizer(inputCol='categoryIndexes', outputCol='categoryVec')
    transformed_df = cv.fit(df_user_catlist).transform(df_user_catlist)
    transformed_df.show()
    

    但发现了以下错误

    IllegalArgumentException: u'requirement failed: Column category must be of type equal to one of the following types: [ArrayType(StringType,true), ArrayType(StringType,false)] but was actually of type ArrayType(IntegerType,true).'
    

    我注意到区别在于输入数据是integertype而不是stringtype,我可以知道(a)如何将其转换为stringtype,或者有更好的方法将其转换为ohe吗?

    1 回复  |  直到 7 年前
        1
  •  2
  •   hamza tuna    7 年前

    需要将字符串强制转换为类别索引:

    from pyspark.sql import functions as F
    
    df_user_catlist = df_user_catlist \
        .withColumn('categoryIndexes', 
             F.col('categoryIndexes').cast('array<string>'))