Source code for pyspark.sql.context

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys
import warnings
import json

if sys.version >= '3':
    basestring = unicode = str
else:
    from itertools import imap as map

from py4j.protocol import Py4JError

from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
from pyspark.sql import since
from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
    _infer_schema, _has_nulltype, _merge_type, _create_converter
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.readwriter import DataFrameReader
from pyspark.sql.utils import install_exception_handler
from pyspark.sql.functions import UserDefinedFunction

try:
    import pandas
    has_pandas = True
except Exception:
    has_pandas = False

__all__ = ["SQLContext", "HiveContext", "UDFRegistration"]


def _monkey_patch_RDD(sqlContext):
    def toDF(self, schema=None, sampleRatio=None):
        """
        Converts current :class:`RDD` into a :class:`DataFrame`

        This is a shorthand for ``sqlContext.createDataFrame(rdd, schema, sampleRatio)``

        :param schema: a StructType or list of names of columns
        :param sampleRatio: the sample ratio of rows used for inferring
        :return: a DataFrame

        >>> rdd.toDF().collect()
        [Row(name=u'Alice', age=1)]
        """
        return sqlContext.createDataFrame(self, schema, sampleRatio)

    RDD.toDF = toDF
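
# A minimal usage sketch of the monkey-patched toDF() above (assumes an active
# SparkContext named `sc`; the column names are illustrative). Creating a
# SQLContext installs the patch, after which any RDD of tuples or Rows gains toDF():
#
#     sqlContext = SQLContext(sc)
#     people = sc.parallelize([("Alice", 1)]).toDF(["name", "age"])
#     people.show()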


class SQLContext(object):
    """Main entry point for Spark SQL functionality.

    A SQLContext can be used to create :class:`DataFrame`, register :class:`DataFrame` as
    tables, execute SQL over tables, cache tables, and read parquet files.

    :param sparkContext: The :class:`SparkContext` backing this SQLContext.
    :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new
        SQLContext in the JVM, instead we make all calls to this object.
    """

    @ignore_unicode_prefix
    def __init__(self, sparkContext, sqlContext=None):
        """Creates a new SQLContext.

        >>> from datetime import datetime
        >>> sqlContext = SQLContext(sc)
        >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1,
        ...     b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
        ...     time=datetime(2014, 8, 1, 14, 1, 5))])
        >>> df = allTypes.toDF()
        >>> df.registerTempTable("allTypes")
        >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
        ...            'from allTypes where b and i > 0').collect()
        [Row(_c0=2, _c1=2.0, _c2=False, _c3=2, _c4=0, \
            time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
        >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
        [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
        """
        self._sc = sparkContext
        self._jsc = self._sc._jsc
        self._jvm = self._sc._jvm
        self._scala_SQLContext = sqlContext
        _monkey_patch_RDD(self)
        install_exception_handler()

    @property
    def _ssql_ctx(self):
        """Accessor for the JVM Spark SQL context.

        Subclasses can override this property to provide their own
        JVM Contexts.
        """
        if self._scala_SQLContext is None:
            self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
        return self._scala_SQLContext

    @since(1.3)
    def setConf(self, key, value):
        """Sets the given Spark SQL configuration property.
        """
        self._ssql_ctx.setConf(key, value)

    @since(1.3)
    def getConf(self, key, defaultValue):
        """Returns the value of Spark SQL configuration property for the given key.

        If the key is not set, returns defaultValue.
        """
        return self._ssql_ctx.getConf(key, defaultValue)
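
    # A short usage sketch of the two config accessors above; the property
    # name and values are illustrative, not required:
    #
    #     sqlContext.setConf("spark.sql.shuffle.partitions", "10")
    #     sqlContext.getConf("spark.sql.shuffle.partitions", "200")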
@property @since("1.3.1") def udf(self): """Returns a :class:`UDFRegistration` for UDF registration. :return: :class:`UDFRegistration` """ return UDFRegistration(self) @since(1.4)
[docs] def range(self, start, end=None, step=1, numPartitions=None): """ Create a :class:`DataFrame` with single LongType column named `id`, containing elements in a range from `start` to `end` (exclusive) with step value `step`. :param start: the start value :param end: the end value (exclusive) :param step: the incremental step (default: 1) :param numPartitions: the number of partitions of the DataFrame :return: :class:`DataFrame` >>> sqlContext.range(1, 7, 2).collect() [Row(id=1), Row(id=3), Row(id=5)] If only one argument is specified, it will be used as the end value. >>> sqlContext.range(3).collect() [Row(id=0), Row(id=1), Row(id=2)] """ if numPartitions is None: numPartitions = self._sc.defaultParallelism if end is None: jdf = self._ssql_ctx.range(0, int(start), int(step), int(numPartitions)) else: jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions)) return DataFrame(jdf, self)

    @ignore_unicode_prefix
    @since(1.2)
    def registerFunction(self, name, f, returnType=StringType()):
        """Registers a lambda function as a UDF so it can be used in SQL statements.

        In addition to a name and the function itself, the return type can be optionally specified.
        When the return type is not given it defaults to a string and conversion will automatically
        be done.  For any other return type, the produced object must match the specified type.

        :param name: name of the UDF
        :param f: python function
        :param returnType: a :class:`DataType` object

        >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x))
        >>> sqlContext.sql("SELECT stringLengthString('test')").collect()
        [Row(_c0=u'4')]

        >>> from pyspark.sql.types import IntegerType
        >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
        >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
        [Row(_c0=4)]

        >>> from pyspark.sql.types import IntegerType
        >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
        >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
        [Row(_c0=4)]
        """
        udf = UserDefinedFunction(f, returnType, name)
        self._ssql_ctx.udf().registerPython(name, udf._judf)

    def _inferSchemaFromList(self, data):
        """
        Infer schema from list of Row or tuple.

        :param data: list of Row or tuple
        :return: StructType
        """
        if not data:
            raise ValueError("can not infer schema from empty dataset")
        first = data[0]
        if type(first) is dict:
            warnings.warn("inferring schema from dict is deprecated, "
                          "please use pyspark.sql.Row instead")
        schema = _infer_schema(first)
        if _has_nulltype(schema):
            for r in data:
                schema = _merge_type(schema, _infer_schema(r))
                if not _has_nulltype(schema):
                    break
            else:
                raise ValueError("Some of types cannot be determined after inferring")
        return schema

    def _inferSchema(self, rdd, samplingRatio=None):
        """
        Infer schema from an RDD of Row or tuple.

        :param rdd: an RDD of Row or tuple
        :param samplingRatio: sampling ratio, or no sampling (default)
        :return: StructType
        """
        first = rdd.first()
        if not first:
            raise ValueError("The first row in RDD is empty, "
                             "can not infer schema")
        if type(first) is dict:
            warnings.warn("Using RDD of dict to inferSchema is deprecated. "
                          "Use pyspark.sql.Row instead")

        if samplingRatio is None:
            schema = _infer_schema(first)
            if _has_nulltype(schema):
                for row in rdd.take(100)[1:]:
                    schema = _merge_type(schema, _infer_schema(row))
                    if not _has_nulltype(schema):
                        break
                else:
                    raise ValueError("Some of types cannot be determined by the "
                                     "first 100 rows, please try again with sampling")
        else:
            if samplingRatio < 0.99:
                rdd = rdd.sample(False, float(samplingRatio))
            schema = rdd.map(_infer_schema).reduce(_merge_type)
        return schema

    @ignore_unicode_prefix
    def inferSchema(self, rdd, samplingRatio=None):
        """
        .. note:: Deprecated in 1.3, use :func:`createDataFrame` instead.
        """
        warnings.warn("inferSchema is deprecated, please use createDataFrame instead.")

        if isinstance(rdd, DataFrame):
            raise TypeError("Cannot apply schema to DataFrame")

        return self.createDataFrame(rdd, None, samplingRatio)

    @ignore_unicode_prefix
    def applySchema(self, rdd, schema):
        """
        .. note:: Deprecated in 1.3, use :func:`createDataFrame` instead.
        """
        warnings.warn("applySchema is deprecated, please use createDataFrame instead")

        if isinstance(rdd, DataFrame):
            raise TypeError("Cannot apply schema to DataFrame")

        if not isinstance(schema, StructType):
            raise TypeError("schema should be StructType, but got %s" % type(schema))

        return self.createDataFrame(rdd, schema)

    def _createFromRDD(self, rdd, schema, samplingRatio):
        """
        Create an RDD for DataFrame from an existing RDD, returns the RDD and schema.
        """
        if schema is None or isinstance(schema, (list, tuple)):
            struct = self._inferSchema(rdd, samplingRatio)
            converter = _create_converter(struct)
            rdd = rdd.map(converter)
            if isinstance(schema, (list, tuple)):
                for i, name in enumerate(schema):
                    struct.fields[i].name = name
                    struct.names[i] = name
            schema = struct

        elif isinstance(schema, StructType):
            # take the first few rows to verify schema
            rows = rdd.take(10)
            for row in rows:
                _verify_type(row, schema)

        else:
            raise TypeError("schema should be StructType or list or None, but got: %s" % schema)

        # convert python objects to sql data
        rdd = rdd.map(schema.toInternal)
        return rdd, schema

    def _createFromLocal(self, data, schema):
        """
        Create an RDD for DataFrame from a list or pandas.DataFrame, returns
        the RDD and schema.
        """
        if has_pandas and isinstance(data, pandas.DataFrame):
            if schema is None:
                schema = [str(x) for x in data.columns]
            data = [r.tolist() for r in data.to_records(index=False)]

        # make sure data could be consumed multiple times
        if not isinstance(data, list):
            data = list(data)

        if schema is None or isinstance(schema, (list, tuple)):
            struct = self._inferSchemaFromList(data)
            if isinstance(schema, (list, tuple)):
                for i, name in enumerate(schema):
                    struct.fields[i].name = name
                    struct.names[i] = name
            schema = struct

        elif isinstance(schema, StructType):
            for row in data:
                _verify_type(row, schema)

        else:
            raise TypeError("schema should be StructType or list or None, but got: %s" % schema)

        # convert python objects to sql data
        data = [schema.toInternal(row) for row in data]
        return self._sc.parallelize(data), schema

    @since(1.3)
    @ignore_unicode_prefix
    def createDataFrame(self, data, schema=None, samplingRatio=None):
        """
        Creates a :class:`DataFrame` from an :class:`RDD` of :class:`tuple`/:class:`list`,
        list or :class:`pandas.DataFrame`.

        When ``schema`` is a list of column names, the type of each column
        will be inferred from ``data``.

        When ``schema`` is ``None``, it will try to infer the schema (column names and types)
        from ``data``, which should be an RDD of :class:`Row`,
        or :class:`namedtuple`, or :class:`dict`.

        If schema inference is needed, ``samplingRatio`` is used to determine the ratio of
        rows used for schema inference. The first row will be used if ``samplingRatio`` is ``None``.

        :param data: an RDD of :class:`Row`/:class:`tuple`/:class:`list`/:class:`dict`,
            :class:`list`, or :class:`pandas.DataFrame`.
        :param schema: a :class:`StructType` or list of column names. default None.
        :param samplingRatio: the sample ratio of rows used for inferring
        :return: :class:`DataFrame`

        >>> l = [('Alice', 1)]
        >>> sqlContext.createDataFrame(l).collect()
        [Row(_1=u'Alice', _2=1)]
        >>> sqlContext.createDataFrame(l, ['name', 'age']).collect()
        [Row(name=u'Alice', age=1)]

        >>> d = [{'name': 'Alice', 'age': 1}]
        >>> sqlContext.createDataFrame(d).collect()
        [Row(age=1, name=u'Alice')]

        >>> rdd = sc.parallelize(l)
        >>> sqlContext.createDataFrame(rdd).collect()
        [Row(_1=u'Alice', _2=1)]
        >>> df = sqlContext.createDataFrame(rdd, ['name', 'age'])
        >>> df.collect()
        [Row(name=u'Alice', age=1)]

        >>> from pyspark.sql import Row
        >>> Person = Row('name', 'age')
        >>> person = rdd.map(lambda r: Person(*r))
        >>> df2 = sqlContext.createDataFrame(person)
        >>> df2.collect()
        [Row(name=u'Alice', age=1)]

        >>> from pyspark.sql.types import *
        >>> schema = StructType([
        ...    StructField("name", StringType(), True),
        ...    StructField("age", IntegerType(), True)])
        >>> df3 = sqlContext.createDataFrame(rdd, schema)
        >>> df3.collect()
        [Row(name=u'Alice', age=1)]

        >>> sqlContext.createDataFrame(df.toPandas()).collect()  # doctest: +SKIP
        [Row(name=u'Alice', age=1)]
        >>> sqlContext.createDataFrame(pandas.DataFrame([[1, 2]])).collect()  # doctest: +SKIP
        [Row(0=1, 1=2)]
        """
        if isinstance(data, DataFrame):
            raise TypeError("data is already a DataFrame")

        if isinstance(data, RDD):
            rdd, schema = self._createFromRDD(data, schema, samplingRatio)
        else:
            rdd, schema = self._createFromLocal(data, schema)
        jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
        jdf = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
        df = DataFrame(jdf, self)
        df._schema = schema
        return df

    @since(1.3)
    def registerDataFrameAsTable(self, df, tableName):
        """Registers the given :class:`DataFrame` as a temporary table in the catalog.

        Temporary tables exist only during the lifetime of this instance of :class:`SQLContext`.

        >>> sqlContext.registerDataFrameAsTable(df, "table1")
        """
        if (df.__class__ is DataFrame):
            self._ssql_ctx.registerDataFrameAsTable(df._jdf, tableName)
        else:
            raise ValueError("Can only register DataFrame as table")

    def parquetFile(self, *paths):
        """Loads a Parquet file, returning the result as a :class:`DataFrame`.

        .. note:: Deprecated in 1.4, use :func:`DataFrameReader.parquet` instead.

        >>> sqlContext.parquetFile('python/test_support/sql/parquet_partitioned').dtypes
        [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
        """
        warnings.warn("parquetFile is deprecated. Use read.parquet() instead.")
        gateway = self._sc._gateway
        jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths))
        for i in range(0, len(paths)):
            jpaths[i] = paths[i]
        jdf = self._ssql_ctx.parquetFile(jpaths)
        return DataFrame(jdf, self)

    def jsonFile(self, path, schema=None, samplingRatio=1.0):
        """Loads a text file storing one JSON object per line as a :class:`DataFrame`.

        .. note:: Deprecated in 1.4, use :func:`DataFrameReader.json` instead.

        >>> sqlContext.jsonFile('python/test_support/sql/people.json').dtypes
        [('age', 'bigint'), ('name', 'string')]
        """
        warnings.warn("jsonFile is deprecated. Use read.json() instead.")
        if schema is None:
            df = self._ssql_ctx.jsonFile(path, samplingRatio)
        else:
            scala_datatype = self._ssql_ctx.parseDataType(schema.json())
            df = self._ssql_ctx.jsonFile(path, scala_datatype)
        return DataFrame(df, self)

    @ignore_unicode_prefix
    @since(1.0)
    def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
        """Loads an RDD storing one JSON object per string as a :class:`DataFrame`.

        If the schema is provided, applies the given schema to this JSON dataset.
        Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema.

        >>> df1 = sqlContext.jsonRDD(json)
        >>> df1.first()
        Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None)

        >>> df2 = sqlContext.jsonRDD(json, df1.schema)
        >>> df2.first()
        Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None)

        >>> from pyspark.sql.types import *
        >>> schema = StructType([
        ...     StructField("field2", StringType()),
        ...     StructField("field3",
        ...                 StructType([StructField("field5", ArrayType(IntegerType()))]))
        ... ])
        >>> df3 = sqlContext.jsonRDD(json, schema)
        >>> df3.first()
        Row(field2=u'row1', field3=Row(field5=None))
        """
        def func(iterator):
            for x in iterator:
                if not isinstance(x, basestring):
                    x = unicode(x)
                if isinstance(x, unicode):
                    x = x.encode("utf-8")
                yield x
        keyed = rdd.mapPartitions(func)
        keyed._bypass_serializer = True
        jrdd = keyed._jrdd.map(self._jvm.BytesToString())
        if schema is None:
            df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio)
        else:
            scala_datatype = self._ssql_ctx.parseDataType(schema.json())
            df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
        return DataFrame(df, self)

    def load(self, path=None, source=None, schema=None, **options):
        """Returns the dataset in a data source as a :class:`DataFrame`.

        .. note:: Deprecated in 1.4, use :func:`DataFrameReader.load` instead.
        """
        warnings.warn("load is deprecated. Use read.load() instead.")
        return self.read.load(path, source, schema, **options)

    @since(1.3)
    def createExternalTable(self, tableName, path=None, source=None, schema=None, **options):
        """Creates an external table based on the dataset in a data source.

        It returns the DataFrame associated with the external table.

        The data source is specified by the ``source`` and a set of ``options``.
        If ``source`` is not specified, the default data source configured by
        ``spark.sql.sources.default`` will be used.

        Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and
        created external table.

        :return: :class:`DataFrame`
        """
        if path is not None:
            options["path"] = path
        if source is None:
            source = self.getConf("spark.sql.sources.default",
                                  "org.apache.spark.sql.parquet")
        if schema is None:
            df = self._ssql_ctx.createExternalTable(tableName, source, options)
        else:
            if not isinstance(schema, StructType):
                raise TypeError("schema should be StructType")
            scala_datatype = self._ssql_ctx.parseDataType(schema.json())
            df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype,
                                                    options)
        return DataFrame(df, self)
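
    # A hedged usage sketch for createExternalTable above; the table name and
    # path are hypothetical, and the call relies on the default (parquet) source:
    #
    #     df = sqlContext.createExternalTable("my_table",
    #                                         path="/tmp/parquet_partitioned")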

    @ignore_unicode_prefix
    @since(1.0)
    def sql(self, sqlQuery):
        """Returns a :class:`DataFrame` representing the result of the given query.

        :return: :class:`DataFrame`

        >>> sqlContext.registerDataFrameAsTable(df, "table1")
        >>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1")
        >>> df2.collect()
        [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
        """
        return DataFrame(self._ssql_ctx.sql(sqlQuery), self)

    @since(1.0)
    def table(self, tableName):
        """Returns the specified table as a :class:`DataFrame`.

        :return: :class:`DataFrame`

        >>> sqlContext.registerDataFrameAsTable(df, "table1")
        >>> df2 = sqlContext.table("table1")
        >>> sorted(df.collect()) == sorted(df2.collect())
        True
        """
        return DataFrame(self._ssql_ctx.table(tableName), self)

    @ignore_unicode_prefix
    @since(1.3)
    def tables(self, dbName=None):
        """Returns a :class:`DataFrame` containing names of tables in the given database.

        If ``dbName`` is not specified, the current database will be used.

        The returned DataFrame has two columns: ``tableName`` and ``isTemporary``
        (a column with :class:`BooleanType` indicating if a table is a temporary one or not).

        :param dbName: string, name of the database to use.
        :return: :class:`DataFrame`

        >>> sqlContext.registerDataFrameAsTable(df, "table1")
        >>> df2 = sqlContext.tables()
        >>> df2.filter("tableName = 'table1'").first()
        Row(tableName=u'table1', isTemporary=True)
        """
        if dbName is None:
            return DataFrame(self._ssql_ctx.tables(), self)
        else:
            return DataFrame(self._ssql_ctx.tables(dbName), self)

    @since(1.3)
    def tableNames(self, dbName=None):
        """Returns a list of names of tables in the database ``dbName``.

        :param dbName: string, name of the database to use. Default to the current database.
        :return: list of table names, in string

        >>> sqlContext.registerDataFrameAsTable(df, "table1")
        >>> "table1" in sqlContext.tableNames()
        True
        >>> "table1" in sqlContext.tableNames("db")
        True
        """
        if dbName is None:
            return [name for name in self._ssql_ctx.tableNames()]
        else:
            return [name for name in self._ssql_ctx.tableNames(dbName)]

    @since(1.0)
    def cacheTable(self, tableName):
        """Caches the specified table in-memory."""
        self._ssql_ctx.cacheTable(tableName)

    @since(1.0)
    def uncacheTable(self, tableName):
        """Removes the specified table from the in-memory cache."""
        self._ssql_ctx.uncacheTable(tableName)

    @since(1.3)
    def clearCache(self):
        """Removes all cached tables from the in-memory cache."""
        self._ssql_ctx.clearCache()

    @property
    @since(1.4)
    def read(self):
        """
        Returns a :class:`DataFrameReader` that can be used to read data
        in as a :class:`DataFrame`.

        :return: :class:`DataFrameReader`
        """
        return DataFrameReader(self)
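
# A minimal sketch of the reader entry point defined above, reusing the test
# data paths that already appear in this module's doctests:
#
#     df = sqlContext.read.json('python/test_support/sql/people.json')
#     df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned')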


class HiveContext(SQLContext):
    """A variant of Spark SQL that integrates with data stored in Hive.

    Configuration for Hive is read from ``hive-site.xml`` on the classpath.
    It supports running both SQL and HiveQL commands.

    :param sparkContext: The SparkContext to wrap.
    :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new
        :class:`HiveContext` in the JVM, instead we make all calls to this object.
    """

    def __init__(self, sparkContext, hiveContext=None):
        SQLContext.__init__(self, sparkContext)
        if hiveContext:
            self._scala_HiveContext = hiveContext

    @property
    def _ssql_ctx(self):
        try:
            if not hasattr(self, '_scala_HiveContext'):
                self._scala_HiveContext = self._get_hive_ctx()
            return self._scala_HiveContext
        except Py4JError as e:
            raise Exception("You must build Spark with Hive. "
                            "Export 'SPARK_HIVE=true' and run "
                            "build/sbt assembly", e)

    def _get_hive_ctx(self):
        return self._jvm.HiveContext(self._jsc.sc())

    def refreshTable(self, tableName):
        """Invalidate and refresh all the cached metadata of the given table.

        For performance reasons, Spark SQL or the external data source library it uses
        might cache certain metadata about a table, such as the location of blocks.
        When those change outside of Spark SQL, users should call this function to
        invalidate the cache.
        """
        self._ssql_ctx.refreshTable(tableName)
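
# A short usage sketch for HiveContext (assumes Spark was built with Hive
# support; the table name is illustrative):
#
#     hiveContext = HiveContext(sc)
#     hiveContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
#     hiveContext.refreshTable("src")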


class UDFRegistration(object):
    """Wrapper for user-defined function registration."""

    def __init__(self, sqlContext):
        self.sqlContext = sqlContext

    def register(self, name, f, returnType=StringType()):
        return self.sqlContext.registerFunction(name, f, returnType)

    register.__doc__ = SQLContext.registerFunction.__doc__


def _test():
    import os
    import doctest
    from pyspark.context import SparkContext
    from pyspark.sql import Row, SQLContext
    import pyspark.sql.context

    os.chdir(os.environ["SPARK_HOME"])

    globs = pyspark.sql.context.__dict__.copy()
    sc = SparkContext('local[4]', 'PythonTest')
    globs['sc'] = sc
    globs['sqlContext'] = SQLContext(sc)
    globs['rdd'] = rdd = sc.parallelize(
        [Row(field1=1, field2="row1"),
         Row(field1=2, field2="row2"),
         Row(field1=3, field2="row3")]
    )
    globs['df'] = rdd.toDF()
    jsonStrings = [
        '{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
        '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
        '"field6":[{"field7": "row2"}]}',
        '{"field1" : null, "field2": "row3", '
        '"field3":{"field4":33, "field5": []}}'
    ]
    globs['jsonStrings'] = jsonStrings
    globs['json'] = sc.parallelize(jsonStrings)
    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.context, globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()