
Spark UDAF generic type mismatch

  • mrbrahman  ·  asked 8 years ago

    I am trying to create a UDAF on Spark (2.0.1, Scala 2.11) as shown below. It essentially aggregates tuples and outputs a Map.

    import org.apache.spark.sql.expressions._
    import org.apache.spark.sql.types._
    import org.apache.spark.sql.functions.udf
    import org.apache.spark.sql.{Row, Column}
    
    class mySumToMap[K, V](keyType: DataType, valueType: DataType) extends UserDefinedAggregateFunction {
      override def inputSchema = new StructType()
        .add("a_key", keyType)
        .add("a_value", valueType)
    
      override def bufferSchema = new StructType()
        .add("buffer_map", MapType(keyType, valueType))
    
      override def dataType = MapType(keyType, valueType)
    
      override def deterministic = true 
    
      override def initialize(buffer: MutableAggregationBuffer) = {
        buffer(0) = Map[K, V]()
      }
    
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    
        // input :: 0 = a_key (k), 1 = a_value
        if ( !(input.isNullAt(0)) ) {
    
          val a_map = buffer(0).asInstanceOf[Map[K, V]]
          val k = input.getAs[K](0)  // get the value of position 0 of the input as string (a_key)
    
          // I've split these on purpose to show that return values are all of type V
          val new_v1: V = a_map.getOrElse(k, 0.asInstanceOf[V])
          val new_v2: V = input.getAs[V](1)
          val new_v: V = new_v1 + new_v2
    
          buffer(0) = if (new_v != 0) a_map + (k -> new_v) else a_map - k
        }
      }
    
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
        val map1: Map[K, V] = buffer1(0).asInstanceOf[Map[K, V]]
        val map2: Map[K, V] = buffer2(0).asInstanceOf[Map[K, V]]
    
        buffer1(0) = map1 ++ map2.map{ case (k,v) => k -> (v + map1.getOrElse(k, 0.asInstanceOf[V])) }
      }
    
      override def evaluate(buffer: Row) = buffer(0).asInstanceOf[Map[K, V]]
    
    }
    

    <console>:74: error: type mismatch;
     found   : V
     required: String
                 val new_v: V = new_v1 + new_v2
                                         ^
    <console>:84: error: type mismatch;
     found   : V
     required: String
               buffer1(0) = map1 ++ map2.map{ case (k,v) => k -> (v + map1.getOrElse(k, 0.asInstanceOf[V])) }
    

    What am I doing wrong?

    For those marking this as a duplicate of Spark UDAF - using generics as input type? - this is not a duplicate, since that question does not deal with DataType. The code above is very specific and complete about the problem faced when using the Map data type.

    1 Answer  |  8 years ago

  •  2
  •  Alper t. Turker  ·  8 years ago

    On an unconstrained type V, + resolves to string concatenation (hence the required: String in the error), and 0 cannot be cast to an arbitrary V. Constrain V with a Numeric context bound:

    class mySumToMap[K, V: Numeric](keyType: DataType, valueType: DataType) 
      extends UserDefinedAggregateFunction {
        ...
    

    Use implicitly to obtain the Numeric instance at runtime:

    val n = implicitly[Numeric[V]]
    

    and use its plus method in place of + and zero in place of 0:

    buffer1(0) = map1 ++ map2.map{ 
      case (k,v) => k -> n.plus(v,  map1.getOrElse(k, n.zero))
    }
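    The same substitution applies to the question's update logic, and it can be checked outside Spark. A minimal runnable sketch (the updated helper below is hypothetical, mirroring the buffer update on a plain Map):

    ```scala
    object NumericUpdateDemo {
      // Mimics the UDAF's update step on a plain Map: add the incoming
      // value to the running total for the key, dropping entries whose
      // total cancels out to zero.
      def updated[K, V: Numeric](acc: Map[K, V], key: K, value: V): Map[K, V] = {
        val n = implicitly[Numeric[V]]
        val newV = n.plus(acc.getOrElse(key, n.zero), value)
        if (newV != n.zero) acc + (key -> newV) else acc - key
      }

      def main(args: Array[String]): Unit = {
        val step1 = updated(Map.empty[String, Int], "a", 5)
        val step2 = updated(step1, "a", -5) // cancels out, so "a" is removed
        println(step1)
        println(step2)
      }
    }
    ```

    Because the arithmetic goes through Numeric, the same code works for Int, Long, Double, BigDecimal, and so on.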
    

    To support a wider set of types, you can use cats Monoid:

    import cats._
    import cats.implicits._
    

    class mySumToMap[K, V: Monoid](keyType: DataType, valueType: DataType) 
      extends UserDefinedAggregateFunction {
        ...
    

    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
      val map1: Map[K, V] = buffer1.getMap[K, V](0)
      val map2: Map[K, V] = buffer2.getMap[K, V](0)
    
      val m = implicitly[Monoid[Map[K, V]]]
    
      buffer1(0) = m.combine(map1, map2)
    }
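    What m.combine does for maps can be emulated without the cats dependency, which makes the semantics easy to verify. A sketch (the combine helper is a hand-rolled stand-in for cats' Monoid[Map[K, V]].combine, shown here with Int addition as the value monoid):

    ```scala
    object MonoidCombineDemo {
      // Emulates cats' Monoid[Map[K, V]].combine: entries of both maps
      // are kept, and values under shared keys are combined with the
      // value monoid (here, Int addition).
      def combine[K](m1: Map[K, Int], m2: Map[K, Int]): Map[K, Int] =
        m2.foldLeft(m1) { case (acc, (k, v)) =>
          acc + (k -> (acc.getOrElse(k, 0) + v))
        }

      def main(args: Array[String]): Unit = {
        println(combine(Map("a" -> 1, "b" -> 2), Map("b" -> 3, "c" -> 4)))
      }
    }
    ```

    With cats, the same result comes from Monoid[Map[K, V]], which is derived automatically whenever a monoid (or semigroup) exists for V, so the merge also works for non-numeric value types such as String or List.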
    