#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys

from pyspark import since
from pyspark.rdd import PythonEvalType


class PandasMapOpsMixin(object):
    """
    Mix-in for pandas map operations. Currently, only :class:`DataFrame`
    can use this class.
    """

    @since(3.0)
    def mapInPandas(self, func, schema):
        """
        Maps an iterator of batches in the current :class:`DataFrame` using a Python native
        function that takes and outputs an iterator of `pandas.DataFrame`\\s, and returns the
        result as a :class:`DataFrame`.

        The function should take an iterator of `pandas.DataFrame`\\s and return
        another iterator of `pandas.DataFrame`\\s. All columns are passed
        together in each `pandas.DataFrame`, and the returned iterator of
        `pandas.DataFrame`\\s is combined into a single :class:`DataFrame`.
        The number of rows in each `pandas.DataFrame` can be controlled by the
        `spark.sql.execution.arrow.maxRecordsPerBatch` configuration.

        :param func: a Python native function that takes an iterator of `pandas.DataFrame`\\s, and
            outputs an iterator of `pandas.DataFrame`\\s.
        :param schema: the return type of the `func` in PySpark. The value can be either a
            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.

        >>> from pyspark.sql.functions import pandas_udf
        >>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
        >>> def filter_func(iterator):
        ...     for pdf in iterator:
        ...         yield pdf[pdf.id == 1]
        >>> df.mapInPandas(filter_func, df.schema).show()  # doctest: +SKIP
        +---+---+
        | id|age|
        +---+---+
        |  1| 21|
        +---+---+
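
        The function may yield any number of `pandas.DataFrame`\\s per input
        batch, so the output need not have the same number of rows as the
        input. A minimal sketch, reusing `df` from above (`duplicate_func` is
        an illustrative name):

        >>> def duplicate_func(iterator):
        ...     for pdf in iterator:
        ...         yield pdf
        ...         yield pdf
        >>> df.mapInPandas(duplicate_func, df.schema).count()  # doctest: +SKIP
        4

        The number of rows per batch can be tuned through the configuration
        mentioned above, for example (the value 2 is illustrative):

        >>> spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "2")  # doctest: +SKIP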

        .. seealso:: :meth:`pyspark.sql.functions.pandas_udf`

        .. note:: Experimental
        """
        from pyspark.sql import DataFrame
        from pyspark.sql.pandas.functions import pandas_udf

        assert isinstance(self, DataFrame)

        # Wrap the user's function as a pandas iterator UDF and apply it to all
        # columns at once; the JVM side plans this as a MapInPandas operation.
        udf = pandas_udf(
            func, returnType=schema, functionType=PythonEvalType.SQL_MAP_PANDAS_ITER_UDF)
        udf_column = udf(*[self[col] for col in self.columns])
        jdf = self._jdf.mapInPandas(udf_column._jc.expr())
        return DataFrame(jdf, self.sql_ctx)


def _test():
    import doctest
    from pyspark.sql import SparkSession
    import pyspark.sql.pandas.map_ops
    globs = pyspark.sql.pandas.map_ops.__dict__.copy()
    spark = SparkSession.builder\
        .master("local[4]")\
        .appName("sql.pandas.map_ops tests")\
        .getOrCreate()
    globs['spark'] = spark
    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.pandas.map_ops, globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
    spark.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()