pyspark.sql.GroupedData.applyInArrow
- GroupedData.applyInArrow(func, schema)
Maps each group of the current DataFrame using an Arrow UDF and returns the result as a DataFrame.
The function should take a pyarrow.Table and return another pyarrow.Table. Alternatively, the user can pass a function that takes a tuple of pyarrow.Scalar grouping key(s) and a pyarrow.Table. For each group, all columns are passed together as a pyarrow.Table to the user function, and the returned pyarrow.Tables are combined into a DataFrame.
The schema should be a StructType describing the schema of the returned pyarrow.Table. The column labels of the returned pyarrow.Table must either match the field names in the defined schema if specified as strings, or match the field data types by position if not strings, e.g. integer indices. The length of the returned pyarrow.Table can be arbitrary.
New in version 4.0.0.
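For instance, a grouped function that only filters rows can return a slice of the input table unchanged, because the returned columns already carry the names declared in the schema. The following is a minimal sketch of name-based matching; the DataFrame and the filter_rows function are illustrative only, assuming an active spark session as in the Examples below:

>>> import pyarrow.compute as pc
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))
>>> def filter_rows(table):
...     # The returned pyarrow.Table keeps the input column names "id" and "v",
...     # so it matches the "id long, v double" schema by name.
...     return table.filter(pc.greater(table.column("v"), 1.0))
>>> df.groupby("id").applyInArrow(
...     filter_rows, schema="id long, v double").show()  # row order may vary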
- Parameters
- func : function
a Python native function that takes a pyarrow.Table and outputs a pyarrow.Table, or that takes one tuple (grouping keys) and a pyarrow.Table and outputs a pyarrow.Table.
- schema : pyspark.sql.types.DataType or str
the return type of the func in PySpark. The value can be either a pyspark.sql.types.DataType object or a DDL-formatted type string; a short sketch of both forms follows this list.
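Both forms of the schema argument describe the same struct. The sketch below builds a StructType equivalent to the DDL string used in the Examples; the variable names are illustrative, and normalize refers to the function defined in the first example below:

>>> from pyspark.sql.types import StructType, StructField, LongType, DoubleType
>>> ddl_schema = "id long, v double"
>>> struct_schema = StructType([
...     StructField("id", LongType()),
...     StructField("v", DoubleType())])
>>> # Either value can be passed as schema:
>>> # df.groupby("id").applyInArrow(normalize, schema=ddl_schema)
>>> # df.groupby("id").applyInArrow(normalize, schema=struct_schema)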
Notes
This function requires a full shuffle. All the data of a group will be loaded into memory, so the user should be aware of the potential OOM risk if data is skewed and certain groups are too large to fit in memory.
This API is unstable and intended for developers.
Examples
>>> from pyspark.sql.functions import ceil
>>> import pyarrow
>>> import pyarrow.compute as pc
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
...     ("id", "v"))
>>> def normalize(table):
...     v = table.column("v")
...     norm = pc.divide(pc.subtract(v, pc.mean(v)), pc.stddev(v, ddof=1))
...     return table.set_column(1, "v", norm)
>>> df.groupby("id").applyInArrow(
...     normalize, schema="id long, v double").show()
+---+-------------------+
| id|                  v|
+---+-------------------+
|  1|-0.7071067811865475|
|  1| 0.7071067811865475|
|  2|-0.8320502943378437|
|  2|-0.2773500981126146|
|  2| 1.1094003924504583|
+---+-------------------+
Alternatively, the user can pass a function that takes two arguments. In this case, the grouping key(s) will be passed as the first argument and the data will be passed as the second argument. The grouping key(s) will be passed as a tuple of Arrow scalar types, e.g., pyarrow.Int32Scalar and pyarrow.FloatScalar. The data will still be passed in as a pyarrow.Table containing all columns from the original Spark DataFrame. This is useful when the user does not want to hardcode grouping key(s) in the function.
>>> df = spark.createDataFrame(
...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
...     ("id", "v"))
>>> def mean_func(key, table):
...     # key is a tuple of one pyarrow.Int64Scalar, which is the value
...     # of 'id' for the current group
...     mean = pc.mean(table.column("v"))
...     return pyarrow.Table.from_pydict({"id": [key[0].as_py()], "v": [mean.as_py()]})
>>> df.groupby('id').applyInArrow(
...     mean_func, schema="id long, v double").show()
+---+---+
| id|  v|
+---+---+
|  1|1.5|
|  2|6.0|
+---+---+
>>> def sum_func(key, table):
...     # key is a tuple of two pyarrow.Int64Scalars, which are the values
...     # of 'id' and 'ceil(df.v / 2)' for the current group
...     sum = pc.sum(table.column("v"))
...     return pyarrow.Table.from_pydict({
...         "id": [key[0].as_py()],
...         "ceil(v / 2)": [key[1].as_py()],
...         "v": [sum.as_py()]
...     })
>>> df.groupby(df.id, ceil(df.v / 2)).applyInArrow(
...     sum_func, schema="id long, `ceil(v / 2)` long, v double").show()
+---+-----------+----+
| id|ceil(v / 2)|   v|
+---+-----------+----+
|  2|          5|10.0|
|  1|          1| 3.0|
|  2|          3| 5.0|
|  2|          2| 3.0|
+---+-----------+----+