#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
from typing import Union, TYPE_CHECKING

from pyspark.rdd import PythonEvalType
from pyspark.sql.types import StructType

if TYPE_CHECKING:
    from pyspark.sql.dataframe import DataFrame
    from pyspark.sql.pandas._typing import PandasMapIterFunction, ArrowMapIterFunction


class PandasMapOpsMixin:
    """
    Mix-in for pandas map operations. Currently, only :class:`DataFrame`
    can use this class.
    """

    def mapInPandas(
        self, func: "PandasMapIterFunction", schema: Union[StructType, str], barrier: bool = False
    ) -> "DataFrame":
        """
        Maps an iterator of batches in the current :class:`DataFrame` using a Python native
        function that takes and outputs a pandas DataFrame, and returns the result as a
        :class:`DataFrame`.

        The function should take an iterator of `pandas.DataFrame`\\s and return
        another iterator of `pandas.DataFrame`\\s. All columns are passed
        together as an iterator of `pandas.DataFrame`\\s to the function and the
        returned iterator of `pandas.DataFrame`\\s is combined as a
        :class:`DataFrame`. The size of each `pandas.DataFrame` can be controlled by
        `spark.sql.execution.arrow.maxRecordsPerBatch`. The size of the function's input and
        output can be different.

        .. versionadded:: 3.0.0

        .. versionchanged:: 3.4.0
            Supports Spark Connect.

        Parameters
        ----------
        func : function
            a Python native function that takes an iterator of `pandas.DataFrame`\\s, and
            outputs an iterator of `pandas.DataFrame`\\s.
        schema : :class:`pyspark.sql.types.DataType` or str
            the return type of the `func` in PySpark. The value can be either a
            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
        barrier : bool, optional, default False
            Use barrier mode execution.

            .. versionadded:: 3.5.0

        Examples
        --------
        >>> from pyspark.sql.functions import pandas_udf
        >>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
        >>> def filter_func(iterator):
        ...     for pdf in iterator:
        ...         yield pdf[pdf.id == 1]
        ...
        >>> df.mapInPandas(filter_func, df.schema).show()  # doctest: +SKIP
        +---+---+
        | id|age|
        +---+---+
        |  1| 21|
        +---+---+
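
        The returned `pandas.DataFrame`\\s do not have to match the input schema or row
        count. As an illustrative sketch (the ``with_double_age`` helper and its
        ``double_age`` column are purely for demonstration, and the output shown is
        indicative), each batch can be reshaped before it is returned:

        >>> import pandas as pd  # doctest: +SKIP
        >>> def with_double_age(iterator):
        ...     for pdf in iterator:
        ...         # Emit an extra column computed from the input batch.
        ...         yield pdf.assign(double_age=pdf.age * 2)
        ...
        >>> df.mapInPandas(with_double_age, "id long, age long, double_age long").show()  # doctest: +SKIP
        +---+---+----------+
        | id|age|double_age|
        +---+---+----------+
        |  1| 21|        42|
        |  2| 30|        60|
        +---+---+----------+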

        Set ``barrier`` to ``True`` to force the ``mapInPandas`` stage to run in
        barrier mode, which ensures that all Python workers in the stage are
        launched concurrently.

        >>> df.mapInPandas(filter_func, df.schema, barrier=True).show()  # doctest: +SKIP
        +---+---+
        | id|age|
        +---+---+
        |  1| 21|
        +---+---+

        Notes
        -----
        This API is experimental.

        See Also
        --------
        pyspark.sql.functions.pandas_udf
        """
        from pyspark.sql import DataFrame
        from pyspark.sql.pandas.functions import pandas_udf

        assert isinstance(self, DataFrame)

        # The usage of the pandas_udf is internal so type checking is disabled.
        udf = pandas_udf(
            func, returnType=schema, functionType=PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
        )  # type: ignore[call-overload]
        udf_column = udf(*[self[col] for col in self.columns])
        jdf = self._jdf.mapInPandas(udf_column._jc.expr(), barrier)
        return DataFrame(jdf, self.sparkSession)

    def mapInArrow(
        self, func: "ArrowMapIterFunction", schema: Union[StructType, str], barrier: bool = False
    ) -> "DataFrame":
        """
        Maps an iterator of batches in the current :class:`DataFrame` using a Python native
        function that takes and outputs a PyArrow `RecordBatch`, and returns the result as a
        :class:`DataFrame`.

        The function should take an iterator of `pyarrow.RecordBatch`\\s and return
        another iterator of `pyarrow.RecordBatch`\\s. All columns are passed
        together as an iterator of `pyarrow.RecordBatch`\\s to the function and the
        returned iterator of `pyarrow.RecordBatch`\\s is combined as a
        :class:`DataFrame`. The size of each `pyarrow.RecordBatch` can be controlled by
        `spark.sql.execution.arrow.maxRecordsPerBatch`. The size of the function's input and
        output can be different.

        .. versionadded:: 3.3.0

        Parameters
        ----------
        func : function
            a Python native function that takes an iterator of `pyarrow.RecordBatch`\\s, and
            outputs an iterator of `pyarrow.RecordBatch`\\s.
        schema : :class:`pyspark.sql.types.DataType` or str
            the return type of the `func` in PySpark. The value can be either a
            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
        barrier : bool, optional, default False
            Use barrier mode execution.

            .. versionadded:: 3.5.0

        Examples
        --------
        >>> import pyarrow  # doctest: +SKIP
        >>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
        >>> def filter_func(iterator):
        ...     for batch in iterator:
        ...         pdf = batch.to_pandas()
        ...         yield pyarrow.RecordBatch.from_pandas(pdf[pdf.id == 1])
        ...
        >>> df.mapInArrow(filter_func, df.schema).show()  # doctest: +SKIP
        +---+---+
        | id|age|
        +---+---+
        |  1| 21|
        +---+---+

        Set ``barrier`` to ``True`` to force the ``mapInArrow`` stage to run in
        barrier mode, which ensures that all Python workers in the stage are
        launched concurrently.

        >>> df.mapInArrow(filter_func, df.schema, barrier=True).show()  # doctest: +SKIP
        +---+---+
        | id|age|
        +---+---+
        |  1| 21|
        +---+---+
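
        Batches can also be transformed directly with ``pyarrow``, without a round trip
        through pandas. As an illustrative sketch (the ``keep_id_only`` helper below is
        purely for demonstration; it rebuilds each batch from its first column, ``id``,
        and the output shown is indicative):

        >>> def keep_id_only(iterator):
        ...     for batch in iterator:
        ...         # Keep only the first column of each incoming batch.
        ...         yield pyarrow.RecordBatch.from_arrays([batch.column(0)], names=["id"])
        ...
        >>> df.mapInArrow(keep_id_only, "id long").show()  # doctest: +SKIP
        +---+
        | id|
        +---+
        |  1|
        |  2|
        +---+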

        Notes
        -----
        This API is unstable, and for developers.

        See Also
        --------
        pyspark.sql.functions.pandas_udf
        pyspark.sql.DataFrame.mapInPandas
        """
        from pyspark.sql import DataFrame
        from pyspark.sql.pandas.functions import pandas_udf

        assert isinstance(self, DataFrame)

        # The usage of the pandas_udf is internal so type checking is disabled.
        udf = pandas_udf(
            func, returnType=schema, functionType=PythonEvalType.SQL_MAP_ARROW_ITER_UDF
        )  # type: ignore[call-overload]
        udf_column = udf(*[self[col] for col in self.columns])
        jdf = self._jdf.pythonMapInArrow(udf_column._jc.expr(), barrier)
        return DataFrame(jdf, self.sparkSession)


def _test() -> None:
    import doctest
    from pyspark.sql import SparkSession
    import pyspark.sql.pandas.map_ops

    globs = pyspark.sql.pandas.map_ops.__dict__.copy()
    spark = (
        SparkSession.builder.master("local[4]").appName("sql.pandas.map_ops tests").getOrCreate()
    )
    globs["spark"] = spark
    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.pandas.map_ops,
        globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF,
    )
    spark.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()