Package pyspark :: Module sql
[frames] | no frames]

Source Code for Module pyspark.sql

   1  # 
   2  # Licensed to the Apache Software Foundation (ASF) under one or more 
   3  # contributor license agreements.  See the NOTICE file distributed with 
   4  # this work for additional information regarding copyright ownership. 
   5  # The ASF licenses this file to You under the Apache License, Version 2.0 
   6  # (the "License"); you may not use this file except in compliance with 
   7  # the License.  You may obtain a copy of the License at 
   8  # 
   9  #    http://www.apache.org/licenses/LICENSE-2.0 
  10  # 
  11  # Unless required by applicable law or agreed to in writing, software 
  12  # distributed under the License is distributed on an "AS IS" BASIS, 
  13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
  14  # See the License for the specific language governing permissions and 
  15  # limitations under the License. 
  16  # 
  17   
  18   
  19  import sys 
  20  import types 
  21  import itertools 
  22  import warnings 
  23  import decimal 
  24  import datetime 
  25  import keyword 
  26  import warnings 
  27  from array import array 
  28  from operator import itemgetter 
  29   
  30  from pyspark.rdd import RDD, PipelinedRDD 
  31  from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer 
  32   
  33  from itertools import chain, ifilter, imap 
  34   
  35  from py4j.protocol import Py4JError 
  36  from py4j.java_collections import ListConverter, MapConverter 
  37   
  38   
  39  __all__ = [ 
  40      "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType", 
  41      "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", 
  42      "ShortType", "ArrayType", "MapType", "StructField", "StructType", 
  43      "SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", 
  44      "SchemaRDD", "Row"] 
class DataType(object):

    """Spark SQL DataType

    Base class of all Spark SQL data types. Two data types compare equal
    when they are instances of the same class with equal attributes, and
    the hash is taken from the string form of the type.
    """

    def __repr__(self):
        return type(self).__name__

    def __hash__(self):
        text = str(self)
        return hash(text)

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return self.__dict__ == other.__dict__

    def __ne__(self, other):
        return not self.__eq__(other)
class PrimitiveTypeSingleton(type):

    """Metaclass for PrimitiveType

    Every class created with this metaclass has exactly one instance: the
    first call constructs it and later calls return the cached object.
    """

    _instances = {}

    def __call__(cls):
        try:
            return cls._instances[cls]
        except KeyError:
            singleton = super(PrimitiveTypeSingleton, cls).__call__()
            cls._instances[cls] = singleton
            return singleton
class PrimitiveType(DataType):

    """Spark SQL PrimitiveType

    Base class of the parameterless data types. The metaclass makes each
    subclass a singleton, so equality reduces to identity.
    """

    __metaclass__ = PrimitiveTypeSingleton

    def __eq__(self, other):
        # each primitive type has a single shared instance
        return other is self
class StringType(PrimitiveType):

    """Spark SQL StringType

    Data type for string values (Python ``str`` and ``unicode``).
    """
class BinaryType(PrimitiveType):

    """Spark SQL BinaryType

    Data type for binary data held in Python ``bytearray`` values.
    """
class BooleanType(PrimitiveType):

    """Spark SQL BooleanType

    Data type for Python ``bool`` values.
    """
class TimestampType(PrimitiveType):

    """Spark SQL TimestampType

    Data type for ``datetime.datetime`` values.
    """
class DecimalType(PrimitiveType):

    """Spark SQL DecimalType

    Data type for ``decimal.Decimal`` values.
    """
class DoubleType(PrimitiveType):

    """Spark SQL DoubleType

    Data type for Python ``float`` values.
    """
class FloatType(PrimitiveType):

    """Spark SQL FloatType

    Data type for single precision floating-point values.
    """
class ByteType(PrimitiveType):

    """Spark SQL ByteType

    The data type representing int values with 1 signed byte.
    """
class IntegerType(PrimitiveType):

    """Spark SQL IntegerType

    Data type for Python ``int`` values.
    """
class LongType(PrimitiveType):

    """Spark SQL LongType

    The data type representing long values. If any value is
    beyond the range of [-9223372036854775808, 9223372036854775807],
    please use DecimalType.
    """
class ShortType(PrimitiveType):

    """Spark SQL ShortType

    Data type for int values stored in 2 signed bytes.
    """
class ArrayType(DataType):

    """Spark SQL ArrayType

    The data type representing list values. An ArrayType object
    comprises two fields, elementType (a DataType) and containsNull (a bool).
    The field of elementType is used to specify the type of array elements.
    The field of containsNull is used to specify if the array has None values.

    """

    def __init__(self, elementType, containsNull=True):
        """Creates an ArrayType

        :param elementType: the data type of elements.
        :param containsNull: indicates whether the list contains None values.

        >>> ArrayType(StringType) == ArrayType(StringType, True)
        True
        >>> ArrayType(StringType, False) == ArrayType(StringType)
        False
        """
        self.elementType = elementType
        self.containsNull = containsNull

    def __repr__(self):
        # Defined as __repr__ (like MapType, StructField and StructType) so
        # that repr() shows the full type rather than DataType's bare class
        # name. str() falls back to __repr__, so the string form (used for
        # hashing and for the schema string sent to the JVM) is unchanged.
        return "ArrayType(%s,%s)" % (self.elementType,
                                     str(self.containsNull).lower())
class MapType(DataType):

    """Spark SQL MapType

    Data type for dict values. A MapType has three attributes:

    * keyType (a DataType): the type of the keys of the map;
    * valueType (a DataType): the type of the values of the map;
    * valueContainsNull (a bool): whether the values may be None.

    Keys of a MapType column are not allowed to be None.

    """

    def __init__(self, keyType, valueType, valueContainsNull=True):
        """Creates a MapType
        :param keyType: the data type of keys.
        :param valueType: the data type of values.
        :param valueContainsNull: indicates whether values contains
        null values.

        >>> (MapType(StringType, IntegerType)
        ...        == MapType(StringType, IntegerType, True))
        True
        >>> (MapType(StringType, IntegerType, False)
        ...        == MapType(StringType, FloatType))
        False
        """
        self.keyType = keyType
        self.valueType = valueType
        self.valueContainsNull = valueContainsNull

    def __repr__(self):
        null_flag = str(self.valueContainsNull).lower()
        return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, null_flag)
class StructField(DataType):

    """Spark SQL StructField

    One field of a StructType, described by three attributes:

    * name (a string): the name of the field;
    * dataType (a DataType): the data type of the field;
    * nullable (a bool): whether values of the field may be None.

    """

    def __init__(self, name, dataType, nullable):
        """Creates a StructField
        :param name: the name of this field.
        :param dataType: the data type of this field.
        :param nullable: indicates whether values of this field
                         can be null.

        >>> (StructField("f1", StringType, True)
        ...      == StructField("f1", StringType, True))
        True
        >>> (StructField("f1", StringType, True)
        ...      == StructField("f2", StringType, True))
        False
        """
        self.name = name
        self.dataType = dataType
        self.nullable = nullable

    def __repr__(self):
        null_flag = str(self.nullable).lower()
        return "StructField(%s,%s,%s)" % (self.name, self.dataType, null_flag)
class StructType(DataType):

    """Spark SQL StructType

    The data type representing rows.
    A StructType object comprises a list of L{StructField}s.

    """

    def __init__(self, fields):
        """Creates a StructType

        :param fields: the list of L{StructField}s of this struct.

        >>> struct1 = StructType([StructField("f1", StringType, True)])
        >>> struct2 = StructType([StructField("f1", StringType, True)])
        >>> struct1 == struct2
        True
        >>> struct1 = StructType([StructField("f1", StringType, True)])
        >>> struct2 = StructType([StructField("f1", StringType, True),
        ...   StructField("f2", IntegerType, False)])
        >>> struct1 == struct2
        False
        """
        self.fields = fields

    def __repr__(self):
        return ("StructType(List(%s))" %
                ",".join(str(field) for field in self.fields))
