#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from enum import Enum
import inspect
import functools
import os
from typing import (
    Any,
    Callable,
    Dict,
    Optional,
    List,
    Sequence,
    TYPE_CHECKING,
    cast,
    TypeVar,
    Union,
)
# For backward compatibility.
from pyspark.errors import (  # noqa: F401
    AnalysisException,
    ParseException,
    IllegalArgumentException,
    StreamingQueryException,
    QueryExecutionException,
    PythonException,
    UnknownException,
    SparkUpgradeException,
    PySparkImportError,
    PySparkNotImplementedError,
    PySparkRuntimeError,
)
from pyspark.util import is_remote_only, JVM_INT_MAX
from pyspark.errors.exceptions.captured import CapturedException  # noqa: F401
from pyspark.find_spark_home import _find_spark_home
if TYPE_CHECKING:
    from py4j.java_collections import JavaArray
    from py4j.java_gateway import (
        JavaClass,
        JavaGateway,
        JavaObject,
        JVMView,
    )
    from pyspark import SparkContext
    from pyspark.sql.session import SparkSession
    from pyspark.sql.dataframe import DataFrame
    from pyspark.pandas._typing import IndexOpsLike, SeriesOrIndex
FuncT = TypeVar("FuncT", bound=Callable[..., Any])
def to_java_array(gateway: "JavaGateway", jtype: "JavaClass", arr: Sequence[Any]) -> "JavaArray":
    """
    Convert python list to java type array
    Parameters
    ----------
    gateway :
        Py4j Gateway
    jtype :
        java type of element in array
    arr :
        python type list
    """
    jarray: "JavaArray" = gateway.new_array(jtype, len(arr))
    for i in range(0, len(arr)):
        jarray[i] = arr[i]
    return jarray
def to_scala_map(jvm: "JVMView", dic: Dict) -> "JavaObject":
    """
    Convert a dict into a Scala Map.
    """
    assert jvm is not None
    return jvm.PythonUtils.toScalaMap(dic)
def require_test_compiled() -> None:
    """Raise Exception if test classes are not compiled"""
    import os
    import glob
    test_class_path = os.path.join(_find_spark_home(), "sql", "core", "target", "*", "test-classes")
    paths = glob.glob(test_class_path)
    if len(paths) == 0:
        raise PySparkRuntimeError(
            errorClass="TEST_CLASS_NOT_COMPILED",
            messageParameters={"test_class_path": test_class_path},
        )
def require_minimum_plotly_version() -> None:
    """Raise ImportError if plotly is not installed"""
    from pyspark.loose_version import LooseVersion
    minimum_plotly_version = "4.8"
    try:
        import plotly
        have_plotly = True
    except ImportError as error:
        have_plotly = False
        raised_error = error
    if not have_plotly:
        raise PySparkImportError(
            errorClass="PACKAGE_NOT_INSTALLED",
            messageParameters={
                "package_name": "Plotly",
                "minimum_version": str(minimum_plotly_version),
            },
        ) from raised_error
    if LooseVersion(plotly.__version__) < LooseVersion(minimum_plotly_version):
        raise PySparkImportError(
            errorClass="UNSUPPORTED_PACKAGE_VERSION",
            messageParameters={
                "package_name": "Plotly",
                "minimum_version": str(minimum_plotly_version),
                "current_version": str(plotly.__version__),
            },
        )
class ForeachBatchFunction:
    """
    This is the Python implementation of Java interface 'ForeachBatchFunction'. This wraps
    the user-defined 'foreachBatch' function such that it can be called from the JVM when
    the query is active.
    """
    def __init__(self, session: "SparkSession", func: Callable[["DataFrame", int], None]):
        self.func = func
        self.session = session
    def call(self, jdf: "JavaObject", batch_id: int) -> None:
        from pyspark.sql.dataframe import DataFrame
        from pyspark.sql.session import SparkSession
        try:
            session_jdf = jdf.sparkSession()
            # assuming that spark context is still the same between JVM and PySpark
            wrapped_session_jdf = SparkSession(self.session.sparkContext, session_jdf)
            self.func(DataFrame(jdf, wrapped_session_jdf), batch_id)
        except Exception as e:
            self.error = e
            raise e
    class Java:
        implements = ["org.apache.spark.sql.execution.streaming.sources.PythonForeachBatchFunction"]
