Source code for pyspark.sql.types

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys
import decimal
import time
import datetime
import keyword
import warnings
import json
import re
import weakref
from array import array
from operator import itemgetter

if sys.version >= "3":
    long = int
    unicode = str

from py4j.protocol import register_input_converter
from py4j.java_gateway import JavaClass

__all__ = [
    "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
    "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
    "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType"]


class DataType(object):
    """Base class for data types."""

    def __repr__(self):
        return self.__class__.__name__

    def __hash__(self):
        return hash(str(self))

    def __eq__(self, other):
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

    def __ne__(self, other):
        return not self.__eq__(other)

    @classmethod
    def typeName(cls):
        return cls.__name__[:-4].lower()

    def simpleString(self):
        return self.typeName()

    def jsonValue(self):
        return self.typeName()

    def json(self):
        return json.dumps(self.jsonValue(),
                          separators=(',', ':'),
                          sort_keys=True)
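# A minimal usage sketch (not one of the module's doctests), relying only on the
# DataType methods above: typeName() strips the trailing "Type" from the class
# name, and json() wraps jsonValue() in a compact JSON string.
#
# >>> StringType().typeName()
# 'string'
# >>> StringType().json()
# '"string"'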
# This singleton pattern does not work with pickle, you will get
# another object after pickle and unpickle
class DataTypeSingleton(type):
    """Metaclass for DataType"""

    _instances = {}

    def __call__(cls):
        if cls not in cls._instances:
            cls._instances[cls] = super(DataTypeSingleton, cls).__call__()
        return cls._instances[cls]


class NullType(DataType):
    """Null type.

    The data type representing None, used for the types that cannot be inferred.
    """

    __metaclass__ = DataTypeSingleton


class AtomicType(DataType):
    """An internal type used to represent everything that is not
    null, UDTs, arrays, structs, and maps."""


class NumericType(AtomicType):
    """Numeric data types.
    """


class IntegralType(NumericType):
    """Integral data types.
    """

    __metaclass__ = DataTypeSingleton


class FractionalType(NumericType):
    """Fractional data types.
    """


class StringType(AtomicType):
    """String data type.
    """

    __metaclass__ = DataTypeSingleton


class BinaryType(AtomicType):
    """Binary (byte array) data type.
    """

    __metaclass__ = DataTypeSingleton


class BooleanType(AtomicType):
    """Boolean data type.
    """

    __metaclass__ = DataTypeSingleton


class DateType(AtomicType):
    """Date (datetime.date) data type.
    """

    __metaclass__ = DataTypeSingleton


class TimestampType(AtomicType):
    """Timestamp (datetime.datetime) data type.
    """

    __metaclass__ = DataTypeSingleton
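# A small usage sketch (not a doctest from the module): atomic types carry no
# state, so equality is purely class-based.
#
# >>> StringType() == StringType()
# True
# >>> StringType() == BinaryType()
# False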
class DecimalType(FractionalType):
    """Decimal (decimal.Decimal) data type.
    """

    def __init__(self, precision=None, scale=None):
        self.precision = precision
        self.scale = scale
        self.hasPrecisionInfo = precision is not None

    def simpleString(self):
        if self.hasPrecisionInfo:
            return "decimal(%d,%d)" % (self.precision, self.scale)
        else:
            return "decimal(10,0)"

    def jsonValue(self):
        if self.hasPrecisionInfo:
            return "decimal(%d,%d)" % (self.precision, self.scale)
        else:
            return "decimal"

    def __repr__(self):
        if self.hasPrecisionInfo:
            return "DecimalType(%d,%d)" % (self.precision, self.scale)
        else:
            return "DecimalType()"
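# A small usage sketch (not a doctest from the module): precision and scale are
# optional; without them the simple string falls back to decimal(10,0).
#
# >>> DecimalType().simpleString()
# 'decimal(10,0)'
# >>> DecimalType(38, 18).simpleString()
# 'decimal(38,18)'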
class DoubleType(FractionalType):
    """Double data type, representing double precision floats.
    """

    __metaclass__ = DataTypeSingleton


class FloatType(FractionalType):
    """Float data type, representing single precision floats.
    """

    __metaclass__ = DataTypeSingleton


class ByteType(IntegralType):
    """Byte data type, i.e. a signed integer in a single byte.
    """

    def simpleString(self):
        return 'tinyint'


class IntegerType(IntegralType):
    """Int data type, i.e. a signed 32-bit integer.
    """

    def simpleString(self):
        return 'int'


class LongType(IntegralType):
    """Long data type, i.e. a signed 64-bit integer.

    If the values are beyond the range of [-9223372036854775808, 9223372036854775807],
    please use :class:`DecimalType`.
    """

    def simpleString(self):
        return 'bigint'


class ShortType(IntegralType):
    """Short data type, i.e. a signed 16-bit integer.
    """

    def simpleString(self):
        return 'smallint'
class ArrayType(DataType):
    """Array data type.

    :param elementType: :class:`DataType` of each element in the array.
    :param containsNull: boolean, whether the array can contain null (None) values.
    """

    def __init__(self, elementType, containsNull=True):
        """
        >>> ArrayType(StringType()) == ArrayType(StringType(), True)
        True
        >>> ArrayType(StringType(), False) == ArrayType(StringType())
        False
        """
        assert isinstance(elementType, DataType), "elementType should be DataType"
        self.elementType = elementType
        self.containsNull = containsNull

    def simpleString(self):
        return 'array<%s>' % self.elementType.simpleString()

    def __repr__(self):
        return "ArrayType(%s,%s)" % (self.elementType,
                                     str(self.containsNull).lower())

    def jsonValue(self):
        return {"type": self.typeName(),
                "elementType": self.elementType.jsonValue(),
                "containsNull": self.containsNull}

    @classmethod
    def fromJson(cls, json):
        return ArrayType(_parse_datatype_json_value(json["elementType"]),
                         json["containsNull"])
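# A small usage sketch (not a doctest from the module): an ArrayType round-trips
# through its JSON value via fromJson().
#
# >>> at = ArrayType(IntegerType(), False)
# >>> at.simpleString()
# 'array<int>'
# >>> ArrayType.fromJson(at.jsonValue()) == at
# True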
class MapType(DataType):
    """Map data type.

    :param keyType: :class:`DataType` of the keys in the map.
    :param valueType: :class:`DataType` of the values in the map.
    :param valueContainsNull: indicates whether values can contain null (None) values.

    Keys in a map data type are not allowed to be null (None).
    """

    def __init__(self, keyType, valueType, valueContainsNull=True):
        """
        >>> (MapType(StringType(), IntegerType())
        ...        == MapType(StringType(), IntegerType(), True))
        True
        >>> (MapType(StringType(), IntegerType(), False)
        ...        == MapType(StringType(), FloatType()))
        False
        """
        assert isinstance(keyType, DataType), "keyType should be DataType"
        assert isinstance(valueType, DataType), "valueType should be DataType"
        self.keyType = keyType
        self.valueType = valueType
        self.valueContainsNull = valueContainsNull

    def simpleString(self):
        return 'map<%s,%s>' % (self.keyType.simpleString(), self.valueType.simpleString())

    def __repr__(self):
        return "MapType(%s,%s,%s)" % (self.keyType, self.valueType,
                                      str(self.valueContainsNull).lower())

    def jsonValue(self):
        return {"type": self.typeName(),
                "keyType": self.keyType.jsonValue(),
                "valueType": self.valueType.jsonValue(),
                "valueContainsNull": self.valueContainsNull}

    @classmethod
    def fromJson(cls, json):
        return MapType(_parse_datatype_json_value(json["keyType"]),
                       _parse_datatype_json_value(json["valueType"]),
                       json["valueContainsNull"])
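# A small usage sketch (not a doctest from the module): key and value types are
# independent, and valueContainsNull defaults to True.
#
# >>> m = MapType(StringType(), DoubleType())
# >>> m.simpleString()
# 'map<string,double>'
# >>> m.valueContainsNull
# True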
class StructField(DataType):
    """A field in :class:`StructType`.

    :param name: string, name of the field.
    :param dataType: :class:`DataType` of the field.
    :param nullable: boolean, whether the field can be null (None) or not.
    :param metadata: a dict from string to simple type that can be serialized to JSON automatically
    """

    def __init__(self, name, dataType, nullable=True, metadata=None):
        """
        >>> (StructField("f1", StringType(), True)
        ...      == StructField("f1", StringType(), True))
        True
        >>> (StructField("f1", StringType(), True)
        ...      == StructField("f2", StringType(), True))
        False
        """
        assert isinstance(dataType, DataType), "dataType should be DataType"
        self.name = name
        self.dataType = dataType
        self.nullable = nullable
        self.metadata = metadata or {}

    def simpleString(self):
        return '%s:%s' % (self.name, self.dataType.simpleString())

    def __repr__(self):
        return "StructField(%s,%s,%s)" % (self.name, self.dataType,
                                          str(self.nullable).lower())

    def jsonValue(self):
        return {"name": self.name,
                "type": self.dataType.jsonValue(),
                "nullable": self.nullable,
                "metadata": self.metadata}

    @classmethod
    def fromJson(cls, json):
        return StructField(json["name"],
                           _parse_datatype_json_value(json["type"]),
                           json["nullable"],
                           json["metadata"])
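# A small usage sketch (not a doctest from the module): metadata is any
# JSON-serializable dict and survives the jsonValue()/fromJson() round trip.
#
# >>> f = StructField("age", IntegerType(), False, {"comment": "in years"})
# >>> StructField.fromJson(f.jsonValue()) == f
# True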
class StructType(DataType):
    """Struct type, consisting of a list of :class:`StructField`.

    This is the data type representing a :class:`Row`.
    """

    def __init__(self, fields):
        """
        >>> struct1 = StructType([StructField("f1", StringType(), True)])
        >>> struct2 = StructType([StructField("f1", StringType(), True)])
        >>> struct1 == struct2
        True
        >>> struct1 = StructType([StructField("f1", StringType(), True)])
        >>> struct2 = StructType([StructField("f1", StringType(), True),
        ...     StructField("f2", IntegerType(), False)])
        >>> struct1 == struct2
        False
        """
        assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType"
        self.fields = fields

    def simpleString(self):
        return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields))

    def __repr__(self):
        return ("StructType(List(%s))" %
                ",".join(str(field) for field in self.fields))

    def jsonValue(self):
        return {"type": self.typeName(),
                "fields": [f.jsonValue() for f in self.fields]}

    @classmethod
    def fromJson(cls, json):
        return StructType([StructField.fromJson(f) for f in json["fields"]])
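# A small usage sketch (not a doctest from the module): a full schema can be
# serialized with json() and rebuilt with _parse_datatype_json_string(), which
# is defined below.
#
# >>> schema = StructType([StructField("name", StringType(), True),
# ...                      StructField("age", IntegerType(), True)])
# >>> schema.simpleString()
# 'struct<name:string,age:int>'
# >>> _parse_datatype_json_string(schema.json()) == schema
# True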
class UserDefinedType(DataType):
    """User-defined type (UDT).

    .. note:: WARN: Spark Internal Use Only
    """

    @classmethod
    def typeName(cls):
        return cls.__name__.lower()

    @classmethod
    def sqlType(cls):
        """
        Underlying SQL storage type for this UDT.
        """
        raise NotImplementedError("UDT must implement sqlType().")

    @classmethod
    def module(cls):
        """
        The Python module of the UDT.
        """
        raise NotImplementedError("UDT must implement module().")

    @classmethod
    def scalaUDT(cls):
        """
        The class name of the paired Scala UDT.
        """
        raise NotImplementedError("UDT must have a paired Scala UDT.")

    def serialize(self, obj):
        """
        Converts a user-type object into a SQL datum.
        """
        raise NotImplementedError("UDT must implement serialize().")

    def deserialize(self, datum):
        """
        Converts a SQL datum into a user-type object.
        """
        raise NotImplementedError("UDT must implement deserialize().")

    def simpleString(self):
        return 'udt'

    def json(self):
        return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)

    def jsonValue(self):
        schema = {
            "type": "udt",
            "class": self.scalaUDT(),
            "pyClass": "%s.%s" % (self.module(), type(self).__name__),
            "sqlType": self.sqlType().jsonValue()
        }
        return schema

    @classmethod
    def fromJson(cls, json):
        pyUDT = json["pyClass"]
        split = pyUDT.rfind(".")
        pyModule = pyUDT[:split]
        pyClass = pyUDT[split+1:]
        m = __import__(pyModule, globals(), locals(), [pyClass])
        UDT = getattr(m, pyClass)
        return UDT()

    def __eq__(self, other):
        return type(self) == type(other)


_atomic_types = [StringType, BinaryType, BooleanType, DecimalType, FloatType, DoubleType,
                 ByteType, ShortType, IntegerType, LongType, DateType, TimestampType]
_all_atomic_types = dict((t.typeName(), t) for t in _atomic_types)
_all_complex_types = dict((v.typeName(), v)
                          for v in [ArrayType, MapType, StructType])


def _parse_datatype_json_string(json_string):
    """Parses the given data type JSON string.

    >>> import pickle
    >>> def check_datatype(datatype):
    ...     pickled = pickle.loads(pickle.dumps(datatype))
    ...     assert datatype == pickled
    ...     scala_datatype = sqlContext._ssql_ctx.parseDataType(datatype.json())
    ...     python_datatype = _parse_datatype_json_string(scala_datatype.json())
    ...     assert datatype == python_datatype
    >>> for cls in _all_atomic_types.values():
    ...     check_datatype(cls())

    >>> # Simple ArrayType.
    >>> simple_arraytype = ArrayType(StringType(), True)
    >>> check_datatype(simple_arraytype)

    >>> # Simple MapType.
    >>> simple_maptype = MapType(StringType(), LongType())
    >>> check_datatype(simple_maptype)

    >>> # Simple StructType.
    >>> simple_structtype = StructType([
    ...     StructField("a", DecimalType(), False),
    ...     StructField("b", BooleanType(), True),
    ...     StructField("c", LongType(), True),
    ...     StructField("d", BinaryType(), False)])
    >>> check_datatype(simple_structtype)

    >>> # Complex StructType.
    >>> complex_structtype = StructType([
    ...     StructField("simpleArray", simple_arraytype, True),
    ...     StructField("simpleMap", simple_maptype, True),
    ...     StructField("simpleStruct", simple_structtype, True),
    ...     StructField("boolean", BooleanType(), False),
    ...     StructField("withMeta", DoubleType(), False, {"name": "age"})])
    >>> check_datatype(complex_structtype)

    >>> # Complex ArrayType.
    >>> complex_arraytype = ArrayType(complex_structtype, True)
    >>> check_datatype(complex_arraytype)

    >>> # Complex MapType.
    >>> complex_maptype = MapType(complex_structtype,
    ...                           complex_arraytype, False)
    >>> check_datatype(complex_maptype)
    >>> check_datatype(ExamplePointUDT())
    >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
    ...                                   StructField("point", ExamplePointUDT(), False)])
    >>> check_datatype(structtype_with_udt)
    """
    return _parse_datatype_json_value(json.loads(json_string))


_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)")


def _parse_datatype_json_value(json_value):
    if not isinstance(json_value, dict):
        if json_value in _all_atomic_types.keys():
            return _all_atomic_types[json_value]()
        elif json_value == 'decimal':
            return DecimalType()
        elif _FIXED_DECIMAL.match(json_value):
            m = _FIXED_DECIMAL.match(json_value)
            return DecimalType(int(m.group(1)), int(m.group(2)))
        else:
            raise ValueError("Could not parse datatype: %s" % json_value)
    else:
        tpe = json_value["type"]
        if tpe in _all_complex_types:
            return _all_complex_types[tpe].fromJson(json_value)
        elif tpe == 'udt':
            return UserDefinedType.fromJson(json_value)
        else:
            raise ValueError("not supported type: %s" % tpe)


# Mapping Python types to Spark SQL DataType
_type_mappings = {
    type(None): NullType,
    bool: BooleanType,
    int: LongType,
    float: DoubleType,
    str: StringType,
    bytearray: BinaryType,
    decimal.Decimal: DecimalType,
    datetime.date: DateType,
    datetime.datetime: TimestampType,
    datetime.time: TimestampType,
}

if sys.version < "3":
    _type_mappings.update({
        unicode: StringType,
        long: LongType,
    })


def _infer_type(obj):
    """Infer the DataType from obj

    >>> p = ExamplePoint(1.0, 2.0)
    >>> _infer_type(p)
    ExamplePointUDT
    """
    if obj is None:
        return NullType()

    if hasattr(obj, '__UDT__'):
        return obj.__UDT__

    dataType = _type_mappings.get(type(obj))
    if dataType is not None:
        return dataType()

    if isinstance(obj, dict):
        for key, value in obj.items():
            if key is not None and value is not None:
                return MapType(_infer_type(key), _infer_type(value), True)
        else:
            return MapType(NullType(), NullType(), True)
    elif isinstance(obj, (list, array)):
        for v in obj:
            if v is not None:
                return ArrayType(_infer_type(obj[0]), True)
        else:
            return ArrayType(NullType(), True)
    else:
        try:
            return _infer_schema(obj)
        except TypeError:
            raise TypeError("not supported type: %s" % type(obj))


def _infer_schema(row):
    """Infer the schema from dict/namedtuple/object"""
    if isinstance(row, dict):
        items = sorted(row.items())

    elif isinstance(row, (tuple, list)):
        if hasattr(row, "__fields__"):  # Row
            items = zip(row.__fields__, tuple(row))
        elif hasattr(row, "_fields"):  # namedtuple
            items = zip(row._fields, tuple(row))
        else:
            names = ['_%d' % i for i in range(1, len(row) + 1)]
            items = zip(names, row)

    elif hasattr(row, "__dict__"):  # object
        items = sorted(row.__dict__.items())

    else:
        raise TypeError("Can not infer schema for type: %s" % type(row))

    fields = [StructField(k, _infer_type(v), True) for k, v in items]
    return StructType(fields)


def _need_python_to_sql_conversion(dataType):
    """
    Checks whether we need python to sql conversion for the given type.
    For now, only UDTs need this conversion.

    >>> _need_python_to_sql_conversion(DoubleType())
    False
    >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False),
    ...                       StructField("values", ArrayType(DoubleType(), False), False)])
    >>> _need_python_to_sql_conversion(schema0)
    False
    >>> _need_python_to_sql_conversion(ExamplePointUDT())
    True
    >>> schema1 = ArrayType(ExamplePointUDT(), False)
    >>> _need_python_to_sql_conversion(schema1)
    True
    >>> schema2 = StructType([StructField("label", DoubleType(), False),
StructField("point", ExamplePointUDT(), False)]) >>> _need_python_to_sql_conversion(schema2) True """ if isinstance(dataType, StructType): return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) elif isinstance(dataType, ArrayType): return _need_python_to_sql_conversion(dataType.elementType) elif isinstance(dataType, MapType): return _need_python_to_sql_conversion(dataType.keyType) or \ _need_python_to_sql_conversion(dataType.valueType) elif isinstance(dataType, UserDefinedType): return True else: return False def _python_to_sql_converter(dataType): """ Returns a converter that converts a Python object into a SQL datum for the given type. >>> conv = _python_to_sql_converter(DoubleType()) >>> conv(1.0) 1.0 >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False)) >>> conv([1.0, 2.0]) [1.0, 2.0] >>> conv = _python_to_sql_converter(ExamplePointUDT()) >>> conv(ExamplePoint(1.0, 2.0)) [1.0, 2.0] >>> schema = StructType([StructField("label", DoubleType(), False), ... StructField("point", ExamplePointUDT(), False)]) >>> conv = _python_to_sql_converter(schema) >>> conv((1.0, ExamplePoint(1.0, 2.0))) (1.0, [1.0, 2.0]) """ if not _need_python_to_sql_conversion(dataType): return lambda x: x if isinstance(dataType, StructType): names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) converters = [_python_to_sql_converter(t) for t in types] def converter(obj): if isinstance(obj, dict): return tuple(c(obj.get(n)) for n, c in zip(names, converters)) elif isinstance(obj, tuple): if hasattr(obj, "__fields__") or hasattr(obj, "_fields"): return tuple(c(v) for c, v in zip(converters, obj)) elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs d = dict(obj) return tuple(c(d.get(n)) for n, c in zip(names, converters)) else: return tuple(c(v) for c, v in zip(converters, obj)) else: raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) return converter elif isinstance(dataType, ArrayType): element_converter = _python_to_sql_converter(dataType.elementType) return lambda a: [element_converter(v) for v in a] elif isinstance(dataType, MapType): key_converter = _python_to_sql_converter(dataType.keyType) value_converter = _python_to_sql_converter(dataType.valueType) return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) elif isinstance(dataType, UserDefinedType): return lambda obj: dataType.serialize(obj) else: raise ValueError("Unexpected type %r" % dataType) def _has_nulltype(dt): """ Return whether there is NullType in `dt` or not """ if isinstance(dt, StructType): return any(_has_nulltype(f.dataType) for f in dt.fields) elif isinstance(dt, ArrayType): return _has_nulltype((dt.elementType)) elif isinstance(dt, MapType): return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType) else: return isinstance(dt, NullType) def _merge_type(a, b): if isinstance(a, NullType): return b elif isinstance(b, NullType): return a elif type(a) is not type(b): # TODO: type cast (such as int -> long) raise TypeError("Can not merge type %s and %s" % (type(a), type(b))) # same type if isinstance(a, StructType): nfs = dict((f.name, f.dataType) for f in b.fields) fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) for f in a.fields] names = set([f.name for f in fields]) for n in nfs: if n not in names: fields.append(StructField(n, nfs[n])) return StructType(fields) elif isinstance(a, ArrayType): return ArrayType(_merge_type(a.elementType, b.elementType), True) elif isinstance(a, MapType): 
        return MapType(_merge_type(a.keyType, b.keyType),
                       _merge_type(a.valueType, b.valueType), True)
    else:
        return a


def _need_converter(dataType):
    if isinstance(dataType, StructType):
        return True
    elif isinstance(dataType, ArrayType):
        return _need_converter(dataType.elementType)
    elif isinstance(dataType, MapType):
        return _need_converter(dataType.keyType) or _need_converter(dataType.valueType)
    elif isinstance(dataType, NullType):
        return True
    else:
        return False


def _create_converter(dataType):
    """Create a converter to drop the names of fields in obj """
    if not _need_converter(dataType):
        return lambda x: x

    if isinstance(dataType, ArrayType):
        conv = _create_converter(dataType.elementType)
        return lambda row: [conv(v) for v in row]

    elif isinstance(dataType, MapType):
        kconv = _create_converter(dataType.keyType)
        vconv = _create_converter(dataType.valueType)
        return lambda row: dict((kconv(k), vconv(v)) for k, v in row.items())

    elif isinstance(dataType, NullType):
        return lambda x: None

    elif not isinstance(dataType, StructType):
        return lambda x: x

    # dataType must be StructType
    names = [f.name for f in dataType.fields]
    converters = [_create_converter(f.dataType) for f in dataType.fields]
    convert_fields = any(_need_converter(f.dataType) for f in dataType.fields)

    def convert_struct(obj):
        if obj is None:
            return

        if isinstance(obj, (tuple, list)):
            if convert_fields:
                return tuple(conv(v) for v, conv in zip(obj, converters))
            else:
                return tuple(obj)

        if isinstance(obj, dict):
            d = obj
        elif hasattr(obj, "__dict__"):  # object
            d = obj.__dict__
        else:
            raise TypeError("Unexpected obj type: %s" % type(obj))

        if convert_fields:
            return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
        else:
            return tuple([d.get(name) for name in names])

    return convert_struct


_BRACKETS = {'(': ')', '[': ']', '{': '}'}


def _split_schema_abstract(s):
    """
    split the schema abstract into fields

    >>> _split_schema_abstract("a b c")
    ['a', 'b', 'c']
    >>> _split_schema_abstract("a(a b)")
    ['a(a b)']
    >>> _split_schema_abstract("a b[] c{a b}")
    ['a', 'b[]', 'c{a b}']
    >>> _split_schema_abstract(" ")
    []
    """
    r = []
    w = ''
    brackets = []
    for c in s:
        if c == ' ' and not brackets:
            if w:
                r.append(w)
            w = ''
        else:
            w += c
            if c in _BRACKETS:
                brackets.append(c)
            elif c in _BRACKETS.values():
                if not brackets or c != _BRACKETS[brackets.pop()]:
                    raise ValueError("unexpected " + c)

    if brackets:
        raise ValueError("brackets not closed: %s" % brackets)
    if w:
        r.append(w)
    return r


def _parse_field_abstract(s):
    """
    Parse a field in schema abstract

    >>> _parse_field_abstract("a")
    StructField(a,NullType,true)
    >>> _parse_field_abstract("b(c d)")
    StructField(b,StructType(...c,NullType,true),StructField(d...
    >>> _parse_field_abstract("a[]")
    StructField(a,ArrayType(NullType,true),true)
    >>> _parse_field_abstract("a{[]}")
    StructField(a,MapType(NullType,ArrayType(NullType,true),true),true)
    """
    if set(_BRACKETS.keys()) & set(s):
        idx = min((s.index(c) for c in _BRACKETS if c in s))
        name = s[:idx]
        return StructField(name, _parse_schema_abstract(s[idx:]), True)
    else:
        return StructField(s, NullType(), True)


def _parse_schema_abstract(s):
    """
    parse abstract into schema

    >>> _parse_schema_abstract("a b c")
    StructType...a...b...c...
    >>> _parse_schema_abstract("a[b c] b{}")
    StructType...a,ArrayType...b...c...b,MapType...
    >>> _parse_schema_abstract("c{} d{a b}")
    StructType...c,MapType...d,MapType...a...b...
    >>> _parse_schema_abstract("a b(t)").fields[1]
    StructField(b,StructType(List(StructField(t,NullType,true))),true)
    """
    s = s.strip()
    if not s:
        return NullType()

    elif s.startswith('('):
        return _parse_schema_abstract(s[1:-1])

    elif s.startswith('['):
        return ArrayType(_parse_schema_abstract(s[1:-1]), True)

    elif s.startswith('{'):
        return MapType(NullType(), _parse_schema_abstract(s[1:-1]))

    parts = _split_schema_abstract(s)
    fields = [_parse_field_abstract(p) for p in parts]
    return StructType(fields)


def _infer_schema_type(obj, dataType):
    """
    Fill the dataType with types inferred from obj

    >>> schema = _parse_schema_abstract("a b c d")
    >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10))
    >>> _infer_schema_type(row, schema)
    StructType...LongType...DoubleType...StringType...DateType...
    >>> row = [[1], {"key": (1, 2.0)}]
    >>> schema = _parse_schema_abstract("a[] b{c d}")
    >>> _infer_schema_type(row, schema)
    StructType...a,ArrayType...b,MapType(StringType,...c,LongType...
    """
    if isinstance(dataType, NullType):
        return _infer_type(obj)

    if not obj:
        return NullType()

    if isinstance(dataType, ArrayType):
        eType = _infer_schema_type(obj[0], dataType.elementType)
        return ArrayType(eType, True)

    elif isinstance(dataType, MapType):
        k, v = next(iter(obj.items()))
        return MapType(_infer_schema_type(k, dataType.keyType),
                       _infer_schema_type(v, dataType.valueType))

    elif isinstance(dataType, StructType):
        fs = dataType.fields
        assert len(fs) == len(obj), \
            "Obj(%s) have different length with fields(%s)" % (obj, fs)
        fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True)
                  for o, f in zip(obj, fs)]
        return StructType(fields)

    else:
        raise TypeError("Unexpected dataType: %s" % type(dataType))


_acceptable_types = {
    BooleanType: (bool,),
    ByteType: (int, long),
    ShortType: (int, long),
    IntegerType: (int, long),
    LongType: (int, long),
    FloatType: (float,),
    DoubleType: (float,),
    DecimalType: (decimal.Decimal,),
    StringType: (str, unicode),
    BinaryType: (bytearray,),
    DateType: (datetime.date, datetime.datetime),
    TimestampType: (datetime.datetime,),
    ArrayType: (list, tuple, array),
    MapType: (dict,),
    StructType: (tuple, list),
}


def _verify_type(obj, dataType):
    """
    Verify the type of obj against dataType, raise an exception
    if they do not match.

    >>> _verify_type(None, StructType([]))
    >>> _verify_type("", StringType())
    >>> _verify_type(0, LongType())
    >>> _verify_type(list(range(3)), ArrayType(ShortType()))
    >>> _verify_type(set(), ArrayType(StringType()))  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    TypeError:...
    >>> _verify_type({}, MapType(StringType(), IntegerType()))
    >>> _verify_type((), StructType([]))
    >>> _verify_type([], StructType([]))
    >>> _verify_type([1], StructType([]))  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
    >>> _verify_type([1.0, 2.0], ExamplePointUDT())  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
""" # all objects are nullable if obj is None: return if isinstance(dataType, UserDefinedType): if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): raise ValueError("%r is not an instance of type %r" % (obj, dataType)) _verify_type(dataType.serialize(obj), dataType.sqlType()) return _type = type(dataType) assert _type in _acceptable_types, "unknown datatype: %s" % dataType # subclass of them can not be deserialized in JVM if type(obj) not in _acceptable_types[_type]: raise TypeError("%s can not accept object in type %s" % (dataType, type(obj))) if isinstance(dataType, ArrayType): for i in obj: _verify_type(i, dataType.elementType) elif isinstance(dataType, MapType): for k, v in obj.items(): _verify_type(k, dataType.keyType) _verify_type(v, dataType.valueType) elif isinstance(dataType, StructType): if len(obj) != len(dataType.fields): raise ValueError("Length of object (%d) does not match with " "length of fields (%d)" % (len(obj), len(dataType.fields))) for v, f in zip(obj, dataType.fields): _verify_type(v, f.dataType) _cached_cls = weakref.WeakValueDictionary() def _restore_object(dataType, obj): """ Restore object during unpickling. """ # use id(dataType) as key to speed up lookup in dict # Because of batched pickling, dataType will be the # same object in most cases. k = id(dataType) cls = _cached_cls.get(k) if cls is None or cls.__datatype is not dataType: # use dataType as key to avoid create multiple class cls = _cached_cls.get(dataType) if cls is None: cls = _create_cls(dataType) _cached_cls[dataType] = cls cls.__datatype = dataType _cached_cls[k] = cls return cls(obj) def _create_object(cls, v): """ Create an customized object with class `cls`. """ # datetime.date would be deserialized as datetime.datetime # from java type, so we need to set it back. if cls is datetime.date and isinstance(v, datetime.datetime): return v.date() return cls(v) if v is not None else v def _create_getter(dt, i): """ Create a getter for item `i` with schema """ cls = _create_cls(dt) def getter(self): return _create_object(cls, self[i]) return getter def _has_struct_or_date(dt): """Return whether `dt` is or has StructType/DateType in it""" if isinstance(dt, StructType): return True elif isinstance(dt, ArrayType): return _has_struct_or_date(dt.elementType) elif isinstance(dt, MapType): return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType) elif isinstance(dt, DateType): return True elif isinstance(dt, UserDefinedType): return True return False def _create_properties(fields): """Create properties according to fields""" ps = {} for i, f in enumerate(fields): name = f.name if (name.startswith("__") and name.endswith("__") or keyword.iskeyword(name)): warnings.warn("field name %s can not be accessed in Python," "use position to access it instead" % name) if _has_struct_or_date(f.dataType): # delay creating object until accessing it getter = _create_getter(f.dataType, i) else: getter = itemgetter(i) ps[name] = property(getter) return ps def _create_cls(dataType): """ Create an class by dataType The created class is similar to namedtuple, but can have nested schema. 
    >>> schema = _parse_schema_abstract("a b c")
    >>> row = (1, 1.0, "str")
    >>> schema = _infer_schema_type(row, schema)
    >>> obj = _create_cls(schema)(row)
    >>> import pickle
    >>> pickle.loads(pickle.dumps(obj))
    Row(a=1, b=1.0, c='str')

    >>> row = [[1], {"key": (1, 2.0)}]
    >>> schema = _parse_schema_abstract("a[] b{c d}")
    >>> schema = _infer_schema_type(row, schema)
    >>> obj = _create_cls(schema)(row)
    >>> pickle.loads(pickle.dumps(obj))
    Row(a=[1], b={'key': Row(c=1, d=2.0)})
    >>> pickle.loads(pickle.dumps(obj.a))
    [1]
    >>> pickle.loads(pickle.dumps(obj.b))
    {'key': Row(c=1, d=2.0)}
    """
    if isinstance(dataType, ArrayType):
        cls = _create_cls(dataType.elementType)

        def List(l):
            if l is None:
                return
            return [_create_object(cls, v) for v in l]

        return List

    elif isinstance(dataType, MapType):
        kcls = _create_cls(dataType.keyType)
        vcls = _create_cls(dataType.valueType)

        def Dict(d):
            if d is None:
                return
            return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items())

        return Dict

    elif isinstance(dataType, DateType):
        return datetime.date

    elif isinstance(dataType, UserDefinedType):
        return lambda datum: dataType.deserialize(datum)

    elif not isinstance(dataType, StructType):
        # no wrapper for atomic types
        return lambda x: x

    class Row(tuple):

        """ Row in DataFrame """
        __datatype = dataType
        __fields__ = tuple(f.name for f in dataType.fields)
        __slots__ = ()

        # create property for fast access
        locals().update(_create_properties(dataType.fields))

        def asDict(self):
            """ Return as a dict """
            return dict((n, getattr(self, n)) for n in self.__fields__)

        def __repr__(self):
            # call collect __repr__ for nested objects
            return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n))
                                          for n in self.__fields__))

        def __reduce__(self):
            return (_restore_object, (self.__datatype, tuple(self)))

    return Row


def _create_row(fields, values):
    row = Row(*values)
    row.__fields__ = fields
    return row


class Row(tuple):

    """
    A row in L{DataFrame}. The fields in it can be accessed like attributes.

    Row can be used to create a row object by using named arguments,
    the fields will be sorted by names.
>>> row = Row(name="Alice", age=11) >>> row Row(age=11, name='Alice') >>> row.name, row.age ('Alice', 11) Row also can be used to create another Row like class, then it could be used to create Row objects, such as >>> Person = Row("name", "age") >>> Person <Row(name, age)> >>> Person("Alice", 11) Row(name='Alice', age=11) """ def __new__(self, *args, **kwargs): if args and kwargs: raise ValueError("Can not use both args " "and kwargs to create Row") if args: # create row class or objects return tuple.__new__(self, args) elif kwargs: # create row objects names = sorted(kwargs.keys()) row = tuple.__new__(self, [kwargs[n] for n in names]) row.__fields__ = names return row else: raise ValueError("No args or kwargs") def asDict(self): """ Return as an dict """ if not hasattr(self, "__fields__"): raise TypeError("Cannot convert a Row class into dict") return dict(zip(self.__fields__, self)) # let object acts like class def __call__(self, *args): """create new Row object""" return _create_row(self, args) def __getattr__(self, item): if item.startswith("__"): raise AttributeError(item) try: # it will be slow when it has many fields, # but this will not be used in normal cases idx = self.__fields__.index(item) return self[idx] except IndexError: raise AttributeError(item) except ValueError: raise AttributeError(item) def __reduce__(self): """Returns a tuple so Python knows how to pickle Row.""" if hasattr(self, "__fields__"): return (_create_row, (self.__fields__, tuple(self))) else: return tuple.__reduce__(self) def __repr__(self): """Printable representation of Row used in Python REPL.""" if hasattr(self, "__fields__"): return "Row(%s)" % ", ".join("%s=%r" % (k, v) for k, v in zip(self.__fields__, tuple(self))) else: return "<Row(%s)>" % ", ".join(self) class DateConverter(object): def can_convert(self, obj): return isinstance(obj, datetime.date) def convert(self, obj, gateway_client): Date = JavaClass("java.sql.Date", gateway_client) return Date.valueOf(obj.strftime("%Y-%m-%d")) class DatetimeConverter(object): def can_convert(self, obj): return isinstance(obj, datetime.datetime) def convert(self, obj, gateway_client): Timestamp = JavaClass("java.sql.Timestamp", gateway_client) return Timestamp(int(time.mktime(obj.timetuple())) * 1000 + obj.microsecond // 1000) # datetime is a subclass of date, we should register DatetimeConverter first register_input_converter(DatetimeConverter()) register_input_converter(DateConverter()) def _test(): import doctest from pyspark.context import SparkContext # let doctest run in pyspark.sql.types, so DataTypes can be picklable import pyspark.sql.types from pyspark.sql import Row, SQLContext from pyspark.sql.tests import ExamplePoint, ExamplePointUDT globs = pyspark.sql.types.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) globs['ExamplePoint'] = ExamplePoint globs['ExamplePointUDT'] = ExamplePointUDT (failure_count, test_count) = doctest.testmod( pyspark.sql.types, globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: exit(-1) if __name__ == "__main__": _test()