def _parse_datatype_list(datatype_list_string):
    """Parses a list of comma separated data types.

    Only commas at parenthesis depth zero separate entries, so nested
    types such as ``MapType(StringType,LongType,true)`` stay intact.
    Each entry is parsed with _parse_datatype_string.
    """
    parsed = []
    depth = 0
    segment_start = 0
    for position, char in enumerate(datatype_list_string):
        if char == "," and depth == 0:
            piece = datatype_list_string[segment_start:position].strip()
            parsed.append(_parse_datatype_string(piece))
            segment_start = position + 1
        elif char == "(":
            depth += 1
        elif char == ")":
            depth -= 1

    # the last entry is not followed by a comma
    tail = datatype_list_string[segment_start:].strip()
    parsed.append(_parse_datatype_string(tail))
    return parsed
# Map each primitive type's class name (e.g. "StringType") to its class,
# collected from the module namespace: classes built by the singleton
# metaclass whose direct base is PrimitiveType.
_all_primitive_types = dict(
    (name, value)
    for name, value in globals().iteritems()
    if type(value) is PrimitiveTypeSingleton
    and value.__base__ == PrimitiveType
)
def _parse_datatype_string(datatype_string):
    """Parses the given data type string.

    Accepts the string form of a data type as produced on the JVM side
    (e.g. ``StringType``, ``ArrayType(StringType,true)`` or
    ``StructType(List(StructField(a,IntegerType,false)))``) and returns the
    matching DataType. A type name that is not recognised yields None.

    >>> def check_datatype(datatype):
    ...     scala_datatype = sqlCtx._ssql_ctx.parseDataType(str(datatype))
    ...     python_datatype = _parse_datatype_string(
    ...         scala_datatype.toString())
    ...     return datatype == python_datatype
    >>> all(check_datatype(cls()) for cls in _all_primitive_types.values())
    True
    >>> # Simple ArrayType.
    >>> simple_arraytype = ArrayType(StringType(), True)
    >>> check_datatype(simple_arraytype)
    True
    >>> # Simple MapType.
    >>> simple_maptype = MapType(StringType(), LongType())
    >>> check_datatype(simple_maptype)
    True
    >>> # Simple StructType.
    >>> simple_structtype = StructType([
    ...     StructField("a", DecimalType(), False),
    ...     StructField("b", BooleanType(), True),
    ...     StructField("c", LongType(), True),
    ...     StructField("d", BinaryType(), False)])
    >>> check_datatype(simple_structtype)
    True
    >>> # Complex StructType.
    >>> complex_structtype = StructType([
    ...     StructField("simpleArray", simple_arraytype, True),
    ...     StructField("simpleMap", simple_maptype, True),
    ...     StructField("simpleStruct", simple_structtype, True),
    ...     StructField("boolean", BooleanType(), False)])
    >>> check_datatype(complex_structtype)
    True
    >>> # Complex ArrayType.
    >>> complex_arraytype = ArrayType(complex_structtype, True)
    >>> check_datatype(complex_arraytype)
    True
    >>> # Complex MapType.
    >>> complex_maptype = MapType(complex_structtype,
    ...                           complex_arraytype, False)
    >>> check_datatype(complex_maptype)
    True
    """
    open_paren = datatype_string.find("(")
    if open_paren == -1:
        # A primitive type: the whole string is the type name.
        open_paren = len(datatype_string)
    type_or_field = datatype_string[:open_paren]
    # text inside the outermost parentheses (empty for primitive types)
    rest_part = datatype_string[open_paren + 1:len(datatype_string) - 1].strip()

    if type_or_field in _all_primitive_types:
        return _all_primitive_types[type_or_field]()

    if type_or_field == "ArrayType":
        # ArrayType(<elementType>,<containsNull>)
        comma = rest_part.rfind(",")
        element_type = _parse_datatype_string(rest_part[:comma].strip())
        return ArrayType(element_type, _parse_trailing_flag(rest_part, comma))

    if type_or_field == "MapType":
        # MapType(<keyType>,<valueType>,<valueContainsNull>)
        comma = rest_part.rfind(",")
        key_type, value_type = _parse_datatype_list(rest_part[:comma].strip())
        return MapType(key_type, value_type,
                       _parse_trailing_flag(rest_part, comma))

    if type_or_field == "StructField":
        # StructField(<name>,<dataType>,<nullable>)
        name_end = rest_part.find(",")
        flag_start = rest_part.rfind(",")
        field_type = _parse_datatype_string(
            rest_part[name_end + 1:flag_start].strip())
        return StructField(rest_part[:name_end].strip(), field_type,
                           _parse_trailing_flag(rest_part, flag_start))

    if type_or_field == "StructType":
        # rest_part looks like List(StructField(field1,IntegerType,false)).
        inner = rest_part[rest_part.find("(") + 1:-1]
        return StructType(_parse_datatype_list(inner))

    return None


def _parse_trailing_flag(text, comma_index):
    """Return False if the text after ``comma_index`` is "false", else True."""
    return text[comma_index + 1:].strip().lower() != "false"
# Mapping Python types to Spark SQL DataType; _infer_type looks values up
# here by their exact Python type.
_type_mappings = dict([
    (bool, BooleanType),
    (int, IntegerType),
    (long, LongType),
    (float, DoubleType),
    (str, StringType),
    (unicode, StringType),
    (decimal.Decimal, DecimalType),
    (datetime.datetime, TimestampType),
    (datetime.date, TimestampType),
    (datetime.time, TimestampType),
])
def _infer_type(obj):
    """Infer the DataType from obj

    Scalars are typed through _type_mappings by their exact Python type.
    Dicts and lists/arrays are typed from their first entry, and any other
    value is treated as a row and passed to _infer_schema.

    :raises ValueError: for None, empty dicts/lists/arrays and values whose
        type is not supported.
    """
    if obj is None:
        raise ValueError("Can not infer type for None")

    scalar_type = _type_mappings.get(type(obj))
    if scalar_type is not None:
        return scalar_type()

    if isinstance(obj, dict):
        if not obj:
            raise ValueError("Can not infer type for empty dict")
        key, value = next(obj.iteritems())
        return MapType(_infer_type(key), _infer_type(value), True)

    if isinstance(obj, (list, array)):
        if not obj:
            raise ValueError("Can not infer type for empty list/array")
        return ArrayType(_infer_type(obj[0]), True)

    try:
        return _infer_schema(obj)
    except ValueError:
        raise ValueError("not supported type: %s" % type(obj))
def _infer_schema(row):
    """Infer the schema from dict/namedtuple/object

    Field names come from the dict keys (sorted), the namedtuple/Row field
    names, the names in a tuple of (name, value) pairs, or an object's
    ``__dict__`` (sorted). Every field is typed with _infer_type and is
    marked nullable.
    """
    if isinstance(row, dict):
        pairs = sorted(row.items())
    elif isinstance(row, tuple):
        if hasattr(row, "_fields"):  # namedtuple
            pairs = zip(row._fields, tuple(row))
        elif hasattr(row, "__FIELDS__"):  # Row
            pairs = zip(row.__FIELDS__, tuple(row))
        elif all(isinstance(x, tuple) and len(x) == 2 for x in row):
            pairs = row
        else:
            raise ValueError("Can't infer schema from tuple")
    elif hasattr(row, "__dict__"):  # plain object
        pairs = sorted(row.__dict__.items())
    else:
        raise ValueError("Can not infer schema for type: %s" % type(row))

    return StructType([StructField(name, _infer_type(value), True)
                       for name, value in pairs])
