public class RandomForest
extends java.lang.Object
This is a sketch of the algorithm to help new developers.
The algorithm partitions data by instances (rows). On each iteration, the algorithm splits a set of nodes. To choose the best split for a given node, sufficient statistics are collected from the distributed data. For each node, the statistics are aggregated on a single worker node, and that worker selects the best split.
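To make the aggregation step concrete, here is a minimal single-machine sketch. It assumes classification with Gini impurity, a single node, and a single already-binned ordered feature. Each simulated partition builds a (bin × class) count histogram, the partial histograms are merged by element-wise addition (the role reduceByKey() plays across workers), and the merged histogram is scanned for the best threshold. The names NodeStatsSketch, partitionStats, merge, and bestSplit are hypothetical and are not part of this class's API.

```java
import java.util.Arrays;

/** Hypothetical sketch: sufficient statistics for ONE node and ONE ordered feature. */
final class NodeStatsSketch {

    /** Builds counts[bin][label] from the rows held by one partition. */
    static long[][] partitionStats(int[] bins, int[] labels, int numBins, int numClasses) {
        long[][] counts = new long[numBins][numClasses];
        for (int i = 0; i < bins.length; i++) {
            counts[bins[i]][labels[i]]++;
        }
        return counts;
    }

    /** Element-wise sum of two partial aggregates (the "reduce" step). */
    static long[][] merge(long[][] a, long[][] b) {
        long[][] out = new long[a.length][];
        for (int bin = 0; bin < a.length; bin++) {
            out[bin] = new long[a[bin].length];
            for (int c = 0; c < a[bin].length; c++) {
                out[bin][c] = a[bin][c] + b[bin][c];
            }
        }
        return out;
    }

    /** Gini impurity of a class-count vector with the given total. */
    static double gini(long[] counts, long total) {
        if (total == 0) return 0.0;
        double sumSq = 0.0;
        for (long c : counts) {
            double p = (double) c / total;
            sumSq += p * p;
        }
        return 1.0 - sumSq;
    }

    /**
     * Split s sends bins 0..s left and the remaining bins right.
     * Returns the split with the largest weighted Gini decrease, or -1 if none helps.
     */
    static int bestSplit(long[][] counts) {
        int numBins = counts.length;
        int numClasses = counts[0].length;
        long[] total = new long[numClasses];
        for (long[] bin : counts) {
            for (int c = 0; c < numClasses; c++) total[c] += bin[c];
        }
        long n = Arrays.stream(total).sum();
        double parent = gini(total, n);

        long[] left = new long[numClasses]; // running prefix sum over bins
        int best = -1;
        double bestGain = 0.0;
        for (int s = 0; s < numBins - 1; s++) {
            for (int c = 0; c < numClasses; c++) left[c] += counts[s][c];
            long[] right = new long[numClasses];
            for (int c = 0; c < numClasses; c++) right[c] = total[c] - left[c];
            long nl = Arrays.stream(left).sum();
            long nr = n - nl;
            if (nl == 0 || nr == 0) continue; // a split must leave rows on both sides
            double gain = parent
                - ((double) nl / n) * gini(left, nl)
                - ((double) nr / n) * gini(right, nr);
            if (gain > bestGain) {
                bestGain = gain;
                best = s;
            }
        }
        return best;
    }
}
```

For example, if one partition holds bins {0, 0, 1} all labeled 0 and another holds bins {2, 3, 3} all labeled 1, the merged histogram separates the classes perfectly between bins 1 and 2, so bestSplit returns split index 1 (bins 0 and 1 go left).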
This setup requires discretization of continuous features. This binning is done in the findSplits() method during initialization, after which each continuous feature becomes an ordered discretized feature with at most maxBins possible values.
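Once sorted split thresholds exist for a continuous feature, each raw value can be mapped to its bin index with a binary search. The sketch below assumes that a value less than or equal to a threshold falls to the left of it, so bin i holds values in (thresholds[i-1], thresholds[i]] and a feature with maxBins - 1 thresholds has at most maxBins bins. BinningSketch and its methods are hypothetical names for illustration.

```java
import java.util.Arrays;

/** Hypothetical sketch: discretizing a continuous feature against split thresholds. */
final class BinningSketch {

    /**
     * Returns the bin index of value, given strictly increasing thresholds.
     * There are thresholds.length + 1 bins in total.
     */
    static int findBin(double value, double[] thresholds) {
        int idx = Arrays.binarySearch(thresholds, value);
        // When value is absent, binarySearch returns -(insertionPoint) - 1, and
        // insertionPoint is the number of thresholds strictly below value.
        return idx >= 0 ? idx : -idx - 1;
    }

    /** Discretizes one continuous column into bin indices. */
    static int[] binColumn(double[] column, double[] thresholds) {
        int[] bins = new int[column.length];
        for (int i = 0; i < column.length; i++) {
            bins[i] = findBin(column[i], thresholds);
        }
        return bins;
    }
}
```

With thresholds {1.0, 2.5, 4.0}, the values 0.0, 3.0, and 5.0 land in bins 0, 2, and 3; a value equal to a threshold, such as 1.0, stays in the lower bin.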
The main loop in the algorithm operates on a queue of nodes (nodeQueue). These nodes lie at the periphery of the tree being trained. If multiple trees are being trained at once, then this queue contains nodes from all of them. Each iteration works roughly as follows:

On the master node:
- Some number of nodes are pulled off the queue (based on the amount of memory required for their sufficient statistics).
- For random forests, if featureSubsetStrategy is not "all", then a subset of candidate features is chosen for each node. See method selectNodesToSplit().

On worker nodes, via method findBestSplits():
- The worker makes one pass over its subset of instances.
- For each (tree, node, feature, split) tuple, the worker collects statistics about splitting. Note that the set of (tree, node) pairs is limited to the nodes selected from the queue for this iteration. The set of features considered can also be limited based on featureSubsetStrategy.
- For each node, the statistics for that node are aggregated to a particular worker via reduceByKey(). The designated worker chooses the best (feature, split) pair, or chooses to stop splitting if the stopping criteria are met.

On the master node:
- The master collects all decisions about splitting nodes and updates the model.
- The updated model is passed to the workers on the next iteration.

This process continues until the node queue is empty.
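The loop described above can be simulated in a short, self-contained sketch. It is a hypothetical model, not this class's code: every node is assumed to need the same number of bytes for its statistics, each node's feature subset is drawn with a partial Fisher-Yates shuffle, node ids use a heap-style numbering (root 1, children 2i and 2i + 1) chosen for illustration, and the worker-side split decision is replaced by a plain depth check. The method name selectNodesToSplit only echoes the method mentioned above; its signature here is invented.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.List;
import java.util.Random;

/** Hypothetical sketch of the node-queue loop shared by several trees. */
final class NodeQueueSketch {

    /** A frontier node waiting for a split decision. */
    static final class PendingNode {
        final int treeIndex;
        final int nodeId;    // heap-style id: root = 1, children = 2i and 2i + 1
        int[] featureSubset; // null means "consider all features"

        PendingNode(int treeIndex, int nodeId) {
            this.treeIndex = treeIndex;
            this.nodeId = nodeId;
        }
    }

    /** Draws k distinct feature indices from [0, n) with a partial Fisher-Yates shuffle. */
    static int[] sampleWithoutReplacement(int n, int k, Random rng) {
        int[] idx = new int[n];
        for (int i = 0; i < n; i++) idx[i] = i;
        for (int i = 0; i < k; i++) {
            int j = i + rng.nextInt(n - i);
            int tmp = idx[i];
            idx[i] = idx[j];
            idx[j] = tmp;
        }
        return Arrays.copyOf(idx, k);
    }

    /** Master side: pull nodes until the memory budget is full, but always take at least one. */
    static List<PendingNode> selectNodesToSplit(Deque<PendingNode> queue, long bytesPerNode,
                                                long budgetBytes, int numFeatures,
                                                int subsetSize, Random rng) {
        List<PendingNode> batch = new ArrayList<>();
        long used = 0;
        while (!queue.isEmpty()) {
            if (!batch.isEmpty() && used + bytesPerNode > budgetBytes) break;
            PendingNode node = queue.poll();
            if (subsetSize < numFeatures) {
                node.featureSubset = sampleWithoutReplacement(numFeatures, subsetSize, rng);
            }
            batch.add(node);
            used += bytesPerNode;
        }
        return batch;
    }

    /** Runs the loop for numTrees trees at once and returns the number of iterations. */
    static int train(int numTrees, int maxDepth, long bytesPerNode, long budgetBytes,
                     int numFeatures, int subsetSize, long seed) {
        Deque<PendingNode> queue = new ArrayDeque<>();
        for (int t = 0; t < numTrees; t++) queue.add(new PendingNode(t, 1));
        Random rng = new Random(seed);
        int iterations = 0;
        while (!queue.isEmpty()) {
            List<PendingNode> batch =
                selectNodesToSplit(queue, bytesPerNode, budgetBytes, numFeatures, subsetSize, rng);
            iterations++;
            // Workers would aggregate statistics for `batch` here; this sketch
            // substitutes a depth check for "best split found, keep splitting".
            for (PendingNode node : batch) {
                int depth = 31 - Integer.numberOfLeadingZeros(node.nodeId); // root depth 0
                if (depth < maxDepth) {
                    queue.add(new PendingNode(node.treeIndex, 2 * node.nodeId));
                    queue.add(new PendingNode(node.treeIndex, 2 * node.nodeId + 1));
                }
            }
        }
        return iterations;
    }
}
```

With two trees, maxDepth = 1, and a budget that fits two nodes, the six nodes (two roots, four children) are processed in three iterations. If the budget is smaller than a single node, the loop still progresses one node at a time and needs six.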
Most of the methods in this implementation support the statistics aggregation, which is the heaviest part of the computation. In general, this implementation is bound either by the cost of computing statistics on workers or by the cost of communicating the sufficient statistics.
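A rough size estimate shows why the sufficient statistics dominate both computation and communication. The sketch below assumes one double per class per bin for classification (class counts for Gini or entropy) and three doubles per bin for regression (count, sum, and sum of squares for variance). These layouts, the helper names, and the 256 MB budget used in the example are illustrative assumptions.

```java
/** Hypothetical sketch: estimating the size of one node's sufficient statistics. */
final class StatsSizeSketch {

    /** Doubles stored per bin under the assumed layouts. */
    static int statsPerBin(boolean isClassification, int numClasses) {
        return isClassification ? numClasses : 3; // regression: count, sum, sum of squares
    }

    /** Bytes of statistics for one node, given the number of bins of each feature. */
    static long bytesPerNode(int[] numBinsPerFeature, boolean isClassification, int numClasses) {
        long totalBins = 0;
        for (int b : numBinsPerFeature) totalBins += b;
        return totalBins * statsPerBin(isClassification, numClasses) * Double.BYTES;
    }

    /** Nodes that fit in a memory budget; at least one so the loop always progresses. */
    static long nodesPerIteration(long budgetBytes, long bytesPerNode) {
        return Math.max(1L, budgetBytes / bytesPerNode);
    }
}
```

With 100 features of 32 bins each and 3 classes, one node needs 100 × 32 × 3 × 8 = 76,800 bytes, so a 256 MB budget holds 3,495 nodes per iteration. Every one of those bytes is computed on the workers and then shuffled for aggregation.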
| Constructor and Description |
|---|
| RandomForest() |
| Modifier and Type | Method and Description |
|---|---|
| protected static Split[][] | findSplits(RDD<LabeledPoint> input, org.apache.spark.ml.tree.impl.DecisionTreeMetadata metadata, long seed) Returns splits for decision tree calculation. |
| protected static void | initializeLogIfNecessary(boolean isInterpreter) |
| protected static boolean | isTraceEnabled() |
| protected static org.slf4j.Logger | log() |
| protected static void | logDebug(scala.Function0<java.lang.String> msg) |
| protected static void | logDebug(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) |
| protected static void | logError(scala.Function0<java.lang.String> msg) |
| protected static void | logError(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) |
| protected static void | logInfo(scala.Function0<java.lang.String> msg) |
| protected static void | logInfo(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) |
| protected static java.lang.String | logName() |
| protected static void | logTrace(scala.Function0<java.lang.String> msg) |
| protected static void | logTrace(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) |
| protected static void | logWarning(scala.Function0<java.lang.String> msg) |
| protected static void | logWarning(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) |
| static org.apache.spark.ml.tree.DecisionTreeModel[] | run(RDD<LabeledPoint> input, Strategy strategy, int numTrees, java.lang.String featureSubsetStrategy, long seed, scala.Option<> instr, scala.Option<java.lang.String> parentUID) Train a random forest. |
public static org.apache.spark.ml.tree.DecisionTreeModel[] run(RDD<LabeledPoint> input, Strategy strategy, int numTrees, java.lang.String featureSubsetStrategy, long seed, scala.Option<> instr, scala.Option<java.lang.String> parentUID)
input
- Training data: RDD of LabeledPoint
strategy
- (undocumented)
numTrees
- (undocumented)
featureSubsetStrategy
- (undocumented)
seed
- (undocumented)
instr
- (undocumented)
parentUID
- (undocumented)
protected static Split[][] findSplits(RDD<LabeledPoint> input, org.apache.spark.ml.tree.impl.DecisionTreeMetadata metadata, long seed)
Continuous features: For each feature, there are numBins - 1 possible splits representing the possible binary decisions at each node in the tree. This finds locations (feature values) for splits using a subsample of the data.
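One simple way to choose numBins - 1 thresholds from a subsample is to sort it and take the values at evenly spaced ranks (approximate quantiles). The sketch below does exactly that; it illustrates the idea under that assumption and is not necessarily the exact rule findSplits() applies. The class name is hypothetical, and drawing the subsample itself (which the original does using the seed) is assumed to have happened already.

```java
import java.util.Arrays;
import java.util.TreeSet;

/** Hypothetical sketch: quantile-style split thresholds for one continuous feature. */
final class ContinuousSplitsSketch {

    /**
     * Returns up to numBins - 1 strictly increasing thresholds taken at evenly
     * spaced ranks of the sorted sample. Duplicate values collapse, so a feature
     * with few distinct values yields fewer thresholds.
     */
    static double[] candidateThresholds(double[] sample, int numBins) {
        int n = sample.length;
        if (n == 0) return new double[0];
        double[] sorted = sample.clone();
        Arrays.sort(sorted);
        TreeSet<Double> thresholds = new TreeSet<>();
        for (int k = 1; k < numBins; k++) {
            // Last index of the k-th of numBins equal-sized groups.
            int rank = Math.max(0, (int) ((long) k * n / numBins) - 1);
            thresholds.add(sorted[rank]);
        }
        // "x <= max" would send every row left, so the maximum is never useful.
        thresholds.remove(sorted[n - 1]);
        return thresholds.stream().mapToDouble(Double::doubleValue).toArray();
    }
}
```

Sorting {8, 1, 7, 2, 6, 3, 5, 4} with numBins = 4 yields thresholds {2.0, 4.0, 6.0}, which gives four bins of two values each. A constant feature yields no thresholds at all, because its only candidate is the maximum.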
Categorical features: For each feature, there is 1 bin per split. Splits and bins are handled in 2 ways:
(a) "unordered features": For multiclass classification with a low-arity feature (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), the feature is split based on subsets of categories.
(b) "ordered features": For regression and binary classification, and for multiclass classification with a high-arity feature, there is one bin per category.
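The two categorical cases differ sharply in how many candidate splits they produce. In the sketch below, an unordered feature of arity k yields every two-way partition of its categories, 2^(k-1) - 1 of them, with the last category pinned to the right so that mirror-image partitions are not counted twice. For ordered features, the sketch assumes a common CART technique: sort the per-category bins by a centroid (here the mean label) and consider only the k - 1 prefixes of that order. The class name, the bitmask encoding, and the choice of the mean label as centroid are illustrative assumptions.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

/** Hypothetical sketch: candidate left-hand category sets for a categorical feature. */
final class CategoricalSplitsSketch {

    /** Unordered feature: all 2^(arity-1) - 1 subsets of categories 0..arity-2 sent left. */
    static List<List<Integer>> unorderedSplits(int arity) {
        List<List<Integer>> splits = new ArrayList<>();
        int numSplits = (1 << (arity - 1)) - 1;
        for (int mask = 1; mask <= numSplits; mask++) {
            List<Integer> left = new ArrayList<>();
            for (int c = 0; c < arity - 1; c++) { // category arity-1 always stays right
                if ((mask & (1 << c)) != 0) left.add(c);
            }
            splits.add(left);
        }
        return splits;
    }

    /** Ordered feature: sort categories by mean label, then take the arity - 1 prefixes. */
    static List<List<Integer>> orderedSplits(double[] meanLabelPerCategory) {
        int arity = meanLabelPerCategory.length;
        Integer[] order = new Integer[arity];
        for (int c = 0; c < arity; c++) order[c] = c;
        Arrays.sort(order, Comparator.comparingDouble(c -> meanLabelPerCategory[c]));
        List<List<Integer>> splits = new ArrayList<>();
        for (int s = 1; s < arity; s++) {
            splits.add(new ArrayList<>(Arrays.asList(order).subList(0, s)));
        }
        return splits;
    }
}
```

A 3-category unordered feature yields the left-hand sets {0}, {1}, and {0, 1}, and a 4-category one already yields 7; an ordered feature with mean labels {0.9, 0.1, 0.5} yields only {1} and {1, 2}. The exponential growth of the unordered case is why it is reserved for low-arity features.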
input
- Training data: RDD of LabeledPoint
metadata
- Learning and dataset metadata
seed
- random seed
Returns
- Split[][] of size (numFeatures, numSplits)
protected static java.lang.String logName()
protected static org.slf4j.Logger log()
protected static void logInfo(scala.Function0<java.lang.String> msg)
protected static void logDebug(scala.Function0<java.lang.String> msg)
protected static void logTrace(scala.Function0<java.lang.String> msg)
protected static void logWarning(scala.Function0<java.lang.String> msg)
protected static void logError(scala.Function0<java.lang.String> msg)
protected static void logInfo(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logDebug(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logTrace(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logWarning(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logError(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static boolean isTraceEnabled()
protected static void initializeLogIfNecessary(boolean isInterpreter)