Class GradientBoostedTreesModel
Object
org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
- All Implemented Interfaces: Serializable, Saveable
Represents a gradient boosted trees model.
param: algo algorithm for the ensemble model, either Classification or Regression
param: trees the decision trees that make up the ensemble
param: treeWeights the weight assigned to each tree in the ensemble
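As an orientation sketch, assuming the summing combination that MLlib applies to gradient boosted tree ensembles (the symbols are illustrative and not API names), the constructor arguments define the model output for a feature vector x as follows, where h_m is the m-th entry of trees and w_m the m-th entry of treeWeights:

```latex
F_M(x) = \sum_{m=1}^{M} w_m \, h_m(x),
\qquad
\hat{y}(x) =
\begin{cases}
F_M(x) & \text{if algo is Regression,} \\
1 & \text{if algo is Classification and } F_M(x) > 0, \\
0 & \text{if algo is Classification and } F_M(x) \le 0.
\end{cases}
```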
Constructor Summary
| Constructor | Description |
| --- | --- |
| GradientBoostedTreesModel(scala.Enumeration.Value algo, DecisionTreeModel[] trees, double[] treeWeights) | |
Method Summary
| Modifier and Type | Method | Description |
| --- | --- | --- |
| scala.Enumeration.Value | algo() | |
| static RDD<scala.Tuple2<Object,Object>> | computeInitialPredictionAndError(RDD<LabeledPoint> data, double initTreeWeight, DecisionTreeModel initTree, Loss loss) | Compute the initial predictions and errors for a dataset for the first iteration of gradient boosting. |
| double[] | evaluateEachIteration(RDD<LabeledPoint> data, Loss loss) | Method to compute error or loss for every iteration of gradient boosting. |
| static GradientBoostedTreesModel | load(SparkContext sc, String path) | |
| static org.apache.spark.internal.Logging.LogStringContext | LogStringContext(scala.StringContext sc) | |
| int | numTrees() | Get number of trees in ensemble. |
| static org.slf4j.Logger | org$apache$spark$internal$Logging$$log_() | |
| static void | org$apache$spark$internal$Logging$$log__$eq(org.slf4j.Logger x$1) | |
| double | predict(Vector features) | Predict values for a single data point using the model trained. |
| JavaRDD<Double> | predict(JavaRDD<Vector> features) | Java-friendly version of org.apache.spark.mllib.tree.model.TreeEnsembleModel.predict. |
| RDD<Object> | predict(RDD<Vector> features) | Predict values for the given data set. |
| void | save(SparkContext sc, String path) | Save this model to the given path. |
| String | toDebugString() | Print the full model to a string. |
| String | toString() | Print a summary of the model. |
| int | totalNumNodes() | Get total number of nodes, summed over all trees in the ensemble. |
| DecisionTreeModel[] | trees() | |
| double[] | treeWeights() | |
| static RDD<scala.Tuple2<Object,Object>> | updatePredictionError(RDD<LabeledPoint> data, RDD<scala.Tuple2<Object,Object>> predictionAndError, double treeWeight, DecisionTreeModel tree, Loss loss) | Update a zipped predictionError RDD (as obtained with computeInitialPredictionAndError). |
Constructor Details
GradientBoostedTreesModel
public GradientBoostedTreesModel(scala.Enumeration.Value algo, DecisionTreeModel[] trees, double[] treeWeights)
Method Details
computeInitialPredictionAndError
public static RDD<scala.Tuple2<Object,Object>> computeInitialPredictionAndError(RDD<LabeledPoint> data, double initTreeWeight, DecisionTreeModel initTree, Loss loss)
Compute the initial predictions and errors for a dataset for the first iteration of gradient boosting.
- Parameters:
data - training data.
initTreeWeight - learning rate assigned to the first tree.
initTree - first DecisionTreeModel.
loss - evaluation metric.
- Returns:
an RDD in which each element is the (prediction, error) pair for the corresponding sample.
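To make the per-sample computation concrete, here is a minimal plain-Java sketch that works on local arrays instead of an RDD. It assumes each sample's prediction is initTreeWeight times the first tree's output and its error is the loss evaluated at that prediction. The nested Tree and Loss interfaces, the squared-error loss, and the class name are hypothetical stand-ins, not Spark's DecisionTreeModel or Loss API.

```java
/** Hypothetical local-array sketch of computeInitialPredictionAndError (no Spark). */
class InitialPredictionSketch {
    /** Hypothetical stand-in for a single decision tree's predict method. */
    interface Tree { double predict(double[] features); }

    /** Hypothetical stand-in for an evaluation metric. */
    interface Loss { double computeError(double prediction, double label); }

    /** Squared error: (label - prediction)^2. */
    static final Loss SQUARED_ERROR = (prediction, label) -> {
        double err = label - prediction;
        return err * err;
    };

    /** Returns one {prediction, error} pair per sample, in input order. */
    static double[][] computeInitialPredictionAndError(
            double[][] features, double[] labels,
            double initTreeWeight, Tree initTree, Loss loss) {
        double[][] result = new double[labels.length][];
        for (int i = 0; i < labels.length; i++) {
            // The weighted output of the first tree is the entire prediction so far.
            double prediction = initTreeWeight * initTree.predict(features[i]);
            result[i] = new double[] {prediction, loss.computeError(prediction, labels[i])};
        }
        return result;
    }

    public static void main(String[] args) {
        // Hypothetical one-split tree: 1.0 if the first feature exceeds 0.5, else 0.0.
        Tree stump = x -> x[0] > 0.5 ? 1.0 : 0.0;
        double[][] pairs = computeInitialPredictionAndError(
                new double[][] {{0.2}, {0.9}}, new double[] {0.0, 2.0},
                0.5, stump, SQUARED_ERROR);
        for (double[] p : pairs) {
            System.out.println("prediction=" + p[0] + " error=" + p[1]);
        }
    }
}
```

The real method keeps the result distributed and returns an RDD of (prediction, error) tuples rather than an array.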
updatePredictionError
public static RDD<scala.Tuple2<Object,Object>> updatePredictionError(RDD<LabeledPoint> data, RDD<scala.Tuple2<Object,Object>> predictionAndError, double treeWeight, DecisionTreeModel tree, Loss loss)
Update a zipped predictionError RDD (as obtained with computeInitialPredictionAndError) with the contribution of one additional tree.
- Parameters:
data - training data.
predictionAndError - RDD of (prediction, error) pairs, one per sample in data.
treeWeight - learning rate (weight) applied to tree.
tree - the tree used to update the predictions and errors.
loss - evaluation metric.
- Returns:
an RDD in which each element is the updated (prediction, error) pair for the corresponding sample.
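The update step can be sketched the same way. In this hypothetical plain-Java version, each sample's new prediction is its previous prediction plus treeWeight times the new tree's output, and the error is recomputed at the new prediction rather than derived from the previous error. Tree, Loss, and the class name are illustrative stand-ins, not Spark types.

```java
/** Hypothetical local-array sketch of updatePredictionError (no Spark). */
class UpdatePredictionErrorSketch {
    /** Hypothetical stand-in for a single decision tree's predict method. */
    interface Tree { double predict(double[] features); }

    /** Hypothetical stand-in for an evaluation metric. */
    interface Loss { double computeError(double prediction, double label); }

    /** Squared error: (label - prediction)^2. */
    static final Loss SQUARED_ERROR = (prediction, label) -> {
        double err = label - prediction;
        return err * err;
    };

    /**
     * predictionAndError[i] holds {prediction, error} for sample i, in the same
     * order as features and labels (positional pairing, like a zip).
     */
    static double[][] updatePredictionError(
            double[][] features, double[] labels, double[][] predictionAndError,
            double treeWeight, Tree tree, Loss loss) {
        double[][] result = new double[labels.length][];
        for (int i = 0; i < labels.length; i++) {
            // Only the previous prediction is reused; the previous error is replaced.
            double newPrediction =
                    predictionAndError[i][0] + treeWeight * tree.predict(features[i]);
            result[i] = new double[] {newPrediction, loss.computeError(newPrediction, labels[i])};
        }
        return result;
    }

    public static void main(String[] args) {
        // One sample with label 2.0 whose current state is {prediction=0.5, error=2.25}.
        double[][] state = {{0.5, 2.25}};
        Tree constantOne = x -> 1.0; // hypothetical tree that always predicts 1.0
        double[][] updated = updatePredictionError(
                new double[][] {{0.9}}, new double[] {2.0}, state, 0.5, constantOne, SQUARED_ERROR);
        System.out.println("prediction=" + updated[0][0] + " error=" + updated[0][1]);
    }
}
```

In Spark the pairing is positional (the "zipped" RDD described above), and RDD.zip requires both RDDs to have the same number of partitions and the same number of elements in each partition; an RDD derived from data, as computeInitialPredictionAndError produces, satisfies that.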
load
public static GradientBoostedTreesModel load(SparkContext sc, String path)
- Parameters:
sc - Spark context used for loading model files.
path - Path specifying the directory to which the model was saved.
- Returns:
Model instance
algo
public scala.Enumeration.Value algo()
trees
public DecisionTreeModel[] trees()
treeWeights
public double[] treeWeights()
save
public void save(SparkContext sc, String path)
Description copied from interface: Saveable
Save this model to the given path.
This saves:
- human-readable (JSON) model metadata to path/metadata/
- Parquet formatted data to path/data/
The model may be loaded using Loader.load.
evaluateEachIteration
public double[] evaluateEachIteration(RDD<LabeledPoint> data, Loss loss)
Method to compute error or loss for every iteration of gradient boosting.
- Parameters:
data - RDD of LabeledPoint
loss - evaluation metric.
- Returns:
an array in which element i is the loss or error, averaged over data, of the ensemble containing the first i+1 trees
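A hypothetical local sketch of the same computation: for every sample, keep a running weighted sum of tree outputs, record the loss after each tree is added, then average each position over the data set. The averaging reflects our reading of the Spark implementation; Tree, Loss, and the class name are illustrative stand-ins.

```java
/** Hypothetical local-array sketch of evaluateEachIteration (no Spark). */
class EvaluateEachIterationSketch {
    /** Hypothetical stand-in for a single decision tree's predict method. */
    interface Tree { double predict(double[] features); }

    /** Hypothetical stand-in for an evaluation metric. */
    interface Loss { double computeError(double prediction, double label); }

    /** Squared error: (label - prediction)^2. */
    static final Loss SQUARED_ERROR = (prediction, label) -> {
        double err = label - prediction;
        return err * err;
    };

    /** result[i] = mean loss of the ensemble formed by trees[0..i]. */
    static double[] evaluateEachIteration(
            double[][] features, double[] labels, Tree[] trees, double[] treeWeights, Loss loss) {
        double[] lossSums = new double[trees.length];
        for (int n = 0; n < labels.length; n++) {
            double runningPrediction = 0.0;
            for (int i = 0; i < trees.length; i++) {
                // Prediction of the ensemble that contains the first i + 1 trees.
                runningPrediction += treeWeights[i] * trees[i].predict(features[n]);
                lossSums[i] += loss.computeError(runningPrediction, labels[n]);
            }
        }
        // Turn per-iteration sums into means over the data set.
        for (int i = 0; i < lossSums.length; i++) {
            lossSums[i] /= labels.length;
        }
        return lossSums;
    }

    public static void main(String[] args) {
        Tree[] trees = {x -> 1.0, x -> 2.0}; // hypothetical constant-output trees
        double[] losses = evaluateEachIteration(
                new double[][] {{0.0}, {0.0}}, new double[] {1.0, 3.0},
                trees, new double[] {1.0, 0.5}, SQUARED_ERROR);
        System.out.println(java.util.Arrays.toString(losses));
    }
}
```

The sketch leaves out one step that, to our understanding, the Spark implementation performs for Classification models: labels are remapped from {0, 1} to {-1, +1} before the loss is evaluated.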
org$apache$spark$internal$Logging$$log_
public static org.slf4j.Logger org$apache$spark$internal$Logging$$log_()
org$apache$spark$internal$Logging$$log__$eq
public static void org$apache$spark$internal$Logging$$log__$eq(org.slf4j.Logger x$1)
LogStringContext
public static org.apache.spark.internal.Logging.LogStringContext LogStringContext(scala.StringContext sc)
predict
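public double predict(Vector features)
Predict values for a single data point using the model trained.
- Parameters:
features - feature vector representing a single data point
- Returns:
predicted value from the trained model: the class label (0.0 or 1.0) for Classification, a real value for Regression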
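A plain-Java sketch of single-point prediction, assuming the summing combination MLlib uses for gradient boosted trees: the raw score is the weighted sum of tree outputs, returned as-is for Regression and thresholded at zero for binary Classification. The Tree interface and the class name are hypothetical; the real method takes a Spark Vector.

```java
/** Hypothetical sketch of GradientBoostedTreesModel.predict for one data point. */
class EnsemblePredictSketch {
    /** Hypothetical stand-in for a single decision tree's predict method. */
    interface Tree { double predict(double[] features); }

    /** Weighted sum of the individual tree outputs. */
    static double weightedSum(Tree[] trees, double[] treeWeights, double[] features) {
        double sum = 0.0;
        for (int i = 0; i < trees.length; i++) {
            sum += treeWeights[i] * trees[i].predict(features);
        }
        return sum;
    }

    /** Regression returns the sum; binary classification maps a positive sum to 1.0, else 0.0. */
    static double predict(boolean classification, Tree[] trees, double[] treeWeights,
                          double[] features) {
        double sum = weightedSum(trees, treeWeights, features);
        if (!classification) {
            return sum;
        }
        return sum > 0.0 ? 1.0 : 0.0;
    }

    public static void main(String[] args) {
        Tree[] trees = {x -> 1.0, x -> -1.0}; // hypothetical trees voting +1 and -1
        double[] weights = {1.0, 0.5};
        double[] point = {0.3};
        System.out.println(predict(false, trees, weights, point)); // regression score
        System.out.println(predict(true, trees, weights, point));  // class label
    }
}
```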
predict
public RDD<Object> predict(RDD<Vector> features)
Predict values for the given data set.
- Parameters:
features - RDD representing data points to be predicted
- Returns:
RDD[Double] where each entry contains the corresponding prediction
predict
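public JavaRDD<Double> predict(JavaRDD<Vector> features)
Java-friendly version of org.apache.spark.mllib.tree.model.TreeEnsembleModel.predict.
- Parameters:
features - JavaRDD of feature vectors to be predicted
- Returns:
JavaRDD of predictions, one per input vector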
toString
public String toString()
Print a summary of the model.
toDebugString
public String toDebugString()
Print the full model to a string.
- Returns:
the full model, including every tree, as a string
numTrees
public int numTrees()
Get number of trees in ensemble.
- Returns:
the number of trees in the ensemble
totalNumNodes
public int totalNumNodes()
Get total number of nodes, summed over all trees in the ensemble.
- Returns:
the total node count (internal and leaf nodes) across all trees
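To illustrate what the two counts measure, here is a hypothetical sketch with a minimal binary Node class (not Spark's Node): numTrees is the number of root nodes, and totalNumNodes counts every internal and leaf node across all trees.

```java
/** Hypothetical sketch of numTrees and totalNumNodes over a minimal binary tree type. */
class NodeCountSketch {
    /** By convention in this sketch, a leaf has no children and a split node has two. */
    static final class Node {
        final Node left;
        final Node right;

        Node(Node left, Node right) {
            this.left = left;
            this.right = right;
        }

        static Node leaf() {
            return new Node(null, null);
        }

        /** Counts this node plus all of its descendants. */
        int numNodes() {
            int count = 1;
            if (left != null) count += left.numNodes();
            if (right != null) count += right.numNodes();
            return count;
        }
    }

    static int numTrees(Node[] roots) {
        return roots.length;
    }

    static int totalNumNodes(Node[] roots) {
        int total = 0;
        for (Node root : roots) {
            total += root.numNodes();
        }
        return total;
    }

    public static void main(String[] args) {
        Node stump = new Node(Node.leaf(), Node.leaf()); // one split, two leaves
        Node[] roots = {stump, Node.leaf()};
        System.out.println(numTrees(roots) + " trees, " + totalNumNodes(roots) + " nodes");
    }
}
```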