def _create_converter(obj, dataType):
    """Create an converter to drop the names of fields in obj

    ``obj`` is a sample value (e.g. the first row). Its shape decides how
    struct values are turned into plain sequences of field values:

    * dict -> tuple of values looked up by field name (missing keys -> None)
    * namedtuple / Row -> tuple
    * tuple of (name, value) pairs -> tuple of the values
    * any other object with ``__dict__`` -> list of attribute values by name

    Arrays and maps are converted element by element, using their first
    element or first value as the sample. Other data types pass through
    unchanged.

    :raises ValueError: if a struct sample has none of the shapes above.
    """
    if isinstance(dataType, ArrayType):
        conv = _create_converter(obj[0], dataType.elementType)
        return lambda row: map(conv, row)

    elif isinstance(dataType, MapType):
        value = obj.values()[0]
        conv = _create_converter(value, dataType.valueType)
        return lambda row: dict((k, conv(v)) for k, v in row.iteritems())

    elif not isinstance(dataType, StructType):
        return lambda x: x

    # dataType must be StructType
    names = [f.name for f in dataType.fields]

    if isinstance(obj, dict):
        conv = lambda o: tuple(o.get(n) for n in names)

    elif isinstance(obj, tuple):
        if hasattr(obj, "_fields"):  # namedtuple
            conv = tuple
        elif hasattr(obj, "__FIELDS__"):  # Row
            conv = tuple
        elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
            conv = lambda o: tuple(v for k, v in o)
        else:
            raise ValueError("unexpected tuple")

    elif hasattr(obj, "__dict__"):  # object
        conv = lambda o: [o.__dict__.get(n, None) for n in names]

    else:
        # Without this branch `conv` was left unbound and the call below
        # failed with an opaque UnboundLocalError.
        raise ValueError("Unexpected obj: %s" % obj)

    if all(isinstance(f.dataType, PrimitiveType) for f in dataType.fields):
        return conv

    # Some fields are nested: build one converter per field, using the
    # sample's own field values to decide how each nested value is handled.
    row = conv(obj)
    convs = [_create_converter(v, f.dataType)
             for v, f in zip(row, dataType.fields)]

    def nested_conv(row):
        return tuple(f(v) for f, v in zip(convs, conv(row)))

    return nested_conv
544 545 -def _drop_schema(rows, schema):
546 """ all the names of fields, becoming tuples""" 547 iterator = iter(rows) 548 row = iterator.next() 549 converter = _create_converter(row, schema) 550 yield converter(row) 551 for i in iterator: 552 yield converter(i)
553 554 555 _BRACKETS = {'(': ')', '[': ']', '{': '}'}
def _split_schema_abstract(s):
    """
    split the schema abstract into fields

    Fields are separated by spaces that are not inside any bracket.
    Mismatched or unclosed brackets raise ValueError.

    >>> _split_schema_abstract("a b  c")
    ['a', 'b', 'c']
    >>> _split_schema_abstract("a(a b)")
    ['a(a b)']
    >>> _split_schema_abstract("a b[] c{a b}")
    ['a', 'b[]', 'c{a b}']
    >>> _split_schema_abstract(" ")
    []
    """

    parts = []
    current = ''
    open_stack = []
    closers = _BRACKETS.values()
    for ch in s:
        if ch == ' ' and not open_stack:
            # a top-level space ends the current field
            if current:
                parts.append(current)
                current = ''
            continue
        current += ch
        if ch in _BRACKETS:
            open_stack.append(ch)
        elif ch in closers:
            if not open_stack or ch != _BRACKETS[open_stack.pop()]:
                raise ValueError("unexpected " + ch)

    if open_stack:
        raise ValueError("brackets not closed: %s" % open_stack)
    if current:
        parts.append(current)
    return parts
def _parse_field_abstract(s):
    """
    Parse a field in schema abstract

    The text before the first opening bracket is the field name. The
    bracketed remainder, if any, is parsed as the field's (partial) type.

    >>> _parse_field_abstract("a")
    StructField(a,None,true)
    >>> _parse_field_abstract("b(c d)")
    StructField(b,StructType(...c,None,true),StructField(d...
    >>> _parse_field_abstract("a[]")
    StructField(a,ArrayType(None,true),true)
    >>> _parse_field_abstract("a{[]}")
    StructField(a,MapType(None,ArrayType(None,true),true),true)
    """
    positions = [s.index(c) for c in _BRACKETS if c in s]
    if not positions:
        return StructField(s, None, True)
    cut = min(positions)
    return StructField(s[:cut], _parse_schema_abstract(s[cut:]), True)
615 616 -def _parse_schema_abstract(s):
617 """ 618 parse abstract into schema 619 620 >>> _parse_schema_abstract("a b c") 621 StructType...a...b...c... 622 >>> _parse_schema_abstract("a[b c] b{}") 623 StructType...a,ArrayType...b...c...b,MapType... 624 >>> _parse_schema_abstract("c{} d{a b}") 625 StructType...c,MapType...d,MapType...a...b... 626 >>> _parse_schema_abstract("a b(t)").fields[1] 627 StructField(b,StructType(List(StructField(t,None,true))),true) 628 """ 629 s = s.strip() 630 if not s: 631 return 632 633 elif s.startswith('('): 634 return _parse_schema_abstract(s[1:-1]) 635 636 elif s.startswith('['): 637 return ArrayType(_parse_schema_abstract(s[1:-1]), True) 638 639 elif s.startswith('{'): 640 return MapType(None, _parse_schema_abstract(s[1:-1])) 641 642 parts = _split_schema_abstract(s) 643 fields = [_parse_field_abstract(p) for p in parts] 644 return StructType(fields)
645
646 647 -def _infer_schema_type(obj, dataType):
648 """ 649 Fill the dataType with types infered from obj 650 651 >>> schema = _parse_schema_abstract("a b c") 652 >>> row = (1, 1.0, "str") 653 >>> _infer_schema_type(row, schema) 654 StructType...IntegerType...DoubleType...StringType... 655 >>> row = [[1], {"key": (1, 2.0)}] 656 >>> schema = _parse_schema_abstract("a[] b{c d}") 657 >>> _infer_schema_type(row, schema) 658 StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType... 659 """ 660 if dataType is None: 661 return _infer_type(obj) 662 663 if not obj: 664 raise ValueError("Can not infer type from empty value") 665 666 if isinstance(dataType, ArrayType): 667 eType = _infer_schema_type(obj[0], dataType.elementType) 668 return ArrayType(eType, True) 669 670 elif isinstance(dataType, MapType): 671 k, v = obj.iteritems().next() 672 return MapType(_infer_type(k), 673 _infer_schema_type(v, dataType.valueType)) 674 675 elif isinstance(dataType, StructType): 676 fs = dataType.fields 677 assert len(fs) == len(obj), \ 678 "Obj(%s) have different length with fields(%s)" % (obj, fs) 679 fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True) 680 for o, f in zip(obj, fs)] 681 return StructType(fields) 682 683 else: 684 raise ValueError("Unexpected dataType: %s" % dataType)
# DataType class -> exact Python types accepted by _verify_type.
# LongType accepts plain int as well as long: Python 2 ints cover most of
# the 64-bit range, and _infer_type maps small integers to int, so rejecting
# them made any user schema with a LongType field fail on ordinary values.
_acceptable_types = {
    BooleanType: (bool,),
    ByteType: (int, long),
    ShortType: (int, long),
    IntegerType: (int, long),
    LongType: (int, long),
    FloatType: (float,),
    DoubleType: (float,),
    DecimalType: (decimal.Decimal,),
    StringType: (str, unicode),
    TimestampType: (datetime.datetime,),
    ArrayType: (list, tuple, array),
    MapType: (dict,),
    StructType: (tuple, list),
}
702 703 704 -def _verify_type(obj, dataType):
705 """ 706 Verify the type of obj against dataType, raise an exception if 707 they do not match. 708 709 >>> _verify_type(None, StructType([])) 710 >>> _verify_type("", StringType()) 711 >>> _verify_type(0, IntegerType()) 712 >>> _verify_type(range(3), ArrayType(ShortType())) 713 >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL 714 Traceback (most recent call last): 715 ... 716 TypeError:... 717 >>> _verify_type({}, MapType(StringType(), IntegerType())) 718 >>> _verify_type((), StructType([])) 719 >>> _verify_type([], StructType([])) 720 >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL 721 Traceback (most recent call last): 722 ... 723 ValueError:... 724 """ 725 # all objects are nullable 726 if obj is None: 727 return 728 729 _type = type(dataType) 730 if _type not in _acceptable_types: 731 return 732 733 if type(obj) not in _acceptable_types[_type]: 734 raise TypeError("%s can not accept abject in type %s" 735 % (dataType, type(obj))) 736 737 if isinstance(dataType, ArrayType): 738 for i in obj: 739 _verify_type(i, dataType.elementType) 740 741 elif isinstance(dataType, MapType): 742 for k, v in obj.iteritems(): 743 _verify_type(k, dataType.keyType) 744 _verify_type(v, dataType.valueType) 745 746 elif isinstance(dataType, StructType): 747 if len(obj) != len(dataType.fields): 748 raise ValueError("Length of object (%d) does not match with" 749 "length of fields (%d)" % (len(obj), len(dataType.fields))) 750 for v, f in zip(obj, dataType.fields): 751 _verify_type(v, f.dataType)
752 753 754 _cached_cls = {}
def _restore_object(dataType, obj):
    """ Restore object during unpickling.

    Looks up (or creates with _create_cls) the class for ``dataType`` and
    wraps ``obj`` in it. Classes are cached in _cached_cls under
    ``id(dataType)`` as a fast path, and under the dataType itself so that
    equal schemas share one class.
    """
    # use id(dataType) as key to speed up lookup in dict
    # Because of batched pickling, dataType will be the
    # same object in most cases.
    k = id(dataType)
    cls = _cached_cls.get(k)
    # The id-keyed entry can only be trusted if it was recorded for this
    # very object. The dataType that produced it may have been garbage
    # collected, and its id reused by an unrelated schema.
    if cls is None or getattr(cls, "__datatype", None) is not dataType:
        # use dataType as key to avoid create multiple class
        cls = _cached_cls.get(dataType)
        if cls is None:
            cls = _create_cls(dataType)
            _cached_cls[dataType] = cls
        cls.__datatype = dataType
        _cached_cls[k] = cls
    return cls(obj)
