Package org.apache.spark.mllib.tree
Class GradientBoostedTrees
Object
org.apache.spark.mllib.tree.GradientBoostedTrees
- All Implemented Interfaces:
- Serializable, org.apache.spark.internal.Logging
public class GradientBoostedTrees
extends Object
implements Serializable, org.apache.spark.internal.Logging
A class that implements Stochastic Gradient Boosting for regression and binary classification.
 
The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999.
Notes on Gradient Boosting vs. TreeBoost:
- This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
- Both algorithms learn tree ensembles by minimizing loss functions.
- TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes based on the loss function, whereas the original gradient boosting method does not.
- When the loss is SquaredError, these methods give the same result, but they could differ for other loss functions.
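The difference in how leaf outputs are chosen can be pictured with a minimal, self-contained sketch in plain Java. It does not use Spark, and `LeafValueSketch` and its methods are hypothetical names introduced only for illustration. The sketch assumes the squared error is written as 0.5 * (y - f)^2 and the absolute error as |y - f|, and it looks at a single leaf: plain gradient boosting predicts the mean of the negative gradients in that leaf, while a TreeBoost-style update picks the constant that minimizes the loss over the leaf.

```java
import java.util.Arrays;

/** Hypothetical illustration of leaf outputs; not part of Spark. */
final class LeafValueSketch {

    /** Arithmetic mean of the given values. */
    static double mean(double[] values) {
        double sum = 0.0;
        for (double v : values) {
            sum += v;
        }
        return sum / values.length;
    }

    /** Median of the given values (average of the two middle values for even length). */
    static double median(double[] values) {
        double[] sorted = values.clone();
        Arrays.sort(sorted);
        int n = sorted.length;
        return n % 2 == 1 ? sorted[n / 2] : 0.5 * (sorted[n / 2 - 1] + sorted[n / 2]);
    }

    /** Gradient-style leaf for 0.5 * (y - f)^2: the negative gradient is the residual y - f. */
    static double gradientLeafSquared(double[] residuals) {
        return mean(residuals);
    }

    /** Loss-minimizing constant for squared error over the leaf: the mean residual. */
    static double optimalLeafSquared(double[] residuals) {
        return mean(residuals);
    }

    /** Gradient-style leaf for |y - f|: the negative gradient is sign(y - f). */
    static double gradientLeafAbsolute(double[] residuals) {
        double[] signs = new double[residuals.length];
        for (int i = 0; i < residuals.length; i++) {
            signs[i] = Math.signum(residuals[i]);
        }
        return mean(signs);
    }

    /** Loss-minimizing constant for absolute error over the leaf: the median residual. */
    static double optimalLeafAbsolute(double[] residuals) {
        return median(residuals);
    }

    public static void main(String[] args) {
        // Residuals y - f of the rows that fall into one leaf (made-up values).
        double[] residuals = {1.0, 2.0, 6.0};
        System.out.println("squared error:  gradient leaf = " + gradientLeafSquared(residuals)
                + ", loss-minimizing leaf = " + optimalLeafSquared(residuals));
        System.out.println("absolute error: gradient leaf = " + gradientLeafAbsolute(residuals)
                + ", loss-minimizing leaf = " + optimalLeafAbsolute(residuals));
    }
}
```

For the made-up residuals {1, 2, 6}, both squared-error leaf values are 3.0, while the absolute-error values are 1.0 (mean of the signs) and 2.0 (median), which is the kind of divergence the last note refers to.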
param: boostingStrategy Parameters for the gradient boosting algorithm.
param: seed Random seed.
- 
Nested Class Summary
Nested classes/interfaces inherited from interface org.apache.spark.internal.Logging
org.apache.spark.internal.Logging.LogStringContext, org.apache.spark.internal.Logging.SparkShellLoggingFilter
- 
Constructor Summary
Constructors
GradientBoostedTrees(BoostingStrategy boostingStrategy)
- 
Method Summary
- static org.apache.spark.internal.Logging.LogStringContext LogStringContext(scala.StringContext sc)
- static org.slf4j.Logger org$apache$spark$internal$Logging$$log_()
- static void org$apache$spark$internal$Logging$$log__$eq(org.slf4j.Logger x$1)
- GradientBoostedTreesModel run(JavaRDD<LabeledPoint> input): Java-friendly API for org.apache.spark.mllib.tree.GradientBoostedTrees.run.
- GradientBoostedTreesModel run(RDD<LabeledPoint> input): Method to train a gradient boosting model.
- GradientBoostedTreesModel runWithValidation(JavaRDD<LabeledPoint> input, JavaRDD<LabeledPoint> validationInput): Java-friendly API for org.apache.spark.mllib.tree.GradientBoostedTrees.runWithValidation.
- GradientBoostedTreesModel runWithValidation(RDD<LabeledPoint> input, RDD<LabeledPoint> validationInput): Method to validate a gradient boosting model.
- static GradientBoostedTreesModel train(JavaRDD<LabeledPoint> input, BoostingStrategy boostingStrategy): Java-friendly API for org.apache.spark.mllib.tree.GradientBoostedTrees.train.
- static GradientBoostedTreesModel train(RDD<LabeledPoint> input, BoostingStrategy boostingStrategy): Method to train a gradient boosting model.
Methods inherited from class java.lang.Object
equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface org.apache.spark.internal.Logging
initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, isTraceEnabled, log, logBasedOnLevel, logDebug, logDebug, logDebug, logDebug, logError, logError, logError, logError, logInfo, logInfo, logInfo, logInfo, logName, LogStringContext, logTrace, logTrace, logTrace, logTrace, logWarning, logWarning, logWarning, logWarning, MDC, org$apache$spark$internal$Logging$$log_, org$apache$spark$internal$Logging$$log__$eq, withLogContext
- 
Constructor Details
- 
GradientBoostedTrees
public GradientBoostedTrees(BoostingStrategy boostingStrategy)
- Parameters:
- boostingStrategy - Parameters for the gradient boosting algorithm.
 
 
- 
- 
Method Details
- 
train
public static GradientBoostedTreesModel train(RDD<LabeledPoint> input, BoostingStrategy boostingStrategy)
Method to train a gradient boosting model.
- Parameters:
- input - Training dataset: RDD of LabeledPoint. For classification, labels should take values {0, 1, ..., numClasses-1}. For regression, labels are real numbers.
- boostingStrategy - Configuration options for the boosting algorithm.
- Returns:
- GradientBoostedTreesModel that can be used for prediction.
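The label conventions stated for input can be checked before training. The following is a minimal sketch in plain Java; `LabelCheckSketch` and its methods are hypothetical helpers, not part of Spark, and reading "real numbers" as finite doubles is an assumption made here.

```java
/** Hypothetical pre-training label checks; not part of Spark. */
final class LabelCheckSketch {

    /** True if the label is one of the integer values 0, 1, ..., numClasses - 1. */
    static boolean isValidClassificationLabel(double label, int numClasses) {
        return label >= 0.0 && label < numClasses && label == Math.rint(label);
    }

    /** True if the label is a finite real number (assumption: NaN and infinities are rejected). */
    static boolean isValidRegressionLabel(double label) {
        return !Double.isNaN(label) && !Double.isInfinite(label);
    }

    public static void main(String[] args) {
        // Binary classification uses numClasses = 2, so only 0.0 and 1.0 pass.
        System.out.println(isValidClassificationLabel(1.0, 2));
        System.out.println(isValidClassificationLabel(2.0, 2));
        System.out.println(isValidRegressionLabel(-3.75));
    }
}
```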
 
- 
train
public static GradientBoostedTreesModel train(JavaRDD<LabeledPoint> input, BoostingStrategy boostingStrategy)
Java-friendly API for org.apache.spark.mllib.tree.GradientBoostedTrees.train.
- Parameters:
- input - Training dataset: JavaRDD of LabeledPoint.
- boostingStrategy - Configuration options for the boosting algorithm.
- Returns:
- GradientBoostedTreesModel that can be used for prediction.
 
- 
org$apache$spark$internal$Logging$$log_
public static org.slf4j.Logger org$apache$spark$internal$Logging$$log_()
- 
org$apache$spark$internal$Logging$$log__$eq
public static void org$apache$spark$internal$Logging$$log__$eq(org.slf4j.Logger x$1)
- 
LogStringContext
public static org.apache.spark.internal.Logging.LogStringContext LogStringContext(scala.StringContext sc)
- 
run
public GradientBoostedTreesModel run(RDD<LabeledPoint> input)
Method to train a gradient boosting model.
- Parameters:
- input - Training dataset: RDD of LabeledPoint.
- Returns:
- GradientBoostedTreesModel that can be used for prediction.
 
- 
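run
public GradientBoostedTreesModel run(JavaRDD<LabeledPoint> input)
Java-friendly API for org.apache.spark.mllib.tree.GradientBoostedTrees.run.
- Parameters:
- input - Training dataset: JavaRDD of LabeledPoint.
- Returns:
- GradientBoostedTreesModel that can be used for prediction.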
 
- 
runWithValidation
public GradientBoostedTreesModel runWithValidation(RDD<LabeledPoint> input, RDD<LabeledPoint> validationInput)
Method to validate a gradient boosting model.
- Parameters:
- input - Training dataset: RDD of LabeledPoint.
- validationInput - Validation dataset. This dataset should be different from the training dataset, but it should follow the same distribution. For example, the two datasets could be created from an original dataset by using org.apache.spark.rdd.RDD.randomSplit().
- Returns:
- GradientBoostedTreesModel that can be used for prediction.
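The split suggested for validationInput can be pictured with a small, self-contained sketch in plain Java. It is not Spark code: `RandomSplitSketch` and `split` are hypothetical names, and the independent per-element random assignment shown here only loosely mirrors what RDD.randomSplit does on a distributed dataset. The point is that both parts are drawn from the same original data, so they follow the same distribution while sharing no rows.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

/** Hypothetical in-memory train/validation split; not part of Spark. */
final class RandomSplitSketch {

    /**
     * Assigns each element independently to the training part with probability
     * trainFraction, otherwise to the validation part. Returns [training, validation].
     */
    static <T> List<List<T>> split(List<T> data, double trainFraction, long seed) {
        Random rng = new Random(seed); // fixed seed makes the split reproducible
        List<T> training = new ArrayList<>();
        List<T> validation = new ArrayList<>();
        for (T item : data) {
            if (rng.nextDouble() < trainFraction) {
                training.add(item);
            } else {
                validation.add(item);
            }
        }
        List<List<T>> parts = new ArrayList<>();
        parts.add(training);
        parts.add(validation);
        return parts;
    }

    public static void main(String[] args) {
        List<Integer> rows = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
        List<List<Integer>> parts = split(rows, 0.8, 42L);
        System.out.println("training rows:   " + parts.get(0));
        System.out.println("validation rows: " + parts.get(1));
    }
}
```

Because each element is assigned by an independent draw, the part sizes only approximate the requested fraction on small inputs.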
 
- 
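runWithValidation
public GradientBoostedTreesModel runWithValidation(JavaRDD<LabeledPoint> input, JavaRDD<LabeledPoint> validationInput)
Java-friendly API for org.apache.spark.mllib.tree.GradientBoostedTrees.runWithValidation.
- Parameters:
- input - Training dataset: JavaRDD of LabeledPoint.
- validationInput - Validation dataset: JavaRDD of LabeledPoint. It should be different from the training dataset, but follow the same distribution.
- Returns:
- GradientBoostedTreesModel that can be used for prediction.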
 
 