Class RandomForest

Object
org.apache.spark.ml.tree.impl.RandomForest

public class RandomForest extends Object
ALGORITHM

This is a sketch of the algorithm to help new developers.

The algorithm partitions data by instances (rows). On each iteration, the algorithm splits a set of nodes. To choose the best split for a given node, sufficient statistics are collected from the distributed data. For each node, these statistics are aggregated on a single worker, which then selects the best split for that node.
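A minimal sketch of what "sufficient statistics" can mean for one node and one binned feature, assuming a classification task scored with Gini impurity (one of the impurity measures Spark supports). Each partition builds a bin-by-class count histogram. Histograms merge by element-wise addition, which is what makes a reduceByKey()-style aggregation possible. The designated worker then scans the bin boundaries for the lowest weighted child impurity. SufficientStatsSketch and all of its methods are hypothetical names, not Spark's API, and stopping criteria such as a minimum number of instances per node are omitted.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch: per-partition class-count histograms for ONE node and ONE
 * binned feature, merged as a reduceByKey would, then scanned for the best
 * threshold by Gini impurity. Names are illustrative, not Spark's API.
 */
class SufficientStatsSketch {

    /** Builds hist[bin][label] counts for the rows of one partition. */
    static long[][] partitionHistogram(int[] bins, int[] labels, int numBins, int numClasses) {
        long[][] hist = new long[numBins][numClasses];
        for (int i = 0; i < bins.length; i++) {
            hist[bins[i]][labels[i]]++;
        }
        return hist;
    }

    /** Element-wise sum: the associative combine step used when aggregating on one worker. */
    static long[][] merge(long[][] a, long[][] b) {
        long[][] out = new long[a.length][];
        for (int bin = 0; bin < a.length; bin++) {
            out[bin] = new long[a[bin].length];
            for (int c = 0; c < a[bin].length; c++) {
                out[bin][c] = a[bin][c] + b[bin][c];
            }
        }
        return out;
    }

    /** Gini impurity 1 - sum(p_c^2) of a class-count vector. */
    static double gini(long[] counts, long total) {
        if (total == 0) return 0.0;
        double sumSq = 0.0;
        for (long n : counts) {
            double p = (double) n / total;
            sumSq += p * p;
        }
        return 1.0 - sumSq;
    }

    /**
     * Returns the split index s (rows with bin <= s go left) that minimizes the
     * weighted Gini impurity of the two children, or -1 if no split separates the data.
     */
    static int bestSplit(long[][] hist) {
        int numClasses = hist[0].length;
        long[] total = new long[numClasses];
        for (long[] row : hist) {
            for (int c = 0; c < numClasses; c++) total[c] += row[c];
        }
        long n = Arrays.stream(total).sum();

        long[] left = new long[numClasses];
        int best = -1;
        double bestImpurity = Double.POSITIVE_INFINITY;
        for (int s = 0; s < hist.length - 1; s++) {
            for (int c = 0; c < numClasses; c++) left[c] += hist[s][c];
            long nLeft = Arrays.stream(left).sum();
            long nRight = n - nLeft;
            if (nLeft == 0 || nRight == 0) continue; // a split must send rows both ways
            long[] right = new long[numClasses];
            for (int c = 0; c < numClasses; c++) right[c] = total[c] - left[c];
            double impurity = (nLeft * gini(left, nLeft) + nRight * gini(right, nRight)) / n;
            if (impurity < bestImpurity) {
                bestImpurity = impurity;
                best = s;
            }
        }
        return best;
    }
}
```

Because merge() is associative and commutative, partitions can be combined in any order, which is the property a distributed reduce relies on.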

This setup requires discretization of continuous features. This binning is done in the findSplits() method during initialization, after which each continuous feature becomes an ordered discretized feature with at most maxBins possible values.
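The binning step can be illustrated with a simplified equal-frequency scheme, assuming a sample of one continuous feature: up to maxBins - 1 thresholds are taken at evenly spaced ranks of the sorted sample, and a value v lands in bin i when thresholds[i-1] < v <= thresholds[i]. This is a stand-in, not Spark's findSplits(), whose sampling and threshold selection differ in detail. BinningSketch and its methods are hypothetical.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch: discretize one continuous feature into at most maxBins
 * ordered bins using thresholds at approximate quantiles of a sample.
 */
class BinningSketch {

    /** Up to maxBins - 1 strictly increasing thresholds picked from the sorted sample. */
    static double[] findThresholds(double[] sample, int maxBins) {
        int numSplits = Math.max(0, maxBins - 1);
        if (sample.length == 0 || numSplits == 0) return new double[0];
        double[] sorted = sample.clone();
        Arrays.sort(sorted);
        double[] picks = new double[numSplits];
        int count = 0;
        for (int k = 1; k <= numSplits; k++) {
            // Rank of the k-th of maxBins equal-frequency boundaries, clamped to the last element.
            int idx = (int) Math.min(sorted.length - 1, (long) k * sorted.length / maxBins);
            double t = sorted[idx];
            // Keep thresholds strictly increasing; a repeated value adds no new split.
            if (count == 0 || t > picks[count - 1]) {
                picks[count++] = t;
            }
        }
        return Arrays.copyOf(picks, count);
    }

    /** Bin index in [0, thresholds.length]: the position of the first threshold >= value. */
    static int binIndex(double value, double[] thresholds) {
        int pos = Arrays.binarySearch(thresholds, value);
        return pos >= 0 ? pos : -pos - 1; // binarySearch encodes a miss as -(insertionPoint) - 1
    }
}
```

With k thresholds there are k + 1 bins, so a feature whose sample has few distinct values simply ends up with fewer than maxBins bins.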

The main loop in the algorithm operates on a queue of nodes (nodeStack). These nodes lie at the periphery of the tree being trained. If multiple trees are being trained at once, this queue contains nodes from all of them. Each iteration works roughly as follows:

On the master node:
- A number of nodes are pulled off the queue, based on the amount of memory required for their sufficient statistics.
- For random forests, if featureSubsetStrategy is not "all", a subset of candidate features is chosen for each node. See method selectNodesToSplit().

On worker nodes, via method findBestSplits():
- The worker makes one pass over its subset of instances.
- For each (tree, node, feature, split) tuple, the worker collects statistics about splitting. Note that the set of (tree, node) pairs is limited to the nodes selected from the queue for this iteration. The set of features considered can also be limited based on featureSubsetStrategy.
- For each node, the statistics for that node are aggregated to a particular worker via reduceByKey(). The designated worker chooses the best (feature, split) pair, or chooses to stop splitting if the stopping criteria are met.

On the master node:
- The master collects all decisions about splitting nodes and updates the model.
- The updated model is passed to the workers on the next iteration.

This process continues until the node queue is empty.
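The master-side selection step could look like the following sketch, assuming a fixed, hypothetical per-node statistics size (bytesPerNode) and memory budget (memoryBudgetBytes); in practice the size depends on the number of features, bins, and statistics per bin. Nodes are popped from a stack (the description calls the structure nodeStack) until the next node would exceed the budget, with at least one node always taken, and each selected node gets a feature subset drawn without replacement by a partial Fisher-Yates shuffle. NodeSelectionSketch, PendingNode, and Selected are illustrative names, not Spark classes.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.List;
import java.util.Random;

/**
 * Hypothetical sketch of one master-side selection step: pop nodes while their
 * estimated statistics fit in a memory budget, and sample candidate features.
 */
class NodeSelectionSketch {

    /** A node awaiting a split; treeIndex says which tree in the forest it belongs to. */
    static final class PendingNode {
        final int treeIndex;
        final int nodeId;
        PendingNode(int treeIndex, int nodeId) { this.treeIndex = treeIndex; this.nodeId = nodeId; }
    }

    /** A selected node plus the candidate features it may split on this iteration. */
    static final class Selected {
        final PendingNode node;
        final int[] features;
        Selected(PendingNode node, int[] features) { this.node = node; this.features = features; }
    }

    static List<Selected> selectNodesToSplit(Deque<PendingNode> stack, long bytesPerNode,
                                             long memoryBudgetBytes, int numFeatures,
                                             int featuresPerNode, Random rng) {
        List<Selected> group = new ArrayList<>();
        long used = 0;
        while (!stack.isEmpty()) {
            // Always take at least one node so the outer loop makes progress.
            if (!group.isEmpty() && used + bytesPerNode > memoryBudgetBytes) break;
            PendingNode node = stack.pop();
            used += bytesPerNode;
            group.add(new Selected(node, sampleFeatures(numFeatures, featuresPerNode, rng)));
        }
        return group;
    }

    /** Partial Fisher-Yates shuffle: k distinct feature indices drawn without replacement. */
    static int[] sampleFeatures(int numFeatures, int k, Random rng) {
        int[] idx = new int[numFeatures];
        for (int i = 0; i < numFeatures; i++) idx[i] = i;
        for (int i = 0; i < k; i++) {
            int j = i + rng.nextInt(numFeatures - i);
            int tmp = idx[i];
            idx[i] = idx[j];
            idx[j] = tmp;
        }
        return Arrays.copyOf(idx, k);
    }
}
```

Calling selectNodesToSplit() repeatedly until the stack is empty mirrors the outer loop; in a full implementation, the children created by each round of splits would be pushed back onto the stack before the next call.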

Most of the methods in this implementation support the statistics aggregation, which is the heaviest part of the computation. In general, this implementation is bound either by the cost of computing statistics on the workers or by the cost of communicating the sufficient statistics.

  • Constructor Details

    • RandomForest

      public RandomForest()
  • Method Details

    • run

      public static DecisionTreeModel[] run(RDD<LabeledPoint> input, Strategy strategy, int numTrees, String featureSubsetStrategy, long seed)
      Train a random forest.

      Parameters:
      input - Training data: RDD of LabeledPoint
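      strategy - tree-training parameters (Strategy) shared by all trees
      numTrees - number of trees to train
      featureSubsetStrategy - how many features to consider as split candidates at each node, given as a strategy name such as "all" or "sqrt"
      seed - random seed for bootstrap sampling and feature subset selection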
      Returns:
      an unweighted set of trees
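For reference, the named featureSubsetStrategy values map to a per-node feature count roughly as follows. This sketch follows the strategy names documented for Spark's tree ensembles ("auto" picks "all" for a single tree, otherwise "sqrt" for classification and "onethird" for regression); treat the exact rounding as an assumption. The helper FeatureSubsetSketch.featuresPerNode is hypothetical, and numeric strategies (a fraction or an explicit count) are not covered.

```java
import java.util.Locale;

/** Hypothetical helper: candidate features per node for a named featureSubsetStrategy. */
class FeatureSubsetSketch {

    static int featuresPerNode(String strategy, int numFeatures, int numTrees, boolean isClassification) {
        String s = strategy.toLowerCase(Locale.ROOT);
        if (s.equals("auto")) {
            // A single tree considers every feature; a forest subsamples.
            s = numTrees == 1 ? "all" : (isClassification ? "sqrt" : "onethird");
        }
        switch (s) {
            case "all":
                return numFeatures;
            case "sqrt":
                return (int) Math.ceil(Math.sqrt(numFeatures));
            case "log2":
                return Math.max(1, (int) Math.ceil(Math.log(numFeatures) / Math.log(2)));
            case "onethird":
                return (int) Math.ceil(numFeatures / 3.0);
            default:
                throw new IllegalArgumentException("Unsupported strategy: " + strategy);
        }
    }
}
```

For 100 features, "sqrt" yields 10 candidates per node and "onethird" yields 34.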
    • runBagged

      public static DecisionTreeModel[] runBagged(RDD<org.apache.spark.ml.tree.impl.BaggedPoint<org.apache.spark.ml.tree.impl.TreePoint>> baggedInput, org.apache.spark.ml.tree.impl.DecisionTreeMetadata metadata, Broadcast<Split[][]> bcSplits, Strategy strategy, int numTrees, String featureSubsetStrategy, long seed, scala.Option<org.apache.spark.ml.util.Instrumentation> instr, boolean prune, scala.Option<String> parentUID)
      Train a random forest with metadata and splits. This method is mainly for GBT, in which bagged input can be reused among trees.

      Parameters:
      baggedInput - bagged training data: RDD of BaggedPoint
      metadata - Learning and dataset metadata for DecisionTree.
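      bcSplits - broadcast candidate splits, one array of splits per feature
      strategy - tree-training parameters (Strategy) shared by all trees
      numTrees - number of trees to train
      featureSubsetStrategy - how many features to consider as split candidates at each node, given as a strategy name such as "all" or "sqrt"
      seed - random seed
      instr - optional Instrumentation used to log training details
      prune - whether to prune the trained trees (for example, merging sibling leaves that make the same prediction)
      parentUID - optional UID of the parent estimator; if set, it is assigned to the returned tree models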
      Returns:
      an unweighted set of trees
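Bagged input attaches to each row one count per tree, recording how many times that tree sees the row. Since the counts travel with the data, the bagged input can be computed once and reused, as the runBagged description notes for GBT. The sketch below assumes the common scheme of Poisson(subsamplingRate) counts to approximate sampling with replacement and Bernoulli 0/1 counts for sampling without replacement. BaggingSketch and its methods are hypothetical, and the Poisson draw uses Knuth's multiplication method rather than whatever sampler Spark uses.

```java
import java.util.Random;

/** Hypothetical sketch of per-tree subsample counts for one bagged row. */
class BaggingSketch {

    /** Knuth's Poisson sampler: multiply uniforms until the product drops to exp(-mean). */
    static int poisson(double mean, Random rng) {
        double limit = Math.exp(-mean);
        double product = rng.nextDouble();
        int k = 0;
        while (product > limit) {
            k++;
            product *= rng.nextDouble();
        }
        return k;
    }

    /** One count per tree: Poisson for "with replacement", Bernoulli 0/1 otherwise. */
    static int[] subsampleCounts(int numTrees, double subsamplingRate, boolean withReplacement, Random rng) {
        int[] counts = new int[numTrees];
        for (int t = 0; t < numTrees; t++) {
            counts[t] = withReplacement
                ? poisson(subsamplingRate, rng)
                : (rng.nextDouble() < subsamplingRate ? 1 : 0);
        }
        return counts;
    }
}
```

A count of 0 means the tree ignores the row; during aggregation, a row's contribution to a tree's statistics can simply be scaled by its count.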
    • run

      public static DecisionTreeModel[] run(RDD<org.apache.spark.ml.feature.Instance> input, Strategy strategy, int numTrees, String featureSubsetStrategy, long seed, scala.Option<org.apache.spark.ml.util.Instrumentation> instr, boolean prune, scala.Option<String> parentUID)
      Train a random forest.

      Parameters:
      input - Training data: RDD of Instance
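      strategy - tree-training parameters (Strategy) shared by all trees
      numTrees - number of trees to train
      featureSubsetStrategy - how many features to consider as split candidates at each node, given as a strategy name such as "all" or "sqrt"
      seed - random seed for bootstrap sampling and feature subset selection
      instr - optional Instrumentation used to log training details
      prune - whether to prune the trained trees (for example, merging sibling leaves that make the same prediction)
      parentUID - optional UID of the parent estimator; if set, it is assigned to the returned tree models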
      Returns:
      an unweighted set of trees
    • org$apache$spark$internal$Logging$$log_

      public static org.slf4j.Logger org$apache$spark$internal$Logging$$log_()
    • org$apache$spark$internal$Logging$$log__$eq

      public static void org$apache$spark$internal$Logging$$log__$eq(org.slf4j.Logger x$1)
    • LogStringContext

      public static org.apache.spark.internal.Logging.LogStringContext LogStringContext(scala.StringContext sc)