org.apache.spark.mllib.tree.model
Class GradientBoostedTreesModel

Object
  extended by org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
All Implemented Interfaces:
java.io.Serializable, Saveable

public class GradientBoostedTreesModel
extends Object
implements Saveable

:: Experimental :: Represents a gradient boosted trees model.

Parameters:
algo - algorithm for the ensemble model, either Classification or Regression
trees - the decision trees that make up the ensemble
treeWeights - the weight of each tree in the ensemble
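The three parameters describe an additive model. The formula below sketches how they combine. The symbols are introduced here for illustration and are not taken from the API: h_m is the m-th entry of trees, w_m is the matching entry of treeWeights, and M is the number of trees.

```latex
F(\mathbf{x}) = \sum_{m=0}^{M-1} w_m \, h_m(\mathbf{x})
```

For Regression, F(x) is the predicted value. Evaluating the same sum over only the first k trees gives the partial ensembles whose losses evaluateEachIteration reports.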

See Also:
Serialized Form

Constructor Summary
GradientBoostedTreesModel(scala.Enumeration.Value algo, DecisionTreeModel[] trees, double[] treeWeights)
           
 
Method Summary
 scala.Enumeration.Value algo()
           Algorithm of the ensemble: Classification or Regression.
static RDD<scala.Tuple2<Object,Object>> computeInitialPredictionAndError(RDD<LabeledPoint> data, double initTreeWeight, DecisionTreeModel initTree, Loss loss)
          Compute the initial predictions and errors for a dataset for the first iteration of gradient boosting.
 double[] evaluateEachIteration(RDD<LabeledPoint> data, Loss loss)
          Method to compute error or loss for every iteration of gradient boosting.
static GradientBoostedTreesModel load(SparkContext sc, String path)
           Load a model previously saved with save(SparkContext, String).
 int numTrees()
           Get the number of trees in the ensemble.
 JavaRDD<Double> predict(JavaRDD<Vector> features)
           Java-friendly version of predict(RDD<Vector>).
 RDD<Object> predict(RDD<Vector> features)
          Predict values for the given data set.
 double predict(Vector features)
           Predict the value for a single data point using the trained model.
 void save(SparkContext sc, String path)
          Save this model to the given path.
 String toDebugString()
          Print the full model to a string.
 String toString()
          Print a summary of the model.
 int totalNumNodes()
          Get total number of nodes, summed over all trees in the ensemble.
 DecisionTreeModel[] trees()
           The decision trees in the ensemble.
 double[] treeWeights()
           The weight of each tree in the ensemble.
static RDD<scala.Tuple2<Object,Object>> updatePredictionError(RDD<LabeledPoint> data, RDD<scala.Tuple2<Object,Object>> predictionAndError, double treeWeight, DecisionTreeModel tree, Loss loss)
           Update a zipped prediction-and-error RDD (as obtained from computeInitialPredictionAndError) by adding one more tree.
 
Methods inherited from class Object
equals, getClass, hashCode, notify, notifyAll, wait, wait, wait
 

Constructor Detail

GradientBoostedTreesModel

public GradientBoostedTreesModel(scala.Enumeration.Value algo,
                                 DecisionTreeModel[] trees,
                                 double[] treeWeights)
Method Detail

computeInitialPredictionAndError

public static RDD<scala.Tuple2<Object,Object>> computeInitialPredictionAndError(RDD<LabeledPoint> data,
                                                                                double initTreeWeight,
                                                                                DecisionTreeModel initTree,
                                                                                Loss loss)
Compute the initial predictions and errors for a dataset for the first iteration of gradient boosting.

Parameters:
data - training data.
initTreeWeight - learning rate assigned to the first tree.
initTree - first DecisionTreeModel.
loss - evaluation metric.
Returns:
an RDD of (prediction, error) pairs, one for each sample.
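The following minimal Spark-free sketch in Java makes the returned pairs concrete. It rests on stated assumptions. A List<double[]> of feature arrays plus a parallel double[] of labels stands in for RDD<LabeledPoint>. A ToDoubleFunction<double[]> stands in for DecisionTreeModel.predict. A DoubleBinaryOperator taking (prediction, label) stands in for the error computed by Loss. The class InitialPredictionSketch is hypothetical and is not part of MLlib.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.DoubleBinaryOperator;
import java.util.function.ToDoubleFunction;

// Hypothetical, Spark-free sketch of computeInitialPredictionAndError (not MLlib code).
class InitialPredictionSketch {

    // Returns one {prediction, error} pair per sample, where
    // prediction = initTreeWeight * initTree(x) and error = loss(prediction, label).
    static List<double[]> computeInitialPredictionAndError(
            List<double[]> features,              // stand-in for the features in RDD<LabeledPoint>
            double[] labels,                      // stand-in for the labels, same order as features
            double initTreeWeight,                // learning rate assigned to the first tree
            ToDoubleFunction<double[]> initTree,  // stand-in for DecisionTreeModel.predict
            DoubleBinaryOperator loss) {          // stand-in for Loss: (prediction, label) -> error
        List<double[]> result = new ArrayList<>(features.size());
        for (int i = 0; i < features.size(); i++) {
            double prediction = initTreeWeight * initTree.applyAsDouble(features.get(i));
            double error = loss.applyAsDouble(prediction, labels[i]);
            result.add(new double[] {prediction, error});
        }
        return result;
    }
}
```

Take weight 0.5, a tree that returns 2.0 + x[0], and squared error (label - prediction)^2. A sample with feature 0.0 and label 3.0 then gets prediction 1.0 and error 4.0.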

updatePredictionError

public static RDD<scala.Tuple2<Object,Object>> updatePredictionError(RDD<LabeledPoint> data,
                                                                     RDD<scala.Tuple2<Object,Object>> predictionAndError,
                                                                     double treeWeight,
                                                                     DecisionTreeModel tree,
                                                                     Loss loss)
Update a zipped prediction-and-error RDD (as obtained from computeInitialPredictionAndError) by adding the contribution of one more tree.

Parameters:
data - training data.
predictionAndError - RDD of (prediction, error) pairs to update; it is zipped with data, so its elements must line up with data.
treeWeight - learning rate of the new tree.
tree - tree used to update the predictions and errors.
loss - evaluation metric.
Returns:
an RDD of (prediction, error) pairs, one for each sample.
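The update step can be sketched the same way. This Spark-free Java sketch uses the same stand-in assumptions: feature and label arrays for the data, ToDoubleFunction<double[]> for the tree, and DoubleBinaryOperator for the loss. UpdatePredictionErrorSketch is a hypothetical name. The sketch shows that only the previous prediction is carried forward. The new tree's weighted output is added to it, and the error is recomputed from the new prediction.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.DoubleBinaryOperator;
import java.util.function.ToDoubleFunction;

// Hypothetical, Spark-free sketch of updatePredictionError (not MLlib code).
class UpdatePredictionErrorSketch {

    // newPrediction = prediction + treeWeight * tree(x); newError = loss(newPrediction, label).
    // Element i of predictionAndError must belong to sample i, as a zip requires.
    static List<double[]> updatePredictionError(
            List<double[]> features,             // stand-in for the features in RDD<LabeledPoint>
            double[] labels,                     // stand-in for the labels, same order as features
            List<double[]> predictionAndError,   // {prediction, error} per sample
            double treeWeight,                   // learning rate of the new tree
            ToDoubleFunction<double[]> tree,     // stand-in for DecisionTreeModel.predict
            DoubleBinaryOperator loss) {         // stand-in for Loss: (prediction, label) -> error
        List<double[]> result = new ArrayList<>(features.size());
        for (int i = 0; i < features.size(); i++) {
            // The old error is ignored; only the old prediction feeds the update.
            double newPrediction = predictionAndError.get(i)[0]
                    + treeWeight * tree.applyAsDouble(features.get(i));
            double newError = loss.applyAsDouble(newPrediction, labels[i]);
            result.add(new double[] {newPrediction, newError});
        }
        return result;
    }
}
```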

load

public static GradientBoostedTreesModel load(SparkContext sc,
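                                             String path)
Load a model previously saved with save(SparkContext, String).

Parameters:
sc - Spark context used to load model data.
path - path to the directory in which the model was saved.
Returns:
the loaded model.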

algo

public scala.Enumeration.Value algo()
Algorithm of the ensemble: Classification or Regression.

trees

public DecisionTreeModel[] trees()
The decision trees in the ensemble.

treeWeights

public double[] treeWeights()
The weight of each tree in the ensemble, in the same order as trees().

save

public void save(SparkContext sc,
                 String path)
Description copied from interface: Saveable
Save this model to the given path.

This saves:
- human-readable (JSON) model metadata to path/metadata/
- Parquet formatted data to path/data/

The model may be loaded using Loader.load.

Specified by:
save in interface Saveable
Parameters:
sc - Spark context used to save model data.
path - Path specifying the directory in which to save this model. If the directory already exists, this method throws an exception.

evaluateEachIteration

public double[] evaluateEachIteration(RDD<LabeledPoint> data,
                                      Loss loss)
Method to compute error or loss for every iteration of gradient boosting.

Parameters:
data - RDD of LabeledPoint on which to evaluate the model.
loss - evaluation metric.
Returns:
an array whose element i is the loss or error of the ensemble made up of the first i+1 trees
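The following Spark-free Java sketch shows the per-iteration evaluation, using the same hypothetical stand-ins for trees, weights, data, and loss. It keeps a running prediction for each sample, so adding tree m costs one tree evaluation per sample instead of re-scoring the whole prefix. The sketch averages the per-sample errors into one number per iteration. That aggregation is an assumption of this sketch.

```java
import java.util.List;
import java.util.function.DoubleBinaryOperator;
import java.util.function.ToDoubleFunction;

// Hypothetical, Spark-free sketch of evaluateEachIteration (not MLlib code).
class EvaluateEachIterationSketch {

    // Element m of the result is the mean loss of the ensemble made of the first m + 1 trees.
    static double[] evaluateEachIteration(
            List<double[]> features,                 // stand-in for the features in RDD<LabeledPoint>
            double[] labels,                         // stand-in for the labels, same order as features
            List<ToDoubleFunction<double[]>> trees,  // stand-in for trees()
            double[] treeWeights,                    // stand-in for treeWeights()
            DoubleBinaryOperator loss) {             // stand-in for Loss: (prediction, label) -> error
        int n = features.size();
        double[] predictions = new double[n];        // running weighted sum per sample
        double[] evaluation = new double[trees.size()];
        for (int m = 0; m < trees.size(); m++) {
            double totalError = 0.0;
            for (int i = 0; i < n; i++) {
                // Add tree m to the running prediction instead of re-scoring trees 0..m.
                predictions[i] += treeWeights[m] * trees.get(m).applyAsDouble(features.get(i));
                totalError += loss.applyAsDouble(predictions[i], labels[i]);
            }
            evaluation[m] = totalError / n;
        }
        return evaluation;
    }
}
```

Take two trees that each return 1.0 with weight 1.0, evaluated on labels of 2.0 under squared error. The result is {1.0, 0.0}: the one-tree ensemble predicts 1.0, and the two-tree ensemble predicts 2.0.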

predict

public double predict(Vector features)
Predict the value for a single data point using the trained model.

Parameters:
features - vector representing a single data point
Returns:
the predicted value: a class label for Classification, a real value for Regression
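A single-point prediction combines the trees by a weighted sum. The Java sketch below is Spark-free and illustrative. The Algo enum, the ToDoubleFunction<double[]> trees, and the class name are hypothetical stand-ins. For Classification, the sketch thresholds the weighted sum at zero and returns 1.0 or 0.0. This follows how MLlib's tree ensembles turn a boosted score into a binary label, but treat the exact rule as an assumption to check against your Spark version.

```java
import java.util.List;
import java.util.function.ToDoubleFunction;

// Hypothetical, Spark-free sketch of predict(Vector) for a boosted ensemble (not MLlib code).
class EnsemblePredictSketch {

    enum Algo { CLASSIFICATION, REGRESSION }  // stand-in for the scala.Enumeration algo value

    static double predict(Algo algo,
                          List<ToDoubleFunction<double[]>> trees,  // stand-in for trees()
                          double[] treeWeights,                    // stand-in for treeWeights()
                          double[] features) {                     // stand-in for Vector
        double sum = 0.0;
        for (int m = 0; m < trees.size(); m++) {
            // Weighted sum of the individual tree outputs.
            sum += treeWeights[m] * trees.get(m).applyAsDouble(features);
        }
        if (algo == Algo.REGRESSION) {
            return sum;                  // regression: the weighted sum itself
        }
        return sum > 0.0 ? 1.0 : 0.0;    // binary classification: threshold the score at 0
    }
}
```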

predict

public RDD<Object> predict(RDD<Vector> features)
Predict values for the given data set.

Parameters:
features - RDD representing data points to be predicted
Returns:
RDD[Double] where each entry contains the corresponding prediction

predict

public JavaRDD<Double> predict(JavaRDD<Vector> features)
Java-friendly version of predict(RDD<Vector>).

Parameters:
features - JavaRDD of feature vectors to predict
Returns:
a JavaRDD<Double> containing the prediction for each input vector

toString

public String toString()
Print a summary of the model.

Overrides:
toString in class Object
Returns:
a summary of the model as a string

toDebugString

public String toDebugString()
Print the full model to a string.

Returns:
the full model, including every tree, as a string

numTrees

public int numTrees()
Get the number of trees in the ensemble.

Returns:
the number of trees in the ensemble

totalNumNodes

public int totalNumNodes()
Get total number of nodes, summed over all trees in the ensemble.

Returns:
the total number of nodes across all trees in the ensemble