pyspark.sql.functions.aggregate

pyspark.sql.functions.aggregate(col: ColumnOrName, initialValue: ColumnOrName, merge: Callable[[pyspark.sql.column.Column, pyspark.sql.column.Column], pyspark.sql.column.Column], finish: Optional[Callable[[pyspark.sql.column.Column], pyspark.sql.column.Column]] = None) → pyspark.sql.column.Column

Applies a binary operator to an initial state and all elements in the array, and reduces this to a single state. The final state is converted into the final result by applying a finish function.

Both functions can use methods of Column, functions defined in pyspark.sql.functions and Scala UserDefinedFunctions. Python UserDefinedFunctions are not supported (SPARK-27052).

New in version 3.1.0.

Changed in version 3.4.0: Supports Spark Connect.
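
Conceptually, the result is a left fold of the array elements into initialValue via merge, with the optional finish applied to the final accumulator. The following is a rough plain-Python analogy for illustration only; Spark evaluates merge and finish once to build Column expressions rather than calling them per element:

>>> from functools import reduce
>>> def fold_analogy(values, initial_value, merge, finish=None):
...     # Plain-Python sketch of the semantics, not how Spark executes it:
...     # fold every element into the accumulator, then convert it with finish.
...     acc = reduce(merge, values, initial_value)
...     return finish(acc) if finish is not None else acc
>>> fold_analogy([20.0, 4.0, 2.0, 6.0, 10.0], 0.0, lambda acc, x: acc + x)
42.0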

Parameters
col : Column or str

name of column or expression

initialValue : Column or str

initial value; name of column or expression

merge : function

a binary function (acc: Column, x: Column) -> Column returning an expression of the same type as initialValue

finish : function

an optional unary function (x: Column) -> Column used to convert the accumulated value into the final result.

Returns
Column

final value after the aggregate function is applied.

Examples

>>> from pyspark.sql.functions import aggregate, lit, struct
>>> df = spark.createDataFrame([(1, [20.0, 4.0, 2.0, 6.0, 10.0])], ("id", "values"))
>>> df.select(aggregate("values", lit(0.0), lambda acc, x: acc + x).alias("sum")).show()
+----+
| sum|
+----+
|42.0|
+----+
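The same sum can also be written with Spark SQL's aggregate higher-order function through expr(); a sketch assuming the same df as above:

>>> from pyspark.sql.functions import expr
>>> df.select(expr("aggregate(values, 0.0D, (acc, x) -> acc + x)").alias("sum")).show()
+----+
| sum|
+----+
|42.0|
+----+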
>>> def merge(acc, x):
...     count = acc.count + 1
...     sum = acc.sum + x
...     return struct(count.alias("count"), sum.alias("sum"))
>>> df.select(
...     aggregate(
...         "values",
...         struct(lit(0).alias("count"), lit(0.0).alias("sum")),
...         merge,
...         lambda acc: acc.sum / acc.count,
...     ).alias("mean")
... ).show()
+----+
|mean|
+----+
| 8.4|
+----+
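The finish function can also change the result type; for example, a sketch (reusing df and the imports above) that multiplies the elements and casts the accumulated product to a long:

>>> df.select(
...     aggregate(
...         "values",
...         lit(1.0),
...         lambda acc, x: acc * x,
...         lambda acc: acc.cast("long"),  # convert the accumulated double to a long
...     ).alias("product")
... ).show()
+-------+
|product|
+-------+
|   9600|
+-------+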