Package org.apache.spark.mllib.tree
Class GradientBoostedTrees
Object
org.apache.spark.mllib.tree.GradientBoostedTrees
- All Implemented Interfaces:
Serializable, org.apache.spark.internal.Logging, scala.Serializable
public class GradientBoostedTrees
extends Object
implements scala.Serializable, org.apache.spark.internal.Logging
A class that implements Stochastic Gradient Boosting for regression and binary classification.
The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999.
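The boosting loop can be illustrated with a small, self-contained Java sketch. Everything in it is hypothetical and is not Spark's implementation. This includes the class `ToyStochasticGbt`, its one-feature regression stumps, the squared-error loss written as 1/2 (y - F)^2, the start from the label mean, and the parameter names.

Each round does three things:
- It computes the pseudo-residuals (negative gradients of the loss).
- It fits a stump to them on a random subsample drawn without replacement. This is the "stochastic" part of Friedman's method.
- It adds the stump to the model, scaled by a learning rate.

Leaf values are plain least-squares means of the residuals, with no TreeBoost-style leaf re-optimization.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical toy stochastic gradient boosting for 1-D regression (not Spark's code). */
final class ToyStochasticGbt {
    /** A depth-1 regression tree (stump) on one numeric feature. */
    static final class Stump {
        final double threshold, leftValue, rightValue;
        Stump(double threshold, double leftValue, double rightValue) {
            this.threshold = threshold;
            this.leftValue = leftValue;
            this.rightValue = rightValue;
        }
        double predict(double x) { return x <= threshold ? leftValue : rightValue; }
    }

    final double initialPrediction;
    final double learningRate;
    final List<Stump> stumps = new ArrayList<>();

    private ToyStochasticGbt(double initialPrediction, double learningRate) {
        this.initialPrediction = initialPrediction;
        this.learningRate = learningRate;
    }

    /** F(x) = F0 + learningRate * (sum of stump outputs). */
    double predict(double x) {
        double f = initialPrediction;
        for (Stump s : stumps) f += learningRate * s.predict(x);
        return f;
    }

    /** Assumes x and y are non-empty and of equal length. */
    static ToyStochasticGbt fit(double[] x, double[] y, int numIterations,
                                double learningRate, double subsamplingRate, long seed) {
        int n = x.length;
        double mean = Arrays.stream(y).average().orElse(0.0); // F0: best constant for squared error
        ToyStochasticGbt model = new ToyStochasticGbt(mean, learningRate);
        double[] f = new double[n];
        Arrays.fill(f, mean);
        double[] residual = new double[n];
        Random rng = new Random(seed);
        for (int m = 0; m < numIterations; m++) {
            // Pseudo-residuals: negative gradient of 1/2 (y - F)^2 with respect to F.
            for (int i = 0; i < n; i++) residual[i] = y[i] - f[i];
            // Stochastic step: this round's stump is fitted on a random subsample only.
            Stump s = fitStump(x, residual, subsample(n, subsamplingRate, rng));
            model.stumps.add(s);
            // Update the running predictions on every training point, sampled or not.
            for (int i = 0; i < n; i++) f[i] += learningRate * s.predict(x[i]);
        }
        return model;
    }

    /** Least-squares stump on the given rows: each leaf predicts the mean target on its side. */
    static Stump fitStump(double[] x, double[] target, int[] rows) {
        Stump best = null;
        double bestSse = Double.POSITIVE_INFINITY;
        for (int cand : rows) {
            double t = x[cand], sumL = 0, sumR = 0;
            int nL = 0, nR = 0;
            for (int i : rows) {
                if (x[i] <= t) { sumL += target[i]; nL++; } else { sumR += target[i]; nR++; }
            }
            if (nR == 0) continue; // not a real split: every row falls on the left
            double meanL = sumL / nL, meanR = sumR / nR, sse = 0;
            for (int i : rows) {
                double err = target[i] - (x[i] <= t ? meanL : meanR);
                sse += err * err;
            }
            if (sse < bestSse) { bestSse = sse; best = new Stump(t, meanL, meanR); }
        }
        if (best == null) { // all sampled x are equal: fall back to a constant stump
            double sum = 0;
            for (int i : rows) sum += target[i];
            best = new Stump(Double.POSITIVE_INFINITY, sum / rows.length, 0.0);
        }
        return best;
    }

    /** Draws round(rate * n) distinct row indices (at least one), without replacement. */
    static int[] subsample(int n, double rate, Random rng) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < n; i++) indices.add(i);
        Collections.shuffle(indices, rng);
        int k = Math.max(1, Math.min(n, (int) Math.round(rate * n)));
        int[] rows = new int[k];
        for (int j = 0; j < k; j++) rows[j] = indices.get(j);
        return rows;
    }

    static double meanSquaredError(ToyStochasticGbt model, double[] x, double[] y) {
        double sse = 0;
        for (int i = 0; i < x.length; i++) {
            double err = y[i] - model.predict(x[i]);
            sse += err * err;
        }
        return sse / x.length;
    }
}
```

With subsamplingRate = 1.0 the sketch reduces to ordinary, non-stochastic gradient boosting. Smaller rates inject randomness into each round, which Friedman (1999) reports can improve both accuracy and speed.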
Notes on Gradient Boosting vs. TreeBoost:
- This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
- Both algorithms learn tree ensembles by minimizing loss functions.
- TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes based on the loss function, whereas the original gradient boosting method does not.
- When the loss is SquaredError, these methods give the same result, but they could differ for other loss functions.
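The SquaredError case in the last note can be checked directly. The short derivation below uses the following conventions:
- The squared-error loss is written as L(y, F) = 1/2 (y - F)^2. With a different constant factor, the gradient-fitted leaf values are scaled by a constant relative to the TreeBoost values, and the learning rate can absorb that constant.
- R_{jm} denotes leaf j of the tree fitted in round m.

```latex
% Pseudo-residuals at round m, with L(y, F) = \tfrac{1}{2}(y - F)^2
r_{im} = -\left[\frac{\partial L(y_i, F)}{\partial F}\right]_{F = F_{m-1}(x_i)} = y_i - F_{m-1}(x_i)

% Gradient boosting: a least-squares tree fitted to the r_{im} predicts, in leaf R_{jm},
\bar{r}_{jm} = \frac{1}{|R_{jm}|} \sum_{x_i \in R_{jm}} r_{im}

% TreeBoost: re-optimize the value of leaf R_{jm} directly on the loss
\gamma_{jm} = \arg\min_{\gamma} \sum_{x_i \in R_{jm}} L\bigl(y_i, F_{m-1}(x_i) + \gamma\bigr)
            = \arg\min_{\gamma} \sum_{x_i \in R_{jm}} \tfrac{1}{2}\,(r_{im} - \gamma)^2

% Setting the derivative \sum_{x_i \in R_{jm}} (\gamma - r_{im}) to zero gives the same value
\gamma_{jm} = \bar{r}_{jm}
```

For contrast, consider absolute error, L(y, F) = |y - F|:
- The pseudo-residuals are sign(y_i - F_{m-1}(x_i)).
- A gradient-boosting leaf therefore predicts the mean of those signs.
- The TreeBoost leaf value is the median of y_i - F_{m-1}(x_i) within the leaf.

The two generally differ.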
param: boostingStrategy Parameters for the gradient boosting algorithm.
param: seed Random seed.
-
Nested Class Summary
Nested classes/interfaces inherited from interface org.apache.spark.internal.Logging
org.apache.spark.internal.Logging.SparkShellLoggingFilter
-
Constructor Summary
GradientBoostedTrees(BoostingStrategy boostingStrategy)
-
Method Summary
Modifier and Type | Method | Description
static org.slf4j.Logger | org$apache$spark$internal$Logging$$log_() |
static void | org$apache$spark$internal$Logging$$log__$eq(org.slf4j.Logger x$1) |
GradientBoostedTreesModel | run(JavaRDD<LabeledPoint> input) | Java-friendly API for org.apache.spark.mllib.tree.GradientBoostedTrees.run.
GradientBoostedTreesModel | run(RDD<LabeledPoint> input) | Method to train a gradient boosting model.
GradientBoostedTreesModel | runWithValidation(JavaRDD<LabeledPoint> input, JavaRDD<LabeledPoint> validationInput) | Java-friendly API for org.apache.spark.mllib.tree.GradientBoostedTrees.runWithValidation.
GradientBoostedTreesModel | runWithValidation(RDD<LabeledPoint> input, RDD<LabeledPoint> validationInput) | Method to train a gradient boosting model, stopping early based on the error on a validation dataset.
static GradientBoostedTreesModel | train(JavaRDD<LabeledPoint> input, BoostingStrategy boostingStrategy) | Java-friendly API for org.apache.spark.mllib.tree.GradientBoostedTrees.train.
static GradientBoostedTreesModel | train(RDD<LabeledPoint> input, BoostingStrategy boostingStrategy) | Method to train a gradient boosting model.
Methods inherited from class java.lang.Object
equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface org.apache.spark.internal.Logging
initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, isTraceEnabled, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning, org$apache$spark$internal$Logging$$log_, org$apache$spark$internal$Logging$$log__$eq
-
Constructor Details
-
GradientBoostedTrees
public GradientBoostedTrees(BoostingStrategy boostingStrategy)
- Parameters:
boostingStrategy
- Parameters for the gradient boosting algorithm.
-
-
Method Details
-
train
public static GradientBoostedTreesModel train(RDD<LabeledPoint> input, BoostingStrategy boostingStrategy)
Method to train a gradient boosting model.
- Parameters:
input
- Training dataset: RDD of LabeledPoint. For classification, labels should take values {0, 1, ..., numClasses-1}; since this class supports binary classification only, that means {0, 1}. For regression, labels are real numbers.
boostingStrategy
- Configuration options for the boosting algorithm.
- Returns:
- GradientBoostedTreesModel that can be used for prediction.
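For classification, a dataset labeled with two arbitrary values (for example -1/+1) needs to be re-encoded as 0/1 before it is passed to train. Below is a minimal Java sketch of that pre-processing step, operating on a plain array of labels. `BinaryLabelEncoder` and its methods are hypothetical helpers, not Spark API. In a Spark job, the same mapping would be applied to the label of each LabeledPoint.

```java
import java.util.TreeSet;

/** Hypothetical pre-processing helper (not Spark API) for binary classification labels. */
final class BinaryLabelEncoder {

    /** Returns true if every label is an integer in [0, numClasses). */
    static boolean isValidClassLabels(double[] labels, int numClasses) {
        for (double label : labels) {
            // rint(NaN) is NaN, and NaN != NaN, so NaN labels are rejected here too.
            if (label != Math.rint(label) || label < 0 || label >= numClasses) return false;
        }
        return true;
    }

    /** Maps the smaller of exactly two distinct label values to 0.0 and the larger to 1.0. */
    static double[] toZeroOne(double[] labels) {
        TreeSet<Double> distinct = new TreeSet<>();
        for (double label : labels) distinct.add(label);
        if (distinct.size() != 2) {
            throw new IllegalArgumentException(
                "expected exactly 2 distinct labels, found " + distinct.size());
        }
        double negative = distinct.first();
        double[] encoded = new double[labels.length];
        for (int i = 0; i < labels.length; i++) {
            // Double.compare matches the TreeSet ordering used to find the two values.
            encoded[i] = Double.compare(labels[i], negative) == 0 ? 0.0 : 1.0;
        }
        return encoded;
    }
}
```

Mapping the smaller value to 0 is an arbitrary convention. Whichever class is encoded as 1.0 becomes the positive class in the model's predictions.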
-
train
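public static GradientBoostedTreesModel train(JavaRDD<LabeledPoint> input, BoostingStrategy boostingStrategy)
Java-friendly API for org.apache.spark.mllib.tree.GradientBoostedTrees.train.
- Parameters:
input
- Training dataset: JavaRDD of LabeledPoint, with the same label requirements as train(RDD<LabeledPoint>, BoostingStrategy).
boostingStrategy
- Configuration options for the boosting algorithm.
- Returns:
- GradientBoostedTreesModel that can be used for prediction.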
-
org$apache$spark$internal$Logging$$log_
public static org.slf4j.Logger org$apache$spark$internal$Logging$$log_()
-
org$apache$spark$internal$Logging$$log__$eq
public static void org$apache$spark$internal$Logging$$log__$eq(org.slf4j.Logger x$1)
-
run
public GradientBoostedTreesModel run(RDD<LabeledPoint> input)
Method to train a gradient boosting model.
- Parameters:
input
- Training dataset: RDD of LabeledPoint.
- Returns:
- GradientBoostedTreesModel that can be used for prediction.
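The returned model predicts by combining the outputs of its trees. Below is a minimal Java sketch of that combination step. It assumes, as in MLlib's tree-ensemble models, that:
- for regression, the prediction is a weighted sum of the tree outputs;
- for binary classification, the prediction is class 1.0 when that sum is positive and 0.0 otherwise.

`WeightedSumEnsemble` is hypothetical. Its trees are modeled as plain functions of a feature array rather than as Spark tree objects.

```java
import java.util.List;
import java.util.function.ToDoubleFunction;

/** Hypothetical sketch (not Spark's GradientBoostedTreesModel) of weighted-sum prediction. */
final class WeightedSumEnsemble {
    private final List<ToDoubleFunction<double[]>> trees;
    private final double[] weights;

    WeightedSumEnsemble(List<ToDoubleFunction<double[]>> trees, double[] weights) {
        if (trees.size() != weights.length) {
            throw new IllegalArgumentException("expected one weight per tree");
        }
        this.trees = trees;
        this.weights = weights.clone();
    }

    /** Regression: sum over trees of weight * tree output. */
    double predictRegression(double[] features) {
        double sum = 0.0;
        for (int i = 0; i < weights.length; i++) {
            sum += weights[i] * trees.get(i).applyAsDouble(features);
        }
        return sum;
    }

    /** Binary classification: 1.0 if the weighted sum is positive, else 0.0. */
    double predictClass(double[] features) {
        return predictRegression(features) > 0.0 ? 1.0 : 0.0;
    }
}
```

In a trained gradient-boosted ensemble, the weights come from the boosting run (for instance, the learning rate for trees after the first). Here they are supplied by hand.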
-
run
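public GradientBoostedTreesModel run(JavaRDD<LabeledPoint> input)
Java-friendly API for org.apache.spark.mllib.tree.GradientBoostedTrees.run.
- Parameters:
input
- Training dataset: JavaRDD of LabeledPoint.
- Returns:
- GradientBoostedTreesModel that can be used for prediction.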
-
runWithValidation
public GradientBoostedTreesModel runWithValidation(RDD<LabeledPoint> input, RDD<LabeledPoint> validationInput)
Method to train a gradient boosting model while monitoring the error on a validation dataset. Training stops early once the improvement in validation error falls below the tolerance given by validationTol in BoostingStrategy.
- Parameters:
input
- Training dataset: RDD of LabeledPoint.
validationInput
- Validation dataset. This dataset should be different from the training dataset, but it should follow the same distribution. For example, the two datasets could be created from an original dataset by using org.apache.spark.rdd.RDD.randomSplit().
- Returns:
- GradientBoostedTreesModel that can be used for prediction.
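The early-stopping behavior of runWithValidation can be sketched on a precomputed sequence of validation errors, one per boosting iteration. This hypothetical Java helper is not Spark's code. It works as follows:
- Training stops at the first iteration whose improvement over the best error so far is smaller than validationTol * max(currentError, 0.01).
- Only the trees up to the best iteration are kept.

The 0.01 floor and the truncation to the best iteration are assumptions about the details, modeled on the validationTol setting of BoostingStrategy.

```java
/** Hypothetical sketch of validation-based early stopping (not Spark's code). */
final class EarlyStopping {
    /**
     * Returns how many leading trees to keep, given the validation error measured after
     * 1, 2, 3, ... trees. Stops at the first iteration whose improvement over the best error
     * so far is below validationTol * max(currentError, 0.01) (an assumed rule).
     */
    static int numTreesToKeep(double[] validationErrors, double validationTol) {
        if (validationErrors.length == 0) return 0;
        double bestError = validationErrors[0]; // error of the one-tree model
        int bestNumTrees = 1;
        for (int m = 1; m < validationErrors.length; m++) {
            double current = validationErrors[m];
            if (bestError - current < validationTol * Math.max(current, 0.01)) {
                break; // improvement too small, or the error got worse: stop adding trees
            } else if (current < bestError) {
                bestError = current;
                bestNumTrees = m + 1; // m is 0-based, so this model has m + 1 trees
            }
        }
        return bestNumTrees;
    }
}
```

Keeping the best prefix rather than the last one means a run that stops because the error rose does not keep the trees that made it worse.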
-
runWithValidation
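public GradientBoostedTreesModel runWithValidation(JavaRDD<LabeledPoint> input, JavaRDD<LabeledPoint> validationInput)
Java-friendly API for org.apache.spark.mllib.tree.GradientBoostedTrees.runWithValidation.
- Parameters:
input
- Training dataset: JavaRDD of LabeledPoint.
validationInput
- Validation dataset: JavaRDD of LabeledPoint, different from the training dataset but drawn from the same distribution.
- Returns:
- GradientBoostedTreesModel that can be used for prediction.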
-