Class RandomForest
This is a sketch of the algorithm to help new developers.
The algorithm partitions data by instances (rows). On each iteration, the algorithm splits a set of nodes. In order to choose the best split for a given node, sufficient statistics are collected from the distributed data. For each node, the statistics are collected to some worker node, and that worker selects the best split.
This setup requires discretization of continuous features. This binning is done in the findSplits() method during initialization, after which each continuous feature becomes an ordered discretized feature with at most maxBins possible values.
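The sketch below illustrates the equal-frequency idea behind this binning for a single continuous feature. It is a minimal, self-contained illustration, not the actual findSplits() code, which operates on a sampled RDD and handles categorical features separately; the function name candidateThresholds is made up for this example.

    // Minimal illustration of equal-frequency binning for one continuous
    // feature: k bins need at most k - 1 thresholds, taken at approximate
    // quantiles of a (non-empty) sample of the feature's values.
    def candidateThresholds(sampledValues: Array[Double], maxBins: Int): Array[Double] = {
      val sorted = sampledValues.sorted
      (1 until maxBins)                                  // maxBins - 1 candidate thresholds
        .map(i => sorted(math.min(sorted.length - 1, i * sorted.length / maxBins)))
        .distinct                                        // repeated values collapse bins
        .toArray
    }

The .distinct step is why the text above says "at most" maxBins possible values: a feature with few distinct values yields fewer usable thresholds.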
The main loop in the algorithm operates on a queue of nodes (nodeStack). These nodes lie at the periphery of the tree being trained. If multiple trees are being trained at once, then this queue contains nodes from all of them. Each iteration works roughly as follows (a simplified, self-contained sketch of this control flow appears after the list):

On the master node:
  - Some number of nodes are pulled off of the queue (based on the amount of memory required for their sufficient statistics).
  - For random forests, if featureSubsetStrategy is not "all", then a subset of candidate features is chosen for each node. See method selectNodesToSplit().

On worker nodes, via method findBestSplits():
  - The worker makes one pass over its subset of instances.
  - For each (tree, node, feature, split) tuple, the worker collects statistics about splitting. Note that the set of (tree, node) pairs is limited to the nodes selected from the queue for this iteration. The set of features considered can also be limited based on featureSubsetStrategy.
  - For each node, the statistics for that node are aggregated to a particular worker via reduceByKey(). The designated worker chooses the best (feature, split) pair, or chooses to stop splitting if the stopping criteria are met.

On the master node:
  - The master collects all decisions about splitting nodes and updates the model.
  - The updated model is passed to the workers on the next iteration.

This process continues until the node queue is empty.
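The following Scala toy condenses that control flow. Everything in it (the Node type, the batch size of 4, the depth-based stopping rule) is invented for illustration; only the overall queue-driven shape mirrors the real selectNodesToSplit()/findBestSplits() division of labor.

    import scala.collection.mutable

    object NodeQueueSketch extends App {
      case class Node(treeIndex: Int, id: Int, depth: Int)

      val numTrees = 2
      val maxDepth = 3
      val nodeStack = mutable.Stack[Node]()
      (0 until numTrees).foreach(t => nodeStack.push(Node(t, id = 1, depth = 0))) // one root per tree

      while (nodeStack.nonEmpty) {
        // "Master": pull a batch of nodes; the real selectNodesToSplit() sizes
        // the batch by the memory its sufficient statistics would need.
        val batch = Seq.fill(4)(if (nodeStack.nonEmpty) Some(nodeStack.pop()) else None).flatten

        // "Workers": findBestSplits() would gather per-(tree, node, feature, split)
        // statistics in one data pass here; we fake the decision with a depth check.
        val decisions = batch.map(n => n -> (n.depth < maxDepth))

        // "Master": apply the split decisions and enqueue the resulting children.
        for ((n, doSplit) <- decisions if doSplit) {
          nodeStack.push(Node(n.treeIndex, n.id * 2, n.depth + 1))     // left child
          nodeStack.push(Node(n.treeIndex, n.id * 2 + 1, n.depth + 1)) // right child
        }
      }
      println(s"trained $numTrees toy trees to depth $maxDepth")
    }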
Most of the methods in this implementation support the statistics aggregation, which is the heaviest part of the computation. In general, this implementation is bound either by the cost of computing statistics on the workers or by the cost of communicating the sufficient statistics.
Constructor Summary
Method Summary
Modifier and Type / Method / Description

static org.apache.spark.internal.Logging.LogStringContext
    LogStringContext(scala.StringContext sc)

static org.slf4j.Logger
    org$apache$spark$internal$Logging$$log_()

static void
    org$apache$spark$internal$Logging$$log__$eq(org.slf4j.Logger x$1)

static DecisionTreeModel[]
    run(RDD<org.apache.spark.ml.feature.Instance> input, Strategy strategy, int numTrees, String featureSubsetStrategy, long seed, scala.Option<org.apache.spark.ml.util.Instrumentation> instr, boolean prune, scala.Option<String> parentUID)
    Train a random forest.

static DecisionTreeModel[]
    run(RDD<LabeledPoint> input, Strategy strategy, int numTrees, String featureSubsetStrategy, long seed)
    Train a random forest.

static DecisionTreeModel[]
    runBagged(RDD<org.apache.spark.ml.tree.impl.BaggedPoint<org.apache.spark.ml.tree.impl.TreePoint>> baggedInput, org.apache.spark.ml.tree.impl.DecisionTreeMetadata metadata, Broadcast<Split[][]> bcSplits, Strategy strategy, int numTrees, String featureSubsetStrategy, long seed, scala.Option<org.apache.spark.ml.util.Instrumentation> instr, boolean prune, scala.Option<String> parentUID)
    Train a random forest with metadata and splits.
Constructor Details
RandomForest
public RandomForest()
Method Details
run
public static DecisionTreeModel[] run(RDD<LabeledPoint> input, Strategy strategy, int numTrees, String featureSubsetStrategy, long seed)

Train a random forest.

Parameters:
    input - Training data: RDD of LabeledPoint
    strategy - (undocumented)
    numTrees - (undocumented)
    featureSubsetStrategy - (undocumented)
    seed - (undocumented)
Returns:
    an unweighted set of trees
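A hedged usage sketch for this overload follows. The Scala source marks this class private[spark], so ordinary applications reach it through estimators such as RandomForestClassifier rather than calling it directly; the sketch assumes a context where the internal API is visible, and the toy two-row dataset exists only for illustration.

    import org.apache.spark.ml.feature.LabeledPoint
    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.ml.tree.impl.RandomForest
    import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("rf-sketch").getOrCreate()

    // Toy two-row dataset, purely for illustration.
    val input = spark.sparkContext.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
      LabeledPoint(1.0, Vectors.dense(1.0, 0.0))))

    // defaultStrategy(Classification) fills in reasonable defaults
    // (numClasses = 2, Gini impurity, etc.).
    val strategy = Strategy.defaultStrategy(Algo.Classification)

    val trees = RandomForest.run(input, strategy, numTrees = 10,
      featureSubsetStrategy = "sqrt", seed = 42L)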
runBagged
public static DecisionTreeModel[] runBagged(RDD<org.apache.spark.ml.tree.impl.BaggedPoint<org.apache.spark.ml.tree.impl.TreePoint>> baggedInput, org.apache.spark.ml.tree.impl.DecisionTreeMetadata metadata, Broadcast<Split[][]> bcSplits, Strategy strategy, int numTrees, String featureSubsetStrategy, long seed, scala.Option<org.apache.spark.ml.util.Instrumentation> instr, boolean prune, scala.Option<String> parentUID)

Train a random forest with metadata and splits. This method is mainly for GBT, in which bagged input can be reused among trees.

Parameters:
    baggedInput - bagged training data: RDD of BaggedPoint
    metadata - Learning and dataset metadata for DecisionTree.
    bcSplits - (undocumented)
    strategy - (undocumented)
    numTrees - (undocumented)
    featureSubsetStrategy - (undocumented)
    seed - (undocumented)
    instr - (undocumented)
    prune - (undocumented)
    parentUID - (undocumented)
Returns:
    an unweighted set of trees
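The sketch below shows the reuse pattern that motivates runBagged() for GBT: convert and bag the input once, persist it, and train one tree per boosting iteration against it. prepareBaggedInput(), instances, strategy, seed, and numIterations are hypothetical placeholders; only the runBagged() call itself matches the signature above, and a real GBT would also re-label targets with pseudo-residuals between iterations.

    import org.apache.spark.ml.tree.impl.RandomForest
    import org.apache.spark.storage.StorageLevel

    // prepareBaggedInput() is a hypothetical stand-in for the internal
    // TreePoint/BaggedPoint conversion; only the runBagged() call is real.
    val (baggedInput, metadata, bcSplits) = prepareBaggedInput(instances, strategy)
    baggedInput.persist(StorageLevel.MEMORY_AND_DISK) // pay the conversion cost once

    val boostedTrees = (0 until numIterations).map { m =>
      // A real GBT would also re-label targets with pseudo-residuals here.
      val Array(tree) = RandomForest.runBagged(baggedInput, metadata, bcSplits,
        strategy, numTrees = 1, featureSubsetStrategy = "all", seed = seed + m,
        instr = None, prune = true, parentUID = None)
      tree
    }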
run
public static DecisionTreeModel[] run(RDD<org.apache.spark.ml.feature.Instance> input, Strategy strategy, int numTrees, String featureSubsetStrategy, long seed, scala.Option<org.apache.spark.ml.util.Instrumentation> instr, boolean prune, scala.Option<String> parentUID)

Train a random forest.

Parameters:
    input - Training data: RDD of Instance
    strategy - (undocumented)
    numTrees - (undocumented)
    featureSubsetStrategy - (undocumented)
    seed - (undocumented)
    instr - (undocumented)
    prune - (undocumented)
    parentUID - (undocumented)
Returns:
    an unweighted set of trees
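A sketch for this Instance-based overload, which carries a per-row weight, is below. The same caveat applies as for the LabeledPoint overload: this is an internal entry point, and the example assumes the SparkSession (spark) and visibility context from the earlier sketch.

    import org.apache.spark.ml.feature.Instance
    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.ml.tree.impl.RandomForest
    import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}

    val weighted = spark.sparkContext.parallelize(Seq(
      Instance(label = 0.0, weight = 1.0, features = Vectors.dense(0.0, 1.0)),
      Instance(label = 1.0, weight = 2.0, features = Vectors.dense(1.0, 0.0)))) // twice the weight

    val trees = RandomForest.run(weighted, Strategy.defaultStrategy(Algo.Classification),
      numTrees = 10, featureSubsetStrategy = "sqrt", seed = 42L,
      instr = None, prune = true, parentUID = None)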
org$apache$spark$internal$Logging$$log_
public static org.slf4j.Logger org$apache$spark$internal$Logging$$log_()
org$apache$spark$internal$Logging$$log__$eq
public static void org$apache$spark$internal$Logging$$log__$eq(org.slf4j.Logger x$1)
LogStringContext
public static org.apache.spark.internal.Logging.LogStringContext LogStringContext(scala.StringContext sc)