DecisionTree

class pyspark.mllib.tree.DecisionTree

Learning algorithm for a decision tree model for classification or regression.

New in version 1.1.0.

Methods

trainClassifier(data, numClasses, …[, …])

Train a decision tree model for classification.

trainRegressor(data, categoricalFeaturesInfo[, …])

Train a decision tree model for regression.

Methods Documentation

classmethod trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity='gini', maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0)

Train a decision tree model for classification.

New in version 1.1.0.

Parameters:
data : pyspark.RDD

Training data: RDD of LabeledPoint. Labels should take values {0, 1, …, numClasses-1}.

numClasses : int

Number of classes for classification.

categoricalFeaturesInfo : dict

Map storing arity of categorical features. An entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, …, k-1}.

impurity : str, optional

Criterion used for information gain calculation. Supported values: “gini” or “entropy”. (default: “gini”)

maxDepth : int, optional

Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means 1 internal node + 2 leaf nodes). (default: 5)

maxBins : int, optional

Number of bins used for finding splits at each node. (default: 32)

minInstancesPerNode : int, optional

Minimum number of instances each child node must contain for the parent split to be created. (default: 1)

minInfoGain : float, optional

Minimum info gain required to create a split. (default: 0.0)

Returns:
DecisionTreeModel
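The impurity, minInstancesPerNode, and minInfoGain parameters above all act on candidate splits, and maxDepth bounds the size of the resulting tree. A minimal pure-Python sketch of the standard formulas follows, assuming per-class label counts as input. The helpers `gini`, `entropy`, `info_gain`, `split_allowed`, and `max_nodes` are hypothetical illustrations, not PySpark API; `split_allowed` is only a simplified reading of the two parameter descriptions, and Spark's distributed split search differs in detail.

```python
from math import log2


def gini(counts):
    """Gini impurity 1 - sum(p_i^2), given per-class label counts."""
    total = sum(counts)
    if total == 0:
        return 0.0
    return 1.0 - sum((c / total) ** 2 for c in counts)


def entropy(counts):
    """Entropy -sum(p_i * log2(p_i)), given per-class label counts."""
    total = sum(counts)
    if total == 0:
        return 0.0
    # Empty classes contribute nothing (p * log p -> 0 as p -> 0).
    return -sum((c / total) * log2(c / total) for c in counts if c > 0)


def info_gain(parent, left, right, impurity=gini):
    """Parent impurity minus the size-weighted impurity of the two children."""
    n = sum(parent)
    weighted = (sum(left) / n) * impurity(left) + (sum(right) / n) * impurity(right)
    return impurity(parent) - weighted


def split_allowed(left, right, gain, min_instances_per_node=1, min_info_gain=0.0):
    """Hypothetical check reflecting minInstancesPerNode and minInfoGain."""
    # Reject the split if either child is too small.
    if min(sum(left), sum(right)) < min_instances_per_node:
        return False
    # Reject the split if it does not gain enough information.
    return gain >= min_info_gain


def max_nodes(max_depth):
    """Most nodes a binary tree of the given depth can have (depth 0 -> 1 leaf)."""
    return 2 ** (max_depth + 1) - 1


# Parent node with 1 example of class 0 and 3 of class 1, split into pure children.
gain = info_gain([1, 3], [1, 0], [0, 3])
print(gain)                                  # 0.375
print(split_allowed([1, 0], [0, 3], gain))   # True
print(max_nodes(1))                          # 3
```

Gini and entropy usually rank splits similarly; entropy penalizes mixed nodes slightly more and costs a logarithm per class.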

Examples

>>> from numpy import array
>>> from pyspark.mllib.regression import LabeledPoint
>>> from pyspark.mllib.tree import DecisionTree
>>>
>>> data = [
...     LabeledPoint(0.0, [0.0]),
...     LabeledPoint(1.0, [1.0]),
...     LabeledPoint(1.0, [2.0]),
...     LabeledPoint(1.0, [3.0])
... ]
>>> model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {})
>>> print(model)
DecisionTreeModel classifier of depth 1 with 3 nodes
>>> print(model.toDebugString())
DecisionTreeModel classifier of depth 1 with 3 nodes
  If (feature 0 <= 0.5)
   Predict: 0.0
  Else (feature 0 > 0.5)
   Predict: 1.0

>>> model.predict(array([1.0]))
1.0
>>> model.predict(array([0.0]))
0.0
>>> rdd = sc.parallelize([[1.0], [0.0]])
>>> model.predict(rdd).collect()
[1.0, 0.0]
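To see why the model above splits at `feature 0 <= 0.5`, the Gini gain of each candidate threshold on the same four points can be computed by hand. This sketch assumes candidate thresholds are midpoints between consecutive distinct feature values, which agrees with the 0.5 in the debug output but is a simplification: Spark bins continuous features (into at most `maxBins` bins) and aggregates statistics across partitions. `gini_of_labels` and `best_threshold` are hypothetical names.

```python
def gini_of_labels(labels):
    """Gini impurity of a list of class labels."""
    n = len(labels)
    counts = {}
    for y in labels:
        counts[y] = counts.get(y, 0) + 1
    return 1.0 - sum((c / n) ** 2 for c in counts.values())


def best_threshold(xs, ys):
    """Return (threshold, gain) maximizing Gini gain over midpoint thresholds."""
    parent = gini_of_labels(ys)
    n = len(ys)
    values = sorted(set(xs))
    best = None
    for lo, hi in zip(values, values[1:]):
        t = (lo + hi) / 2
        left = [y for x, y in zip(xs, ys) if x <= t]
        right = [y for x, y in zip(xs, ys) if x > t]
        gain = (parent
                - len(left) / n * gini_of_labels(left)
                - len(right) / n * gini_of_labels(right))
        if best is None or gain > best[1]:
            best = (t, gain)
    return best


# Same single feature and labels as the `data` list in the example above.
xs = [0.0, 1.0, 2.0, 3.0]
ys = [0.0, 1.0, 1.0, 1.0]
threshold, gain = best_threshold(xs, ys)
print(threshold, gain)  # 0.5 0.375
```

The threshold 0.5 isolates the single 0.0 label from the three 1.0 labels, leaving both children pure (gain 0.375, versus 0.125 at 1.5 and about 0.042 at 2.5), which is consistent with the model predicting 0.0 for `[0.0]` and 1.0 for `[1.0]`.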

classmethod trainRegressor(data, categoricalFeaturesInfo, impurity='variance', maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0)

Train a decision tree model for regression.

Parameters:
data : pyspark.RDD

Training data: RDD of LabeledPoint. Labels are real numbers.

categoricalFeaturesInfo : dict

Map storing arity of categorical features. An entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, …, k-1}.

impurity : str, optional

Criterion used for information gain calculation. The only supported value for regression is “variance”. (default: “variance”)

maxDepth : int, optional

Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means 1 internal node + 2 leaf nodes). (default: 5)

maxBins : int, optional

Number of bins used for finding splits at each node. (default: 32)

minInstancesPerNode : int, optional

Minimum number of instances each child node must contain for the parent split to be created. (default: 1)

minInfoGain : float, optional

Minimum info gain required to create a split. (default: 0.0)

Returns:
DecisionTreeModel
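For regression, the "variance" impurity of a node is the (population) variance of its labels, and the node's prediction is their mean. Variance can be computed from three additive statistics (count, sum, sum of squares), so statistics gathered on separate partitions can simply be added before the impurity is evaluated. The sketch below assumes that formulation; `label_stats`, `merge_stats`, `variance_from_stats`, and `mean_from_stats` are hypothetical helpers, not Spark's code.

```python
def label_stats(labels):
    """Additive statistics (count, sum, sum of squares) for a list of labels."""
    return (len(labels), sum(labels), sum(y * y for y in labels))


def merge_stats(a, b):
    """Combine statistics from two disjoint groups of labels by adding them."""
    return tuple(x + y for x, y in zip(a, b))


def variance_from_stats(stats):
    """Population variance: (sum of squares - sum^2 / count) / count."""
    count, total, total_sq = stats
    if count == 0:
        return 0.0
    return (total_sq - total * total / count) / count


def mean_from_stats(stats):
    """Mean label, i.e. the value a leaf with these labels would predict."""
    count, total, _ = stats
    return total / count if count else 0.0


# Two "partitions" holding labels [0.0, 1.0] each.
part_a = label_stats([0.0, 1.0])
part_b = label_stats([0.0, 1.0])
combined = merge_stats(part_a, part_b)
print(combined)                       # (4, 2.0, 2.0)
print(variance_from_stats(combined))  # 0.25
print(mean_from_stats(combined))      # 0.5
```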

Examples

>>> from pyspark.mllib.regression import LabeledPoint
>>> from pyspark.mllib.tree import DecisionTree
>>> from pyspark.mllib.linalg import SparseVector
>>>
>>> sparse_data = [
...     LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
...     LabeledPoint(1.0, SparseVector(2, {1: 1.0})),
...     LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
...     LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
... ]
>>>
>>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), {})
>>> model.predict(SparseVector(2, {1: 1.0}))
1.0
>>> model.predict(SparseVector(2, {1: 0.0}))
0.0
>>> rdd = sc.parallelize([[0.0, 1.0], [0.0, 0.0]])
>>> model.predict(rdd).collect()
[1.0, 0.0]
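The regression example can be checked by hand by densifying the four sparse vectors and searching for the split with the largest variance reduction. This sketch assumes midpoint thresholds between consecutive distinct values and runs on a single machine; `variance`, `best_split`, and `predict` are hypothetical helpers, not Spark's implementation.

```python
def variance(labels):
    """Population variance of a list of real-valued labels."""
    n = len(labels)
    mean = sum(labels) / n
    return sum((y - mean) ** 2 for y in labels) / n


def best_split(rows, labels):
    """Return (feature, threshold, gain) with the largest variance reduction."""
    parent = variance(labels)
    n = len(labels)
    best = None
    for f in range(len(rows[0])):
        values = sorted({row[f] for row in rows})
        # A feature with a single distinct value yields no candidate thresholds.
        for lo, hi in zip(values, values[1:]):
            t = (lo + hi) / 2
            left = [y for row, y in zip(rows, labels) if row[f] <= t]
            right = [y for row, y in zip(rows, labels) if row[f] > t]
            gain = (parent
                    - len(left) / n * variance(left)
                    - len(right) / n * variance(right))
            if best is None or gain > best[2]:
                best = (f, t, gain)
    return best


# Dense equivalents of the four SparseVector(2, ...) features in sparse_data.
rows = [[0.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 2.0]]
labels = [0.0, 1.0, 0.0, 1.0]

feature, threshold, gain = best_split(rows, labels)

# Each leaf predicts the mean label of the rows that reach it.
left = [y for row, y in zip(rows, labels) if row[feature] <= threshold]
right = [y for row, y in zip(rows, labels) if row[feature] > threshold]
left_mean, right_mean = sum(left) / len(left), sum(right) / len(right)


def predict(x):
    return left_mean if x[feature] <= threshold else right_mean


print(feature, threshold, gain)                   # 1 0.5 0.25
print(predict([0.0, 1.0]), predict([0.0, 0.0]))   # 1.0 0.0
```

Feature 0 is always 0.0 and offers no candidate threshold, so the split falls on feature 1 at 0.5, where both children are pure; the leaf means 0.0 and 1.0 agree with the `predict` results above.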

New in version 1.1.0.