773 774 -def _create_object(cls, v):
775 """ Create an customized object with class `cls`. """ 776 return cls(v) if v is not None else v
777
def _create_getter(dt, i):
    """ Create a getter for item `i` with schema

    The returned function reads position ``i`` of the row and wraps it in
    the class built for ``dt`` (None stays None).
    """
    item_cls = _create_cls(dt)

    def getter(self):
        # wrap lazily, only when the field is actually read
        return _create_object(item_cls, self[i])

    return getter
def _has_struct(dt):
    """Return whether `dt` is or has StructType in it

    Arrays are followed through their element type and maps through their
    value type. Map key types are not inspected.
    """
    while True:
        if isinstance(dt, StructType):
            return True
        if isinstance(dt, ArrayType):
            dt = dt.elementType
        elif isinstance(dt, MapType):
            dt = dt.valueType
        else:
            return False
799 800 -def _create_properties(fields):
801 """Create properties according to fields""" 802 ps = {} 803 for i, f in enumerate(fields): 804 name = f.name 805 if (name.startswith("__") and name.endswith("__") 806 or keyword.iskeyword(name)): 807 warnings.warn("field name %s can not be accessed in Python," 808 "use position to access it instead" % name) 809 if _has_struct(f.dataType): 810 # delay creating object until accessing it 811 getter = _create_getter(f.dataType, i) 812 else: 813 getter = itemgetter(i) 814 ps[name] = property(getter) 815 return ps
816
def _create_cls(dataType):
    """
    Create a class for values of dataType

    Arrays become a ``List`` (list subclass) and maps a ``Dict`` (dict
    subclass); both wrap elements on access. Structs become a ``Row``
    (tuple subclass) with one property per field. The created class is
    similar to namedtuple, but can have nested schema. Any other data type
    raises an Exception.

    >>> schema = _parse_schema_abstract("a b c")
    >>> row = (1, 1.0, "str")
    >>> schema = _infer_schema_type(row, schema)
    >>> obj = _create_cls(schema)(row)
    >>> import pickle
    >>> pickle.loads(pickle.dumps(obj))
    Row(a=1, b=1.0, c='str')

    >>> row = [[1], {"key": (1, 2.0)}]
    >>> schema = _parse_schema_abstract("a[] b{c d}")
    >>> schema = _infer_schema_type(row, schema)
    >>> obj = _create_cls(schema)(row)
    >>> pickle.loads(pickle.dumps(obj))
    Row(a=[1], b={'key': Row(c=1, d=2.0)})
    """

    if isinstance(dataType, ArrayType):
        item_cls = _create_cls(dataType.elementType)

        class List(list):

            def __getitem__(self, i):
                # wrap the stored element according to the element type
                return _create_object(item_cls, list.__getitem__(self, i))

            def __repr__(self):
                # index through __getitem__ so nested values use their repr
                shown = [repr(self[pos]) for pos in range(len(self))]
                return "[%s]" % ", ".join(shown)

            def __reduce__(self):
                return list.__reduce__(self)

        return List

    if isinstance(dataType, MapType):
        value_cls = _create_cls(dataType.valueType)

        class Dict(dict):

            def __getitem__(self, k):
                # wrap the stored value according to the value type
                return _create_object(value_cls, dict.__getitem__(self, k))

            def __repr__(self):
                # look values up through __getitem__ so nested values use
                # their repr
                shown = ["%r: %r" % (key, self[key]) for key in self]
                return "{%s}" % ", ".join(shown)

            def __reduce__(self):
                return dict.__reduce__(self)

        return Dict

    if not isinstance(dataType, StructType):
        raise Exception("unexpected data type: %s" % dataType)

    class Row(tuple):

        """ Row in SchemaRDD """
        __DATATYPE__ = dataType
        __FIELDS__ = tuple(f.name for f in dataType.fields)
        __slots__ = ()

        # one property per field, so row.<name> reads the value by position
        locals().update(_create_properties(dataType.fields))

        def __repr__(self):
            pairs = ["%s=%r" % (n, getattr(self, n)) for n in self.__FIELDS__]
            return "Row(%s)" % ", ".join(pairs)

        def __reduce__(self):
            # pickle as (schema, plain values); _restore_object rebuilds it
            return (_restore_object, (self.__DATATYPE__, tuple(self)))

    return Row
