org.apache.spark.mllib.tree
Class RandomForest

Object
  extended by org.apache.spark.mllib.tree.RandomForest
All Implemented Interfaces:
java.io.Serializable, Logging

public class RandomForest
extends Object
implements scala.Serializable, Logging

:: Experimental :: A class that implements a Random Forest learning algorithm for classification and regression. It supports both continuous and categorical features.

The settings for featureSubsetStrategy are based on the following references:
- log2: tested in Breiman (2001)
- sqrt: recommended by the Breiman manual for random forests
- The defaults of sqrt (classification) and onethird (regression) match the R randomForest package.
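The following Java sketch shows what the explicit strategies mean in practice by computing how many candidate features a node would consider. It is not Spark's code. The class FeatureSubsetSizes is hypothetical, and two choices are assumptions of this sketch: results are rounded up with Math.ceil, and log2 is clamped to at least one feature.

```java
/**
 * Hypothetical helper (not part of Spark): number of candidate features per node
 * for each explicit featureSubsetStrategy. Rounding up with Math.ceil is an
 * assumption of this sketch; "auto" must be resolved to a concrete strategy first.
 */
final class FeatureSubsetSizes {
    private FeatureSubsetSizes() {}

    static int numFeaturesPerNode(String strategy, int numFeatures) {
        if (numFeatures < 1) {
            throw new IllegalArgumentException("numFeatures must be at least 1");
        }
        switch (strategy) {
            case "all":
                return numFeatures;
            case "sqrt":
                return (int) Math.ceil(Math.sqrt(numFeatures));
            case "log2":
                // log2(1) == 0, so clamp to at least one candidate feature.
                return Math.max(1, (int) Math.ceil(Math.log(numFeatures) / Math.log(2)));
            case "onethird":
                return (int) Math.ceil(numFeatures / 3.0);
            default:
                throw new IllegalArgumentException("Unsupported or unresolved strategy: " + strategy);
        }
    }
}
```

Under these assumptions, 100 features give 10 candidates per node for "sqrt", 7 for "log2", and 34 for "onethird".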

See Also:
Breiman manual for random forests: http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf

Parameters:
strategy - The configuration parameters for the random forest algorithm, which specify the type of algorithm (classification, regression, etc.), feature type (continuous, categorical), depth of the tree, quantile calculation strategy, etc.
numTrees - Number of trees in the forest. If 1, no bootstrapping is used. If > 1, each tree is trained on a bootstrap sample of the input.
featureSubsetStrategy - Number of features to consider for splits at each node. Supported: "auto", "all", "sqrt", "log2", "onethird". If "auto" is set, this parameter is set based on numTrees: if numTrees == 1, set to "all"; if numTrees > 1 (forest), set to "sqrt" for classification and to "onethird" for regression.
seed - Random seed for bootstrapping and choosing feature subsets.
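The "auto" rule and the numTrees/bootstrapping rule for these parameters can be restated as a small Java sketch. The class AutoStrategyResolver is hypothetical, and its isClassification flag is an assumed stand-in for the algorithm type that Spark reads from strategy.

```java
/**
 * Hypothetical sketch (not Spark's implementation) of the documented rules:
 * how "auto" resolves, and when bootstrapping is used.
 * isClassification stands in for the algorithm type carried by Strategy.
 */
final class AutoStrategyResolver {
    private AutoStrategyResolver() {}

    static String resolve(String featureSubsetStrategy, int numTrees, boolean isClassification) {
        if (!"auto".equals(featureSubsetStrategy)) {
            return featureSubsetStrategy; // explicit choices are used as given
        }
        if (numTrees == 1) {
            return "all"; // a single tree considers every feature at each node
        }
        return isClassification ? "sqrt" : "onethird";
    }

    static boolean usesBootstrapping(int numTrees) {
        return numTrees > 1; // numTrees == 1: no bootstrapping
    }
}
```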


Constructor Summary
RandomForest(Strategy strategy, int numTrees, String featureSubsetStrategy, int seed)
           
 
Method Summary
 RandomForestModel run(RDD<LabeledPoint> input)
          Method to train a random forest model over an RDD
static String[] supportedFeatureSubsetStrategies()
          List of supported feature subset sampling strategies.
static RandomForestModel trainClassifier(JavaRDD<LabeledPoint> input, int numClasses, java.util.Map<Integer,Integer> categoricalFeaturesInfo, int numTrees, String featureSubsetStrategy, String impurity, int maxDepth, int maxBins, int seed)
          Java-friendly API for RandomForest$.trainClassifier(org.apache.spark.rdd.RDD, int, scala.collection.immutable.Map, int, java.lang.String, java.lang.String, int, int, int)
static RandomForestModel trainClassifier(RDD<LabeledPoint> input, int numClasses, scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo, int numTrees, String featureSubsetStrategy, String impurity, int maxDepth, int maxBins, int seed)
          Method to train a random forest model for binary or multiclass classification.
static RandomForestModel trainClassifier(RDD<LabeledPoint> input, Strategy strategy, int numTrees, String featureSubsetStrategy, int seed)
          Method to train a random forest model for binary or multiclass classification.
static RandomForestModel trainRegressor(JavaRDD<LabeledPoint> input, java.util.Map<Integer,Integer> categoricalFeaturesInfo, int numTrees, String featureSubsetStrategy, String impurity, int maxDepth, int maxBins, int seed)
          Java-friendly API for RandomForest$.trainRegressor(org.apache.spark.rdd.RDD, scala.collection.immutable.Map, int, java.lang.String, java.lang.String, int, int, int)
static RandomForestModel trainRegressor(RDD<LabeledPoint> input, scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo, int numTrees, String featureSubsetStrategy, String impurity, int maxDepth, int maxBins, int seed)
          Method to train a random forest model for regression.
static RandomForestModel trainRegressor(RDD<LabeledPoint> input, Strategy strategy, int numTrees, String featureSubsetStrategy, int seed)
          Method to train a random forest model for regression.
 
Methods inherited from class Object
equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 
Methods inherited from interface org.apache.spark.Logging
initializeIfNecessary, initializeLogging, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning
 

Constructor Detail

RandomForest

public RandomForest(Strategy strategy,
                    int numTrees,
                    String featureSubsetStrategy,
                    int seed)
Method Detail

trainClassifier

public static RandomForestModel trainClassifier(RDD<LabeledPoint> input,
                                                Strategy strategy,
                                                int numTrees,
                                                String featureSubsetStrategy,
                                                int seed)
Method to train a random forest model for binary or multiclass classification.

Parameters:
input - Training dataset: RDD of LabeledPoint. Labels should take values {0, 1, ..., numClasses-1}.
strategy - Parameters for training each tree in the forest.
numTrees - Number of trees in the random forest.
featureSubsetStrategy - Number of features to consider for splits at each node. Supported: "auto", "all", "sqrt", "log2", "onethird". If "auto" is set, this parameter is set based on numTrees: if numTrees == 1, set to "all"; if numTrees > 1 (forest) set to "sqrt".
seed - Random seed for bootstrapping and choosing feature subsets.
Returns:
a random forest model that can be used for prediction
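Because LabeledPoint stores its label as a double, a caller may want to check the label requirement above before training. The class ClassLabelCheck in this Java sketch is hypothetical, not a Spark API. It accepts only whole numbers in {0, 1, ..., numClasses-1}.

```java
/** Hypothetical pre-training check (not part of Spark) for classification labels. */
final class ClassLabelCheck {
    private ClassLabelCheck() {}

    /** Returns true if every label is a whole number in {0, 1, ..., numClasses-1}. */
    static boolean labelsValid(double[] labels, int numClasses) {
        for (double label : labels) {
            // NaN fails this test too, because NaN != NaN.
            boolean whole = label == Math.rint(label);
            if (!whole || label < 0 || label > numClasses - 1) {
                return false;
            }
        }
        return true;
    }
}
```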

trainClassifier

public static RandomForestModel trainClassifier(RDD<LabeledPoint> input,
                                                int numClasses,
                                                scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo,
                                                int numTrees,
                                                String featureSubsetStrategy,
                                                String impurity,
                                                int maxDepth,
                                                int maxBins,
                                                int seed)
Method to train a random forest model for binary or multiclass classification.

Parameters:
input - Training dataset: RDD of LabeledPoint. Labels should take values {0, 1, ..., numClasses-1}.
numClasses - Number of classes for classification.
categoricalFeaturesInfo - Map storing arity of categorical features. E.g., an entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
numTrees - Number of trees in the random forest.
featureSubsetStrategy - Number of features to consider for splits at each node. Supported: "auto", "all", "sqrt", "log2", "onethird". If "auto" is set, this parameter is set based on numTrees: if numTrees == 1, set to "all"; if numTrees > 1 (forest) set to "sqrt".
impurity - Criterion used for information gain calculation. Supported values: "gini" (recommended) or "entropy".
maxDepth - Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (suggested value: 4)
maxBins - Maximum number of bins used for splitting features (suggested value: 100)
seed - Random seed for bootstrapping and choosing feature subsets.
Returns:
a random forest model that can be used for prediction
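The impurity parameter selects how candidate splits are scored. As a rough illustration, this Java sketch computes Gini impurity (1 minus the sum of squared class proportions) and entropy from per-class counts at a node, plus the Gini information gain of a binary split: parent impurity minus the size-weighted child impurities. The class name and the use of log base 2 for entropy are assumptions of this sketch, not a description of Spark internals.

```java
/** Hypothetical sketch (not Spark's code) of classification impurity measures. */
final class ClassificationImpurity {
    private ClassificationImpurity() {}

    private static long total(long[] counts) {
        long t = 0;
        for (long c : counts) t += c;
        return t;
    }

    /** Gini impurity: 1 - sum_i p_i^2, where p_i is the fraction of class i. */
    static double gini(long[] classCounts) {
        long n = total(classCounts);
        if (n == 0) return 0.0;
        double impurity = 1.0;
        for (long c : classCounts) {
            double p = (double) c / n;
            impurity -= p * p;
        }
        return impurity;
    }

    /** Entropy in bits: -sum_i p_i * log2(p_i); empty classes contribute 0. */
    static double entropy(long[] classCounts) {
        long n = total(classCounts);
        if (n == 0) return 0.0;
        double impurity = 0.0;
        for (long c : classCounts) {
            if (c == 0) continue;
            double p = (double) c / n;
            impurity -= p * (Math.log(p) / Math.log(2));
        }
        return impurity;
    }

    /** Gini gain of splitting parent into left and right. */
    static double giniGain(long[] parent, long[] left, long[] right) {
        double n = total(parent);
        if (n == 0) return 0.0;
        return gini(parent)
            - (total(left) / n) * gini(left)
            - (total(right) / n) * gini(right);
    }
}
```

A perfectly balanced two-class node scores 0.5 (Gini) or 1 bit (entropy); a split that separates the classes completely recovers all of that as gain.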

trainClassifier

public static RandomForestModel trainClassifier(JavaRDD<LabeledPoint> input,
                                                int numClasses,
                                                java.util.Map<Integer,Integer> categoricalFeaturesInfo,
                                                int numTrees,
                                                String featureSubsetStrategy,
                                                String impurity,
                                                int maxDepth,
                                                int maxBins,
                                                int seed)
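Java-friendly API for RandomForest$.trainClassifier(org.apache.spark.rdd.RDD, int, scala.collection.immutable.Map, int, java.lang.String, java.lang.String, int, int, int). It accepts a JavaRDD and a java.util.Map and delegates to that overload.

Parameters:
input - Training dataset: JavaRDD of LabeledPoint. Labels should take values {0, 1, ..., numClasses-1}.
numClasses - Number of classes for classification.
categoricalFeaturesInfo - Map storing arity of categorical features. E.g., an entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
numTrees - Number of trees in the random forest.
featureSubsetStrategy - Number of features to consider for splits at each node. Supported: "auto", "all", "sqrt", "log2", "onethird". If "auto" is set: if numTrees == 1, set to "all"; if numTrees > 1 (forest), set to "sqrt".
impurity - Criterion used for information gain calculation. Supported values: "gini" (recommended) or "entropy".
maxDepth - Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
maxBins - Maximum number of bins used for splitting features.
seed - Random seed for bootstrapping and choosing feature subsets.
Returns:
a random forest model that can be used for prediction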
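The Java-friendly overloads take categoricalFeaturesInfo as a java.util.Map<Integer,Integer>, where an entry (n -> k) declares feature n categorical with values {0, 1, ..., k-1}. This Java sketch builds such a map and checks a dense feature vector against it before training. The class CategoricalFeatureCheck and the example arities are hypothetical; features without an entry are treated as continuous and left unchecked.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical pre-training check (not part of Spark) for categorical feature values. */
final class CategoricalFeatureCheck {
    private CategoricalFeatureCheck() {}

    /** For each entry (n -> k), feature n must be a whole number in {0, ..., k-1}. */
    static boolean conforms(double[] features, Map<Integer, Integer> categoricalFeaturesInfo) {
        for (Map.Entry<Integer, Integer> e : categoricalFeaturesInfo.entrySet()) {
            int index = e.getKey();
            int arity = e.getValue();
            if (index < 0 || index >= features.length) return false; // feature missing
            double v = features[index];
            if (v != Math.rint(v) || v < 0 || v >= arity) return false;
        }
        return true;
    }

    /** Example map: feature 0 has 3 categories, feature 4 has 2. */
    static Map<Integer, Integer> exampleInfo() {
        Map<Integer, Integer> info = new HashMap<>();
        info.put(0, 3);
        info.put(4, 2);
        return info;
    }
}
```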

trainRegressor

public static RandomForestModel trainRegressor(RDD<LabeledPoint> input,
                                               Strategy strategy,
                                               int numTrees,
                                               String featureSubsetStrategy,
                                               int seed)
Method to train a random forest model for regression.

Parameters:
input - Training dataset: RDD of LabeledPoint. Labels are real numbers.
strategy - Parameters for training each tree in the forest.
numTrees - Number of trees in the random forest.
featureSubsetStrategy - Number of features to consider for splits at each node. Supported: "auto", "all", "sqrt", "log2", "onethird". If "auto" is set, this parameter is set based on numTrees: if numTrees == 1, set to "all"; if numTrees > 1 (forest) set to "onethird".
seed - Random seed for bootstrapping and choosing feature subsets.
Returns:
a random forest model that can be used for prediction
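The seed parameter drives two random choices: which rows each tree sees (bootstrapping) and which features each node may split on. This Java sketch shows both with a seeded java.util.Random. It swaps in two textbook techniques, uniform draws with replacement for the bootstrap and a partial Fisher-Yates shuffle for the feature subset; Spark's own sampling scheme may differ. The class SeededSampling is hypothetical.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical sketch (not Spark's code) of seeded bootstrap and feature-subset sampling. */
final class SeededSampling {
    private SeededSampling() {}

    /** Classic bootstrap: numRows uniform draws with replacement; returns per-row draw counts. */
    static int[] bootstrapCounts(int numRows, Random rng) {
        int[] counts = new int[numRows];
        for (int draw = 0; draw < numRows; draw++) {
            counts[rng.nextInt(numRows)]++;
        }
        return counts;
    }

    /** Picks k distinct feature indices from 0..numFeatures-1 via a partial Fisher-Yates shuffle. */
    static int[] featureSubset(int numFeatures, int k, Random rng) {
        if (k < 1 || k > numFeatures) {
            throw new IllegalArgumentException("k must be in [1, numFeatures]");
        }
        int[] indices = new int[numFeatures];
        for (int i = 0; i < numFeatures; i++) indices[i] = i;
        for (int i = 0; i < k; i++) {
            int j = i + rng.nextInt(numFeatures - i); // pick from the not-yet-chosen tail
            int tmp = indices[i];
            indices[i] = indices[j];
            indices[j] = tmp;
        }
        return Arrays.copyOf(indices, k);
    }
}
```

Reusing the same seed reproduces the same samples. With uniform draws, roughly 1/e (about 36.8%) of rows are absent from a given bootstrap sample when the dataset is large.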

trainRegressor

public static RandomForestModel trainRegressor(RDD<LabeledPoint> input,
                                               scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo,
                                               int numTrees,
                                               String featureSubsetStrategy,
                                               String impurity,
                                               int maxDepth,
                                               int maxBins,
                                               int seed)
Method to train a random forest model for regression.

Parameters:
input - Training dataset: RDD of LabeledPoint. Labels are real numbers.
categoricalFeaturesInfo - Map storing arity of categorical features. E.g., an entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
numTrees - Number of trees in the random forest.
featureSubsetStrategy - Number of features to consider for splits at each node. Supported: "auto", "all", "sqrt", "log2", "onethird". If "auto" is set, this parameter is set based on numTrees: if numTrees == 1, set to "all"; if numTrees > 1 (forest) set to "onethird".
impurity - Criterion used for information gain calculation. Supported value: "variance".
maxDepth - Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (suggested value: 4)
maxBins - Maximum number of bins used for splitting features (suggested value: 100)
seed - Random seed for bootstrapping and choosing feature subsets.
Returns:
a random forest model that can be used for prediction
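For regression the only supported impurity is "variance": a split is good when it lowers the spread of the real-valued labels. This Java sketch computes the population variance of a node's labels (two-pass, for numerical stability) and the gain of a binary split as parent variance minus the size-weighted child variances. The class VarianceImpurity is hypothetical, not Spark's implementation.

```java
/** Hypothetical sketch (not Spark's code) of variance impurity for regression. */
final class VarianceImpurity {
    private VarianceImpurity() {}

    /** Population variance: mean squared deviation from the mean. */
    static double variance(double[] labels) {
        if (labels.length == 0) return 0.0;
        double mean = 0.0;
        for (double y : labels) mean += y;
        mean /= labels.length;
        double sumSqDev = 0.0;
        for (double y : labels) sumSqDev += (y - mean) * (y - mean);
        return sumSqDev / labels.length;
    }

    /** Variance reduction from splitting parent into left and right. */
    static double gain(double[] parent, double[] left, double[] right) {
        double n = parent.length;
        if (n == 0) return 0.0;
        return variance(parent)
            - (left.length / n) * variance(left)
            - (right.length / n) * variance(right);
    }
}
```

For labels {1, 2, 3, 4} the variance is 1.25; splitting into {1, 2} and {3, 4} leaves 0.25 in each child, for a gain of 1.0.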

trainRegressor

public static RandomForestModel trainRegressor(JavaRDD<LabeledPoint> input,
                                               java.util.Map<Integer,Integer> categoricalFeaturesInfo,
                                               int numTrees,
                                               String featureSubsetStrategy,
                                               String impurity,
                                               int maxDepth,
                                               int maxBins,
                                               int seed)
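Java-friendly API for RandomForest$.trainRegressor(org.apache.spark.rdd.RDD, scala.collection.immutable.Map, int, java.lang.String, java.lang.String, int, int, int). It accepts a JavaRDD and a java.util.Map and delegates to that overload.

Parameters:
input - Training dataset: JavaRDD of LabeledPoint. Labels are real numbers.
categoricalFeaturesInfo - Map storing arity of categorical features. E.g., an entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
numTrees - Number of trees in the random forest.
featureSubsetStrategy - Number of features to consider for splits at each node. Supported: "auto", "all", "sqrt", "log2", "onethird". If "auto" is set: if numTrees == 1, set to "all"; if numTrees > 1 (forest), set to "onethird".
impurity - Criterion used for information gain calculation. Supported value: "variance".
maxDepth - Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
maxBins - Maximum number of bins used for splitting features.
seed - Random seed for bootstrapping and choosing feature subsets.
Returns:
a random forest model that can be used for prediction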
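The maxDepth parameter of the train methods is described by example: depth 0 means 1 leaf node, and depth 1 means 1 internal node plus 2 leaf nodes. Extending that binary-tree rule gives an upper bound on tree size, which this Java sketch computes per tree and for a whole forest. The class TreeSizeBounds is hypothetical; real trees are often smaller because splitting stops early.

```java
/** Hypothetical helper (not part of Spark): upper bounds on binary tree size by depth. */
final class TreeSizeBounds {
    private TreeSizeBounds() {}

    private static void checkDepth(int maxDepth) {
        if (maxDepth < 0 || maxDepth > 61) {
            throw new IllegalArgumentException("maxDepth out of range for this sketch: " + maxDepth);
        }
    }

    /** At most 2^maxDepth leaves. */
    static long maxLeaves(int maxDepth) {
        checkDepth(maxDepth);
        return 1L << maxDepth;
    }

    /** A full binary tree has one fewer internal node than leaves. */
    static long maxInternalNodes(int maxDepth) {
        return maxLeaves(maxDepth) - 1;
    }

    /** Total nodes: 2^(maxDepth + 1) - 1. */
    static long maxNodes(int maxDepth) {
        return 2 * maxLeaves(maxDepth) - 1;
    }

    static long maxNodesInForest(int numTrees, int maxDepth) {
        return numTrees * maxNodes(maxDepth);
    }
}
```

With the suggested maxDepth of 4, a tree has at most 16 leaves and 31 nodes, so a 10-tree forest has at most 310 nodes.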

supportedFeatureSubsetStrategies

public static String[] supportedFeatureSubsetStrategies()
List of supported feature subset sampling strategies.

Returns:
the names accepted by featureSubsetStrategy: "auto", "all", "sqrt", "log2", "onethird"
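A caller may want to validate arguments before starting a distributed job. This Java sketch checks featureSubsetStrategy against the values listed on this page and rejects numTrees below 1, since the documentation only describes 1 and > 1. The class StrategyValidation is hypothetical; the list is copied from this page rather than read from supportedFeatureSubsetStrategies(), and the exact, case-sensitive comparison is an assumption.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/** Hypothetical argument checks (not part of Spark) mirroring the documented constraints. */
final class StrategyValidation {
    private StrategyValidation() {}

    // Copied from this documentation, not read from Spark at runtime.
    static final List<String> SUPPORTED = Collections.unmodifiableList(
        Arrays.asList("auto", "all", "sqrt", "log2", "onethird"));

    // Exact, case-sensitive match (an assumption of this sketch).
    static boolean isSupported(String featureSubsetStrategy) {
        return SUPPORTED.contains(featureSubsetStrategy);
    }

    static void requireValid(String featureSubsetStrategy, int numTrees) {
        if (numTrees < 1) {
            throw new IllegalArgumentException("numTrees must be at least 1, got " + numTrees);
        }
        if (!isSupported(featureSubsetStrategy)) {
            throw new IllegalArgumentException("Unsupported featureSubsetStrategy \""
                + featureSubsetStrategy + "\"; supported: " + SUPPORTED);
        }
    }
}
```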

run

public RandomForestModel run(RDD<LabeledPoint> input)
Method to train a random forest model over an RDD

Parameters:
input - Training data: RDD of LabeledPoint
Returns:
a random forest model that can be used for prediction
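A random forest predicts by combining its trees' outputs: the standard choices are a majority vote for classification and the mean for regression. This Java sketch shows both over an array of per-tree predictions. It illustrates the general technique and does not claim to reproduce RandomForestModel's internals; the class ForestAggregation and its tie-breaking rule (smaller label wins) are hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical sketch (not Spark's code) of combining per-tree predictions. */
final class ForestAggregation {
    private ForestAggregation() {}

    /** Classification: the most frequent predicted label; ties go to the smaller label. */
    static double majorityVote(double[] treePredictions) {
        requireNonEmpty(treePredictions);
        Map<Double, Integer> votes = new HashMap<>();
        for (double p : treePredictions) {
            votes.merge(p, 1, Integer::sum);
        }
        double best = Double.NaN;
        int bestCount = 0;
        for (Map.Entry<Double, Integer> e : votes.entrySet()) {
            double label = e.getKey();
            int count = e.getValue();
            if (count > bestCount || (count == bestCount && label < best)) {
                best = label;
                bestCount = count;
            }
        }
        return best;
    }

    /** Regression: the average of the trees' predictions. */
    static double mean(double[] treePredictions) {
        requireNonEmpty(treePredictions);
        double sum = 0.0;
        for (double p : treePredictions) sum += p;
        return sum / treePredictions.length;
    }

    private static void requireNonEmpty(double[] predictions) {
        if (predictions.length == 0) {
            throw new IllegalArgumentException("need at least one tree prediction");
        }
    }
}
```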