# Python implementation of 'org.apache.spark.sql.catalyst.util.StringConcat'
class StringConcat:
    def __init__(self, maxLength: int = JVM_INT_MAX - 15):
        self.maxLength: int = maxLength
        self.strings: List[str] = []
        self.length: int = 0
    def atLimit(self) -> bool:
        return self.length >= self.maxLength
    def append(self, s: str) -> None:
        if s is not None:
            sLen = len(s)
            if not self.atLimit():
                available = self.maxLength - self.length
                stringToAppend = s if available >= sLen else s[0:available]
                self.strings.append(stringToAppend)
            self.length = min(self.length + sLen, JVM_INT_MAX - 15)
    def toString(self) -> str:
        # finalLength = self.maxLength if self.atLimit()  else self.length
        return "".join(self.strings)
# Python implementation of 'org.apache.spark.util.SparkSchemaUtils.escapeMetaCharacters'
def escape_meta_characters(s: str) -> str:
    return (
        s.replace("\n", "\\n")
        .replace("\r", "\\r")
        .replace("\t", "\\t")
        .replace("\f", "\\f")
        .replace("\b", "\\b")
        .replace("\u000B", "\\v")
        .replace("\u0007", "\\a")
    )
def to_str(value: Any) -> Optional[str]:
    """
    A wrapper over str(), but converts bool values to lower case strings.
    If None is given, just returns None, instead of converting it to string "None".
    """
    if isinstance(value, bool):
        return str(value).lower()
    elif value is None:
        return value
    else:
        return str(value)
def enum_to_value(value: Any) -> Any:
    """Convert an Enum to its value if it is not None."""
    return enum_to_value(value.value) if value is not None and isinstance(value, Enum) else value
def is_timestamp_ntz_preferred() -> bool:
    """
    Return a bool if TimestampNTZType is preferred according to the SQL configuration set.
    """
    if is_remote():
        from pyspark.sql.connect.session import SparkSession as ConnectSparkSession
        session = ConnectSparkSession.getActiveSession()
        if session is None:
            return False
        else:
            return session.conf.get("spark.sql.timestampType", None) == "TIMESTAMP_NTZ"
    else:
        from pyspark import SparkContext
        jvm = SparkContext._jvm
        return jvm is not None and jvm.PythonSQLUtils.isTimestampNTZPreferred()
[docs]def is_remote() -> bool:
    """
    Returns if the current running environment is for Spark Connect.
    .. versionadded:: 4.0.0
    Notes
    -----
    This will only return ``True`` if there is a remote session running.
    Otherwise, it returns ``False``.
    This API is unstable, and for developers.
    Returns
    -------
    bool
    Examples
    --------
    >>> from pyspark.sql import is_remote
    >>> is_remote()
    False
    """
    return ("SPARK_CONNECT_MODE_ENABLED" in os.environ) or is_remote_only() 
def try_remote_functions(f: FuncT) -> FuncT:
    """Mark API supported from Spark Connect."""
    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
            from pyspark.sql.connect import functions
            return getattr(functions, f.__name__)(*args, **kwargs)
        else:
            return f(*args, **kwargs)
    return cast(FuncT, wrapped)
def try_partitioning_remote_functions(f: FuncT) -> FuncT:
    """Mark API supported from Spark Connect."""
    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
            from pyspark.sql.connect.functions import partitioning
            return getattr(partitioning, f.__name__)(*args, **kwargs)
        else:
            return f(*args, **kwargs)
    return cast(FuncT, wrapped)
def try_remote_avro_functions(f: FuncT) -> FuncT:
    """Mark API supported from Spark Connect."""
    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
            from pyspark.sql.connect.avro import functions
            return getattr(functions, f.__name__)(*args, **kwargs)
        else:
            return f(*args, **kwargs)
    return cast(FuncT, wrapped)
def try_remote_protobuf_functions(f: FuncT) -> FuncT:
    """Mark API supported from Spark Connect."""
    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
            from pyspark.sql.connect.protobuf import functions
            return getattr(functions, f.__name__)(*args, **kwargs)
        else:
            return f(*args, **kwargs)
    return cast(FuncT, wrapped)
def get_active_spark_context() -> "SparkContext":
    """Raise RuntimeError if SparkContext is not initialized,
    otherwise, returns the active SparkContext."""
    from pyspark import SparkContext
    sc = SparkContext._active_spark_context
    if sc is None or sc._jvm is None:
        raise PySparkRuntimeError(
            errorClass="SESSION_OR_CONTEXT_NOT_EXISTS",
            messageParameters={},
        )
    return sc
def try_remote_session_classmethod(f: FuncT) -> FuncT:
    """Mark API supported from Spark Connect."""
    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
            from pyspark.sql.connect.session import SparkSession
            assert inspect.isclass(args[0])
            return getattr(SparkSession, f.__name__)(*args[1:], **kwargs)
        else:
            return f(*args, **kwargs)
    return cast(FuncT, wrapped)