901 902 -class SQLContext:
903 904 """Main entry point for SparkSQL functionality. 905 906 A SQLContext can be used create L{SchemaRDD}s, register L{SchemaRDD}s as 907 tables, execute SQL over tables, cache tables, and read parquet files. 908 """ 909
910 - def __init__(self, sparkContext, sqlContext=None):
911 """Create a new SQLContext. 912 913 @param sparkContext: The SparkContext to wrap. 914 @param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new 915 SQLContext in the JVM, instead we make all calls to this object. 916 917 >>> srdd = sqlCtx.inferSchema(rdd) 918 >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL 919 Traceback (most recent call last): 920 ... 921 TypeError:... 922 923 >>> bad_rdd = sc.parallelize([1,2,3]) 924 >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL 925 Traceback (most recent call last): 926 ... 927 ValueError:... 928 929 >>> from datetime import datetime 930 >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, 931 ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), 932 ... time=datetime(2014, 8, 1, 14, 1, 5))]) 933 >>> srdd = sqlCtx.inferSchema(allTypes) 934 >>> srdd.registerTempTable("allTypes") 935 >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' 936 ... 'from allTypes where b and i > 0').collect() 937 [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] 938 >>> srdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, 939 ... x.row.a, x.list)).collect() 940 [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] 941 """ 942 self._sc = sparkContext 943 self._jsc = self._sc._jsc 944 self._jvm = self._sc._jvm 945 self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray 946 947 if sqlContext: 948 self._scala_SQLContext = sqlContext
949 950 @property
951 - def _ssql_ctx(self):
952 """Accessor for the JVM SparkSQL context. 953 954 Subclasses can override this property to provide their own 955 JVM Contexts. 956 """ 957 if not hasattr(self, '_scala_SQLContext'): 958 self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) 959 return self._scala_SQLContext
960
961 - def registerFunction(self, name, f, returnType=StringType()):
962 """Registers a lambda function as a UDF so it can be used in SQL statements. 963 964 In addition to a name and the function itself, the return type can be optionally specified. 965 When the return type is not given it default to a string and conversion will automatically 966 be done. For any other return type, the produced object must match the specified type. 967 968 >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) 969 >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() 970 [Row(c0=u'4')] 971 >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) 972 >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() 973 [Row(c0=4)] 974 >>> sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) 975 >>> sqlCtx.sql("SELECT twoArgs('test', 1)").collect() 976 [Row(c0=5)] 977 """ 978 func = lambda _, it: imap(lambda x: f(*x), it) 979 command = (func, 980 BatchedSerializer(PickleSerializer(), 1024), 981 BatchedSerializer(PickleSerializer(), 1024)) 982 env = MapConverter().convert(self._sc.environment, 983 self._sc._gateway._gateway_client) 984 includes = ListConverter().convert(self._sc._python_includes, 985 self._sc._gateway._gateway_client) 986 self._ssql_ctx.registerPython(name, 987 bytearray(CloudPickleSerializer().dumps(command)), 988 env, 989 includes, 990 self._sc.pythonExec, 991 self._sc._javaAccumulator, 992 str(returnType))
993
994 - def inferSchema(self, rdd):
995 """Infer and apply a schema to an RDD of L{Row}s. 996 997 We peek at the first row of the RDD to determine the fields' names 998 and types. Nested collections are supported, which include array, 999 dict, list, Row, tuple, namedtuple, or object. 1000 1001 All the rows in `rdd` should have the same type with the first one, 1002 or it will cause runtime exceptions. 1003 1004 Each row could be L{pyspark.sql.Row} object or namedtuple or objects, 1005 using dict is deprecated. 1006 1007 >>> rdd = sc.parallelize( 1008 ... [Row(field1=1, field2="row1"), 1009 ... Row(field1=2, field2="row2"), 1010 ... Row(field1=3, field2="row3")]) 1011 >>> srdd = sqlCtx.inferSchema(rdd) 1012 >>> srdd.collect()[0] 1013 Row(field1=1, field2=u'row1') 1014 1015 >>> NestedRow = Row("f1", "f2") 1016 >>> nestedRdd1 = sc.parallelize([ 1017 ... NestedRow(array('i', [1, 2]), {"row1": 1.0}), 1018 ... NestedRow(array('i', [2, 3]), {"row2": 2.0})]) 1019 >>> srdd = sqlCtx.inferSchema(nestedRdd1) 1020 >>> srdd.collect() 1021 [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})] 1022 1023 >>> nestedRdd2 = sc.parallelize([ 1024 ... NestedRow([[1, 2], [2, 3]], [1, 2]), 1025 ... NestedRow([[2, 3], [3, 4]], [2, 3])]) 1026 >>> srdd = sqlCtx.inferSchema(nestedRdd2) 1027 >>> srdd.collect() 1028 [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] 1029 """ 1030 1031 if isinstance(rdd, SchemaRDD): 1032 raise TypeError("Cannot apply schema to SchemaRDD") 1033 1034 first = rdd.first() 1035 if not first: 1036 raise ValueError("The first row in RDD is empty, " 1037 "can not infer schema") 1038 if type(first) is dict: 1039 warnings.warn("Using RDD of dict to inferSchema is deprecated," 1040 "please use pyspark.Row instead") 1041 1042 schema = _infer_schema(first) 1043 rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema)) 1044 return self.applySchema(rdd, schema)
1045
def applySchema(self, rdd, schema):
    """
    Applies the given schema to the given RDD of L{tuple} or L{list}s.

    The tuples or lists may hold complex nested structures such as lists,
    maps or nested rows. The schema must be a L{StructType}, and it has to
    agree with the types of the objects in each row; a mismatch can
    surface as an exception at runtime.

    @param rdd: a plain RDD (not a L{SchemaRDD}) of tuples or lists.
    @param schema: the L{StructType} describing each row.
    @return: a L{SchemaRDD} bound to this context.
    @raise TypeError: if C{rdd} is already a L{SchemaRDD} or C{schema} is
        not a L{StructType}.

    >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")])
    >>> schema = StructType([StructField("field1", IntegerType(), False),
    ...     StructField("field2", StringType(), False)])
    >>> srdd = sqlCtx.applySchema(rdd2, schema)
    >>> sqlCtx.registerRDDAsTable(srdd, "table1")
    >>> srdd2 = sqlCtx.sql("SELECT * from table1")
    >>> srdd2.collect()
    [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]

    >>> from datetime import datetime
    >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
    ...     datetime(2010, 1, 1, 1, 1, 1),
    ...     {"a": 1}, (2,), [1, 2, 3], None)])
    >>> schema = StructType([
    ...     StructField("byte1", ByteType(), False),
    ...     StructField("byte2", ByteType(), False),
    ...     StructField("short1", ShortType(), False),
    ...     StructField("short2", ShortType(), False),
    ...     StructField("int", IntegerType(), False),
    ...     StructField("float", FloatType(), False),
    ...     StructField("time", TimestampType(), False),
    ...     StructField("map",
    ...         MapType(StringType(), IntegerType(), False), False),
    ...     StructField("struct",
    ...         StructType([StructField("b", ShortType(), False)]), False),
    ...     StructField("list", ArrayType(ByteType(), False), False),
    ...     StructField("null", DoubleType(), True)])
    >>> srdd = sqlCtx.applySchema(rdd, schema)
    >>> results = srdd.map(
    ...     lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.time,
    ...         x.map["a"], x.struct.b, x.list, x.null))
    >>> results.collect()[0]
    (127, -128, -32768, 32767, 2147483647, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)

    >>> srdd.registerTempTable("table2")
    >>> sqlCtx.sql(
    ...   "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
    ...   "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " +
    ...   "float + 1.5 as float FROM table2").collect()
    [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)]

    >>> rdd = sc.parallelize([(127, -32768, 1.0,
    ...     datetime(2010, 1, 1, 1, 1, 1),
    ...     {"a": 1}, (2,), [1, 2, 3])])
    >>> abstract = "byte short float time map{} struct(b) list[]"
    >>> schema = _parse_schema_abstract(abstract)
    >>> typedSchema = _infer_schema_type(rdd.first(), schema)
    >>> srdd = sqlCtx.applySchema(rdd, typedSchema)
    >>> srdd.collect()
    [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])]
    """
    if isinstance(rdd, SchemaRDD):
        raise TypeError("Cannot apply schema to SchemaRDD")
    if not isinstance(schema, StructType):
        raise TypeError("schema should be StructType")

    # Cheap sanity check: validate a small sample of rows (at most 10)
    # against the schema before handing the data to the JVM.
    for sample_row in rdd.take(10):
        _verify_type(sample_row, schema)

    is_batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
    java_rdd = self._pythonToJava(rdd._jrdd, is_batched)
    scala_srdd = self._ssql_ctx.applySchemaToPythonRDD(java_rdd.rdd(),
                                                      str(schema))
    return SchemaRDD(scala_srdd.toJavaSchemaRDD(), self)
1125
def registerRDDAsTable(self, rdd, tableName):
    """Registers the given RDD as a temporary table in the catalog.

    Temporary tables exist only during the lifetime of this instance of
    SQLContext.

    @param rdd: the L{SchemaRDD} (or subclass) to register.
    @param tableName: the name under which the table is registered.
    @raise ValueError: if C{rdd} is not a L{SchemaRDD}.

    >>> srdd = sqlCtx.inferSchema(rdd)
    >>> sqlCtx.registerRDDAsTable(srdd, "table1")
    """
    # isinstance (instead of an exact __class__ comparison) accepts
    # SchemaRDD subclasses, matching the SchemaRDD checks done by
    # inferSchema/applySchema. The exception type is kept as ValueError
    # for existing callers.
    if not isinstance(rdd, SchemaRDD):
        raise ValueError("Can only register SchemaRDD as table")
    srdd = rdd._jschema_rdd.baseSchemaRDD()
    self._ssql_ctx.registerRDDAsTable(srdd, tableName)
1140
def parquetFile(self, path):
    """Loads a Parquet file, returning the result as a L{SchemaRDD}.

    @param path: location of the Parquet data, passed to the JVM context.
    @return: a L{SchemaRDD} bound to this context.

    >>> import tempfile, shutil
    >>> parquetFile = tempfile.mkdtemp()
    >>> shutil.rmtree(parquetFile)
    >>> srdd = sqlCtx.inferSchema(rdd)
    >>> srdd.saveAsParquetFile(parquetFile)
    >>> srdd2 = sqlCtx.parquetFile(parquetFile)
    >>> sorted(srdd.collect()) == sorted(srdd2.collect())
    True
    """
    scala_srdd = self._ssql_ctx.parquetFile(path)
    java_srdd = scala_srdd.toJavaSchemaRDD()
    return SchemaRDD(java_srdd, self)
