
Source Code for Module pyspark.mllib.tree

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from py4j.java_collections import MapConverter

from pyspark import SparkContext, RDD
from pyspark.mllib._common import \
    _get_unmangled_rdd, _get_unmangled_double_vector_rdd, _serialize_double_vector, \
    _deserialize_labeled_point, _get_unmangled_labeled_point_rdd, \
    _deserialize_double
from pyspark.mllib.regression import LabeledPoint
from pyspark.serializers import NoOpSerializer


class DecisionTreeModel(object):

    """
    A decision tree model for classification or regression.

    EXPERIMENTAL: This is an experimental API.
    It will probably be modified for Spark v1.2.
    """

    def __init__(self, sc, java_model):
        """
        :param sc: Spark context
        :param java_model: Handle to Java model object
        """
        self._sc = sc
        self._java_model = java_model

    def __del__(self):
        self._sc._gateway.detach(self._java_model)

    def predict(self, x):
        """
        Predict the label of one or more examples.

        :param x: Data point (feature vector),
                  or an RDD of data points (feature vectors).
        """
        pythonAPI = self._sc._jvm.PythonMLLibAPI()
        if isinstance(x, RDD):
            # Bulk prediction
            if x.count() == 0:
                return self._sc.parallelize([])
            dataBytes = _get_unmangled_double_vector_rdd(x, cache=False)
            jSerializedPreds = \
                pythonAPI.predictDecisionTreeModel(self._java_model,
                                                   dataBytes._jrdd)
            serializedPreds = RDD(jSerializedPreds, self._sc, NoOpSerializer())
            return serializedPreds.map(lambda bytes: _deserialize_double(bytearray(bytes)))
        else:
            # Assume x is a single data point.
            x_ = _serialize_double_vector(x)
            return pythonAPI.predictDecisionTreeModel(self._java_model, x_)

    def numNodes(self):
        return self._java_model.numNodes()

    def depth(self):
        return self._java_model.depth()

    def __str__(self):
        return self._java_model.toString()


class DecisionTree(object):

    """
    Learning algorithm for a decision tree model
    for classification or regression.

    EXPERIMENTAL: This is an experimental API.
    It will probably be modified for Spark v1.2.

    Example usage:

    >>> from numpy import array
    >>> import sys
    >>> from pyspark.mllib.regression import LabeledPoint
    >>> from pyspark.mllib.tree import DecisionTree
    >>> from pyspark.mllib.linalg import SparseVector
    >>>
    >>> data = [
    ...     LabeledPoint(0.0, [0.0]),
    ...     LabeledPoint(1.0, [1.0]),
    ...     LabeledPoint(1.0, [2.0]),
    ...     LabeledPoint(1.0, [3.0])
    ... ]
    >>> categoricalFeaturesInfo = {}  # no categorical features
    >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2,
    ...                                      categoricalFeaturesInfo=categoricalFeaturesInfo)
    >>> sys.stdout.write(str(model))
    DecisionTreeModel classifier
      If (feature 0 <= 0.5)
       Predict: 0.0
      Else (feature 0 > 0.5)
       Predict: 1.0
    >>> model.predict(array([1.0])) > 0
    True
    >>> model.predict(array([0.0])) == 0
    True
    >>> sparse_data = [
    ...     LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
    ...     LabeledPoint(1.0, SparseVector(2, {1: 1.0})),
    ...     LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
    ...     LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
    ... ]
    >>>
    >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data),
    ...                                     categoricalFeaturesInfo=categoricalFeaturesInfo)
    >>> model.predict(array([0.0, 1.0])) == 1
    True
    >>> model.predict(array([0.0, 0.0])) == 0
    True
    >>> model.predict(SparseVector(2, {1: 1.0})) == 1
    True
    >>> model.predict(SparseVector(2, {1: 0.0})) == 0
    True
    """

    @staticmethod
    def trainClassifier(data, numClasses, categoricalFeaturesInfo,
                        impurity="gini", maxDepth=4, maxBins=100):
        """
        Train a DecisionTreeModel for classification.

        :param data: Training data: RDD of LabeledPoint.
                     Labels are integers {0,1,...,numClasses-1}.
        :param numClasses: Number of classes for classification.
        :param categoricalFeaturesInfo: Map from categorical feature index
                                        to number of categories.
                                        Any feature not in this map
                                        is treated as continuous.
        :param impurity: Supported values: "entropy" or "gini"
        :param maxDepth: Max depth of tree.
                         E.g., depth 0 means 1 leaf node.
                         Depth 1 means 1 internal node + 2 leaf nodes.
        :param maxBins: Number of bins used for finding splits at each node.
        :return: DecisionTreeModel
        """
        sc = data.context
        dataBytes = _get_unmangled_labeled_point_rdd(data)
        categoricalFeaturesInfoJMap = \
            MapConverter().convert(categoricalFeaturesInfo,
                                   sc._gateway._gateway_client)
        model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
            dataBytes._jrdd, "classification",
            numClasses, categoricalFeaturesInfoJMap,
            impurity, maxDepth, maxBins)
        dataBytes.unpersist()
        return DecisionTreeModel(sc, model)

    @staticmethod
    def trainRegressor(data, categoricalFeaturesInfo,
                       impurity="variance", maxDepth=4, maxBins=100):
        """
        Train a DecisionTreeModel for regression.

        :param data: Training data: RDD of LabeledPoint.
                     Labels are real numbers.
        :param categoricalFeaturesInfo: Map from categorical feature index
                                        to number of categories.
                                        Any feature not in this map
                                        is treated as continuous.
        :param impurity: Supported values: "variance"
        :param maxDepth: Max depth of tree.
                         E.g., depth 0 means 1 leaf node.
                         Depth 1 means 1 internal node + 2 leaf nodes.
        :param maxBins: Number of bins used for finding splits at each node.
        :return: DecisionTreeModel
        """
        sc = data.context
        dataBytes = _get_unmangled_labeled_point_rdd(data)
        categoricalFeaturesInfoJMap = \
            MapConverter().convert(categoricalFeaturesInfo,
                                   sc._gateway._gateway_client)
        model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
            dataBytes._jrdd, "regression",
            0, categoricalFeaturesInfoJMap,
            impurity, maxDepth, maxBins)
        dataBytes.unpersist()
        return DecisionTreeModel(sc, model)


def _test():
    import doctest
    globs = globals().copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()