org.apache.spark.mllib.tree
Class GradientBoostedTrees

Object
  extended by org.apache.spark.mllib.tree.GradientBoostedTrees
All Implemented Interfaces:
java.io.Serializable, Logging

public class GradientBoostedTrees
extends Object
implements scala.Serializable, Logging

:: Experimental :: A class that implements Stochastic Gradient Boosting for regression and binary classification.

The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999.

Notes on Gradient Boosting vs. TreeBoost:
- This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
- Both algorithms learn tree ensembles by minimizing loss functions.
- TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes based on the loss function, whereas the original gradient boosting method does not.
- When the loss is SquaredError, these methods give the same result, but they could differ for other loss functions.
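
The difference described in these notes can be illustrated for a single leaf. The following is a minimal sketch in plain Java, not Spark code: the class and method names are hypothetical, and it assumes the squared error is written as (y - F)^2 / 2 so that the negative gradient equals the residual. Plain gradient boosting fits a regression tree to the negative gradients, so a leaf predicts their mean. TreeBoost instead picks the leaf value that minimizes the loss directly.

```java
import java.util.Arrays;

// Hypothetical illustration class; not part of Spark.
class LeafValueSketch {
    // Least-squares fit of a constant: the value a regression-tree leaf predicts.
    static double mean(double[] v) {
        double sum = 0.0;
        for (double x : v) {
            sum += x;
        }
        return sum / v.length;
    }

    static double median(double[] v) {
        double[] c = v.clone();
        Arrays.sort(c);
        int n = c.length;
        return n % 2 == 1 ? c[n / 2] : 0.5 * (c[n / 2 - 1] + c[n / 2]);
    }

    // Gradient boosting, squared error (y - F)^2 / 2:
    // the negative gradient is the residual, so the leaf is the mean residual.
    static double gbLeafSquared(double[] residuals) {
        return mean(residuals);
    }

    // TreeBoost, squared error: the constant minimizing the summed loss
    // over the leaf is also the mean residual.
    static double treeBoostLeafSquared(double[] residuals) {
        return mean(residuals);
    }

    // Gradient boosting, absolute error |y - F|:
    // the negative gradient is sign(residual), so the leaf is the mean sign.
    static double gbLeafAbsolute(double[] residuals) {
        double[] signs = new double[residuals.length];
        for (int i = 0; i < residuals.length; i++) {
            signs[i] = Math.signum(residuals[i]);
        }
        return mean(signs);
    }

    // TreeBoost, absolute error: the constant minimizing the summed
    // absolute error over the leaf is the median residual.
    static double treeBoostLeafAbsolute(double[] residuals) {
        return median(residuals);
    }

    public static void main(String[] args) {
        // Residuals (label minus current prediction) of the points in one leaf.
        double[] residuals = {-1.0, 0.5, 4.0};
        System.out.println("squared, gradient boosting:  " + gbLeafSquared(residuals));
        System.out.println("squared, TreeBoost:          " + treeBoostLeafSquared(residuals));
        System.out.println("absolute, gradient boosting: " + gbLeafAbsolute(residuals));
        System.out.println("absolute, TreeBoost:         " + treeBoostLeafAbsolute(residuals));
    }
}
```

For squared error the two leaf values coincide. For absolute error they differ, which is the case the last note refers to.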

param: boostingStrategy Parameters for the gradient boosting algorithm.

See Also:
Serialized Form

Constructor Summary
GradientBoostedTrees(BoostingStrategy boostingStrategy)
           
 
Method Summary
 GradientBoostedTreesModel run(JavaRDD<LabeledPoint> input)
          Java-friendly API for org.apache.spark.mllib.tree.GradientBoostedTrees.run(RDD).
 GradientBoostedTreesModel run(RDD<LabeledPoint> input)
          Method to train a gradient boosting model.
 GradientBoostedTreesModel runWithValidation(JavaRDD<LabeledPoint> input, JavaRDD<LabeledPoint> validationInput)
          Java-friendly API for org.apache.spark.mllib.tree.GradientBoostedTrees.runWithValidation(RDD, RDD).
 GradientBoostedTreesModel runWithValidation(RDD<LabeledPoint> input, RDD<LabeledPoint> validationInput)
          Method to train a gradient boosting model, using a separate validation dataset to decide when to stop adding trees.
static GradientBoostedTreesModel train(JavaRDD<LabeledPoint> input, BoostingStrategy boostingStrategy)
          Java-friendly API for GradientBoostedTrees.train(org.apache.spark.rdd.RDD, org.apache.spark.mllib.tree.configuration.BoostingStrategy).
static GradientBoostedTreesModel train(RDD<LabeledPoint> input, BoostingStrategy boostingStrategy)
          Method to train a gradient boosting model.
 
Methods inherited from class Object
equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 
Methods inherited from interface org.apache.spark.Logging
initializeIfNecessary, initializeLogging, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning
 

Constructor Detail

GradientBoostedTrees

public GradientBoostedTrees(BoostingStrategy boostingStrategy)
Method Detail

train

public static GradientBoostedTreesModel train(RDD<LabeledPoint> input,
                                              BoostingStrategy boostingStrategy)
Method to train a gradient boosting model.

Parameters:
input - Training dataset: RDD of LabeledPoint. For classification, labels should take values {0, 1, ..., numClasses-1}. For regression, labels are real numbers.
boostingStrategy - Configuration options for the boosting algorithm.
Returns:
a gradient boosted trees model that can be used for prediction
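
A minimal usage sketch for the static train method follows. It assumes Spark core and MLlib are on the classpath and runs against a local master. The class name, application name, and the tiny in-memory dataset are made up for illustration, and the parameter values are arbitrary rather than recommended settings.

```java
import java.util.Arrays;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.GradientBoostedTrees;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;

// Hypothetical example class; requires Spark on the classpath.
public class GbtTrainSketch {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("GbtTrainSketch").setMaster("local[2]");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        try {
            // Made-up binary classification data: labels are 0.0 or 1.0.
            JavaRDD<LabeledPoint> data = jsc.parallelize(Arrays.asList(
                new LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
                new LabeledPoint(0.0, Vectors.dense(0.1, 0.9)),
                new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
                new LabeledPoint(1.0, Vectors.dense(0.9, 0.2))));

            // Start from the default classification settings and adjust a few.
            BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Classification");
            boostingStrategy.setNumIterations(3);               // number of trees (illustrative)
            boostingStrategy.getTreeStrategy().setMaxDepth(2);  // depth of each tree (illustrative)

            // data.rdd() exposes the underlying RDD<LabeledPoint>.
            GradientBoostedTreesModel model =
                GradientBoostedTrees.train(data.rdd(), boostingStrategy);

            double prediction = model.predict(Vectors.dense(0.95, 0.1));
            System.out.println("Predicted label: " + prediction);
        } finally {
            jsc.stop();
        }
    }
}
```

Passing the JavaRDD itself instead of data.rdd() selects the Java-friendly overload of train.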

train

public static GradientBoostedTreesModel train(JavaRDD<LabeledPoint> input,
                                              BoostingStrategy boostingStrategy)
Java-friendly API for GradientBoostedTrees.train(org.apache.spark.rdd.RDD, org.apache.spark.mllib.tree.configuration.BoostingStrategy).

Parameters:
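input - Training dataset: JavaRDD of LabeledPoint. For classification, labels should take values {0, 1, ..., numClasses-1}. For regression, labels are real numbers.
boostingStrategy - Configuration options for the boosting algorithm.
Returns:
a gradient boosted trees model that can be used for prediction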

run

public GradientBoostedTreesModel run(RDD<LabeledPoint> input)
Method to train a gradient boosting model.

Parameters:
input - Training dataset: RDD of LabeledPoint.
Returns:
a gradient boosted trees model that can be used for prediction

run

public GradientBoostedTreesModel run(JavaRDD<LabeledPoint> input)
Java-friendly API for org.apache.spark.mllib.tree.GradientBoostedTrees.run(RDD).

Parameters:
input - Training dataset: JavaRDD of LabeledPoint.
Returns:
a gradient boosted trees model that can be used for prediction

runWithValidation

public GradientBoostedTreesModel runWithValidation(RDD<LabeledPoint> input,
                                                   RDD<LabeledPoint> validationInput)
Method to train a gradient boosting model, using a separate validation dataset to decide when to stop adding trees.

Parameters:
input - Training dataset: RDD of LabeledPoint.
validationInput - Validation dataset. This dataset should be different from the training dataset, but it should follow the same distribution. For example, the two datasets could be created from a single original dataset by using org.apache.spark.rdd.RDD.randomSplit().
Returns:
a gradient boosted trees model that can be used for prediction
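
The split suggested for validationInput could look like the following sketch. It assumes Spark core and MLlib are on the classpath. The class name, the synthetic regression data, the 80/20 weights, and the seed are illustrative assumptions. It calls the JavaRDD overload of runWithValidation, which takes the same two datasets.

```java
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.GradientBoostedTrees;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;

// Hypothetical example class; requires Spark on the classpath.
public class GbtValidationSketch {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("GbtValidationSketch").setMaster("local[2]");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        try {
            // Synthetic regression data: label = 2 * x for a single feature x.
            List<LabeledPoint> points = new ArrayList<LabeledPoint>();
            for (int i = 0; i < 200; i++) {
                double x = i / 200.0;
                points.add(new LabeledPoint(2.0 * x, Vectors.dense(x)));
            }
            JavaRDD<LabeledPoint> data = jsc.parallelize(points);

            // Both parts come from the same original dataset, so they follow
            // the same distribution while containing different points.
            JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[] {0.8, 0.2}, 11L);
            JavaRDD<LabeledPoint> training = splits[0];
            JavaRDD<LabeledPoint> validation = splits[1];

            BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Regression");
            boostingStrategy.setNumIterations(20);  // upper bound on the number of trees (illustrative)

            GradientBoostedTreesModel model =
                new GradientBoostedTrees(boostingStrategy).runWithValidation(training, validation);

            // The model may contain fewer trees than numIterations if training stopped early.
            System.out.println("Trees in model: " + model.numTrees());
        } finally {
            jsc.stop();
        }
    }
}
```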

runWithValidation

public GradientBoostedTreesModel runWithValidation(JavaRDD<LabeledPoint> input,
                                                   JavaRDD<LabeledPoint> validationInput)
Java-friendly API for org.apache.spark.mllib.tree.GradientBoostedTrees.runWithValidation(RDD, RDD).

Parameters:
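input - Training dataset: JavaRDD of LabeledPoint.
validationInput - Validation dataset: JavaRDD of LabeledPoint. This dataset should be different from the training dataset, but it should follow the same distribution.
Returns:
a gradient boosted trees model that can be used for prediction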