1155
def jsonFile(self, path, schema=None):
    """
    Loads a text file storing one JSON object per line as a
    L{SchemaRDD}.

    When C{schema} is given it is applied to the JSON dataset; when it is
    omitted, the whole dataset is scanned once to determine the schema.

    @param path: location of the newline-delimited JSON text.
    @param schema: optional L{StructType}; its string form is parsed by
        the JVM context.
    @return: a L{SchemaRDD} bound to this context.

    >>> import tempfile, shutil
    >>> jsonFile = tempfile.mkdtemp()
    >>> shutil.rmtree(jsonFile)
    >>> ofn = open(jsonFile, 'w')
    >>> for json in jsonStrings:
    ...   print>>ofn, json
    >>> ofn.close()
    >>> srdd1 = sqlCtx.jsonFile(jsonFile)
    >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
    >>> srdd2 = sqlCtx.sql(
    ...   "SELECT field1 AS f1, field2 as f2, field3 as f3, "
    ...   "field6 as f4 from table1")
    >>> for r in srdd2.collect():
    ...     print r
    Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
    Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
    Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
    >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema())
    >>> sqlCtx.registerRDDAsTable(srdd3, "table2")
    >>> srdd4 = sqlCtx.sql(
    ...   "SELECT field1 AS f1, field2 as f2, field3 as f3, "
    ...   "field6 as f4 from table2")
    >>> for r in srdd4.collect():
    ...     print r
    Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
    Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
    Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
    >>> schema = StructType([
    ...     StructField("field2", StringType(), True),
    ...     StructField("field3",
    ...         StructType([
    ...             StructField("field5",
    ...                 ArrayType(IntegerType(), False), True)]), False)])
    >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema)
    >>> sqlCtx.registerRDDAsTable(srdd5, "table3")
    >>> srdd6 = sqlCtx.sql(
    ...   "SELECT field2 AS f1, field3.field5 as f2, "
    ...   "field3.field5[0] as f3 from table3")
    >>> srdd6.collect()
    [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
    """
    if schema is not None:
        # Hand the schema to the JVM in its serialized string form.
        scala_datatype = self._ssql_ctx.parseDataType(str(schema))
        scala_srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
    else:
        scala_srdd = self._ssql_ctx.jsonFile(path)
    return SchemaRDD(scala_srdd.toJavaSchemaRDD(), self)
1214
def jsonRDD(self, rdd, schema=None):
    """Loads an RDD storing one JSON object per string as a L{SchemaRDD}.

    When C{schema} is given it is applied to the JSON dataset; when it is
    omitted, the whole dataset is scanned once to determine the schema.

    Non-string elements are converted with C{unicode()} and every unicode
    string is UTF-8 encoded before being shipped to the JVM.

    @param rdd: an RDD of JSON strings.
    @param schema: optional L{StructType}; its string form is parsed by
        the JVM context.
    @return: a L{SchemaRDD} bound to this context.

    >>> srdd1 = sqlCtx.jsonRDD(json)
    >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
    >>> srdd2 = sqlCtx.sql(
    ...   "SELECT field1 AS f1, field2 as f2, field3 as f3, "
    ...   "field6 as f4 from table1")
    >>> for r in srdd2.collect():
    ...     print r
    Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
    Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
    Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
    >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema())
    >>> sqlCtx.registerRDDAsTable(srdd3, "table2")
    >>> srdd4 = sqlCtx.sql(
    ...   "SELECT field1 AS f1, field2 as f2, field3 as f3, "
    ...   "field6 as f4 from table2")
    >>> for r in srdd4.collect():
    ...     print r
    Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
    Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
    Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
    >>> schema = StructType([
    ...     StructField("field2", StringType(), True),
    ...     StructField("field3",
    ...         StructType([
    ...             StructField("field5",
    ...                 ArrayType(IntegerType(), False), True)]), False)])
    >>> srdd5 = sqlCtx.jsonRDD(json, schema)
    >>> sqlCtx.registerRDDAsTable(srdd5, "table3")
    >>> srdd6 = sqlCtx.sql(
    ...   "SELECT field2 AS f1, field3.field5 as f2, "
    ...   "field3.field5[0] as f3 from table3")
    >>> srdd6.collect()
    [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)]

    >>> sqlCtx.jsonRDD(sc.parallelize(['{}',
    ...         '{"key0": {"key1": "value1"}}'])).collect()
    [Row(key0=None), Row(key0=Row(key1=u'value1'))]
    >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}',
    ...         '{"key0": {"key1": "value1"}}'])).collect()
    [Row(key0=None), Row(key0=Row(key1=u'value1'))]
    """

    def _to_utf8(records):
        # Normalise every element to a UTF-8 byte string.
        for record in records:
            if not isinstance(record, basestring):
                record = unicode(record)
            if isinstance(record, unicode):
                record = record.encode("utf-8")
            yield record

    encoded = rdd.mapPartitions(_to_utf8)
    # Must be set before _jrdd is accessed so the raw bytes are passed
    # through unserialized.
    encoded._bypass_serializer = True
    jrdd = encoded._jrdd.map(self._jvm.BytesToString())
    if schema is not None:
        scala_datatype = self._ssql_ctx.parseDataType(str(schema))
        scala_srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
    else:
        scala_srdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
    return SchemaRDD(scala_srdd.toJavaSchemaRDD(), self)
1282
def sql(self, sqlQuery):
    """Return a L{SchemaRDD} representing the result of the given query.

    @param sqlQuery: the query text, passed to the JVM SQL context.
    @return: a L{SchemaRDD} bound to this context.

    >>> srdd = sqlCtx.inferSchema(rdd)
    >>> sqlCtx.registerRDDAsTable(srdd, "table1")
    >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
    >>> srdd2.collect()
    [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
    """
    scala_srdd = self._ssql_ctx.sql(sqlQuery)
    return SchemaRDD(scala_srdd.toJavaSchemaRDD(), self)
1293
def table(self, tableName):
    """Returns the specified table as a L{SchemaRDD}.

    @param tableName: name of a table known to this context's catalog.
    @return: a L{SchemaRDD} bound to this context.

    >>> srdd = sqlCtx.inferSchema(rdd)
    >>> sqlCtx.registerRDDAsTable(srdd, "table1")
    >>> srdd2 = sqlCtx.table("table1")
    >>> sorted(srdd.collect()) == sorted(srdd2.collect())
    True
    """
    scala_srdd = self._ssql_ctx.table(tableName)
    return SchemaRDD(scala_srdd.toJavaSchemaRDD(), self)
1304
def cacheTable(self, tableName):
    """Caches the specified table in-memory.

    @param tableName: name of the table; the request is delegated to the
        JVM SQL context.
    """
    jvm_ctx = self._ssql_ctx
    jvm_ctx.cacheTable(tableName)
1308
def uncacheTable(self, tableName):
    """Removes the specified table from the in-memory cache.

    @param tableName: name of the table; the request is delegated to the
        JVM SQL context.
    """
    jvm_ctx = self._ssql_ctx
    jvm_ctx.uncacheTable(tableName)
