
pyspark: how are DataFrame describe() and summary() implemented?

  • ZK Zhao · 6 years ago

    I'd like to know how df.describe() and df.summary() are implemented.

    As in https://spark.apache.org/docs/latest/api/python/_modules/pyspark/sql/dataframe.html#DataFrame.summary :

    def summary(self, *statistics):
        if len(statistics) == 1 and isinstance(statistics[0], list):
            statistics = statistics[0]
        jdf = self._jdf.summary(self._jseq(statistics))
        return DataFrame(jdf, self.sql_ctx)
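
    (For context: the first two lines of the wrapper just mean the statistics can be passed either as separate arguments or as one list. A minimal sketch, assuming a SparkSession named spark:)

    df = spark.range(10).toDF("x")
    df.summary("min", "50%", "max").show()    # varargs form
    df.summary(["min", "50%", "max"]).show()  # list form, unwrapped by the isinstance check
    df.summary().show()                       # no arguments -> the default statistics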
    

    I'm not very familiar with OO in Python, so I'm a bit confused. Where are the quantiles and the other statistics actually implemented?

    1 Answer  |  6 years ago

  • user8371915 · 6 years ago
    • jdf is a reference to the Java Dataset object, accessed through py4j (see the quick inspection sketch below).
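
      A quick way to see that bridge from a live session (a sketch; it assumes an active SparkSession named spark and pokes at private attributes that can change between versions):

      df = spark.range(3)
      print(type(df._jdf))                 # <class 'py4j.java_gateway.JavaObject'>
      print(df._jdf.getClass().getName())  # org.apache.spark.sql.Dataset
      # any public method of the Java Dataset can be called through this proxy:
      print(df._jdf.count())               # 3
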
    • The Python code simply calls its summary method:

      jdf = self._jdf.summary(self._jseq(statistics))
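
      _jseq and the DataFrame(jdf, self.sql_ctx) wrapper are the whole Python/Java round trip. Spelled out step by step (a sketch against the same private internals, so version-dependent; assumes an existing DataFrame df):

      from pyspark.sql import DataFrame

      stats = ["count", "min", "25%", "75%", "max"]
      jseq = df._jseq(stats)                   # Python list -> Java Seq[String] via py4j
      jdf = df._jdf.summary(jseq)              # executes on the JVM, returns a Java Dataset
      summary_df = DataFrame(jdf, df.sql_ctx)  # wrap it back into a Python DataFrame
      summary_df.show()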
      
    • Dataset.summary calls the StatFunctions.summary method:

      def summary(statistics: String*): DataFrame = StatFunctions.summary(this, statistics.toSeq)
      
    • which is implemented like this:

      def summary(ds: Dataset[_], statistics: Seq[String]): DataFrame = {
        val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max")
        val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics

        val percentiles = selectedStatistics.filter(a => a.endsWith("%")).map { p =>
          try {
            p.stripSuffix("%").toDouble / 100.0
          } catch {
            case e: NumberFormatException =>
              throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e)
          }
        }
        require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]")

        var percentileIndex = 0
        val statisticFns = selectedStatistics.map { stats =>
          if (stats.endsWith("%")) {
            val index = percentileIndex
            percentileIndex += 1
            (child: Expression) =>
              GetArrayItem(
                new ApproximatePercentile(child, Literal.create(percentiles)).toAggregateExpression(),
                Literal(index))
          } else {
            stats.toLowerCase(Locale.ROOT) match {
              case "count" => (child: Expression) => Count(child).toAggregateExpression()
              case "mean" => (child: Expression) => Average(child).toAggregateExpression()
              case "stddev" => (child: Expression) => StddevSamp(child).toAggregateExpression()
              case "min" => (child: Expression) => Min(child).toAggregateExpression()
              case "max" => (child: Expression) => Max(child).toAggregateExpression()
              case _ => throw new IllegalArgumentException(s"$stats is not a recognised statistic")
            }
          }
        }

        val selectedCols = ds.logicalPlan.output
          .filter(a => a.dataType.isInstanceOf[NumericType] || a.dataType.isInstanceOf[StringType])

        val aggExprs = statisticFns.flatMap { func =>
          selectedCols.map(c => Column(Cast(func(c), StringType)).as(c.name))
        }

        // If there is no selected columns, we don't need to run this aggregate, so make it a lazy val.
        lazy val aggResult = ds.select(aggExprs: _*).queryExecution.toRdd.collect().head

        // We will have one row for each selected statistic in the result.
        val result = Array.fill[InternalRow](selectedStatistics.length) {
          // each row has the statistic name, and statistic values of each selected column.
          new GenericInternalRow(selectedCols.length + 1)
        }

        var rowIndex = 0
        while (rowIndex < result.length) {
          val statsName = selectedStatistics(rowIndex)
          result(rowIndex).update(0, UTF8String.fromString(statsName))
          for (colIndex <- selectedCols.indices) {
            val statsValue = aggResult.getUTF8String(rowIndex * selectedCols.length + colIndex)
            result(rowIndex).update(colIndex + 1, statsValue)
          }
          rowIndex += 1
        }

        // All columns are string type
        val output = AttributeReference("summary", StringType)() +:
          selectedCols.map(c => AttributeReference(c.name, StringType)())

        Dataset.ofRows(ds.sparkSession, LocalRelation(output, result))
      }
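
    • So the percentile entries ("25%", "50%", ...) are parsed into fractions and evaluated with the ApproximatePercentile aggregate, while count/mean/stddev/min/max map to the ordinary aggregate expressions; every result is cast to string and assembled into the rows you see. You can cross-check the quantile part from PySpark with public APIs built on the same quantile sketch (again a sketch, assuming a SparkSession named spark):

      from pyspark.sql import functions as F

      df = spark.range(100).toDF("x")

      # what summary() reports for the quartiles...
      df.summary("25%", "50%", "75%").show()

      # ...should line up with approxQuantile (same underlying algorithm)
      print(df.stat.approxQuantile("x", [0.25, 0.5, 0.75], 0.0))

      # or with the SQL aggregate that backs ApproximatePercentile
      df.select(F.expr("percentile_approx(x, array(0.25, 0.5, 0.75))")).show()

      # describe() takes the same route, just without the percentile statistics:
      df.describe().show()
      df.summary("count", "mean", "stddev", "min", "max").show()  # same rows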