#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import ClassVar, Type, Dict, List, Optional, Union, cast, TYPE_CHECKING
from pyspark.util import local_connect_and_auth
from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer
from pyspark.errors import PySparkRuntimeError
if TYPE_CHECKING:
    from pyspark.resource import ResourceInformation


class TaskContext:
    """
    Contextual information about a task which can be read or mutated during
    execution. To access the TaskContext for a running task, use:
    :meth:`TaskContext.get`.
    .. versionadded:: 2.2.0

    Examples
    --------
    >>> from pyspark import TaskContext

    Get a task context instance from :class:`RDD`.

    >>> spark.sparkContext.setLocalProperty("key1", "value")
    >>> taskcontext = spark.sparkContext.parallelize([1]).map(lambda _: TaskContext.get()).first()
    >>> isinstance(taskcontext.attemptNumber(), int)
    True
    >>> isinstance(taskcontext.partitionId(), int)
    True
    >>> isinstance(taskcontext.stageId(), int)
    True
    >>> isinstance(taskcontext.taskAttemptId(), int)
    True
    >>> taskcontext.getLocalProperty("key1")
    'value'
    >>> isinstance(taskcontext.cpus(), int)
    True

    Get a task context instance from a dataframe via Python UDF.

    >>> from pyspark.sql import Row
    >>> from pyspark.sql.functions import udf
    >>> @udf("STRUCT<anum: INT, partid: INT, stageid: INT, taskaid: INT, prop: STRING, cpus: INT>")
    ... def taskcontext_as_row():
    ...    taskcontext = TaskContext.get()
    ...    return Row(
    ...        anum=taskcontext.attemptNumber(),
    ...        partid=taskcontext.partitionId(),
    ...        stageid=taskcontext.stageId(),
    ...        taskaid=taskcontext.taskAttemptId(),
    ...        prop=taskcontext.getLocalProperty("key2"),
    ...        cpus=taskcontext.cpus())
    ...
    >>> spark.sparkContext.setLocalProperty("key2", "value")
    >>> [(anum, partid, stageid, taskaid, prop, cpus)] = (
    ...     spark.range(1).select(taskcontext_as_row()).first()
    ... )
    >>> isinstance(anum, int)
    True
    >>> isinstance(partid, int)
    True
    >>> isinstance(stageid, int)
    True
    >>> isinstance(taskaid, int)
    True
    >>> prop
    'value'
    >>> isinstance(cpus, int)
    True

    Get a task context instance from a dataframe via Pandas UDF.

    >>> import pandas as pd  # doctest: +SKIP
    >>> from pyspark.sql.functions import pandas_udf
    >>> @pandas_udf("STRUCT<"
    ...     "anum: INT, partid: INT, stageid: INT, taskaid: INT, prop: STRING, cpus: INT>")
    ... def taskcontext_as_row(_):
    ...    taskcontext = TaskContext.get()
    ...    return pd.DataFrame({
    ...        "anum": [taskcontext.attemptNumber()],
    ...        "partid": [taskcontext.partitionId()],
    ...        "stageid": [taskcontext.stageId()],
    ...        "taskaid": [taskcontext.taskAttemptId()],
    ...        "prop": [taskcontext.getLocalProperty("key3")],
    ...        "cpus": [taskcontext.cpus()]
    ...    })  # doctest: +SKIP
    ...
    >>> spark.sparkContext.setLocalProperty("key3", "value")  # doctest: +SKIP
    >>> [(anum, partid, stageid, taskaid, prop, cpus)] = (
    ...     spark.range(1).select(taskcontext_as_row("id")).first()
    ... )  # doctest: +SKIP
    >>> isinstance(anum, int)
    True
    >>> isinstance(partid, int)
    True
    >>> isinstance(stageid, int)
    True
    >>> isinstance(taskaid, int)
    True
    >>> prop
    'value'
    >>> isinstance(cpus, int)
    True
    """
    _taskContext: ClassVar[Optional["TaskContext"]] = None
    _attemptNumber: Optional[int] = None
    _partitionId: Optional[int] = None
    _stageId: Optional[int] = None
    _taskAttemptId: Optional[int] = None
    _localProperties: Optional[Dict[str, str]] = None
    _cpus: Optional[int] = None
    _resources: Optional[Dict[str, "ResourceInformation"]] = None

    def __new__(cls: Type["TaskContext"]) -> "TaskContext":
        """
        Even if users construct :class:`TaskContext` instead of using :meth:`get`, give them the
        singleton.
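
        A sketch of the resulting behavior (skipped so doctests do not create a
        driver-side singleton):

        >>> TaskContext() is TaskContext()  # doctest: +SKIP
        True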
        """
        taskContext = cls._taskContext
        if taskContext is not None:
            return taskContext
        cls._taskContext = taskContext = object.__new__(cls)
        return taskContext

    @classmethod
    def _getOrCreate(cls: Type["TaskContext"]) -> "TaskContext":
        """Internal function to get or create global :class:`TaskContext`."""
        if cls._taskContext is None:
            cls._taskContext = TaskContext()
        return cls._taskContext

    @classmethod
    def _setTaskContext(cls: Type["TaskContext"], taskContext: "TaskContext") -> None:
        cls._taskContext = taskContext

    @classmethod
    def get(cls: Type["TaskContext"]) -> Optional["TaskContext"]:
        """
        Return the currently active :class:`TaskContext`. This can be called inside of
        user functions to access contextual information about running tasks.

        Returns
        -------
        :class:`TaskContext`, optional

        Notes
        -----
        Must be called on the worker, not the driver. Returns ``None`` if not initialized.
        """
        return cls._taskContext

    def stageId(self) -> int:
        """
        The ID of the stage that this task belongs to.

        Returns
        -------
        int
            current stage id.
        """
        return cast(int, self._stageId)

    def partitionId(self) -> int:
        """
        The ID of the RDD partition that is computed by this task.

        Returns
        -------
        int
            current partition id.
        """
        return cast(int, self._partitionId)

    def attemptNumber(self) -> int:
        """
        How many times this task has been attempted.  The first task attempt will be assigned
        attemptNumber = 0, and subsequent attempts will have increasing attempt numbers.

        Returns
        -------
        int
            current attempt number.
        """
        return cast(int, self._attemptNumber)

    def taskAttemptId(self) -> int:
        """
        An ID that is unique to this task attempt (within the same :class:`SparkContext`,
        no two task attempts will share the same attempt ID).  This is roughly equivalent
        to Hadoop's `TaskAttemptID`.

        Returns
        -------
        int
            current task attempt id.
        """
        return cast(int, self._taskAttemptId)

    def getLocalProperty(self, key: str) -> Optional[str]:
        """
        Get a local property set upstream in the driver, or None if it is missing.

        Parameters
        ----------
        key : str
            the key of the local property to get.

        Returns
        -------
        str, optional
            the value of the local property, or None if it is missing.
        """
        return cast(Dict[str, str], self._localProperties).get(key, None)

    def cpus(self) -> int:
        """
        CPUs allocated to the task.

        Returns
        -------
        int
            the number of CPUs.
        """
        return cast(int, self._cpus)

    def resources(self) -> Dict[str, "ResourceInformation"]:
        """
        Resources allocated to the task. The key is the resource name and the value is information
        about the resource.

        Returns
        -------
        dict
            a dictionary of a string resource name, and :class:`ResourceInformation`.
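
        Examples
        --------
        A minimal sketch of inspecting the resource map inside a task; it is
        skipped here because it assumes the application was configured with
        custom task resources (for example GPUs).

        >>> def show_resources(_):  # doctest: +SKIP
        ...     taskcontext = TaskContext.get()
        ...     # Each entry maps a resource name to its addresses on this executor.
        ...     for name, info in taskcontext.resources().items():
        ...         print(name, info.addresses)
        ...     return iter([])
        ...
        >>> spark.sparkContext.parallelize([1]).mapPartitions(show_resources).collect()  # doctest: +SKIP
        []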
        """
        from pyspark.resource import ResourceInformation
        return cast(Dict[str, "ResourceInformation"], self._resources)  
BARRIER_FUNCTION = 1
ALL_GATHER_FUNCTION = 2
def _load_from_socket(
    conn_info: Optional[Union[str, int]],
    auth_secret: Optional[str],
    function: int,
    all_gather_message: Optional[str] = None,
) -> List[str]:
    """
    Load data from the given socket. This is a blocking call, so it returns only
    when the socket connection has been closed.
    """
    (sockfile, sock) = local_connect_and_auth(conn_info, auth_secret)
    # The call may block forever, so no timeout
    sock.settimeout(None)
    if function == BARRIER_FUNCTION:
        # Make a barrier() function call.
        write_int(function, sockfile)
    elif function == ALL_GATHER_FUNCTION:
        # Make an all_gather() function call.
        write_int(function, sockfile)
        write_with_length(cast(str, all_gather_message).encode("utf-8"), sockfile)
    else:
        raise ValueError("Unrecognized function type")
    sockfile.flush()
    # Collect result.
    length = read_int(sockfile)
    res = []
    for _ in range(length):
        res.append(UTF8Deserializer().loads(sockfile))
    # Release resources.
    sockfile.close()
    sock.close()
    return res


class BarrierTaskContext(TaskContext):
    """
    A :class:`TaskContext` with extra contextual info and tooling for tasks in a barrier stage.
    Use :func:`BarrierTaskContext.get` to obtain the barrier context for a running barrier task.

    .. versionadded:: 2.4.0

    Notes
    -----
    This API is experimental

    Examples
    --------
    Set a barrier, and execute it with RDD.

    >>> from pyspark import BarrierTaskContext
    >>> def block_and_do_something(itr):
    ...     taskcontext = BarrierTaskContext.get()
    ...     # Do something.
    ...
    ...     # Wait until all tasks finished.
    ...     taskcontext.barrier()
    ...
    ...     return itr
    ...
    >>> rdd = spark.sparkContext.parallelize([1])
    >>> rdd.barrier().mapPartitions(block_and_do_something).collect()
    [1]
    """
    _conn_info: ClassVar[Optional[Union[str, int]]] = None
    _secret: ClassVar[Optional[str]] = None

    @classmethod
    def _getOrCreate(cls: Type["BarrierTaskContext"]) -> "BarrierTaskContext":
        """
        Internal function to get or create global :class:`BarrierTaskContext`. We need to make
        sure :class:`BarrierTaskContext` is returned from here because it is needed in the Python
        worker reuse scenario; see SPARK-25921 for more details.
        """
        if not isinstance(cls._taskContext, BarrierTaskContext):
            cls._taskContext = object.__new__(cls)
        return cls._taskContext

    @classmethod
    def get(cls: Type["BarrierTaskContext"]) -> "BarrierTaskContext":
        """
        Return the currently active :class:`BarrierTaskContext`.
        This can be called inside of user functions to access contextual information about
        running tasks.

        Notes
        -----
        Must be called on the worker, not the driver. An exception will be raised
        if this is not called inside a barrier stage.

        This API is experimental
        """
        if not isinstance(cls._taskContext, BarrierTaskContext):
            raise PySparkRuntimeError(
                errorClass="NOT_IN_BARRIER_STAGE",
                messageParameters={},
            )
        return cls._taskContext

    @classmethod
    def _initialize(
        cls: Type["BarrierTaskContext"], conn_info: Optional[Union[str, int]], secret: Optional[str]
    ) -> None:
        """
        Initialize :class:`BarrierTaskContext`, other methods within :class:`BarrierTaskContext`
        can only be called after BarrierTaskContext is initialized.
        """
        cls._conn_info = conn_info
        cls._secret = secret

    def barrier(self) -> None:
        """
        Sets a global barrier and waits until all tasks in this stage hit this barrier.
        Similar to the `MPI_Barrier` function in MPI, this function blocks until all tasks
        in the same stage have reached this routine.

        .. versionadded:: 2.4.0

        Notes
        -----
        This API is experimental

        In a barrier stage, each task must have the same number of `barrier()`
        calls, in all possible code branches. Otherwise, you may get the job hanging
        or a `SparkException` after timeout.
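
        Examples
        --------
        A minimal sketch: each task sleeps for a different amount of time, then
        waits at the barrier so that all tasks leave it together. It assumes a
        barrier stage with two tasks, so it is skipped here.

        >>> import time
        >>> from pyspark import BarrierTaskContext
        >>> def wait_then_proceed(iterator):  # doctest: +SKIP
        ...     context = BarrierTaskContext.get()
        ...     time.sleep(context.partitionId())  # tasks arrive at different times
        ...     context.barrier()  # no task proceeds until every task has arrived
        ...     yield context.partitionId()
        ...
        >>> rdd = spark.sparkContext.parallelize(range(2), 2)  # doctest: +SKIP
        >>> sorted(rdd.barrier().mapPartitions(wait_then_proceed).collect())  # doctest: +SKIP
        [0, 1]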
        """
        if self._conn_info is None:
            raise PySparkRuntimeError(
                errorClass="CALL_BEFORE_INITIALIZE",
                messageParameters={
                    "func_name": "barrier",
                    "object": "BarrierTaskContext",
                },
            )
        else:
            _load_from_socket(self._conn_info, self._secret, BARRIER_FUNCTION)

    def allGather(self, message: str = "") -> List[str]:
        """
        This function blocks until all tasks in the same stage have reached this routine.
        Each task passes in a message and returns with a list of all the messages passed in
        by each of those tasks.

        .. versionadded:: 3.0.0

        Notes
        -----
        This API is experimental

        In a barrier stage, each task must have the same number of `allGather()`
        calls, in all possible code branches. Otherwise, you may get the job hanging
        or a `SparkException` after timeout.
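
        Examples
        --------
        A minimal sketch: each task contributes its partition ID as a message and
        receives the messages from every task in the stage. It assumes a barrier
        stage with two tasks, so it is skipped here.

        >>> from pyspark import BarrierTaskContext
        >>> def exchange_ids(iterator):  # doctest: +SKIP
        ...     context = BarrierTaskContext.get()
        ...     # Blocks until every task has called allGather(), then returns
        ...     # the full list of messages.
        ...     messages = context.allGather(str(context.partitionId()))
        ...     yield sorted(messages)
        ...
        >>> rdd = spark.sparkContext.parallelize(range(2), 2)  # doctest: +SKIP
        >>> rdd.barrier().mapPartitions(exchange_ids).first()  # doctest: +SKIP
        ['0', '1']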
        """
        if not isinstance(message, str):
            raise TypeError("Argument `message` must be of type `str`")
        elif self._conn_info is None:
            raise PySparkRuntimeError(
                errorClass="CALL_BEFORE_INITIALIZE",
                messageParameters={
                    "func_name": "allGather",
                    "object": "BarrierTaskContext",
                },
            )
        else:
            return _load_from_socket(self._conn_info, self._secret, ALL_GATHER_FUNCTION, message)

    def getTaskInfos(self) -> List["BarrierTaskInfo"]:
        """
        Returns :class:`BarrierTaskInfo` for all tasks in this barrier stage,
        ordered by partition ID.

        .. versionadded:: 2.4.0

        Notes
        -----
        This API is experimental

        Examples
        --------
        >>> from pyspark import BarrierTaskContext
        >>> rdd = spark.sparkContext.parallelize([1])
        >>> barrier_info = rdd.barrier().mapPartitions(
        ...     lambda _: [BarrierTaskContext.get().getTaskInfos()]).collect()[0][0]
        >>> barrier_info.address
        '...:...'
        """
        if self._conn_info is None:
            raise PySparkRuntimeError(
                errorClass="CALL_BEFORE_INITIALIZE",
                messageParameters={
                    "func_name": "getTaskInfos",
                    "object": "BarrierTaskContext",
                },
            )
        else:
            addresses = cast(Dict[str, str], self._localProperties).get("addresses", "")
            return [BarrierTaskInfo(h.strip()) for h in addresses.split(",")]


class BarrierTaskInfo:
    """
    Carries all task infos of a barrier task.

    .. versionadded:: 2.4.0

    Attributes
    ----------
    address : str
        The IPv4 address (host:port) of the executor that the barrier task is running on.

    Notes
    -----
    This API is experimental
    """
    def __init__(self, address: str) -> None:
        self.address = address


def _test() -> None:
    import doctest
    import sys
    from pyspark.sql import SparkSession
    globs = globals().copy()
    globs["spark"] = (
        SparkSession.builder.master("local[2]").appName("taskcontext tests").getOrCreate()
    )
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs["spark"].stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()