pyspark.sql module

Module Context

Important classes of Spark SQL and DataFrames:

class pyspark.sql.SQLContext(sparkContext, sqlContext=None)

Main entry point for Spark SQL functionality.

A SQLContext can be used to create DataFrames, register DataFrames as tables, execute SQL over tables, cache tables, and read parquet files.

When created, SQLContext adds a method called toDF to RDD, which can be used to convert an RDD into a DataFrame; it is a shorthand for SQLContext.createDataFrame().

Parameters:
  • sparkContext – The SparkContext backing this SQLContext.
  • sqlContext – An optional JVM Scala SQLContext. If set, we do not instantiate a new SQLContext in the JVM, instead we make all calls to this object.
applySchema(rdd, schema)

Note: Deprecated in 1.3, use createDataFrame() instead.

cacheTable(tableName)

Caches the specified table in-memory.

clearCache()

Removes all cached tables from the in-memory cache.

createDataFrame(data, schema=None, samplingRatio=None)

Creates a DataFrame from an RDD of tuple/list, list or pandas.DataFrame.

When schema is a list of column names, the type of each column will be inferred from data.

When schema is None, it will try to infer the schema (column names and types) from data, which should be an RDD of Row, or namedtuple, or dict.

If schema inference is needed, samplingRatio determines the ratio of rows used for schema inference. If samplingRatio is None, only the first row is used.

Parameters:
  • data – an RDD of Row/tuple/list/dict, list, or pandas.DataFrame.
  • schema – a StructType or list of column names. default None.
  • samplingRatio – the sample ratio of rows used for inferring the schema.
>>> l = [('Alice', 1)]
>>> sqlContext.createDataFrame(l).collect()
[Row(_1=u'Alice', _2=1)]
>>> sqlContext.createDataFrame(l, ['name', 'age']).collect()
[Row(name=u'Alice', age=1)]
>>> d = [{'name': 'Alice', 'age': 1}]
>>> sqlContext.createDataFrame(d).collect()
[Row(age=1, name=u'Alice')]
>>> rdd = sc.parallelize(l)
>>> sqlContext.createDataFrame(rdd).collect()
[Row(_1=u'Alice', _2=1)]
>>> df = sqlContext.createDataFrame(rdd, ['name', 'age'])
>>> df.collect()
[Row(name=u'Alice', age=1)]
>>> from pyspark.sql import Row
>>> Person = Row('name', 'age')
>>> person = rdd.map(lambda r: Person(*r))
>>> df2 = sqlContext.createDataFrame(person)
>>> df2.collect()
[Row(name=u'Alice', age=1)]
>>> from pyspark.sql.types import *
>>> schema = StructType([
...    StructField("name", StringType(), True),
...    StructField("age", IntegerType(), True)])
>>> df3 = sqlContext.createDataFrame(rdd, schema)
>>> df3.collect()
[Row(name=u'Alice', age=1)]
>>> sqlContext.createDataFrame(df.toPandas()).collect()  
[Row(name=u'Alice', age=1)]
createExternalTable(tableName, path=None, source=None, schema=None, **options)

Creates an external table based on the dataset in a data source.

It returns the DataFrame associated with the external table.

The data source is specified by the source and a set of options. If source is not specified, the default data source configured by spark.sql.sources.default will be used.

Optionally, a schema can be provided as the schema of the returned DataFrame and created external table.
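
A minimal sketch of the call shape (the table name and path are hypothetical, and assume JSON files already exist at that location):

>>> people = sqlContext.createExternalTable("people_ext", path="/data/people",
...                                         source="json")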

getConf(key, defaultValue)

Returns the value of Spark SQL configuration property for the given key.

If the key is not set, returns defaultValue.
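
For illustration, reading a property with a fallback; the key shown here is hypothetical and not set, so the default is returned:

>>> sqlContext.getConf("spark.sql.myCustomKey", u"defaultValue")
u'defaultValue'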

inferSchema(rdd, samplingRatio=None)

Note: Deprecated in 1.3, use createDataFrame() instead.

jsonFile(path, schema=None, samplingRatio=1.0)

Loads a text file storing one JSON object per line as a DataFrame.

If the schema is provided, applies the given schema to this JSON dataset. Otherwise, it samples the dataset with ratio samplingRatio to determine the schema.

>>> import tempfile, shutil
>>> jsonFile = tempfile.mkdtemp()
>>> shutil.rmtree(jsonFile)
>>> with open(jsonFile, 'w') as f:
...     f.writelines(jsonStrings)
>>> df1 = sqlContext.jsonFile(jsonFile)
>>> df1.printSchema()
root
 |-- field1: long (nullable = true)
 |-- field2: string (nullable = true)
 |-- field3: struct (nullable = true)
 |    |-- field4: long (nullable = true)
>>> from pyspark.sql.types import *
>>> schema = StructType([
...     StructField("field2", StringType()),
...     StructField("field3",
...         StructType([StructField("field5", ArrayType(IntegerType()))]))])
>>> df2 = sqlContext.jsonFile(jsonFile, schema)
>>> df2.printSchema()
root
 |-- field2: string (nullable = true)
 |-- field3: struct (nullable = true)
 |    |-- field5: array (nullable = true)
 |    |    |-- element: integer (containsNull = true)
jsonRDD(rdd, schema=None, samplingRatio=1.0)

Loads an RDD storing one JSON object per string as a DataFrame.

If the schema is provided, applies the given schema to this JSON dataset. Otherwise, it samples the dataset with ratio samplingRatio to determine the schema.

>>> df1 = sqlContext.jsonRDD(json)
>>> df1.first()
Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None)
>>> df2 = sqlContext.jsonRDD(json, df1.schema)
>>> df2.first()
Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None)
>>> from pyspark.sql.types import *
>>> schema = StructType([
...     StructField("field2", StringType()),
...     StructField("field3",
...                 StructType([StructField("field5", ArrayType(IntegerType()))]))
... ])
>>> df3 = sqlContext.jsonRDD(json, schema)
>>> df3.first()
Row(field2=u'row1', field3=Row(field5=None))
load(path=None, source=None, schema=None, **options)

Returns the dataset in a data source as a DataFrame.

The data source is specified by the source and a set of options. If source is not specified, the default data source configured by spark.sql.sources.default will be used.

Optionally, a schema can be provided as the schema of the returned DataFrame.
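
A minimal sketch of the call shape, assuming Parquet data already exists at a hypothetical path:

>>> df_loaded = sqlContext.load(path="/data/people.parquet", source="parquet")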

parquetFile(*paths)

Loads a Parquet file, returning the result as a DataFrame.

>>> import tempfile, shutil
>>> parquetFile = tempfile.mkdtemp()
>>> shutil.rmtree(parquetFile)
>>> df.saveAsParquetFile(parquetFile)
>>> df2 = sqlContext.parquetFile(parquetFile)
>>> sorted(df.collect()) == sorted(df2.collect())
True
registerDataFrameAsTable(df, tableName)

Registers the given DataFrame as a temporary table in the catalog.

Temporary tables exist only during the lifetime of this instance of SQLContext.

>>> sqlContext.registerDataFrameAsTable(df, "table1")
registerFunction(name, f, returnType=StringType)

Registers a lambda function as a UDF so it can be used in SQL statements.

In addition to a name and the function itself, the return type can be optionally specified. When the return type is not given, it defaults to a string and conversion is done automatically. For any other return type, the produced object must match the specified type.

Parameters:
  • name – name of the UDF
  • f – a Python function (for example, a lambda)
  • returnType – a DataType object
>>> sqlContext.registerFunction("stringLengthString", lambda x: len(x))
>>> sqlContext.sql("SELECT stringLengthString('test')").collect()
[Row(c0=u'4')]
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(c0=4)]
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(c0=4)]
setConf(key, value)

Sets the given Spark SQL configuration property.
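
For illustration, setting a property and reading it back (the key is a hypothetical custom property):

>>> sqlContext.setConf("spark.sql.myCustomKey", u"10")
>>> sqlContext.getConf("spark.sql.myCustomKey", u"200")
u'10'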

sql(sqlQuery)

Returns a DataFrame representing the result of the given query.

>>> sqlContext.registerDataFrameAsTable(df, "table1")
>>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1")
>>> df2.collect()
[Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
table(tableName)

Returns the specified table as a DataFrame.

>>> sqlContext.registerDataFrameAsTable(df, "table1")
>>> df2 = sqlContext.table("table1")
>>> sorted(df.collect()) == sorted(df2.collect())
True
tableNames(dbName=None)

Returns a list of names of tables in the database dbName.

If dbName is not specified, the current database will be used.

>>> sqlContext.registerDataFrameAsTable(df, "table1")
>>> "table1" in sqlContext.tableNames()
True
>>> "table1" in sqlContext.tableNames("db")
True
tables(dbName=None)

Returns a DataFrame containing names of tables in the given database.

If dbName is not specified, the current database will be used.

The returned DataFrame has two columns: tableName and isTemporary (a column with BooleanType indicating if a table is a temporary one or not).

>>> sqlContext.registerDataFrameAsTable(df, "table1")
>>> df2 = sqlContext.tables()
>>> df2.filter("tableName = 'table1'").first()
Row(tableName=u'table1', isTemporary=True)
udf

Returns a UDFRegistration for UDF registration.

uncacheTable(tableName)

Removes the specified table from the in-memory cache.

class pyspark.sql.HiveContext(sparkContext, hiveContext=None)

A variant of Spark SQL that integrates with data stored in Hive.

Configuration for Hive is read from hive-site.xml on the classpath. It supports running both SQL and HiveQL commands.

Parameters:
  • sparkContext – The SparkContext to wrap.
  • hiveContext – An optional JVM Scala HiveContext. If set, we do not instantiate a new HiveContext in the JVM, instead we make all calls to this object.
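
A minimal sketch of creating a HiveContext and issuing a HiveQL statement; it assumes Hive and hive-site.xml are available on the classpath:

>>> from pyspark.sql import HiveContext
>>> hiveContext = HiveContext(sc)
>>> tables = hiveContext.sql("SHOW TABLES").collect()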
class pyspark.sql.DataFrame(jdf, sql_ctx)

A distributed collection of data grouped into named columns.

A DataFrame is equivalent to a relational table in Spark SQL, and can be created using various functions in SQLContext:

people = sqlContext.parquetFile("...")

Once created, it can be manipulated using the various domain-specific language (DSL) functions defined in: DataFrame, Column.

To select a column from the data frame, use the apply method:

ageCol = people.age

A more concrete example:

# To create DataFrame using SQLContext
people = sqlContext.parquetFile("...")
department = sqlContext.parquetFile("...")

people.filter(people.age > 30).join(department, people.deptId == department.id) \
      .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"})
agg(*exprs)

Aggregate on the entire DataFrame without groups (shorthand for df.groupBy.agg()).

>>> df.agg({"age": "max"}).collect()
[Row(MAX(age)=5)]
>>> from pyspark.sql import functions as F
>>> df.agg(F.min(df.age)).collect()
[Row(MIN(age)=2)]
cache()

Persists with the default storage level (MEMORY_ONLY_SER).

collect()

Returns all the records as a list of Row.

>>> df.collect()
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
columns

Returns all column names as a list.

>>> df.columns
[u'age', u'name']
count()

Returns the number of rows in this DataFrame.

>>> df.count()
2L
describe(*cols)

Computes statistics for numeric columns.

This includes count, mean, stddev, min, and max. If no columns are given, this function computes statistics for all numerical columns.

>>> df.describe().show()
summary age
count   2
mean    3.5
stddev  1.5
min     2
max     5
distinct()

Returns a new DataFrame containing the distinct rows in this DataFrame.

>>> df.distinct().count()
2L
dropna(how='any', thresh=None, subset=None)

Returns a new DataFrame omitting rows with null values.

This is an alias for na.drop().

Parameters:
  • how – ‘any’ or ‘all’. If ‘any’, drop a row if it contains any nulls. If ‘all’, drop a row only if all its values are null.
  • thresh – int, default None. If specified, drop rows that have fewer than thresh non-null values. This overrides the how parameter.
  • subset – optional list of column names to consider.
>>> df4.dropna().show()
age height name
10  80     Alice
>>> df4.na.drop().show()
age height name
10  80     Alice
dtypes

Returns all column names and their data types as a list.

>>> df.dtypes
[('age', 'int'), ('name', 'string')]
explain(extended=False)

Prints the (logical and physical) plans to the console for debugging purpose.

Parameters:
  • extended – boolean, default False. If False, prints only the physical plan.
>>> df.explain()
PhysicalRDD [age#0,name#1], MapPartitionsRDD[...] at mapPartitions at SQLContext.scala:...
>>> df.explain(True)
== Parsed Logical Plan ==
...
== Analyzed Logical Plan ==
...
== Optimized Logical Plan ==
...
== Physical Plan ==
...
== RDD ==
fillna(value, subset=None)

Replaces null values; this is an alias for na.fill().

Parameters:
  • value – int, long, float, string, or dict. Value to replace null values with. If the value is a dict, then subset is ignored and value must be a mapping from column name (string) to replacement value. The replacement value must be an int, long, float, or string.
  • subset – optional list of column names to consider. Columns specified in subset that do not have matching data type are ignored. For example, if value is a string, and subset contains a non-string column, then the non-string column is simply ignored.
>>> df4.fillna(50).show()
age height name
10  80     Alice
5   50     Bob
50  50     Tom
50  50     null
>>> df4.fillna({'age': 50, 'name': 'unknown'}).show()
age height name
10  80     Alice
5   null   Bob
50  null   Tom
50  null   unknown
>>> df4.na.fill({'age': 50, 'name': 'unknown'}).show()
age height name
10  80     Alice
5   null   Bob
50  null   Tom
50  null   unknown
filter(condition)

Filters rows using the given condition.

where() is an alias for filter().

Parameters:
  • condition – a Column of types.BooleanType, or a string containing a SQL expression.
>>> df.filter(df.age > 3).collect()
[Row(age=5, name=u'Bob')]
>>> df.where(df.age == 2).collect()
[Row(age=2, name=u'Alice')]
>>> df.filter("age > 3").collect()
[Row(age=5, name=u'Bob')]
>>> df.where("age = 2").collect()
[Row(age=2, name=u'Alice')]
first()

Returns the first row as a Row.

>>> df.first()
Row(age=2, name=u'Alice')
flatMap(f)

Returns a new RDD by first applying the f function to each Row, and then flattening the results.

This is a shorthand for df.rdd.flatMap().

>>> df.flatMap(lambda p: p.name).collect()
[u'A', u'l', u'i', u'c', u'e', u'B', u'o', u'b']
foreach(f)

Applies the f function to each Row of this DataFrame.

This is a shorthand for df.rdd.foreach().

>>> def f(person):
...     print person.name
>>> df.foreach(f)
foreachPartition(f)

Applies the f function to each partition of this DataFrame.

This is a shorthand for df.rdd.foreachPartition().

>>> def f(people):
...     for person in people:
...         print person.name
>>> df.foreachPartition(f)
groupBy(*cols)

Groups the DataFrame using the specified columns, so we can run aggregation on them. See GroupedData for all the available aggregate functions.

Parameters:
  • cols – list of columns to group by. Each element should be a column name (string) or an expression (Column).
>>> df.groupBy().avg().collect()
[Row(AVG(age)=3.5)]
>>> df.groupBy('name').agg({'age': 'mean'}).collect()
[Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)]
>>> df.groupBy(df.name).avg().collect()
[Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)]
head(n=None)

Returns the first n rows as a list of Row, or the first Row if n is None.

>>> df.head()
Row(age=2, name=u'Alice')
>>> df.head(1)
[Row(age=2, name=u'Alice')]
insertInto(tableName, overwrite=False)

Inserts the contents of this DataFrame into the specified table, optionally overwriting any existing data.

intersect(other)

Return a new DataFrame containing rows only in both this frame and another frame.

This is equivalent to INTERSECT in SQL.
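
A small illustration with two throwaway DataFrames (df_a and df_b are names used only in this sketch):

>>> df_a = sqlContext.createDataFrame([('a', 1), ('b', 2)], ['k', 'v'])
>>> df_b = sqlContext.createDataFrame([('b', 2), ('c', 3)], ['k', 'v'])
>>> df_a.intersect(df_b).collect()
[Row(k=u'b', v=2)]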

isLocal()

Returns True if the collect() and take() methods can be run locally (without any Spark executors).

join(other, joinExprs=None, joinType=None)

Joins with another DataFrame, using the given join expression.

The following example performs a full outer join between df and df2.

Parameters:
  • other – Right side of the join
  • joinExprs – Join expression
  • joinType – str, default ‘inner’. One of inner, outer, left_outer, right_outer, semijoin.
>>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
[Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)]
limit(num)

Limits the result count to the number specified.

>>> df.limit(1).collect()
[Row(age=2, name=u'Alice')]
>>> df.limit(0).collect()
[]
map(f)

Returns a new RDD by applying the f function to each Row.

This is a shorthand for df.rdd.map().

>>> df.map(lambda p: p.name).collect()
[u'Alice', u'Bob']
mapPartitions(f, preservesPartitioning=False)

Returns a new RDD by applying the f function to each partition.

This is a shorthand for df.rdd.mapPartitions().

>>> rdd = sc.parallelize([1, 2, 3, 4], 4)
>>> def f(iterator): yield 1
>>> rdd.mapPartitions(f).sum()
4
na

Returns a DataFrameNaFunctions for handling missing values.

orderBy(*cols)

Returns a new DataFrame sorted by the specified column(s).

Parameters:
  • cols – list of Column to sort by.
>>> df.sort(df.age.desc()).collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
>>> df.orderBy(df.age.desc()).collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
>>> from pyspark.sql.functions import *
>>> df.sort(asc("age")).collect()
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
>>> df.orderBy(desc("age"), "name").collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
persist(storageLevel=StorageLevel(False, True, False, False, 1))

Sets the storage level to persist its values across operations after the first time it is computed. This can only be used to assign a new storage level if the RDD does not have a storage level set yet. If no storage level is specified, it defaults to MEMORY_ONLY_SER.
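
For example, persisting with an explicit storage level (a sketch; StorageLevel is imported from the top-level pyspark package):

>>> from pyspark import StorageLevel
>>> df.persist(StorageLevel.MEMORY_AND_DISK)
DataFrame[age: int, name: string]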

printSchema()

Prints out the schema in the tree format.

>>> df.printSchema()
root
 |-- age: integer (nullable = true)
 |-- name: string (nullable = true)
rdd

Returns the content as a pyspark.RDD of Row.
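
For example, dropping to the underlying RDD for arbitrary Python transformations (same df as in the other examples):

>>> df.rdd.map(lambda row: row.name).collect()
[u'Alice', u'Bob']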

registerAsTable(name)

DEPRECATED: use registerTempTable() instead

registerTempTable(name)

Registers this DataFrame as a temporary table using the given name.

The lifetime of this temporary table is tied to the SQLContext that was used to create this DataFrame.

>>> df.registerTempTable("people")
>>> df2 = sqlContext.sql("select * from people")
>>> sorted(df.collect()) == sorted(df2.collect())
True
repartition(numPartitions)

Returns a new DataFrame that has exactly numPartitions partitions.

>>> df.repartition(10).rdd.getNumPartitions()
10
sample(withReplacement, fraction, seed=None)

Returns a sampled subset of this DataFrame.

>>> df.sample(False, 0.5, 97).count()
1L
save(path=None, source=None, mode='error', **options)

Saves the contents of the DataFrame to a data source.

The data source is specified by the source and a set of options. If source is not specified, the default data source configured by spark.sql.sources.default will be used.

Additionally, mode is used to specify the behavior of the save operation when data already exists in the data source. There are four modes:

  • append: Append contents of this DataFrame to existing data.
  • overwrite: Overwrite existing data.
  • error: Throw an exception if data already exists.
  • ignore: Silently ignore this operation if data already exists.
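
A minimal sketch of the call shape, writing Parquet to a hypothetical output path:

>>> df.save(path="/tmp/people.parquet", source="parquet", mode="overwrite")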
saveAsParquetFile(path)

Saves the contents as a Parquet file, preserving the schema.

Files that are written out using this method can be read back in as a DataFrame using SQLContext.parquetFile().

>>> import tempfile, shutil
>>> parquetFile = tempfile.mkdtemp()
>>> shutil.rmtree(parquetFile)
>>> df.saveAsParquetFile(parquetFile)
>>> df2 = sqlContext.parquetFile(parquetFile)
>>> sorted(df2.collect()) == sorted(df.collect())
True
saveAsTable(tableName, source=None, mode='error', **options)

Saves the contents of this DataFrame to a data source as a table.

The data source is specified by the source and a set of options. If source is not specified, the default data source configured by spark.sql.sources.default will be used.

Additionally, mode is used to specify the behavior of the saveAsTable operation when table already exists in the data source. There are four modes:

  • append: Append contents of this DataFrame to existing data.
  • overwrite: Overwrite existing data.
  • error: Throw an exception if data already exists.
  • ignore: Silently ignore this operation if data already exists.
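
A minimal sketch of the call shape (the table name is hypothetical; depending on the deployment, a persistent table typically requires a HiveContext-backed metastore):

>>> df.saveAsTable("people_saved", source="parquet", mode="overwrite")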
schema

Returns the schema of this DataFrame as a types.StructType.

>>> df.schema
StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true)))
select(*cols)

Projects a set of expressions and returns a new DataFrame.

Parameters:
  • cols – list of column names (string) or expressions (Column). If one of the column names is ‘*’, that column is expanded to include all columns in the current DataFrame.
>>> df.select('*').collect()
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
>>> df.select('name', 'age').collect()
[Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
>>> df.select(df.name, (df.age + 10).alias('age')).collect()
[Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
selectExpr(*expr)

Projects a set of SQL expressions and returns a new DataFrame.

This is a variant of select() that accepts SQL expressions.

>>> df.selectExpr("age * 2", "abs(age)").collect()
[Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)]
show(n=20)

Prints the first n rows to the console.

>>> df
DataFrame[age: int, name: string]
>>> df.show()
age name
2   Alice
5   Bob
sort(*cols)

Returns a new DataFrame sorted by the specified column(s).

Parameters:
  • cols – list of Column to sort by.
>>> df.sort(df.age.desc()).collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
>>> df.orderBy(df.age.desc()).collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
>>> from pyspark.sql.functions import *
>>> df.sort(asc("age")).collect()
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
>>> df.orderBy(desc("age"), "name").collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
subtract(other)

Return a new DataFrame containing rows in this frame but not in another frame.

This is equivalent to EXCEPT in SQL.
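
A small illustration with two throwaway DataFrames (df_a and df_b are names used only in this sketch):

>>> df_a = sqlContext.createDataFrame([('a', 1), ('b', 2)], ['k', 'v'])
>>> df_b = sqlContext.createDataFrame([('b', 2), ('c', 3)], ['k', 'v'])
>>> df_a.subtract(df_b).collect()
[Row(k=u'a', v=1)]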

take(num)

Returns the first num rows as a list of Row.

>>> df.take(2)
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
toJSON(use_unicode=False)

Converts a DataFrame into an RDD of strings.

Each row is turned into a JSON document as one element in the returned RDD.

>>> df.toJSON().first()
'{"age":2,"name":"Alice"}'
toPandas()

Returns the contents of this DataFrame as Pandas pandas.DataFrame.

This is only available if Pandas is installed and available.

>>> df.toPandas()  
   age   name
0    2  Alice
1    5    Bob
unionAll(other)

Return a new DataFrame containing union of rows in this frame and another frame.

This is equivalent to UNION ALL in SQL.
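
A small illustration with throwaway DataFrames (df_a and df_b are names used only in this sketch); like UNION ALL, duplicates are kept:

>>> df_a = sqlContext.createDataFrame([('a', 1)], ['k', 'v'])
>>> df_b = sqlContext.createDataFrame([('a', 1), ('b', 2)], ['k', 'v'])
>>> df_a.unionAll(df_b).count()
3L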

unpersist(blocking=True)

Marks the DataFrame as non-persistent, and removes all blocks for it from memory and disk.

where(condition)

Filters rows using the given condition.

where() is an alias for filter().

Parameters:
  • condition – a Column of types.BooleanType, or a string containing a SQL expression.
>>> df.filter(df.age > 3).collect()
[Row(age=5, name=u'Bob')]
>>> df.where(df.age == 2).collect()
[Row(age=2, name=u'Alice')]
>>> df.filter("age > 3").collect()
[Row(age=5, name=u'Bob')]
>>> df.where("age = 2").collect()
[Row(age=2, name=u'Alice')]
withColumn(colName, col)

Returns a new DataFrame by adding a column.

Parameters:
  • colName – string, name of the new column.
  • col – a Column expression for the new column.
>>> df.withColumn('age2', df.age + 2).collect()
[Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
withColumnRenamed(existing, new)

Returns a new DataFrame by renaming an existing column.

Parameters:
  • existing – string, name of the existing column to rename.
  • new – string, new name of the column.
>>> df.withColumnRenamed('age', 'age2').collect()
[Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')]
class pyspark.sql.GroupedData(jdf, sql_ctx)

A set of methods for aggregations on a DataFrame, created by DataFrame.groupBy().

agg(*exprs)

Compute aggregates and returns the result as a DataFrame.

The available aggregate functions are avg, max, min, sum, count.

If exprs is a single dict mapping from string to string, then the key is the column to perform aggregation on, and the value is the aggregate function.

Alternatively, exprs can also be a list of aggregate Column expressions.

Parameters:
  • exprs – a dict mapping from column name (string) to aggregate functions (string), or a list of Column.
>>> gdf = df.groupBy(df.name)
>>> gdf.agg({"*": "count"}).collect()
[Row(name=u'Bob', COUNT(1)=1), Row(name=u'Alice', COUNT(1)=1)]
>>> from pyspark.sql import functions as F
>>> gdf.agg(F.min(df.age)).collect()
[Row(MIN(age)=5), Row(MIN(age)=2)]
avg(*args)

Computes average values for each numeric column for each group.

mean() is an alias for avg().

Parameters:
  • cols – list of column names (string). Non-numeric columns are ignored.
>>> df.groupBy().avg('age').collect()
[Row(AVG(age)=3.5)]
>>> df3.groupBy().avg('age', 'height').collect()
[Row(AVG(age)=3.5, AVG(height)=82.5)]
count()

Counts the number of records for each group.

>>> df.groupBy(df.age).count().collect()
[Row(age=2, count=1), Row(age=5, count=1)]
max(*args)

Computes the max value for each numeric column for each group.

>>> df.groupBy().max('age').collect()
[Row(MAX(age)=5)]
>>> df3.groupBy().max('age', 'height').collect()
[Row(MAX(age)=5, MAX(height)=85)]
mean(*args)

Computes average values for each numeric column for each group.

mean() is an alias for avg().

Parameters:
  • cols – list of column names (string). Non-numeric columns are ignored.
>>> df.groupBy().mean('age').collect()
[Row(AVG(age)=3.5)]
>>> df3.groupBy().mean('age', 'height').collect()
[Row(AVG(age)=3.5, AVG(height)=82.5)]
min(*args)

Computes the min value for each numeric column for each group.

Parameters:
  • cols – list of column names (string). Non-numeric columns are ignored.
>>> df.groupBy().min('age').collect()
[Row(MIN(age)=2)]
>>> df3.groupBy().min('age', 'height').collect()
[Row(MIN(age)=2, MIN(height)=80)]
sum(*args)

Computes the sum for each numeric column for each group.

Parameters:
  • cols – list of column names (string). Non-numeric columns are ignored.
>>> df.groupBy().sum('age').collect()
[Row(SUM(age)=7)]
>>> df3.groupBy().sum('age', 'height').collect()
[Row(SUM(age)=7, SUM(height)=165)]
class pyspark.sql.Column(jc)

A column in a DataFrame.

Column instances can be created by:

# 1. Select a column out of a DataFrame

df.colName
df["colName"]

# 2. Create from an expression
df.colName + 1
1 / df.colName
alias(alias)

Returns an alias for this column.

>>> df.select(df.age.alias("age2")).collect()
[Row(age2=2), Row(age2=5)]
asc()

Returns a sort expression based on the ascending order of the given column name.

cast(dataType)

Converts the column into type dataType.

>>> df.select(df.age.cast("string").alias('ages')).collect()
[Row(ages=u'2'), Row(ages=u'5')]
>>> df.select(df.age.cast(StringType()).alias('ages')).collect()
[Row(ages=u'2'), Row(ages=u'5')]
desc()

Returns a sort expression based on the descending order of the given column name.

endswith(other)

binary operator

getField(other)

An expression that gets a field by name from a StructType column.
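
A small illustration using a nested Row (df_nested is a name used only in this sketch):

>>> from pyspark.sql import Row
>>> df_nested = sqlContext.createDataFrame([Row(r=Row(a=1, b="b"))])
>>> values = df_nested.select(df_nested.r.getField("b")).collect()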

inSet(*cols)

A boolean expression that is evaluated to true if the value of this expression is contained by the evaluated values of the arguments.

>>> df[df.name.inSet("Bob", "Mike")].collect()
[Row(age=5, name=u'Bob')]
>>> df[df.age.inSet([1, 2, 3])].collect()
[Row(age=2, name=u'Alice')]
isNotNull()

True if the current expression is not null.

isNull()

True if the current expression is null.

like(other)

binary operator

rlike(other)

binary operator

startswith(other)

binary operator
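
like(), rlike(), startswith(), and endswith() each take a string pattern and return a boolean Column; for example, with the same df as above:

>>> df.filter(df.name.startswith("Al")).collect()
[Row(age=2, name=u'Alice')]
>>> df.filter(df.name.like("%b")).collect()
[Row(age=5, name=u'Bob')]
>>> df.filter(df.name.rlike("ice$")).collect()
[Row(age=2, name=u'Alice')]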

substr(startPos, length)

Returns a Column which is a substring of the column.

Parameters:
  • startPos – start position (int or Column)
  • length – length of the substring (int or Column)
>>> df.select(df.name.substr(1, 3).alias("col")).collect()
[Row(col=u'Ali'), Row(col=u'Bob')]
class pyspark.sql.Row

A row in DataFrame. The fields in it can be accessed like attributes.

Row can be used to create a row object by using named arguments; the fields will be sorted by name.

>>> row = Row(name="Alice", age=11)
>>> row
Row(age=11, name='Alice')
>>> row.name, row.age
('Alice', 11)

Row can also be used to create another Row-like class, which can then be used to create Row objects, for example:

>>> Person = Row("name", "age")
>>> Person
<Row(name, age)>
>>> Person("Alice", 11)
Row(name='Alice', age=11)
asDict()

Returns this Row as a dict.

class pyspark.sql.DataFrameNaFunctions(df)

Functionality for working with missing data in DataFrame.

drop(how='any', thresh=None, subset=None)

Returns a new DataFrame omitting rows with null values.

DataFrame.dropna() is an alias for this method.

Parameters:
  • how – ‘any’ or ‘all’. If ‘any’, drop a row if it contains any nulls. If ‘all’, drop a row only if all its values are null.
  • thresh – int, default None. If specified, drop rows that have fewer than thresh non-null values. This overrides the how parameter.
  • subset – optional list of column names to consider.
>>> df4.dropna().show()
age height name
10  80     Alice
>>> df4.na.drop().show()
age height name
10  80     Alice
fill(value, subset=None)

Replaces null values. DataFrame.fillna() is an alias for this method.

Parameters:
  • value – int, long, float, string, or dict. Value to replace null values with. If the value is a dict, then subset is ignored and value must be a mapping from column name (string) to replacement value. The replacement value must be an int, long, float, or string.
  • subset – optional list of column names to consider. Columns specified in subset that do not have matching data type are ignored. For example, if value is a string, and subset contains a non-string column, then the non-string column is simply ignored.
>>> df4.fillna(50).show()
age height name
10  80     Alice
5   50     Bob
50  50     Tom
50  50     null
>>> df4.fillna({'age': 50, 'name': 'unknown'}).show()
age height name
10  80     Alice
5   null   Bob
50  null   Tom
50  null   unknown
>>> df4.na.fill({'age': 50, 'name': 'unknown'}).show()
age height name
10  80     Alice
5   null   Bob
50  null   Tom
50  null   unknown

pyspark.sql.types module

class pyspark.sql.types.DataType[source]

Base class for data types.

json()[source]
jsonValue()[source]
simpleString()[source]
classmethod typeName()[source]
class pyspark.sql.types.NullType[source]

Null type.

The data type representing None, used for the types that cannot be inferred.

class pyspark.sql.types.StringType[source]

String data type.

class pyspark.sql.types.BinaryType[source]

Binary (byte array) data type.

class pyspark.sql.types.BooleanType[source]

Boolean data type.

class pyspark.sql.types.DateType[source]

Date (datetime.date) data type.

class pyspark.sql.types.TimestampType[source]

Timestamp (datetime.datetime) data type.

class pyspark.sql.types.DecimalType(precision=None, scale=None)[source]

Decimal (decimal.Decimal) data type.

jsonValue()[source]
simpleString()[source]
class pyspark.sql.types.DoubleType[source]

Double data type, representing double precision floats.

class pyspark.sql.types.FloatType[source]

Float data type, representing single precision floats.

class pyspark.sql.types.ByteType[source]

Byte data type, i.e. a signed integer in a single byte.

simpleString()[source]
class pyspark.sql.types.IntegerType[source]

Int data type, i.e. a signed 32-bit integer.

simpleString()[source]
class pyspark.sql.types.LongType[source]

Long data type, i.e. a signed 64-bit integer.

If the values are beyond the range of [-9223372036854775808, 9223372036854775807], please use DecimalType.

simpleString()[source]
class pyspark.sql.types.ShortType[source]

Short data type, i.e. a signed 16-bit integer.

simpleString()[source]
class pyspark.sql.types.ArrayType(elementType, containsNull=True)[source]

Array data type.

Parameters:
  • elementType – DataType of each element in the array.
  • containsNull – boolean, whether the array can contain null (None) values.
classmethod fromJson(json)[source]
jsonValue()[source]
simpleString()[source]
class pyspark.sql.types.MapType(keyType, valueType, valueContainsNull=True)[source]

Map data type.

Parameters:
  • keyType – DataType of the keys in the map.
  • valueType – DataType of the values in the map.
  • valueContainsNull – indicates whether values can contain null (None) values.

Keys in a map data type are not allowed to be null (None).

classmethod fromJson(json)[source]
jsonValue()[source]
simpleString()[source]
class pyspark.sql.types.StructField(name, dataType, nullable=True, metadata=None)[source]

A field in StructType.

Parameters:
  • name – string, name of the field.
  • dataType – DataType of the field.
  • nullable – boolean, whether the field can be null (None) or not.
  • metadata – a dict from string to simple type that can be serialized to JSON automatically
classmethod fromJson(json)[source]
jsonValue()[source]
simpleString()[source]
class pyspark.sql.types.StructType(fields)[source]

Struct type, consisting of a list of StructField.

This is the data type representing a Row.

classmethod fromJson(json)[source]
jsonValue()[source]
simpleString()[source]

pyspark.sql.functions module

A collection of builtin functions.

pyspark.sql.functions.abs(col)

Computes the absolute value.

pyspark.sql.functions.approxCountDistinct(col, rsd=None)[source]

Returns a new Column for approximate distinct count of col.

>>> df.agg(approxCountDistinct(df.age).alias('c')).collect()
[Row(c=2)]
pyspark.sql.functions.asc(col)

Returns a sort expression based on the ascending order of the given column name.

pyspark.sql.functions.avg(col)

Aggregate function: returns the average of the values in a group.

pyspark.sql.functions.col(col)

Returns a Column based on the given column name.

pyspark.sql.functions.column(col)

Returns a Column based on the given column name.

pyspark.sql.functions.count(col)

Aggregate function: returns the number of items in a group.

pyspark.sql.functions.countDistinct(col, *cols)[source]

Returns a new Column for distinct count of col or cols.

>>> df.agg(countDistinct(df.age, df.name).alias('c')).collect()
[Row(c=2)]
>>> df.agg(countDistinct("age", "name").alias('c')).collect()
[Row(c=2)]
pyspark.sql.functions.desc(col)

Returns a sort expression based on the descending order of the given column name.

pyspark.sql.functions.first(col)

Aggregate function: returns the first value in a group.

pyspark.sql.functions.last(col)

Aggregate function: returns the last value in a group.

pyspark.sql.functions.lit(col)

Creates a Column of literal value.
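
For example, adding a constant column (same df as in the DataFrame examples above):

>>> from pyspark.sql.functions import lit
>>> df.select(df.name, lit(1).alias('one')).collect()
[Row(name=u'Alice', one=1), Row(name=u'Bob', one=1)]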

pyspark.sql.functions.lower(col)

Converts a string expression to lower case.

pyspark.sql.functions.max(col)

Aggregate function: returns the maximum value of the expression in a group.

pyspark.sql.functions.mean(col)

Aggregate function: returns the average of the values in a group.

pyspark.sql.functions.min(col)

Aggregate function: returns the minimum value of the expression in a group.

pyspark.sql.functions.sqrt(col)

Computes the square root of the specified float value.

pyspark.sql.functions.sum(col)

Aggregate function: returns the sum of all values in the expression.

pyspark.sql.functions.sumDistinct(col)

Aggregate function: returns the sum of distinct values in the expression.

pyspark.sql.functions.udf(f, returnType=StringType)[source]

Creates a Column expression representing a user defined function (UDF).

>>> from pyspark.sql.types import IntegerType
>>> slen = udf(lambda s: len(s), IntegerType())
>>> df.select(slen(df.name).alias('slen')).collect()
[Row(slen=5), Row(slen=3)]
pyspark.sql.functions.upper(col)

Converts a string expression to upper case.
