Fonctions définies par l’utilisateur d’agrégation (UDAF)
S’applique à : Databricks Runtime
Les fonctions d’agrégation définies par l’utilisateur (UDAF) sont des routines programmables par l’utilisateur qui agissent sur plusieurs lignes à la fois et retournent une valeur agrégée unique. Cette documentation répertorie les classes requises pour créer et inscrire des fonctions d’agrégation définies par l’utilisateur. Elle contient également des exemples montrant comment définir et inscrire des fonctions d’agrégation définies par l’utilisateur dans Scala, et les appeler dans Spark SQL.
Aggrégateur
Syntaxe Aggregator[-IN, BUF, OUT]
Classe de base pour des agrégations définies par l’utilisateur qui peuvent être utilisées dans des opérations de jeu de données afin de prendre tous les éléments d’un groupe et de les réduire à une seule valeur.
IN : type d’entrée pour l’agrégation.
BUF : type de valeur intermédiaire de la réduction.
OUT : type du résultat de la sortie finale.
bufferEncoder: Encoder[BUF]
Encodeur pour le type valeur intermédiaire.
finish(reduction: BUF): OUT
Transformer la sortie de la réduction.
merge(b1: BUF, b2: BUF): BUF
Fusionner deux valeurs intermédiaires.
outputEncoder: Encoder[OUT]
Encodeur pour le type de valeur de la sortie finale.
reduce(b: BUF, a: IN): BUF
Agréger les valeur d’entrée
a
dans une valeur intermédiaire actuelle. Pour les performances, la fonction peut modifierb
et le retourner au lieu de construire un nouvel objet pourb
.zero: BUF
Valeur initiale du résultat intermédiaire pour cette agrégation.
Exemples
Fonctions d'agrégation définies par l'utilisateur de type sécurisé
Les agrégations définies par l’utilisateur pour des jeux de données fortement typés tournent autour de la classe abstraite Aggregator
.
Par exemple, une moyenne définie par l’utilisateur de type sécurisé peut se présenter comme suit :
Scala
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
case class Employee(name: String, salary: Long)
case class Average(var sum: Long, var count: Long)
object MyAverage extends Aggregator[Employee, Average, Double] {
// A zero value for this aggregation. Should satisfy the property that any b + zero = b
def zero: Average = Average(0L, 0L)
// Combine two values to produce a new value. For performance, the function may modify `buffer`
// and return it instead of constructing a new object
def reduce(buffer: Average, employee: Employee): Average = {
buffer.sum += employee.salary
buffer.count += 1
buffer
}
// Merge two intermediate values
def merge(b1: Average, b2: Average): Average = {
b1.sum += b2.sum
b1.count += b2.count
b1
}
// Transform the output of the reduction
def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
// The Encoder for the intermediate value type
val bufferEncoder: Encoder[Average] = Encoders.product
// The Encoder for the final output value type
val outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
Java
import java.io.Serializable;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.TypedColumn;
import org.apache.spark.sql.expressions.Aggregator;
public static class Employee implements Serializable {
private String name;
private long salary;
// Constructors, getters, setters...
}
public static class Average implements Serializable {
private long sum;
private long count;
// Constructors, getters, setters...
}
public static class MyAverage extends Aggregator<Employee, Average, Double> {
// A zero value for this aggregation. Should satisfy the property that any b + zero = b
public Average zero() {
return new Average(0L, 0L);
}
// Combine two values to produce a new value. For performance, the function may modify `buffer`
// and return it instead of constructing a new object
public Average reduce(Average buffer, Employee employee) {
long newSum = buffer.getSum() + employee.getSalary();
long newCount = buffer.getCount() + 1;
buffer.setSum(newSum);
buffer.setCount(newCount);
return buffer;
}
// Merge two intermediate values
public Average merge(Average b1, Average b2) {
long mergedSum = b1.getSum() + b2.getSum();
long mergedCount = b1.getCount() + b2.getCount();
b1.setSum(mergedSum);
b1.setCount(mergedCount);
return b1;
}
// Transform the output of the reduction
public Double finish(Average reduction) {
return ((double) reduction.getSum()) / reduction.getCount();
}
// The Encoder for the intermediate value type
public Encoder<Average> bufferEncoder() {
return Encoders.bean(Average.class);
}
// The Encoder for the final output value type
public Encoder<Double> outputEncoder() {
return Encoders.DOUBLE();
}
}
Encoder<Employee> employeeEncoder = Encoders.bean(Employee.class);
String path = "examples/src/main/resources/employees.json";
Dataset<Employee> ds = spark.read().format("json").load(path).as(employeeEncoder);
ds.show();
// +-------+------+
// | name|salary|
// +-------+------+
// |Michael| 3000|
// | Andy| 4500|
// | Justin| 3500|
// | Berta| 4000|
// +-------+------+
MyAverage myAverage = new MyAverage();
// Convert the function to a `TypedColumn` and give it a name
TypedColumn<Employee, Double> averageSalary = myAverage.toColumn().name("average_salary");
Dataset<Double> result = ds.select(averageSalary);
result.show();
// +--------------+
// |average_salary|
// +--------------+
// | 3750.0|
// +--------------+
Fonctions d'agrégation définies par l'utilisateur non typées
Des agrégations typées telles que décrites ci-dessus peuvent également être inscrites en tant que fonctions d’agrégation définies par l’utilisateur non typées pour une utilisation avec trames de données. Par exemple, une moyenne définie par l’utilisateur pour des trames de données non typées peut se présenter comme suit :
Scala
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions
case class Average(var sum: Long, var count: Long)
object MyAverage extends Aggregator[Long, Average, Double] {
// A zero value for this aggregation. Should satisfy the property that any b + zero = b
def zero: Average = Average(0L, 0L)
// Combine two values to produce a new value. For performance, the function may modify `buffer`
// and return it instead of constructing a new object
def reduce(buffer: Average, data: Long): Average = {
buffer.sum += data
buffer.count += 1
buffer
}
// Merge two intermediate values
def merge(b1: Average, b2: Average): Average = {
b1.sum += b2.sum
b1.count += b2.count
b1
}
// Transform the output of the reduction
def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
// The Encoder for the intermediate value type
val bufferEncoder: Encoder[Average] = Encoders.product
// The Encoder for the final output value type
val outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
// Register the function to access it
spark.udf.register("myAverage", functions.udaf(MyAverage))
val df = spark.read.format("json").load("examples/src/main/resources/employees.json")
df.createOrReplaceTempView("employees")
df.show()
// +-------+------+
// | name|salary|
// +-------+------+
// |Michael| 3000|
// | Andy| 4500|
// | Justin| 3500|
// | Berta| 4000|
// +-------+------+
val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
result.show()
// +--------------+
// |average_salary|
// +--------------+
// | 3750.0|
// +--------------+
Java
import java.io.Serializable;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.functions;
public static class Average implements Serializable {
private long sum;
private long count;
// Constructors, getters, setters...
}
public static class MyAverage extends Aggregator<Long, Average, Double> {
// A zero value for this aggregation. Should satisfy the property that any b + zero = b
public Average zero() {
return new Average(0L, 0L);
}
// Combine two values to produce a new value. For performance, the function may modify `buffer`
// and return it instead of constructing a new object
public Average reduce(Average buffer, Long data) {
long newSum = buffer.getSum() + data;
long newCount = buffer.getCount() + 1;
buffer.setSum(newSum);
buffer.setCount(newCount);
return buffer;
}
// Merge two intermediate values
public Average merge(Average b1, Average b2) {
long mergedSum = b1.getSum() + b2.getSum();
long mergedCount = b1.getCount() + b2.getCount();
b1.setSum(mergedSum);
b1.setCount(mergedCount);
return b1;
}
// Transform the output of the reduction
public Double finish(Average reduction) {
return ((double) reduction.getSum()) / reduction.getCount();
}
// The Encoder for the intermediate value type
public Encoder<Average> bufferEncoder() {
return Encoders.bean(Average.class);
}
// The Encoder for the final output value type
public Encoder<Double> outputEncoder() {
return Encoders.DOUBLE();
}
}
// Register the function to access it
spark.udf().register("myAverage", functions.udaf(new MyAverage(), Encoders.LONG()));
Dataset<Row> df = spark.read().format("json").load("examples/src/main/resources/employees.json");
df.createOrReplaceTempView("employees");
df.show();
// +-------+------+
// | name|salary|
// +-------+------+
// |Michael| 3000|
// | Andy| 4500|
// | Justin| 3500|
// | Berta| 4000|
// +-------+------+
Dataset<Row> result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
result.show();
// +--------------+
// |average_salary|
// +--------------+
// | 3750.0|
// +--------------+
SQL
-- Compile and place UDAF MyAverage in a JAR file called `MyAverage.jar` in /tmp.
CREATE FUNCTION myAverage AS 'MyAverage' USING JAR '/tmp/MyAverage.jar';
SHOW USER FUNCTIONS;
+------------------+
| function|
+------------------+
| default.myAverage|
+------------------+
CREATE TEMPORARY VIEW employees
USING org.apache.spark.sql.json
OPTIONS (
path "examples/src/main/resources/employees.json"
);
SELECT * FROM employees;
+-------+------+
| name|salary|
+-------+------+
|Michael| 3000|
| Andy| 4500|
| Justin| 3500|
| Berta| 4000|
+-------+------+
SELECT myAverage(salary) as average_salary FROM employees;
+--------------+
|average_salary|
+--------------+
| 3750.0|
+--------------+