Поделиться через


Определяемые пользователем агрегатные функции (UDAFs)

Область применения:проверка помечено да Databricks Runtime

Определяемые пользователем агрегатные функции (UDAFs) — это программируемые пользователем подпрограммы, которые работают с несколькими строками одновременно и возвращают в результате одно агрегированное значение. В этой документации перечислены классы, необходимые для создания и регистрации UDAFS. Он также содержит примеры, демонстрирующие определение и регистрацию UDAFS в Scala и их вызов в Spark SQL.

Агрегатор

СинтаксисAggregator[-IN, BUF, OUT]

Базовый класс для определяемых пользователем агрегатов, который можно использовать в операциях набора данных для получения всех элементов группы и их уменьшения до одного значения.

  • IN: тип входных данных для агрегирования.

  • BUF: тип промежуточного значения сокращения.

  • OUT: тип конечного выходного результата.

  • bufferEncoder: Encoder[BUF]

    Кодировщик для промежуточного типа значения.

  • finish(reduction: BUF): OUT

    Преобразование выходных данных сокращения.

  • merge(b1: BUF, b2: BUF): BUF

    Объединение двух промежуточных значений.

  • outputEncoder: Encoder[OUT]

    Кодировщик для конечного типа выходного значения.

  • reduce(b: BUF, a: IN): BUF

    Агрегировать входное значение a в текущее промежуточное значение. Для повышения производительности функция может изменять b и возвращать ее вместо создания нового объекта для b.

  • нуль: BUF

    Начальное значение промежуточного результата для этого агрегирования.

Примеры

Типобезопасные пользовательские агрегатные функции

Определяемые пользователем агрегаты для строго типизированных наборов данных вращаются вокруг абстрактного Aggregator класса. Например, типобезопасное среднее значение, определяемое пользователем, может выглядеть следующим образом:

Нетипизированные определяемые пользователем агрегатные функции

Типизированные агрегаты, как описано выше, также могут быть зарегистрированы как нетипизированные статистические определяемые пользователем функции для использования с кадрами данных. Например, определяемое пользователем среднее значение для нетипизированных кадров данных может выглядеть следующим образом:

Scala

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions

case class Average(var sum: Long, var count: Long)

object MyAverage extends Aggregator[Long, Average, Double] {
  // A zero value for this aggregation. Should satisfy the property that any b + zero = b
  def zero: Average = Average(0L, 0L)
  // Combine two values to produce a new value. For performance, the function may modify `buffer`
  // and return it instead of constructing a new object
  def reduce(buffer: Average, data: Long): Average = {
    buffer.sum += data
    buffer.count += 1
    buffer
  }
  // Merge two intermediate values
  def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }
  // Transform the output of the reduction
  def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
  // The Encoder for the intermediate value type
  val bufferEncoder: Encoder[Average] = Encoders.product
  // The Encoder for the final output value type
  val outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Register the function to access it
spark.udf.register("myAverage", functions.udaf(MyAverage))

val df = spark.read.format("json").load("examples/src/main/resources/employees.json")
df.createOrReplaceTempView("employees")
df.show()
// +-------+------+
// |   name|salary|
// +-------+------+
// |Michael|  3000|
// |   Andy|  4500|
// | Justin|  3500|
// |  Berta|  4000|
// +-------+------+

val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
result.show()
// +--------------+
// |average_salary|
// +--------------+
// |        3750.0|
// +--------------+

Java

import java.io.Serializable;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.functions;

public static class Average implements Serializable  {
    private long sum;
    private long count;

    // Constructors, getters, setters...

}

public static class MyAverage extends Aggregator<Long, Average, Double> {
  // A zero value for this aggregation. Should satisfy the property that any b + zero = b
  public Average zero() {
    return new Average(0L, 0L);
  }
  // Combine two values to produce a new value. For performance, the function may modify `buffer`
  // and return it instead of constructing a new object
  public Average reduce(Average buffer, Long data) {
    long newSum = buffer.getSum() + data;
    long newCount = buffer.getCount() + 1;
    buffer.setSum(newSum);
    buffer.setCount(newCount);
    return buffer;
  }
  // Merge two intermediate values
  public Average merge(Average b1, Average b2) {
    long mergedSum = b1.getSum() + b2.getSum();
    long mergedCount = b1.getCount() + b2.getCount();
    b1.setSum(mergedSum);
    b1.setCount(mergedCount);
    return b1;
  }
  // Transform the output of the reduction
  public Double finish(Average reduction) {
    return ((double) reduction.getSum()) / reduction.getCount();
  }
  // The Encoder for the intermediate value type
  public Encoder<Average> bufferEncoder() {
    return Encoders.bean(Average.class);
  }
  // The Encoder for the final output value type
  public Encoder<Double> outputEncoder() {
    return Encoders.DOUBLE();
  }
}

// Register the function to access it
spark.udf().register("myAverage", functions.udaf(new MyAverage(), Encoders.LONG()));

Dataset<Row> df = spark.read().format("json").load("examples/src/main/resources/employees.json");
df.createOrReplaceTempView("employees");
df.show();
// +-------+------+
// |   name|salary|
// +-------+------+
// |Michael|  3000|
// |   Andy|  4500|
// | Justin|  3500|
// |  Berta|  4000|
// +-------+------+

Dataset<Row> result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
result.show();
// +--------------+
// |average_salary|
// +--------------+
// |        3750.0|
// +--------------+

SQL

-- Compile and place UDAF MyAverage in a JAR file called `MyAverage.jar` in /tmp.
CREATE FUNCTION myAverage AS 'MyAverage' USING JAR '/tmp/MyAverage.jar';

SHOW USER FUNCTIONS;
+------------------+
|          function|
+------------------+
| default.myAverage|
+------------------+

CREATE TEMPORARY VIEW employees
USING org.apache.spark.sql.json
OPTIONS (
    path "examples/src/main/resources/employees.json"
);

SELECT * FROM employees;
+-------+------+
|   name|salary|
+-------+------+
|Michael|  3000|
|   Andy|  4500|
| Justin|  3500|
|  Berta|  4000|
+-------+------+

SELECT myAverage(salary) as average_salary FROM employees;
+--------------+
|average_salary|
+--------------+
|        3750.0|
+--------------+