1312
class HiveContext(SQLContext):

    """A variant of Spark SQL that integrates with data stored in Hive.

    Configuration for Hive is read from hive-site.xml on the classpath.
    It supports running both SQL and HiveQL commands.
    """

    def __init__(self, sparkContext, hiveContext=None):
        """Create a new HiveContext.

        @param sparkContext: The SparkContext to wrap.
        @param hiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new
        HiveContext in the JVM, instead we make all calls to this object.
        """
        SQLContext.__init__(self, sparkContext)

        # Explicit None check: the argument is a py4j proxy object and we
        # only want to know whether one was supplied.
        if hiveContext is not None:
            self._scala_HiveContext = hiveContext

    @property
    def _ssql_ctx(self):
        """Return the JVM HiveContext, creating it on first access.

        @raise Exception: if creating the JVM context fails with a
            Py4JError (Spark built without Hive support).
        """
        try:
            if not hasattr(self, '_scala_HiveContext'):
                self._scala_HiveContext = self._get_hive_ctx()
            return self._scala_HiveContext
        except Py4JError as e:
            raise Exception("You must build Spark with Hive. "
                            "Export 'SPARK_HIVE=true' and run "
                            "sbt/sbt assembly", e)

    def _get_hive_ctx(self):
        """Construct the JVM HiveContext; subclasses override this hook."""
        return self._jvm.HiveContext(self._jsc.sc())

    def hiveql(self, hqlQuery):
        """
        DEPRECATED: Use sql()
        """
        # stacklevel=2 attributes the warning to the caller's line.
        warnings.warn("hiveql() is deprecated as the sql function now parses using HiveQL by " +
                      "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'",
                      DeprecationWarning, stacklevel=2)
        return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery).toJavaSchemaRDD(), self)

    def hql(self, hqlQuery):
        """
        DEPRECATED: Use sql()
        """
        warnings.warn("hql() is deprecated as the sql function now parses using HiveQL by " +
                      "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'",
                      DeprecationWarning, stacklevel=2)
        # Call the JVM directly instead of self.hiveql(), which would emit a
        # second, misleading "hiveql() is deprecated" warning.
        return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery).toJavaSchemaRDD(), self)
1365
class LocalHiveContext(HiveContext):

    """Starts up an instance of hive where metadata is stored locally.

    An in-process metadata store is created with data stored in ./metadata.
    Warehouse data is stored in ./warehouse.

    DEPRECATED: use L{HiveContext} instead.

    >>> import os
    >>> hiveCtx = LocalHiveContext(sc)
    >>> try:
    ...     supress = hiveCtx.sql("DROP TABLE src")
    ... except Exception:
    ...     pass
    >>> kv1 = os.path.join(os.environ["SPARK_HOME"],
    ...        'examples/src/main/resources/kv1.txt')
    >>> supress = hiveCtx.sql(
    ...     "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
    >>> supress = hiveCtx.sql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src"
    ...        % kv1)
    >>> results = hiveCtx.sql("FROM src SELECT value"
    ...      ).map(lambda r: int(r.value.split('_')[1]))
    >>> num = results.count()
    >>> reduce_sum = results.reduce(lambda x, y: x + y)
    >>> num
    500
    >>> reduce_sum
    130091
    """

    def __init__(self, sparkContext, sqlContext=None):
        # sqlContext is forwarded as HiveContext's optional JVM hiveContext.
        HiveContext.__init__(self, sparkContext, sqlContext)
        # stacklevel=2 attributes the warning to the constructing caller.
        warnings.warn("LocalHiveContext is deprecated. "
                      "Use HiveContext instead.", DeprecationWarning,
                      stacklevel=2)

    def _get_hive_ctx(self):
        """Construct the JVM LocalHiveContext for this SparkContext."""
        return self._jvm.LocalHiveContext(self._jsc.sc())
1403
class TestHiveContext(HiveContext):

    """A L{HiveContext} whose JVM peer is created via C{TestHiveContext}."""

    def _get_hive_ctx(self):
        # Wrap this context's SparkContext in the JVM TestHiveContext.
        jvm = self._jvm
        return jvm.TestHiveContext(self._jsc.sc())
1409
def _create_row(fields, values):
    """Build a L{Row} holding C{values} whose field names are C{fields}.

    Used by L{Row.__call__} and when unpickling named rows.
    """
    new_row = Row(*values)
    new_row.__FIELDS__ = fields
    return new_row
1415
class Row(tuple):

    """
    A row in L{SchemaRDD}. The fields in it can be accessed like attributes.

    Row can be used to create a row object by using named arguments,
    the fields will be sorted by names.

    >>> row = Row(name="Alice", age=11)
    >>> row
    Row(age=11, name='Alice')
    >>> row.name, row.age
    ('Alice', 11)

    Row also can be used to create another Row like class, then it
    could be used to create Row objects, such as

    >>> Person = Row("name", "age")
    >>> Person
    <Row(name, age)>
    >>> Person("Alice", 11)
    Row(name='Alice', age=11)
    """

    def __new__(cls, *args, **kwargs):
        """Create a Row from positional values or from keyword fields.

        Positional arguments produce a bare tuple-like Row (usable as a
        "row class" via L{__call__}); keyword arguments produce a named
        row whose fields are sorted by name.

        @raise ValueError: if both or neither of args/kwargs are given.
        """
        if args and kwargs:
            raise ValueError("Can not use both args "
                             "and kwargs to create Row")
        if args:
            # create row class or objects
            return tuple.__new__(cls, args)

        elif kwargs:
            # create row objects
            names = sorted(kwargs.keys())
            values = tuple(kwargs[n] for n in names)
            row = tuple.__new__(cls, values)
            row.__FIELDS__ = names
            return row

        else:
            raise ValueError("No args or kwargs")

    # let an instance act like a class: Row("a", "b")(1, 2) builds a row
    def __call__(self, *args):
        """create new Row object"""
        return _create_row(self, args)

    def __getattr__(self, item):
        """Look up a field value by name.

        @raise AttributeError: for dunder names, unknown field names, or a
            field without a corresponding value.
        """
        if item.startswith("__"):
            raise AttributeError(item)
        try:
            # it will be slow when it has many fields,
            # but this will not be used in normal cases
            idx = self.__FIELDS__.index(item)
            return self[idx]
        except (ValueError, IndexError):
            # list.index raises ValueError for an unknown field; translate
            # it so getattr(row, name, default) and hasattr() behave.
            raise AttributeError(item)

    def __reduce__(self):
        # Named rows are rebuilt through _create_row so their field names
        # survive pickling.
        if hasattr(self, "__FIELDS__"):
            return (_create_row, (self.__FIELDS__, tuple(self)))
        else:
            return tuple.__reduce__(self)

    def __repr__(self):
        if hasattr(self, "__FIELDS__"):
            return "Row(%s)" % ", ".join("%s=%r" % (k, v)
                                         for k, v in zip(self.__FIELDS__, self))
        else:
            return "<Row(%s)>" % ", ".join(self)
