public class RandomForest
extends Object
implements scala.Serializable, org.apache.spark.internal.Logging
The settings for featureSubsetStrategy are based on the following references:
- log2: tested in Breiman (2001)
- sqrt: recommended by the Breiman manual for random forests
- The defaults of sqrt (classification) and onethird (regression) match the R randomForest package.
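A minimal Java sketch of how a strategy name could translate into a per-node feature count follows. It combines the "auto" rule documented for featureSubsetStrategy (all for a single tree, otherwise sqrt for classification and onethird for regression) with the named fractions. The class FeatureSubsetSketch is hypothetical, and rounding up to a whole feature (with log2 clamped to at least one feature) is an assumption rather than a documented guarantee.

```java
import java.util.Locale;

/** Hypothetical helper: follows the documented strategy names, not Spark internals. */
final class FeatureSubsetSketch {

    /** Resolves "auto" with the documented rule; any other name is returned lower-cased. */
    static String resolve(String strategy, int numTrees, boolean isClassification) {
        String s = strategy.toLowerCase(Locale.ROOT);
        if (!s.equals("auto")) {
            return s;
        }
        if (numTrees == 1) {
            return "all"; // a single tree considers every feature
        }
        return isClassification ? "sqrt" : "onethird"; // documented forest defaults
    }

    /** Candidate features per node; rounding up (ceiling) is an ASSUMPTION of this sketch. */
    static int featuresPerNode(String resolvedStrategy, int numFeatures) {
        switch (resolvedStrategy) {
            case "all":
                return numFeatures;
            case "sqrt":
                return (int) Math.ceil(Math.sqrt(numFeatures));
            case "log2":
                // log2(1) = 0, so clamp to at least one feature.
                return Math.max(1, (int) Math.ceil(Math.log(numFeatures) / Math.log(2)));
            case "onethird":
                return (int) Math.ceil(numFeatures / 3.0);
            default:
                throw new IllegalArgumentException("Unsupported strategy: " + resolvedStrategy);
        }
    }

    public static void main(String[] args) {
        String strategy = resolve("auto", 100, true);
        System.out.println(strategy + " -> " + featuresPerNode(strategy, 100));
    }
}
```

For example, with 100 features a classification forest resolves "auto" to "sqrt" and would consider 10 candidate features at each node.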
Constructor and Description |
---|
RandomForest(Strategy strategy, int numTrees, String featureSubsetStrategy, int seed) |
Modifier and Type | Method and Description |
---|---|
static void | org$apache$spark$internal$Logging$$log__$eq(org.slf4j.Logger x$1) |
static org.slf4j.Logger | org$apache$spark$internal$Logging$$log_() |
RandomForestModel | run(RDD<LabeledPoint> input): Method to train a random forest model over an RDD. |
static String[] | supportedFeatureSubsetStrategies(): List of supported feature subset sampling strategies. |
static RandomForestModel | trainClassifier(JavaRDD<LabeledPoint> input, int numClasses, java.util.Map<Integer,Integer> categoricalFeaturesInfo, int numTrees, String featureSubsetStrategy, String impurity, int maxDepth, int maxBins, int seed): Java-friendly API for org.apache.spark.mllib.tree.RandomForest.trainClassifier. |
static RandomForestModel | trainClassifier(RDD<LabeledPoint> input, int numClasses, scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo, int numTrees, String featureSubsetStrategy, String impurity, int maxDepth, int maxBins, int seed): Method to train a random forest model for binary or multiclass classification. |
static RandomForestModel | trainClassifier(RDD<LabeledPoint> input, Strategy strategy, int numTrees, String featureSubsetStrategy, int seed): Method to train a random forest model for binary or multiclass classification. |
static RandomForestModel | trainRegressor(JavaRDD<LabeledPoint> input, java.util.Map<Integer,Integer> categoricalFeaturesInfo, int numTrees, String featureSubsetStrategy, String impurity, int maxDepth, int maxBins, int seed): Java-friendly API for org.apache.spark.mllib.tree.RandomForest.trainRegressor. |
static RandomForestModel | trainRegressor(RDD<LabeledPoint> input, scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo, int numTrees, String featureSubsetStrategy, String impurity, int maxDepth, int maxBins, int seed): Method to train a random forest model for regression. |
static RandomForestModel | trainRegressor(RDD<LabeledPoint> input, Strategy strategy, int numTrees, String featureSubsetStrategy, int seed): Method to train a random forest model for regression. |
Methods inherited from class java.lang.Object
equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait

Methods inherited from interface org.apache.spark.internal.Logging
$init$, initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, initLock, isTraceEnabled, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning, org$apache$spark$internal$Logging$$log__$eq, org$apache$spark$internal$Logging$$log_, uninitialize
public RandomForest(Strategy strategy, int numTrees, String featureSubsetStrategy, int seed)
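The seed argument controls bootstrapping and feature-subset selection. To illustrate only the general idea of a reproducible bootstrap (drawing n rows with replacement so that the same seed always yields the same sample), here is a sketch built on java.util.Random. BootstrapSketch and its per-tree seed offset are hypothetical; Spark's own sampling scheme is not shown here and may differ.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical illustration of seeded bootstrapping; NOT Spark's sampling code. */
final class BootstrapSketch {

    /** Draws n row indices uniformly with replacement; the same seed yields the same sample. */
    static int[] bootstrapIndices(int n, long seed) {
        Random rng = new Random(seed);
        int[] indices = new int[n];
        for (int i = 0; i < n; i++) {
            indices[i] = rng.nextInt(n); // index in [0, n)
        }
        return indices;
    }

    /** One reproducible sample per tree; offsetting the base seed by tree index is an ASSUMPTION. */
    static int[][] perTreeSamples(int n, int numTrees, long baseSeed) {
        int[][] samples = new int[numTrees][];
        for (int t = 0; t < numTrees; t++) {
            samples[t] = bootstrapIndices(n, baseSeed + t);
        }
        return samples;
    }

    public static void main(String[] args) {
        for (int[] sample : perTreeSamples(5, 3, 42L)) {
            System.out.println(Arrays.toString(sample));
        }
    }
}
```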
public static RandomForestModel trainClassifier(RDD<LabeledPoint> input, Strategy strategy, int numTrees, String featureSubsetStrategy, int seed)
Parameters:
input - Training dataset: RDD of LabeledPoint. Labels should take values {0, 1, ..., numClasses-1}.
strategy - Parameters for training each tree in the forest.
numTrees - Number of trees in the random forest.
featureSubsetStrategy - Number of features to consider for splits at each node. Supported values: "auto", "all", "sqrt", "log2", "onethird". If "auto" is set, this parameter is set based on numTrees: if numTrees == 1, set to "all"; if numTrees is greater than 1 (forest), set to "sqrt".
seed - Random seed for bootstrapping and choosing feature subsets.

public static RandomForestModel trainClassifier(RDD<LabeledPoint> input, int numClasses, scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo, int numTrees, String featureSubsetStrategy, String impurity, int maxDepth, int maxBins, int seed)
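For the "gini" and "entropy" values of impurity, the standard definitions over the class proportions p_i at a node are Gini = 1 - sum(p_i^2) and entropy = -sum(p_i * log2(p_i)). The sketch below computes both from per-class counts; ClassificationImpurity is a hypothetical name, and measuring entropy in bits (base-2 logarithm) is an assumption.

```java
/** Hypothetical helper computing standard classification impurities from class counts. */
final class ClassificationImpurity {

    private static double total(long[] classCounts) {
        double total = 0.0;
        for (long c : classCounts) {
            total += c;
        }
        return total;
    }

    /** Gini impurity: 1 - sum(p_i^2); an empty node is treated as pure (0). */
    static double gini(long[] classCounts) {
        double total = total(classCounts);
        if (total == 0) {
            return 0.0;
        }
        double sumSq = 0.0;
        for (long c : classCounts) {
            double p = c / total;
            sumSq += p * p;
        }
        return 1.0 - sumSq;
    }

    /** Entropy in bits (base 2 is an ASSUMPTION): -sum(p_i * log2(p_i)), skipping empty classes. */
    static double entropy(long[] classCounts) {
        double total = total(classCounts);
        if (total == 0) {
            return 0.0;
        }
        double h = 0.0;
        for (long c : classCounts) {
            if (c == 0) {
                continue; // by convention 0 * log(0) = 0
            }
            double p = c / total;
            h -= p * (Math.log(p) / Math.log(2));
        }
        return h;
    }

    public static void main(String[] args) {
        long[] counts = {5, 5};
        System.out.println("gini = " + gini(counts) + ", entropy = " + entropy(counts));
    }
}
```

A node holding five examples of each of two classes has Gini 0.5 and entropy 1.0 bit, while a pure node scores 0 under both.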
Parameters:
input - Training dataset: RDD of LabeledPoint. Labels should take values {0, 1, ..., numClasses-1}.
numClasses - Number of classes for classification.
categoricalFeaturesInfo - Map storing arity of categorical features. An entry (n to k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
numTrees - Number of trees in the random forest.
featureSubsetStrategy - Number of features to consider for splits at each node. Supported values: "auto", "all", "sqrt", "log2", "onethird". If "auto" is set, this parameter is set based on numTrees: if numTrees == 1, set to "all"; if numTrees is greater than 1 (forest), set to "sqrt".
impurity - Criterion used for information gain calculation. Supported values: "gini" (recommended) or "entropy".
maxDepth - Maximum depth of the tree (e.g., depth 0 means 1 leaf node, depth 1 means 1 internal node + 2 leaf nodes). (suggested value: 4)
maxBins - Maximum number of bins used for splitting features. (suggested value: 100)
seed - Random seed for bootstrapping and choosing feature subsets.

public static RandomForestModel trainClassifier(JavaRDD<LabeledPoint> input, int numClasses, java.util.Map<Integer,Integer> categoricalFeaturesInfo, int numTrees, String featureSubsetStrategy, String impurity, int maxDepth, int maxBins, int seed)
Java-friendly API for org.apache.spark.mllib.tree.RandomForest.trainClassifier.
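Since this overload accepts a java.util.Map<Integer,Integer>, categoricalFeaturesInfo can be built with ordinary Java collections: an entry (n to k) marks feature n as categorical with categories {0, 1, ..., k-1}, and features without an entry are treated as continuous. The hypothetical respectsArity helper below checks a single feature vector against such a map; it is an illustration, not part of the Spark API.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical validation helper for a categoricalFeaturesInfo-style map. */
final class CategoricalInfoSketch {

    /** True if every categorical feature holds a whole-number category in [0, arity). */
    static boolean respectsArity(double[] features, Map<Integer, Integer> categoricalFeaturesInfo) {
        for (Map.Entry<Integer, Integer> entry : categoricalFeaturesInfo.entrySet()) {
            int index = entry.getKey();
            int arity = entry.getValue();
            if (index < 0 || index >= features.length) {
                return false; // map refers to a feature the vector does not have
            }
            double v = features[index];
            // Feature values are doubles here, so require a whole number in range (rejects NaN).
            if (v != Math.floor(v) || v < 0 || v >= arity) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        Map<Integer, Integer> info = new HashMap<>();
        info.put(0, 3); // feature 0 is categorical with categories {0, 1, 2}
        info.put(2, 2); // feature 2 is categorical with categories {0, 1}
        // Feature 1 has no entry, so it is continuous and any value is accepted.
        System.out.println(respectsArity(new double[] {2.0, 0.37, 1.0}, info));
    }
}
```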
Parameters:
input - (undocumented)
numClasses - (undocumented)
categoricalFeaturesInfo - (undocumented)
numTrees - (undocumented)
featureSubsetStrategy - (undocumented)
impurity - (undocumented)
maxDepth - (undocumented)
maxBins - (undocumented)
seed - (undocumented)

public static RandomForestModel trainRegressor(RDD<LabeledPoint> input, Strategy strategy, int numTrees, String featureSubsetStrategy, int seed)
Parameters:
input - Training dataset: RDD of LabeledPoint. Labels are real numbers.
strategy - Parameters for training each tree in the forest.
numTrees - Number of trees in the random forest.
featureSubsetStrategy - Number of features to consider for splits at each node. Supported values: "auto", "all", "sqrt", "log2", "onethird". If "auto" is set, this parameter is set based on numTrees: if numTrees == 1, set to "all"; if numTrees is greater than 1 (forest), set to "onethird".
seed - Random seed for bootstrapping and choosing feature subsets.

public static RandomForestModel trainRegressor(RDD<LabeledPoint> input, scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo, int numTrees, String featureSubsetStrategy, String impurity, int maxDepth, int maxBins, int seed)
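With "variance" as the impurity, a node's impurity is the (population) variance of the labels that reach it, and a candidate split is scored by how much it reduces that variance, weighting each child by its share of the examples. The sketch below computes both quantities for plain label arrays; VarianceImpurity is a hypothetical name, and the weighted-decrease form of information gain is the standard decision-tree definition, assumed here rather than taken from this page.

```java
/** Hypothetical helper: variance impurity and weighted variance decrease for a split. */
final class VarianceImpurity {

    /** Population variance: mean of (y - mean)^2, equal to mean(y^2) - mean(y)^2; 0 when empty. */
    static double variance(double[] labels) {
        if (labels.length == 0) {
            return 0.0;
        }
        double mean = 0.0;
        for (double y : labels) {
            mean += y;
        }
        mean /= labels.length;
        double sumSq = 0.0;
        for (double y : labels) {
            sumSq += (y - mean) * (y - mean);
        }
        return sumSq / labels.length;
    }

    /** Parent variance minus the size-weighted variances of the two children (ASSUMED gain form). */
    static double gain(double[] left, double[] right) {
        int n = left.length + right.length;
        if (n == 0) {
            return 0.0;
        }
        double[] parent = new double[n];
        System.arraycopy(left, 0, parent, 0, left.length);
        System.arraycopy(right, 0, parent, left.length, right.length);
        double leftWeight = (double) left.length / n;
        double rightWeight = (double) right.length / n;
        return variance(parent) - leftWeight * variance(left) - rightWeight * variance(right);
    }

    public static void main(String[] args) {
        System.out.println("gain = " + gain(new double[] {1, 2}, new double[] {3, 4}));
    }
}
```

Splitting labels {1, 2, 3, 4} into {1, 2} and {3, 4} lowers the variance from 1.25 to 0.25 in each child, a gain of 1.0.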
Parameters:
input - Training dataset: RDD of LabeledPoint. Labels are real numbers.
categoricalFeaturesInfo - Map storing arity of categorical features. An entry (n to k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
numTrees - Number of trees in the random forest.
featureSubsetStrategy - Number of features to consider for splits at each node. Supported values: "auto", "all", "sqrt", "log2", "onethird". If "auto" is set, this parameter is set based on numTrees: if numTrees == 1, set to "all"; if numTrees is greater than 1 (forest), set to "onethird".
impurity - Criterion used for information gain calculation. The only supported value for regression is "variance".
maxDepth - Maximum depth of the tree (e.g., depth 0 means 1 leaf node, depth 1 means 1 internal node + 2 leaf nodes). (suggested value: 4)
maxBins - Maximum number of bins used for splitting features. (suggested value: 100)
seed - Random seed for bootstrapping and choosing feature subsets.

public static RandomForestModel trainRegressor(JavaRDD<LabeledPoint> input, java.util.Map<Integer,Integer> categoricalFeaturesInfo, int numTrees, String featureSubsetStrategy, String impurity, int maxDepth, int maxBins, int seed)
Java-friendly API for org.apache.spark.mllib.tree.RandomForest.trainRegressor.
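The maxDepth convention documented for the RDD-based overloads (depth 0 is a single leaf, depth 1 is one internal node plus two leaves) bounds the size of each tree: a tree of depth d has at most 2^d leaves, 2^d - 1 internal nodes, and 2^(d+1) - 1 nodes in total. The hypothetical helper below does that arithmetic; with the suggested maxDepth of 4 it gives at most 16 leaves and 31 nodes per tree.

```java
/** Hypothetical helper: upper bounds on binary tree size for a given maxDepth. */
final class TreeSizeSketch {

    private static void checkDepth(int maxDepth) {
        // 2^63 does not fit in a long, so this illustration caps depth at 62.
        if (maxDepth < 0 || maxDepth > 62) {
            throw new IllegalArgumentException("maxDepth out of range: " + maxDepth);
        }
    }

    /** Maximum leaves: 2^maxDepth (depth 0 = a single leaf). */
    static long maxLeaves(int maxDepth) {
        checkDepth(maxDepth);
        return 1L << maxDepth;
    }

    /** Maximum internal (split) nodes: one fewer than the number of leaves. */
    static long maxInternalNodes(int maxDepth) {
        return maxLeaves(maxDepth) - 1;
    }

    /** Maximum total nodes: 2^(maxDepth + 1) - 1. */
    static long maxNodes(int maxDepth) {
        return maxLeaves(maxDepth) + maxInternalNodes(maxDepth);
    }

    public static void main(String[] args) {
        for (int d = 0; d <= 4; d++) {
            System.out.println("depth " + d + ": " + maxLeaves(d) + " leaves, " + maxNodes(d) + " nodes");
        }
    }
}
```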
Parameters:
input - (undocumented)
categoricalFeaturesInfo - (undocumented)
numTrees - (undocumented)
featureSubsetStrategy - (undocumented)
impurity - (undocumented)
maxDepth - (undocumented)
maxBins - (undocumented)
seed - (undocumented)

public static String[] supportedFeatureSubsetStrategies()
public static org.slf4j.Logger org$apache$spark$internal$Logging$$log_()
public static void org$apache$spark$internal$Logging$$log__$eq(org.slf4j.Logger x$1)
public RandomForestModel run(RDD<LabeledPoint> input)
Parameters:
input - Training data: RDD of LabeledPoint.
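For the classification overloads, labels are documented to take values {0, 1, ..., numClasses-1}. The hypothetical check below tests that contract on a local array of labels before training; it rejects fractional, negative, out-of-range, and NaN labels. It is a sketch, not a Spark API.

```java
/** Hypothetical pre-training check for the documented classification label contract. */
final class LabelCheckSketch {

    /** True if every label is a whole number in [0, numClasses). */
    static boolean labelsInRange(double[] labels, int numClasses) {
        for (double y : labels) {
            // NaN fails the first test because NaN != NaN.
            if (y != Math.floor(y) || y < 0 || y >= numClasses) {
                return false; // fractional, negative, too large, or NaN
            }
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(labelsInRange(new double[] {0, 1, 2, 1}, 3));
    }
}
```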