pyspark.sql.functions.grouping_id

pyspark.sql.functions.grouping_id(*cols: ColumnOrName) → pyspark.sql.column.Column

Aggregate function: returns the level of grouping, equal to

(grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + … + grouping(cn)

New in version 2.0.0.

Changed in version 3.4.0: Supports Spark Connect.

Parameters
cols : Column or str

columns to check for.

Returns
Column

the level of grouping it relates to.

Notes

The list of columns should match the grouping columns exactly, or be empty (meaning all of the grouping columns).
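For instance, passing the grouping columns explicitly should yield the same bit pattern as the zero-argument form. A minimal sketch, assuming an active spark session; df2 is hypothetical sample data:

>>> from pyspark.sql import functions as sf
>>> df2 = spark.createDataFrame([(1, "a"), (2, "b")], ["k", "v"])
>>> df2.cube("v").agg(sf.grouping_id("v").alias("gid")).orderBy("v").show()
+----+---+
|   v|gid|
+----+---+
|NULL|  1|
|   a|  0|
|   b|  0|
+----+---+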

Examples

>>> from pyspark.sql import functions as sf
>>> df = spark.createDataFrame([(1, "a", "a"),
...                             (3, "a", "a"),
...                             (4, "b", "c")], ["c1", "c2", "c3"])
>>> df.cube("c2", "c3").agg(sf.grouping_id(), sf.sum("c1")).orderBy("c2", "c3").show()
+----+----+-------------+-------+
|  c2|  c3|grouping_id()|sum(c1)|
+----+----+-------------+-------+
|NULL|NULL|            3|      8|
|NULL|   a|            2|      4|
|NULL|   c|            2|      4|
|   a|NULL|            1|      4|
|   a|   a|            0|      4|
|   b|NULL|            1|      4|
|   b|   c|            0|      4|
+----+----+-------------+-------+
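
As a cross-check of the formula above, the same value can be assembled by hand from grouping(). A minimal sketch reusing df and the sf alias from the example; shiftleft mirrors the << in the formula:

>>> res = df.cube("c2", "c3").agg(
...     sf.grouping_id().alias("gid"),
...     (sf.shiftleft(sf.grouping("c2"), 1) + sf.grouping("c3")).alias("manual"))
>>> res.where("gid != manual").count()
0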