org.apache.spark.mllib.tree
Class DecisionTree

Object
  extended by org.apache.spark.mllib.tree.DecisionTree
All Implemented Interfaces:
java.io.Serializable, Logging

public class DecisionTree
extends Object
implements scala.Serializable, Logging

:: Experimental :: A class that implements a decision tree learning algorithm for classification and regression. It supports both continuous and categorical features.

Parameters:
strategy - The configuration parameters for the tree algorithm, which specify the type of algorithm (classification, regression, etc.), feature type (continuous, categorical), depth of the tree, quantile calculation strategy, etc.

See Also:
Serialized Form

Constructor Summary
DecisionTree(Strategy strategy)
           
 
Method Summary
 DecisionTreeModel run(RDD<LabeledPoint> input)
          Method to train a decision tree model over an RDD.
static DecisionTreeModel train(RDD<LabeledPoint> input, scala.Enumeration.Value algo, Impurity impurity, int maxDepth)
          Method to train a decision tree model.
static DecisionTreeModel train(RDD<LabeledPoint> input, scala.Enumeration.Value algo, Impurity impurity, int maxDepth, int numClasses)
          Method to train a decision tree model.
static DecisionTreeModel train(RDD<LabeledPoint> input, scala.Enumeration.Value algo, Impurity impurity, int maxDepth, int numClasses, int maxBins, scala.Enumeration.Value quantileCalculationStrategy, scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo)
          Method to train a decision tree model.
static DecisionTreeModel train(RDD<LabeledPoint> input, Strategy strategy)
          Method to train a decision tree model.
static DecisionTreeModel trainClassifier(JavaRDD<LabeledPoint> input, int numClasses, java.util.Map<Integer,Integer> categoricalFeaturesInfo, String impurity, int maxDepth, int maxBins)
          Java-friendly API for DecisionTree$.trainClassifier(org.apache.spark.rdd.RDD, int, scala.collection.immutable.Map, java.lang.String, int, int)
static DecisionTreeModel trainClassifier(RDD<LabeledPoint> input, int numClasses, scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo, String impurity, int maxDepth, int maxBins)
          Method to train a decision tree model for binary or multiclass classification.
static DecisionTreeModel trainRegressor(JavaRDD<LabeledPoint> input, java.util.Map<Integer,Integer> categoricalFeaturesInfo, String impurity, int maxDepth, int maxBins)
          Java-friendly API for DecisionTree$.trainRegressor(org.apache.spark.rdd.RDD, scala.collection.immutable.Map, java.lang.String, int, int)
static DecisionTreeModel trainRegressor(RDD<LabeledPoint> input, scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo, String impurity, int maxDepth, int maxBins)
          Method to train a decision tree model for regression.
 
Methods inherited from class Object
equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 
Methods inherited from interface org.apache.spark.Logging
initializeIfNecessary, initializeLogging, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning
 

Constructor Detail

DecisionTree

public DecisionTree(Strategy strategy)
Method Detail

train

public static DecisionTreeModel train(RDD<LabeledPoint> input,
                                      Strategy strategy)
Method to train a decision tree model. The method supports binary and multiclass classification and regression.

Note: Using trainClassifier(RDD, int, Map, String, int, int) and trainRegressor(RDD, Map, String, int, int) is recommended to clearly separate classification and regression.

Parameters:
input - Training dataset: RDD of LabeledPoint. For classification, labels should take values {0, 1, ..., numClasses-1}. For regression, labels are real numbers.
strategy - The configuration parameters for the tree algorithm which specify the type of algorithm (classification, regression, etc.), feature type (continuous, categorical), depth of the tree, quantile calculation strategy, etc.
Returns:
DecisionTreeModel that can be used for prediction

train

public static DecisionTreeModel train(RDD<LabeledPoint> input,
                                      scala.Enumeration.Value algo,
                                      Impurity impurity,
                                      int maxDepth)
Method to train a decision tree model. The method supports binary and multiclass classification and regression.

Note: Using trainClassifier(RDD, int, Map, String, int, int) and trainRegressor(RDD, Map, String, int, int) is recommended to clearly separate classification and regression.

Parameters:
input - Training dataset: RDD of LabeledPoint. For classification, labels should take values {0, 1, ..., numClasses-1}. For regression, labels are real numbers.
algo - algorithm, classification or regression
impurity - impurity criterion used for information gain calculation
maxDepth - Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
Returns:
DecisionTreeModel that can be used for prediction
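
The maxDepth convention above (depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes) can be turned into an upper bound on tree size. The following is a minimal sketch in plain Java with no Spark dependency. It assumes every split is binary, and the class and method names are hypothetical, not part of the MLlib API.

```java
// Hypothetical helper, not part of MLlib: upper bounds on the size of a
// binary decision tree for a given maxDepth.
final class TreeDepthSketch {

    // A tree of depth d has at most 2^d leaf nodes.
    static long maxLeaves(int maxDepth) {
        checkDepth(maxDepth);
        return 1L << maxDepth;
    }

    // A full binary tree has one fewer internal node than it has leaves.
    static long maxInternalNodes(int maxDepth) {
        return maxLeaves(maxDepth) - 1;
    }

    // Total node count: leaves plus internal nodes, i.e. 2^(d+1) - 1.
    static long maxNodes(int maxDepth) {
        return maxLeaves(maxDepth) + maxInternalNodes(maxDepth);
    }

    // Keeps the shift within the range of a long; the bound is this sketch's
    // own choice, not a limit documented for MLlib.
    private static void checkDepth(int maxDepth) {
        if (maxDepth < 0 || maxDepth > 61) {
            throw new IllegalArgumentException(
                "maxDepth must be in [0, 61], got " + maxDepth);
        }
    }

    public static void main(String[] args) {
        for (int depth = 0; depth <= 3; depth++) {
            System.out.println("maxDepth=" + depth
                + " internal<=" + maxInternalNodes(depth)
                + " leaves<=" + maxLeaves(depth));
        }
    }
}
```

These are upper bounds only: a trained tree may stop splitting earlier and contain fewer nodes.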

train

public static DecisionTreeModel train(RDD<LabeledPoint> input,
                                      scala.Enumeration.Value algo,
                                      Impurity impurity,
                                      int maxDepth,
                                      int numClasses)
Method to train a decision tree model. The method supports binary and multiclass classification and regression.

Note: Using trainClassifier(RDD, int, Map, String, int, int) and trainRegressor(RDD, Map, String, int, int) is recommended to clearly separate classification and regression.

Parameters:
input - Training dataset: RDD of LabeledPoint. For classification, labels should take values {0, 1, ..., numClasses-1}. For regression, labels are real numbers.
algo - algorithm, classification or regression
impurity - impurity criterion used for information gain calculation
maxDepth - Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
numClasses - number of classes for classification. Default value of 2.
Returns:
DecisionTreeModel that can be used for prediction
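
For classification, the input labels are expected to take values {0, 1, ..., numClasses-1}. The following is a minimal sketch of a pre-training sanity check in plain Java, operating on a raw array of label values rather than an RDD. The class and method names are hypothetical and are not part of the MLlib API.

```java
// Hypothetical helper, not part of MLlib: checks that classification labels
// fall in {0, 1, ..., numClasses-1}.
final class LabelCheckSketch {

    // A label is valid if it is a whole number in [0, numClasses).
    // NaN fails the first comparison and is therefore rejected.
    static boolean isValidClassLabel(double label, int numClasses) {
        return label >= 0
            && label < numClasses
            && label == Math.floor(label);
    }

    // Returns the index of the first invalid label, or -1 if all are valid.
    static int firstInvalidIndex(double[] labels, int numClasses) {
        for (int i = 0; i < labels.length; i++) {
            if (!isValidClassLabel(labels[i], numClasses)) {
                return i;
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        double[] labels = {0.0, 1.0, 2.0, 1.0};
        System.out.println("first invalid index for numClasses=3: "
            + firstInvalidIndex(labels, 3));
        System.out.println("first invalid index for numClasses=2: "
            + firstInvalidIndex(labels, 2));
    }
}
```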

train

public static DecisionTreeModel train(RDD<LabeledPoint> input,
                                      scala.Enumeration.Value algo,
                                      Impurity impurity,
                                      int maxDepth,
                                      int numClasses,
                                      int maxBins,
                                      scala.Enumeration.Value quantileCalculationStrategy,
                                      scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo)
Method to train a decision tree model. The method supports binary and multiclass classification and regression.

Note: Using trainClassifier(RDD, int, Map, String, int, int) and trainRegressor(RDD, Map, String, int, int) is recommended to clearly separate classification and regression.

Parameters:
input - Training dataset: RDD of LabeledPoint. For classification, labels should take values {0, 1, ..., numClasses-1}. For regression, labels are real numbers.
algo - classification or regression
impurity - criterion used for information gain calculation
maxDepth - Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
numClasses - number of classes for classification. Default value of 2.
maxBins - maximum number of bins used for splitting features
quantileCalculationStrategy - algorithm for calculating quantiles
categoricalFeaturesInfo - Map storing arity of categorical features. E.g., an entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
Returns:
DecisionTreeModel that can be used for prediction
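
An entry (n -> k) in categoricalFeaturesInfo means feature n takes one of the category indices {0, 1, ..., k-1}. The following is a minimal sketch of a consistency check between a single feature vector and such a map, written in plain Java with a java.util.Map and a double[] standing in for the Spark types. The class and method names are hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper, not part of MLlib: checks one feature vector against
// a categoricalFeaturesInfo map of (feature index -> number of categories).
final class CategoricalInfoSketch {

    static boolean isConsistent(double[] features,
                                Map<Integer, Integer> categoricalFeaturesInfo) {
        for (Map.Entry<Integer, Integer> entry : categoricalFeaturesInfo.entrySet()) {
            int featureIndex = entry.getKey();
            int arity = entry.getValue();
            // The map must refer to a feature that exists in the vector.
            if (featureIndex < 0 || featureIndex >= features.length) {
                return false;
            }
            double value = features[featureIndex];
            // Categorical values must be whole numbers in [0, arity).
            if (value < 0 || value >= arity || value != Math.floor(value)) {
                return false;
            }
        }
        // Features absent from the map are treated as continuous: no check.
        return true;
    }

    public static void main(String[] args) {
        Map<Integer, Integer> info = new HashMap<>();
        info.put(0, 2); // illustrative: feature 0 has categories {0, 1}
        info.put(2, 3); // illustrative: feature 2 has categories {0, 1, 2}

        System.out.println(isConsistent(new double[] {1.0, 3.7, 2.0}, info));
        System.out.println(isConsistent(new double[] {2.0, 3.7, 0.0}, info));
    }
}
```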

trainClassifier

public static DecisionTreeModel trainClassifier(RDD<LabeledPoint> input,
                                                int numClasses,
                                                scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo,
                                                String impurity,
                                                int maxDepth,
                                                int maxBins)
Method to train a decision tree model for binary or multiclass classification.

Parameters:
input - Training dataset: RDD of LabeledPoint. Labels should take values {0, 1, ..., numClasses-1}.
numClasses - number of classes for classification.
categoricalFeaturesInfo - Map storing arity of categorical features. E.g., an entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
impurity - Criterion used for information gain calculation. Supported values: "gini" (recommended) or "entropy".
maxDepth - Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (suggested value: 5)
maxBins - maximum number of bins used for splitting features (suggested value: 32)
Returns:
DecisionTreeModel that can be used for prediction
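
To make the impurity parameter concrete, the sketch below computes the two supported criteria from per-class label counts, using the textbook definitions: Gini impurity 1 - sum(p_i^2) and entropy -sum(p_i * log2(p_i)). It also computes information gain as the parent impurity minus the size-weighted impurity of the two children. This is plain Java for illustration, not MLlib's implementation. The base-2 logarithm and all class and method names are assumptions.

```java
// Hypothetical helper, not part of MLlib: textbook impurity measures computed
// from per-class counts at a tree node.
final class ImpuritySketch {

    private static double sum(double[] values) {
        double total = 0.0;
        for (double v : values) {
            total += v;
        }
        return total;
    }

    // Gini impurity: 1 - sum over classes of p_i^2.
    static double gini(double[] classCounts) {
        double total = sum(classCounts);
        if (total == 0.0) {
            return 0.0;
        }
        double sumSquares = 0.0;
        for (double count : classCounts) {
            double p = count / total;
            sumSquares += p * p;
        }
        return 1.0 - sumSquares;
    }

    // Entropy: -sum over classes of p_i * log2(p_i); empty classes contribute 0.
    // The base-2 logarithm is an assumption of this sketch.
    static double entropy(double[] classCounts) {
        double total = sum(classCounts);
        if (total == 0.0) {
            return 0.0;
        }
        double result = 0.0;
        for (double count : classCounts) {
            if (count > 0) {
                double p = count / total;
                result -= p * (Math.log(p) / Math.log(2.0));
            }
        }
        return result;
    }

    // Information gain of a binary split under Gini impurity:
    // impurity(parent) - weighted impurity of the left and right children.
    static double giniGain(double[] leftCounts, double[] rightCounts) {
        if (leftCounts.length != rightCounts.length) {
            throw new IllegalArgumentException("count arrays must have equal length");
        }
        double[] parentCounts = new double[leftCounts.length];
        for (int i = 0; i < parentCounts.length; i++) {
            parentCounts[i] = leftCounts[i] + rightCounts[i];
        }
        double leftTotal = sum(leftCounts);
        double rightTotal = sum(rightCounts);
        double total = leftTotal + rightTotal;
        if (total == 0.0) {
            return 0.0;
        }
        return gini(parentCounts)
            - (leftTotal / total) * gini(leftCounts)
            - (rightTotal / total) * gini(rightCounts);
    }

    public static void main(String[] args) {
        double[] balanced = {5, 5};
        System.out.println("gini=" + gini(balanced) + " entropy=" + entropy(balanced));
        System.out.println("gain of a perfect split="
            + giniGain(new double[] {5, 0}, new double[] {0, 5}));
    }
}
```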

trainClassifier

public static DecisionTreeModel trainClassifier(JavaRDD<LabeledPoint> input,
                                                int numClasses,
                                                java.util.Map<Integer,Integer> categoricalFeaturesInfo,
                                                String impurity,
                                                int maxDepth,
                                                int maxBins)
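Java-friendly API for trainClassifier(RDD, int, Map, String, int, int). It takes a JavaRDD and a java.util.Map in place of the Scala RDD and immutable Map.

Parameters:
input - Training dataset: JavaRDD of LabeledPoint. Labels should take values {0, 1, ..., numClasses-1}.
numClasses - number of classes for classification.
categoricalFeaturesInfo - Map storing arity of categorical features. E.g., an entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
impurity - Criterion used for information gain calculation. Supported values: "gini" (recommended) or "entropy".
maxDepth - Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (suggested value: 5)
maxBins - maximum number of bins used for splitting features (suggested value: 32)
Returns:
DecisionTreeModel that can be used for prediction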
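
The following is a minimal sketch of preparing the arguments for the Java-friendly trainClassifier overload. It runs without Spark, so the training call itself appears only as a comment that mirrors the signature above. The helper class, the feature indices and the arities are hypothetical. The values 5 and 32 are the suggested maxDepth and maxBins from the Scala overload's documentation.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper, not part of MLlib: builds and checks the arguments for
// the Java-friendly trainClassifier overload.
final class ClassifierArgsSketch {

    // Illustrative arities only; real values depend on the dataset.
    static Map<Integer, Integer> exampleCategoricalFeaturesInfo() {
        Map<Integer, Integer> info = new HashMap<>();
        info.put(0, 2); // feature 0: categories {0, 1}
        info.put(4, 3); // feature 4: categories {0, 1, 2}
        return info;
    }

    // The documented values for classification are "gini" and "entropy".
    // Treating any other string as unsupported is this sketch's assumption.
    static boolean isSupportedClassifierImpurity(String impurity) {
        return "gini".equals(impurity) || "entropy".equals(impurity);
    }

    public static void main(String[] args) {
        int numClasses = 2;
        Map<Integer, Integer> categoricalFeaturesInfo = exampleCategoricalFeaturesInfo();
        String impurity = "gini";
        int maxDepth = 5;
        int maxBins = 32;

        if (!isSupportedClassifierImpurity(impurity)) {
            throw new IllegalArgumentException("unsupported impurity: " + impurity);
        }
        System.out.println("numClasses=" + numClasses
            + " categoricalFeaturesInfo=" + categoricalFeaturesInfo
            + " impurity=" + impurity
            + " maxDepth=" + maxDepth
            + " maxBins=" + maxBins);

        // With Spark on the classpath and a JavaRDD<LabeledPoint> named `input`
        // (not constructed in this sketch), the call would look like:
        // DecisionTreeModel model = DecisionTree.trainClassifier(
        //     input, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins);
    }
}
```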

trainRegressor

public static DecisionTreeModel trainRegressor(RDD<LabeledPoint> input,
                                               scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo,
                                               String impurity,
                                               int maxDepth,
                                               int maxBins)
Method to train a decision tree model for regression.

Parameters:
input - Training dataset: RDD of LabeledPoint. Labels are real numbers.
categoricalFeaturesInfo - Map storing arity of categorical features. E.g., an entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
impurity - Criterion used for information gain calculation. Supported values: "variance".
maxDepth - Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (suggested value: 5)
maxBins - maximum number of bins used for splitting features (suggested value: 32)
Returns:
DecisionTreeModel that can be used for prediction
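
For regression, the only supported impurity is "variance". The sketch below computes the population variance of the labels at a node and the reduction in variance achieved by a binary split, weighted by the size of each child. It is plain Java for illustration and not MLlib's implementation. The class and method names are hypothetical.

```java
// Hypothetical helper, not part of MLlib: variance impurity for regression
// labels and the variance reduction of a binary split.
final class VarianceSketch {

    static double mean(double[] labels) {
        if (labels.length == 0) {
            return 0.0;
        }
        double total = 0.0;
        for (double label : labels) {
            total += label;
        }
        return total / labels.length;
    }

    // Population variance: the mean squared deviation from the mean.
    static double variance(double[] labels) {
        if (labels.length == 0) {
            return 0.0;
        }
        double m = mean(labels);
        double sumSquaredDeviations = 0.0;
        for (double label : labels) {
            double deviation = label - m;
            sumSquaredDeviations += deviation * deviation;
        }
        return sumSquaredDeviations / labels.length;
    }

    // variance(parent) minus the size-weighted variance of the two children,
    // where the parent holds the labels of both children.
    static double varianceReduction(double[] left, double[] right) {
        int total = left.length + right.length;
        if (total == 0) {
            return 0.0;
        }
        double[] parent = new double[total];
        System.arraycopy(left, 0, parent, 0, left.length);
        System.arraycopy(right, 0, parent, left.length, right.length);
        return variance(parent)
            - (left.length / (double) total) * variance(left)
            - (right.length / (double) total) * variance(right);
    }

    public static void main(String[] args) {
        double[] left = {1.0, 2.0};
        double[] right = {3.0, 4.0};
        System.out.println("variance reduction=" + varianceReduction(left, right));
    }
}
```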

trainRegressor

public static DecisionTreeModel trainRegressor(JavaRDD<LabeledPoint> input,
                                               java.util.Map<Integer,Integer> categoricalFeaturesInfo,
                                               String impurity,
                                               int maxDepth,
                                               int maxBins)
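Java-friendly API for trainRegressor(RDD, Map, String, int, int). It takes a JavaRDD and a java.util.Map in place of the Scala RDD and immutable Map.

Parameters:
input - Training dataset: JavaRDD of LabeledPoint. Labels are real numbers.
categoricalFeaturesInfo - Map storing arity of categorical features. E.g., an entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
impurity - Criterion used for information gain calculation. Supported values: "variance".
maxDepth - Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (suggested value: 5)
maxBins - maximum number of bins used for splitting features (suggested value: 32)
Returns:
DecisionTreeModel that can be used for prediction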

run

public DecisionTreeModel run(RDD<LabeledPoint> input)
Method to train a decision tree model over an RDD.

Parameters:
input - Training data: RDD of LabeledPoint
Returns:
DecisionTreeModel that can be used for prediction