#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys
from typing import Union, TYPE_CHECKING

from pyspark.rdd import PythonEvalType
from pyspark.sql.types import StructType

if TYPE_CHECKING:
    from pyspark.sql.dataframe import DataFrame
    from pyspark.sql.pandas._typing import PandasMapIterFunction, ArrowMapIterFunction


class PandasMapOpsMixin:
"""
Min-in for pandas map operations. Currently, only :class:`DataFrame`
can use this class.
"""
def mapInPandas(
self, func: "PandasMapIterFunction", schema: Union[StructType, str], barrier: bool = False
) -> "DataFrame":
"""
Maps an iterator of batches in the current :class:`DataFrame` using a Python native
function that takes and outputs a pandas DataFrame, and returns the result as a
:class:`DataFrame`.
The function should take an iterator of `pandas.DataFrame`\\s and return
another iterator of `pandas.DataFrame`\\s. All columns are passed
together as an iterator of `pandas.DataFrame`\\s to the function and the
returned iterator of `pandas.DataFrame`\\s are combined as a :class:`DataFrame`.
Each `pandas.DataFrame` size can be controlled by
`spark.sql.execution.arrow.maxRecordsPerBatch`. The size of the function's input and
output can be different.
.. versionadded:: 3.0.0
.. versionchanged:: 3.4.0
Supports Spark Connect.
Parameters
----------
func : function
a Python native function that takes an iterator of `pandas.DataFrame`\\s, and
outputs an iterator of `pandas.DataFrame`\\s.
schema : :class:`pyspark.sql.types.DataType` or str
the return type of the `func` in PySpark. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
barrier : bool, optional, default False
Use barrier mode execution.
.. versionadded: 3.5.0
Examples
--------
>>> from pyspark.sql.functions import pandas_udf
>>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
        >>> def filter_func(iterator):
        ...     for pdf in iterator:
        ...         yield pdf[pdf.id == 1]
        ...
>>> df.mapInPandas(filter_func, df.schema).show() # doctest: +SKIP
+---+---+
| id|age|
+---+---+
| 1| 21|
+---+---+
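
        The number of rows, and even the set of columns, in the output may differ
        from the input; for example, ``func`` can reduce each batch to a single
        summary row. A minimal sketch (the ``"num_rows long"`` DDL string and the
        single-batch output shown below are illustrative):

        >>> import pandas as pd  # doctest: +SKIP
        >>> def count_func(iterator):
        ...     for pdf in iterator:
        ...         # One output row per input batch; batching depends on
        ...         # spark.sql.execution.arrow.maxRecordsPerBatch and partitioning.
        ...         yield pd.DataFrame({"num_rows": [len(pdf)]})
        ...
        >>> df.mapInPandas(count_func, "num_rows long").show() # doctest: +SKIP
        +--------+
        |num_rows|
        +--------+
        |       2|
        +--------+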

        Set ``barrier`` to ``True`` to force the ``mapInPandas`` stage to run in
        barrier mode; this ensures that all Python workers in the stage are
        launched concurrently.

        >>> df.mapInPandas(filter_func, df.schema, barrier=True).show() # doctest: +SKIP
+---+---+
| id|age|
+---+---+
| 1| 21|
+---+---+
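
        As noted above, the size of each ``pandas.DataFrame`` batch handed to
        ``func`` is capped by ``spark.sql.execution.arrow.maxRecordsPerBatch``,
        so batch sizes can be tuned through that configuration (the value
        ``"1000"`` below is only an example):

        >>> spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1000") # doctest: +SKIP
        >>> df.mapInPandas(filter_func, df.schema).show() # doctest: +SKIP
        +---+---+
        | id|age|
        +---+---+
        |  1| 21|
        +---+---+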

        Notes
        -----
        This API is experimental.

        See Also
        --------
        pyspark.sql.functions.pandas_udf
        """
        from pyspark.sql import DataFrame
        from pyspark.sql.pandas.functions import pandas_udf

        assert isinstance(self, DataFrame)

        # The usage of the pandas_udf is internal so type checking is disabled.
        udf = pandas_udf(
            func, returnType=schema, functionType=PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
        )  # type: ignore[call-overload]
        # Apply the iterator UDF over all columns of this DataFrame.
        udf_column = udf(*[self[col] for col in self.columns])
        jdf = self._jdf.mapInPandas(udf_column._jc.expr(), barrier)
        return DataFrame(jdf, self.sparkSession)

    def mapInArrow(
        self, func: "ArrowMapIterFunction", schema: Union[StructType, str], barrier: bool = False
    ) -> "DataFrame":
"""
Maps an iterator of batches in the current :class:`DataFrame` using a Python native
function that takes and outputs a PyArrow's `RecordBatch`, and returns the result as a
:class:`DataFrame`.
The function should take an iterator of `pyarrow.RecordBatch`\\s and return
another iterator of `pyarrow.RecordBatch`\\s. All columns are passed
together as an iterator of `pyarrow.RecordBatch`\\s to the function and the
returned iterator of `pyarrow.RecordBatch`\\s are combined as a :class:`DataFrame`.
Each `pyarrow.RecordBatch` size can be controlled by
`spark.sql.execution.arrow.maxRecordsPerBatch`. The size of the function's input and
output can be different.
.. versionadded:: 3.3.0
Parameters
----------
func : function
a Python native function that takes an iterator of `pyarrow.RecordBatch`\\s, and
outputs an iterator of `pyarrow.RecordBatch`\\s.
schema : :class:`pyspark.sql.types.DataType` or str
the return type of the `func` in PySpark. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
barrier : bool, optional, default False
Use barrier mode execution.
.. versionadded: 3.5.0
Examples
--------
>>> import pyarrow # doctest: +SKIP
>>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
        >>> def filter_func(iterator):
        ...     for batch in iterator:
        ...         pdf = batch.to_pandas()
        ...         yield pyarrow.RecordBatch.from_pandas(pdf[pdf.id == 1])
        ...
>>> df.mapInArrow(filter_func, df.schema).show() # doctest: +SKIP
+---+---+
| id|age|
+---+---+
| 1| 21|
+---+---+
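
        Transformations can also be expressed directly against the Arrow data with
        ``pyarrow.compute``, avoiding a pandas round trip. A minimal sketch,
        assuming the two-column ``(id, age)`` frame above:

        >>> import pyarrow.compute as pc  # doctest: +SKIP
        >>> def add_one_to_age(iterator):
        ...     for batch in iterator:
        ...         # Column 0 is "id", column 1 is "age"; the schema is unchanged.
        ...         yield pyarrow.RecordBatch.from_arrays(
        ...             [batch.column(0), pc.add(batch.column(1), 1)],
        ...             schema=batch.schema,
        ...         )
        ...
        >>> df.mapInArrow(add_one_to_age, df.schema).show() # doctest: +SKIP
        +---+---+
        | id|age|
        +---+---+
        |  1| 22|
        |  2| 31|
        +---+---+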

        Set ``barrier`` to ``True`` to force the ``mapInArrow`` stage to run in
        barrier mode; this ensures that all Python workers in the stage are
        launched concurrently.

        >>> df.mapInArrow(filter_func, df.schema, barrier=True).show() # doctest: +SKIP
+---+---+
| id|age|
+---+---+
| 1| 21|
+---+---+

        Notes
        -----
        This API is unstable, and is intended for developers.

        See Also
        --------
        pyspark.sql.functions.pandas_udf
        pyspark.sql.DataFrame.mapInPandas
        """
        from pyspark.sql import DataFrame
        from pyspark.sql.pandas.functions import pandas_udf

        assert isinstance(self, DataFrame)

        # The usage of the pandas_udf is internal so type checking is disabled.
        udf = pandas_udf(
            func, returnType=schema, functionType=PythonEvalType.SQL_MAP_ARROW_ITER_UDF
        )  # type: ignore[call-overload]
        # Apply the iterator UDF over all columns of this DataFrame.
        udf_column = udf(*[self[col] for col in self.columns])
        jdf = self._jdf.pythonMapInArrow(udf_column._jc.expr(), barrier)
        return DataFrame(jdf, self.sparkSession)

def _test() -> None:
    import doctest
    from pyspark.sql import SparkSession
    import pyspark.sql.pandas.map_ops

    globs = pyspark.sql.pandas.map_ops.__dict__.copy()
    spark = (
        SparkSession.builder.master("local[4]").appName("sql.pandas.map_ops tests").getOrCreate()
    )
    globs["spark"] = spark
    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.pandas.map_ops,
        globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF,
    )
    spark.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()