def dispatch_df_method(f: FuncT) -> FuncT:
    """
    For the use cases of direct DataFrame.method(df, ...), it checks if self
    is a Connect DataFrame or Classic DataFrame, and dispatches.
    """
    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
            from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
            if isinstance(args[0], ConnectDataFrame):
                return getattr(ConnectDataFrame, f.__name__)(*args, **kwargs)
        else:
            from pyspark.sql.classic.dataframe import DataFrame as ClassicDataFrame
            if isinstance(args[0], ClassicDataFrame):
                return getattr(ClassicDataFrame, f.__name__)(*args, **kwargs)
        raise PySparkNotImplementedError(
            errorClass="NOT_IMPLEMENTED",
            messageParameters={"feature": f"DataFrame.{f.__name__}"},
        )
    return cast(FuncT, wrapped)
def dispatch_col_method(f: FuncT) -> FuncT:
    """
    For the use cases of direct Column.method(col, ...), it checks if self
    is a Connect Column or Classic Column, and dispatches.
    """
    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
            from pyspark.sql.connect.column import Column as ConnectColumn
            if isinstance(args[0], ConnectColumn):
                return getattr(ConnectColumn, f.__name__)(*args, **kwargs)
        else:
            from pyspark.sql.classic.column import Column as ClassicColumn
            if isinstance(args[0], ClassicColumn):
                return getattr(ClassicColumn, f.__name__)(*args, **kwargs)
        raise PySparkNotImplementedError(
            errorClass="NOT_IMPLEMENTED",
            messageParameters={"feature": f"Column.{f.__name__}"},
        )
    return cast(FuncT, wrapped)
def dispatch_window_method(f: FuncT) -> FuncT:
    """
    For use cases of direct Window.method(col, ...), this function dispatches
    the call to either ConnectWindow or ClassicWindow based on the execution
    environment.
    """
    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
            from pyspark.sql.connect.window import Window as ConnectWindow
            return getattr(ConnectWindow, f.__name__)(*args, **kwargs)
        else:
            from pyspark.sql.classic.window import Window as ClassicWindow
            return getattr(ClassicWindow, f.__name__)(*args, **kwargs)
    return cast(FuncT, wrapped)
def dispatch_table_arg_method(f: FuncT) -> FuncT:
    """
    Dispatches TableArg method calls to either ConnectTableArg or ClassicTableArg
    based on the execution environment.
    """
    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
            from pyspark.sql.connect.table_arg import TableArg as ConnectTableArg
            return getattr(ConnectTableArg, f.__name__)(*args, **kwargs)
        else:
            from pyspark.sql.classic.table_arg import TableArg as ClassicTableArg
            return getattr(ClassicTableArg, f.__name__)(*args, **kwargs)
    return cast(FuncT, wrapped)
def pyspark_column_op(
    func_name: str, left: "IndexOpsLike", right: Any, fillna: Any = None
) -> Union["SeriesOrIndex", None]:
    """
    Wrapper function for column_op to get proper Column class.
    """
    from pyspark.pandas.base import column_op
    from pyspark.sql.column import Column
    from pyspark.pandas.data_type_ops.base import _is_extension_dtypes
    result = column_op(getattr(Column, func_name))(left, right)
    # It works as expected on extension dtype, so we don't need to call `fillna` for this case.
    if (fillna is not None) and (_is_extension_dtypes(left) or _is_extension_dtypes(right)):
        fillna = None
    # TODO(SPARK-43877): Fix behavior difference for compare binary functions.
    return result.fillna(fillna) if fillna is not None else result
def get_lit_sql_str(val: str) -> str:
    # Equivalent to `lit(val)._jc.expr().sql()` for string typed val
    # See `sql` definition in `sql/catalyst/src/main/scala/org/apache/spark/
    # sql/catalyst/expressions/literals.scala`
    return "'" + val.replace("\\", "\\\\").replace("'", "\\'") + "'"
class NumpyHelper:
    @staticmethod
    def linspace(start: float, stop: float, num: int) -> Sequence[float]:
        if num == 1:
            return [float(start)]
        step = (float(stop) - float(start)) / (num - 1)
        return [start + step * i for i in range(num)]
def remote_only(func: Union[Callable, property]) -> Union[Callable, property]:
    """
    Decorator to mark a function or method as only available in Spark Connect.
    This decorator allows for easy identification of Spark Connect-specific APIs.
    """
    if isinstance(func, property):
        # If it's a property, we need to set the attribute on the getter function
        getter_func = func.fget
        getter_func._remote_only = True  # type: ignore[union-attr]
        return property(getter_func)
    else:
        func._remote_only = True  # type: ignore[attr-defined]
        return func