1488
class SchemaRDD(RDD):

    """An RDD of L{Row} objects that has an associated schema.

    The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can
    utilize the relational query api exposed by SparkSQL.

    For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the
    L{SchemaRDD} is not operated on directly, as its underlying
    implementation is an RDD composed of Java objects. Instead it is
    converted to a PythonRDD in the JVM, on which Python operations can
    be done.

    This class receives raw tuples from Java but assigns a class to it in
    all its data-collection methods (mapPartitionsWithIndex, collect, take,
    etc) so that PySpark sees them as Row objects with named fields.
    """

    def __init__(self, jschema_rdd, sql_ctx):
        self.sql_ctx = sql_ctx
        self._sc = sql_ctx._sc
        clsName = jschema_rdd.getClass().getName()
        assert clsName.endswith("JavaSchemaRDD"), "jschema_rdd must be JavaSchemaRDD"
        self._jschema_rdd = jschema_rdd

        self.is_cached = False
        self.is_checkpointed = False
        self.ctx = self.sql_ctx._sc
        # the _jrdd is created by javaToPython(), serialized by pickle
        self._jrdd_deserializer = BatchedSerializer(PickleSerializer())

    @property
    def _jrdd(self):
        """Lazy evaluation of PythonRDD object.

        Only done when a user calls methods defined by the
        L{pyspark.rdd.RDD} super class (map, filter, etc.).
        """
        if not hasattr(self, '_lazy_jrdd'):
            self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython()
        return self._lazy_jrdd

    @property
    def _id(self):
        """The id of the lazily created Python-side JVM RDD."""
        return self._jrdd.id()

    def saveAsParquetFile(self, path):
        """Save the contents as a Parquet file, preserving the schema.

        Files that are written out using this method can be read back in as
        a SchemaRDD using the L{SQLContext.parquetFile} method.

        >>> import tempfile, shutil
        >>> parquetFile = tempfile.mkdtemp()
        >>> shutil.rmtree(parquetFile)
        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.saveAsParquetFile(parquetFile)
        >>> srdd2 = sqlCtx.parquetFile(parquetFile)
        >>> sorted(srdd2.collect()) == sorted(srdd.collect())
        True
        """
        self._jschema_rdd.saveAsParquetFile(path)

    def registerTempTable(self, name):
        """Registers this RDD as a temporary table using the given name.

        The lifetime of this temporary table is tied to the L{SQLContext}
        that was used to create this SchemaRDD.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.registerTempTable("test")
        >>> srdd2 = sqlCtx.sql("select * from test")
        >>> sorted(srdd.collect()) == sorted(srdd2.collect())
        True
        """
        self._jschema_rdd.registerTempTable(name)

    def registerAsTable(self, name):
        """DEPRECATED: use L{registerTempTable} instead."""
        # stacklevel=2 attributes the warning to the caller's line.
        warnings.warn("Use registerTempTable instead of registerAsTable.",
                      DeprecationWarning, stacklevel=2)
        self.registerTempTable(name)

    def insertInto(self, tableName, overwrite=False):
        """Inserts the contents of this SchemaRDD into the specified table.

        Optionally overwriting any existing data.
        """
        self._jschema_rdd.insertInto(tableName, overwrite)

    def saveAsTable(self, tableName):
        """Creates a new table with the contents of this SchemaRDD."""
        self._jschema_rdd.saveAsTable(tableName)

    def schema(self):
        """Returns the schema of this SchemaRDD (represented by
        a L{StructType})."""
        return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString())

    def schemaString(self):
        """Returns the output schema in the tree format."""
        return self._jschema_rdd.schemaString()

    def printSchema(self):
        """Prints out the schema in the tree format."""
        # Parenthesised single argument: same output as the Python 2 print
        # statement, and also valid Python 3 syntax.
        print(self.schemaString())

    def count(self):
        """Return the number of elements in this RDD.

        Unlike the base RDD implementation of count, this implementation
        leverages the query optimizer to compute the count on the SchemaRDD,
        which supports features such as filter pushdown.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.count()
        3L
        >>> srdd.count() == srdd.map(lambda x: x).count()
        True
        """
        return self._jschema_rdd.count()

    def collect(self):
        """
        Return a list that contains all of the rows in this RDD.

        Each object in the list is on Row, the fields can be accessed as
        attributes.
        """
        rows = RDD.collect(self)
        cls = _create_cls(self.schema())
        # A list comprehension guarantees a list regardless of how map()
        # behaves on the running interpreter.
        return [cls(row) for row in rows]

    # Convert each object in the RDD to a Row with the right class
    # for this SchemaRDD, so that fields can be accessed as attributes.
    def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
        """
        Return a new RDD by applying a function to each partition of this RDD,
        while tracking the index of the original partition.

        >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
        >>> def f(splitIndex, iterator): yield splitIndex
        >>> rdd.mapPartitionsWithIndex(f).sum()
        6
        """
        rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer)

        schema = self.schema()

        def applySchema(_, it):
            cls = _create_cls(schema)
            return itertools.imap(cls, it)

        objrdd = rdd.mapPartitionsWithIndex(applySchema, preservesPartitioning)
        return objrdd.mapPartitionsWithIndex(f, preservesPartitioning)

    # We override the default cache/persist/checkpoint behavior
    # as we want to cache the underlying SchemaRDD object in the JVM,
    # not the PythonRDD checkpointed by the super class
    def cache(self):
        """Cache the underlying JVM SchemaRDD and return self."""
        self.is_cached = True
        self._jschema_rdd.cache()
        return self

    def persist(self, storageLevel):
        """Persist the underlying JVM SchemaRDD at C{storageLevel}; return self."""
        self.is_cached = True
        javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
        self._jschema_rdd.persist(javaStorageLevel)
        return self

    def unpersist(self, blocking=True):
        """Unpersist the underlying JVM SchemaRDD and return self."""
        self.is_cached = False
        self._jschema_rdd.unpersist(blocking)
        return self

    def checkpoint(self):
        """Mark the underlying JVM SchemaRDD for checkpointing."""
        self.is_checkpointed = True
        self._jschema_rdd.checkpoint()

    def isCheckpointed(self):
        """Return whether the JVM SchemaRDD reports itself as checkpointed."""
        return self._jschema_rdd.isCheckpointed()

    def getCheckpointFile(self):
        """Return the JVM SchemaRDD's checkpoint file, or None if the JVM
        side reports none."""
        checkpointFile = self._jschema_rdd.getCheckpointFile()
        if checkpointFile.isPresent():
            return checkpointFile.get()
        return None

    def coalesce(self, numPartitions, shuffle=False):
        """Return a SchemaRDD coalesced to C{numPartitions} by the JVM."""
        rdd = self._jschema_rdd.coalesce(numPartitions, shuffle)
        return SchemaRDD(rdd, self.sql_ctx)

    def distinct(self):
        """Return a SchemaRDD of the distinct rows, computed by the JVM."""
        rdd = self._jschema_rdd.distinct()
        return SchemaRDD(rdd, self.sql_ctx)

    def intersection(self, other):
        """Return the JVM-computed intersection with another SchemaRDD.

        @raise ValueError: if C{other} is not a L{SchemaRDD}.
        """
        # isinstance (rather than an exact class check) also accepts
        # SchemaRDD subclasses.
        if not isinstance(other, SchemaRDD):
            raise ValueError("Can only intersect with another SchemaRDD")
        rdd = self._jschema_rdd.intersection(other._jschema_rdd)
        return SchemaRDD(rdd, self.sql_ctx)

    def repartition(self, numPartitions):
        """Return a SchemaRDD repartitioned to C{numPartitions} by the JVM."""
        rdd = self._jschema_rdd.repartition(numPartitions)
        return SchemaRDD(rdd, self.sql_ctx)

    def subtract(self, other, numPartitions=None):
        """Return the JVM-computed difference with another SchemaRDD.

        @param numPartitions: optional partition count for the result.
        @raise ValueError: if C{other} is not a L{SchemaRDD}.
        """
        if not isinstance(other, SchemaRDD):
            raise ValueError("Can only subtract another SchemaRDD")
        if numPartitions is None:
            rdd = self._jschema_rdd.subtract(other._jschema_rdd)
        else:
            rdd = self._jschema_rdd.subtract(other._jschema_rdd,
                                             numPartitions)
        return SchemaRDD(rdd, self.sql_ctx)
1704
def _test():
    """Run this module's doctests against a local SparkContext.

    Exits the process with status -1 if any doctest fails.
    """
    import doctest
    import sys
    from pyspark.context import SparkContext
    # let doctest run in pyspark.sql, so DataTypes can be picklable
    import pyspark.sql
    from pyspark.sql import Row, SQLContext
    globs = pyspark.sql.__dict__.copy()
    # The small batch size here ensures that we see multiple batches,
    # even in these small test examples:
    sc = SparkContext('local[4]', 'PythonTest', batchSize=2)
    try:
        globs['sc'] = sc
        globs['sqlCtx'] = SQLContext(sc)
        globs['rdd'] = sc.parallelize(
            [Row(field1=1, field2="row1"),
             Row(field1=2, field2="row2"),
             Row(field1=3, field2="row3")]
        )
        jsonStrings = [
            '{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
            '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
            '"field6":[{"field7": "row2"}]}',
            '{"field1" : null, "field2": "row3", '
            '"field3":{"field4":33, "field5": []}}'
        ]
        globs['jsonStrings'] = jsonStrings
        globs['json'] = sc.parallelize(jsonStrings)
        (failure_count, test_count) = doctest.testmod(
            pyspark.sql, globs=globs, optionflags=doctest.ELLIPSIS)
    finally:
        # Stop the context even if doctest itself raises.
        sc.stop()
    if failure_count:
        # sys.exit works even when the site-provided exit() builtin is absent.
        sys.exit(-1)
if __name__ == "__main__":
    # Executed as a script: run the module's doctests.
    _test()