Package org.apache.spark.mllib.tree
Class RandomForest
java.lang.Object
  org.apache.spark.mllib.tree.RandomForest
All Implemented Interfaces:
Serializable, org.apache.spark.internal.Logging
A class that implements a Random Forest learning algorithm for classification and regression. It supports both continuous and categorical features.
The settings for featureSubsetStrategy are based on the following references:
- log2: tested in Breiman (2001)
- sqrt: recommended by the Breiman manual for random forests
- The defaults of sqrt (classification) and onethird (regression) match the R randomForest package.
Parameters:
strategy - The configuration parameters for the random forest algorithm, which specify the type of random forest (classification or regression), feature type (continuous, categorical), depth of the tree, quantile calculation strategy, etc.
numTrees - If 1, then no bootstrapping is used. If greater than 1, then bootstrapping is done.
featureSubsetStrategy - Number of features to consider for splits at each node. Supported values: "auto", "all", "sqrt", "log2", "onethird". Supported numerical values: "(0.0-1.0]", "[1-n]". If "auto" is set, this parameter is set based on numTrees: if numTrees == 1, set to "all"; if numTrees is greater than 1 (forest), set to "sqrt" for classification and to "onethird" for regression. If a real value "n" in the range (0, 1.0] is set, use n * number of features. If an integer value "n" in the range (1, num features) is set, use n features.
seed - Random seed for bootstrapping and choosing feature subsets.
See Also:
- Breiman (2001)
- Breiman manual for random forests
- Serialized Form
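The featureSubsetStrategy rules above can be sketched as a plain-Java helper that maps a strategy string to the number of features considered at each node. This is an illustration, not Spark's implementation: the class FeatureSubsetSketch and its method resolveSubsetSize are hypothetical, and rounding fractional results up with Math.ceil, keeping "log2" at a minimum of one feature, and capping integer values at the number of features are assumptions, because this page does not state how Spark rounds or clamps.

```java
/**
 * Hypothetical helper (not part of Spark): resolves a featureSubsetStrategy
 * string to the number of features considered for splits at each node.
 */
final class FeatureSubsetSketch {
    private FeatureSubsetSketch() {}

    static int resolveSubsetSize(String strategy, int numFeatures, int numTrees, boolean classification) {
        if (numFeatures < 1) {
            throw new IllegalArgumentException("numFeatures must be at least 1");
        }
        String s = strategy.trim();
        // "auto" depends on the forest size and the task, as documented.
        if (s.equals("auto")) {
            s = (numTrees == 1) ? "all" : (classification ? "sqrt" : "onethird");
        }
        switch (s) {
            case "all":
                return numFeatures;
            case "sqrt":
                // Assumption: fractional results are rounded up.
                return (int) Math.ceil(Math.sqrt(numFeatures));
            case "log2":
                // Assumption: rounded up, and never fewer than one feature.
                return Math.max(1, (int) Math.ceil(Math.log(numFeatures) / Math.log(2)));
            case "onethird":
                return (int) Math.ceil(numFeatures / 3.0);
            default:
                break;
        }
        if (s.matches("[1-9]\\d*")) {
            // Integer n: use n features (assumption: capped at numFeatures).
            // Parsing as a double avoids int overflow on very long digit strings.
            return (int) Math.min(Double.parseDouble(s), numFeatures);
        }
        double fraction;
        try {
            fraction = Double.parseDouble(s);
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("Unsupported featureSubsetStrategy: " + strategy, e);
        }
        // Real value n in (0, 1.0]: use n * number of features.
        if (!(fraction > 0.0 && fraction <= 1.0)) {
            throw new IllegalArgumentException("Fraction must be in (0, 1.0]: " + strategy);
        }
        return (int) Math.ceil(fraction * numFeatures);
    }
}
```

For 100 features and a 10-tree forest, "auto" resolves to "sqrt" (10 features per node) for classification and to "onethird" (34 features under the round-up assumption) for regression; a single tree with "auto" uses all 100.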
Nested Class Summary
Nested classes/interfaces inherited from interface org.apache.spark.internal.Logging
org.apache.spark.internal.Logging.LogStringContext, org.apache.spark.internal.Logging.SparkShellLoggingFilter
Constructor Summary
Constructor | Description
RandomForest(Strategy strategy, int numTrees, String featureSubsetStrategy, int seed) |
Method Summary
Modifier and Type | Method | Description
static org.apache.spark.internal.Logging.LogStringContext | LogStringContext(scala.StringContext sc) |
static org.slf4j.Logger | org$apache$spark$internal$Logging$$log_() |
static void | org$apache$spark$internal$Logging$$log__$eq(org.slf4j.Logger x$1) |
RandomForestModel | run(RDD<LabeledPoint> input) | Method to train a random forest model over an RDD.
static String[] | supportedFeatureSubsetStrategies() | List of supported feature subset sampling strategies.
static RandomForestModel | trainClassifier(JavaRDD<LabeledPoint> input, int numClasses, Map<Integer, Integer> categoricalFeaturesInfo, int numTrees, String featureSubsetStrategy, String impurity, int maxDepth, int maxBins, int seed) | Java-friendly API for org.apache.spark.mllib.tree.RandomForest.trainClassifier
static RandomForestModel | trainClassifier(RDD<LabeledPoint> input, int numClasses, scala.collection.immutable.Map<Object, Object> categoricalFeaturesInfo, int numTrees, String featureSubsetStrategy, String impurity, int maxDepth, int maxBins, int seed) | Method to train a random forest model for binary or multiclass classification.
static RandomForestModel | trainClassifier(RDD<LabeledPoint> input, Strategy strategy, int numTrees, String featureSubsetStrategy, int seed) | Method to train a random forest model for binary or multiclass classification.
static RandomForestModel | trainRegressor(JavaRDD<LabeledPoint> input, Map<Integer, Integer> categoricalFeaturesInfo, int numTrees, String featureSubsetStrategy, String impurity, int maxDepth, int maxBins, int seed) | Java-friendly API for org.apache.spark.mllib.tree.RandomForest.trainRegressor
static RandomForestModel | trainRegressor(RDD<LabeledPoint> input, Strategy strategy, int numTrees, String featureSubsetStrategy, int seed) | Method to train a random forest model for regression.
static RandomForestModel | trainRegressor(RDD<LabeledPoint> input, scala.collection.immutable.Map<Object, Object> categoricalFeaturesInfo, int numTrees, String featureSubsetStrategy, String impurity, int maxDepth, int maxBins, int seed) | Method to train a random forest model for regression.
Methods inherited from class java.lang.Object
equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface org.apache.spark.internal.Logging
initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, isTraceEnabled, log, logDebug, logDebug, logDebug, logDebug, logError, logError, logError, logError, logInfo, logInfo, logInfo, logInfo, logName, LogStringContext, logTrace, logTrace, logTrace, logTrace, logWarning, logWarning, logWarning, logWarning, org$apache$spark$internal$Logging$$log_, org$apache$spark$internal$Logging$$log__$eq, withLogContext
Constructor Details
RandomForest
public RandomForest(Strategy strategy, int numTrees, String featureSubsetStrategy, int seed)
Method Details
trainClassifier
public static RandomForestModel trainClassifier(RDD<LabeledPoint> input, Strategy strategy, int numTrees, String featureSubsetStrategy, int seed)
Method to train a random forest model for binary or multiclass classification.
Parameters:
input - Training dataset: RDD of LabeledPoint. Labels should take values {0, 1, ..., numClasses-1}, where numClasses is the number of classes configured in strategy.
strategy - Parameters for training each tree in the forest.
numTrees - Number of trees in the random forest.
featureSubsetStrategy - Number of features to consider for splits at each node. Supported values: "auto", "all", "sqrt", "log2", "onethird". If "auto" is set, this parameter is set based on numTrees: if numTrees == 1, set to "all"; if numTrees is greater than 1 (forest), set to "sqrt".
seed - Random seed for bootstrapping and choosing feature subsets.
Returns:
RandomForestModel that can be used for prediction.
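For the classification trainers, labels must be the class indices {0, 1, ..., numClasses-1}, stored as doubles in LabeledPoint. A pre-training check of that rule can be sketched on a plain double[] of labels so the example stays self-contained; LabelCheckSketch and firstInvalidLabel are hypothetical names, and with Spark one would apply the same test to the label of each LabeledPoint in the RDD.

```java
/** Hypothetical helper (not part of Spark): validates classification labels. */
final class LabelCheckSketch {
    private LabelCheckSketch() {}

    /**
     * Returns the index of the first label that is not a whole number in
     * {0, 1, ..., numClasses-1}, or -1 if every label is valid.
     */
    static int firstInvalidLabel(double[] labels, int numClasses) {
        for (int i = 0; i < labels.length; i++) {
            double y = labels[i];
            // NaN fails this equality, so NaN labels are reported as invalid.
            boolean whole = y == Math.rint(y);
            if (!whole || y < 0 || y >= numClasses) {
                return i;
            }
        }
        return -1;
    }
}
```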
trainClassifier
public static RandomForestModel trainClassifier(RDD<LabeledPoint> input, int numClasses, scala.collection.immutable.Map<Object, Object> categoricalFeaturesInfo, int numTrees, String featureSubsetStrategy, String impurity, int maxDepth, int maxBins, int seed)
Method to train a random forest model for binary or multiclass classification.
Parameters:
input - Training dataset: RDD of LabeledPoint. Labels should take values {0, 1, ..., numClasses-1}.
numClasses - Number of classes for classification.
categoricalFeaturesInfo - Map storing arity of categorical features. An entry (n to k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
numTrees - Number of trees in the random forest.
featureSubsetStrategy - Number of features to consider for splits at each node. Supported values: "auto", "all", "sqrt", "log2", "onethird". If "auto" is set, this parameter is set based on numTrees: if numTrees == 1, set to "all"; if numTrees is greater than 1 (forest), set to "sqrt".
impurity - Criterion used for information gain calculation. Supported values: "gini" (recommended) or "entropy".
maxDepth - Maximum depth of the tree (e.g., depth 0 means 1 leaf node, depth 1 means 1 internal node + 2 leaf nodes). (suggested value: 4)
maxBins - Maximum number of bins used for splitting features. (suggested value: 100)
seed - Random seed for bootstrapping and choosing feature subsets.
Returns:
RandomForestModel that can be used for prediction.
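The impurity parameter of trainClassifier accepts "gini" or "entropy". The textbook definitions of both measures, computed from per-class counts at a node, can be sketched as follows (Gini = 1 minus the sum of squared class proportions; entropy = minus the sum of p times log2 p). ImpuritySketch is a hypothetical class, the base-2 logarithm is an assumption of this sketch, and the code is not taken from Spark.

```java
/** Hypothetical helper (not part of Spark): classification impurity measures. */
final class ImpuritySketch {
    private ImpuritySketch() {}

    /** Gini impurity: 1 - sum over classes of p_i^2. */
    static double gini(double[] classCounts) {
        double total = sum(classCounts);
        if (total == 0) {
            return 0.0; // an empty node is treated as pure
        }
        double sumSq = 0.0;
        for (double c : classCounts) {
            double p = c / total;
            sumSq += p * p;
        }
        return 1.0 - sumSq;
    }

    /** Entropy: -sum over classes of p_i * log2(p_i), with 0 * log(0) taken as 0. */
    static double entropy(double[] classCounts) {
        double total = sum(classCounts);
        if (total == 0) {
            return 0.0;
        }
        double h = 0.0;
        for (double c : classCounts) {
            if (c == 0) {
                continue; // skip empty classes to avoid log(0)
            }
            double p = c / total;
            h -= p * (Math.log(p) / Math.log(2));
        }
        return h;
    }

    private static double sum(double[] xs) {
        double s = 0.0;
        for (double x : xs) {
            s += x;
        }
        return s;
    }
}
```

Both measures are 0 for a pure node; for two equally frequent classes Gini is 0.5 and base-2 entropy is 1.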
trainClassifier
public static RandomForestModel trainClassifier(JavaRDD<LabeledPoint> input, int numClasses, Map<Integer, Integer> categoricalFeaturesInfo, int numTrees, String featureSubsetStrategy, String impurity, int maxDepth, int maxBins, int seed)
Java-friendly API for org.apache.spark.mllib.tree.RandomForest.trainClassifier
Parameters (undocumented):
input, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed - Each corresponds to the parameter of the same name in the trainClassifier overload that takes an RDD<LabeledPoint> and a scala.collection.immutable.Map, except that input is a JavaRDD<LabeledPoint> and categoricalFeaturesInfo is a java.util.Map<Integer, Integer>.
Returns:
(undocumented)
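The Java-friendly overloads take categoricalFeaturesInfo as a java.util.Map<Integer, Integer>. Following the meaning documented for the Scala overload (an entry n to k marks feature n as categorical with values {0, 1, ..., k-1}), the sketch below builds such a map and checks one feature vector against it. CategoricalInfoSketch, valuesMatchArity, and exampleInfo are hypothetical names, and the code does not call Spark.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical helper (not part of Spark): checks values against categoricalFeaturesInfo. */
final class CategoricalInfoSketch {
    private CategoricalInfoSketch() {}

    /**
     * Returns true if, for every entry (n -> k), features[n] is a whole number in
     * {0, 1, ..., k-1}. Features not listed in the map are treated as continuous.
     */
    static boolean valuesMatchArity(double[] features, Map<Integer, Integer> info) {
        for (Map.Entry<Integer, Integer> e : info.entrySet()) {
            int n = e.getKey();
            int k = e.getValue();
            if (n < 0 || n >= features.length) {
                return false; // the map refers to a feature index that does not exist
            }
            double v = features[n];
            if (v != Math.rint(v) || v < 0 || v >= k) {
                return false;
            }
        }
        return true;
    }

    /** Example map: feature 0 has 3 categories, feature 4 has 2 categories. */
    static Map<Integer, Integer> exampleInfo() {
        Map<Integer, Integer> info = new HashMap<>();
        info.put(0, 3);
        info.put(4, 2);
        return info;
    }
}
```

In the example map, features 1 to 3 are absent, so any real value is accepted for them.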
trainRegressor
public static RandomForestModel trainRegressor(RDD<LabeledPoint> input, Strategy strategy, int numTrees, String featureSubsetStrategy, int seed)
Method to train a random forest model for regression.
Parameters:
input - Training dataset: RDD of LabeledPoint. Labels are real numbers.
strategy - Parameters for training each tree in the forest.
numTrees - Number of trees in the random forest.
featureSubsetStrategy - Number of features to consider for splits at each node. Supported values: "auto", "all", "sqrt", "log2", "onethird". If "auto" is set, this parameter is set based on numTrees: if numTrees == 1, set to "all"; if numTrees is greater than 1 (forest), set to "onethird".
seed - Random seed for bootstrapping and choosing feature subsets.
Returns:
RandomForestModel that can be used for prediction.
trainRegressor
public static RandomForestModel trainRegressor(RDD<LabeledPoint> input, scala.collection.immutable.Map<Object, Object> categoricalFeaturesInfo, int numTrees, String featureSubsetStrategy, String impurity, int maxDepth, int maxBins, int seed)
Method to train a random forest model for regression.
Parameters:
input - Training dataset: RDD of LabeledPoint. Labels are real numbers.
categoricalFeaturesInfo - Map storing arity of categorical features. An entry (n to k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
numTrees - Number of trees in the random forest.
featureSubsetStrategy - Number of features to consider for splits at each node. Supported values: "auto", "all", "sqrt", "log2", "onethird". If "auto" is set, this parameter is set based on numTrees: if numTrees == 1, set to "all"; if numTrees is greater than 1 (forest), set to "onethird".
impurity - Criterion used for information gain calculation. The only supported value for regression is "variance".
maxDepth - Maximum depth of the tree (e.g., depth 0 means 1 leaf node, depth 1 means 1 internal node + 2 leaf nodes). (suggested value: 4)
maxBins - Maximum number of bins used for splitting features. (suggested value: 100)
seed - Random seed for bootstrapping and choosing feature subsets.
Returns:
RandomForestModel that can be used for prediction.
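For regression, the only supported impurity is "variance". A minimal sketch of variance impurity over the labels that reach a node is shown below, assuming the population variance (dividing by the number of labels rather than by one less). VarianceSketch is a hypothetical name, and a two-pass computation is used so that rounding can never push the result below zero.

```java
/** Hypothetical helper (not part of Spark): variance impurity for regression. */
final class VarianceSketch {
    private VarianceSketch() {}

    /** Population variance of the labels at a node; 0 for an empty node. */
    static double variance(double[] labels) {
        if (labels.length == 0) {
            return 0.0;
        }
        // First pass: mean of the labels.
        double mean = 0.0;
        for (double y : labels) {
            mean += y;
        }
        mean /= labels.length;
        // Second pass: mean squared deviation from that mean.
        double ss = 0.0;
        for (double y : labels) {
            double d = y - mean;
            ss += d * d;
        }
        return ss / labels.length;
    }
}
```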
trainRegressor
public static RandomForestModel trainRegressor(JavaRDD<LabeledPoint> input, Map<Integer, Integer> categoricalFeaturesInfo, int numTrees, String featureSubsetStrategy, String impurity, int maxDepth, int maxBins, int seed)
Java-friendly API for org.apache.spark.mllib.tree.RandomForest.trainRegressor
Parameters (undocumented):
input, categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed - Each corresponds to the parameter of the same name in the trainRegressor overload that takes an RDD<LabeledPoint> and a scala.collection.immutable.Map, except that input is a JavaRDD<LabeledPoint> and categoricalFeaturesInfo is a java.util.Map<Integer, Integer>.
Returns:
(undocumented)
supportedFeatureSubsetStrategies
public static String[] supportedFeatureSubsetStrategies()
List of supported feature subset sampling strategies.
Returns:
(undocumented)
org$apache$spark$internal$Logging$$log_
public static org.slf4j.Logger org$apache$spark$internal$Logging$$log_()
org$apache$spark$internal$Logging$$log__$eq
public static void org$apache$spark$internal$Logging$$log__$eq(org.slf4j.Logger x$1)
LogStringContext
public static org.apache.spark.internal.Logging.LogStringContext LogStringContext(scala.StringContext sc)
run
public RandomForestModel run(RDD<LabeledPoint> input)
Method to train a random forest model over an RDD.
Parameters:
input - Training data: RDD of LabeledPoint.
Returns:
RandomForestModel that can be used for prediction.
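The class description states that numTrees == 1 disables bootstrapping, numTrees greater than 1 enables it, and seed controls the randomness. The classic bootstrap (each tree sees numRows rows drawn uniformly with replacement) can be sketched as below. BootstrapSketch and bootstrapCounts are hypothetical names, and this page does not say which resampling scheme run uses internally, so treat this as an illustration of the idea rather than Spark's sampler.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical helper (not Spark's sampler): per-tree bootstrap counts. */
final class BootstrapSketch {
    private BootstrapSketch() {}

    /**
     * Returns counts[t][i] = number of times row i is drawn for tree t.
     * With numTrees == 1 every row is used exactly once (no bootstrapping);
     * otherwise each tree draws numRows rows uniformly with replacement.
     */
    static int[][] bootstrapCounts(int numRows, int numTrees, int seed) {
        int[][] counts = new int[numTrees][numRows];
        if (numTrees == 1) {
            Arrays.fill(counts[0], 1);
            return counts;
        }
        // One seeded generator makes the sampling for the whole forest reproducible.
        Random rng = new Random(seed);
        for (int t = 0; t < numTrees; t++) {
            for (int draw = 0; draw < numRows; draw++) {
                counts[t][rng.nextInt(numRows)]++;
            }
        }
        return counts;
    }
}
```

Rows with a count of zero for a tree are that tree's out-of-bag rows; for large numRows, roughly 1 - 1/e (about 63.2%) of the distinct rows appear in each bootstrap sample.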