public class GradientBoostedTreesModel extends java.lang.Object implements Saveable
param: algo - algorithm for the ensemble model, either Classification or Regression
param: trees - tree ensembles
param: treeWeights - tree ensemble weights
Constructor and Description |
---|
GradientBoostedTreesModel(scala.Enumeration.Value algo, DecisionTreeModel[] trees, double[] treeWeights) |
Modifier and Type | Method and Description |
---|---|
scala.Enumeration.Value | algo() |
protected static scala.Enumeration.Value | combiningStrategy() |
protected scala.Enumeration.Value | combiningStrategy() |
static RDD<scala.Tuple2<java.lang.Object,java.lang.Object>> | computeInitialPredictionAndError(RDD<LabeledPoint> data, double initTreeWeight, DecisionTreeModel initTree, Loss loss) - :: DeveloperApi :: Compute the initial predictions and errors for a dataset for the first iteration of gradient boosting. |
double[] | evaluateEachIteration(RDD<LabeledPoint> data, Loss loss) - Method to compute error or loss for every iteration of gradient boosting. |
protected java.lang.String | formatVersion() - Current version of model save/load format. |
static GradientBoostedTreesModel | load(SparkContext sc, java.lang.String path) |
static int | numTrees() |
int | numTrees() - Get number of trees in ensemble. |
static JavaRDD<java.lang.Double> | predict(JavaRDD<Vector> features) |
JavaRDD<java.lang.Double> | predict(JavaRDD<Vector> features) - Java-friendly version of TreeEnsembleModel.predict(org.apache.spark.mllib.linalg.Vector). |
static RDD<java.lang.Object> | predict(RDD<Vector> features) |
RDD<java.lang.Object> | predict(RDD<Vector> features) - Predict values for the given data set. |
static double | predict(Vector features) |
double | predict(Vector features) - Predict values for a single data point using the model trained. |
void | save(SparkContext sc, java.lang.String path) - Save this model to the given path. |
static java.lang.String | toDebugString() |
java.lang.String | toDebugString() - Print the full model to a string. |
static java.lang.String | toString() |
java.lang.String | toString() - Print a summary of the model. |
static int | totalNumNodes() |
int | totalNumNodes() - Get total number of nodes, summed over all trees in the ensemble. |
DecisionTreeModel[] | trees() |
double[] | treeWeights() |
static RDD<scala.Tuple2<java.lang.Object,java.lang.Object>> | updatePredictionError(RDD<LabeledPoint> data, RDD<scala.Tuple2<java.lang.Object,java.lang.Object>> predictionAndError, double treeWeight, DecisionTreeModel tree, Loss loss) - :: DeveloperApi :: Update a zipped predictionError RDD (as obtained with computeInitialPredictionAndError) |
public GradientBoostedTreesModel(scala.Enumeration.Value algo, DecisionTreeModel[] trees, double[] treeWeights)
public static RDD<scala.Tuple2<java.lang.Object,java.lang.Object>> computeInitialPredictionAndError(RDD<LabeledPoint> data, double initTreeWeight, DecisionTreeModel initTree, Loss loss)
data
- training data.
initTreeWeight
- learning rate assigned to the first tree.
initTree
- first DecisionTreeModel.
loss
- evaluation metric.

public static RDD<scala.Tuple2<java.lang.Object,java.lang.Object>> updatePredictionError(RDD<LabeledPoint> data, RDD<scala.Tuple2<java.lang.Object,java.lang.Object>> predictionAndError, double treeWeight, DecisionTreeModel tree, Loss loss)
data
- training data.
predictionAndError
- predictionError RDD.
treeWeight
- learning rate.
tree
- tree used to update the prediction and error.
loss
- evaluation metric.

public static GradientBoostedTreesModel load(SparkContext sc, java.lang.String path)
sc
- Spark context used for loading model files.
path
- Path specifying the directory to which the model was saved.

protected static scala.Enumeration.Value combiningStrategy()
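The per-point bookkeeping described for computeInitialPredictionAndError and updatePredictionError can be sketched in plain Java. This is an illustration under stated assumptions, not Spark code: it assumes each boosting step adds `treeWeight * tree(x)` to the running prediction and then recomputes the error. The `Tree` and `PointLoss` interfaces and the class name are hypothetical stand-ins for DecisionTreeModel and Loss, and arrays stand in for RDDs.

```java
/** Plain-Java stand-in for the (prediction, error) bookkeeping; no Spark types are used. */
final class PredictionErrorSketch {

    /** Hypothetical stand-in for DecisionTreeModel: maps a feature vector to a value. */
    interface Tree {
        double predict(double[] features);
    }

    /** Hypothetical stand-in for Loss: error of one prediction against its label. */
    interface PointLoss {
        double error(double prediction, double label);
    }

    /** First iteration: prediction = initTreeWeight * initTree(x), then the error. */
    static double[][] computeInitial(double[][] features, double[] labels,
                                     double initTreeWeight, Tree initTree,
                                     PointLoss loss) {
        double[][] predAndError = new double[features.length][2];
        for (int i = 0; i < features.length; i++) {
            double pred = initTreeWeight * initTree.predict(features[i]);
            predAndError[i][0] = pred;
            predAndError[i][1] = loss.error(pred, labels[i]);
        }
        return predAndError;
    }

    /** Later iterations: add treeWeight * tree(x) to the old prediction, recompute the error. */
    static double[][] update(double[][] features, double[] labels,
                             double[][] predAndError, double treeWeight,
                             Tree tree, PointLoss loss) {
        double[][] updated = new double[features.length][2];
        for (int i = 0; i < features.length; i++) {
            double pred = predAndError[i][0] + treeWeight * tree.predict(features[i]);
            updated[i][0] = pred;
            updated[i][1] = loss.error(pred, labels[i]);
        }
        return updated;
    }

    public static void main(String[] args) {
        double[][] x = {{0.0}, {1.0}};
        double[] y = {1.0, 3.0};
        PointLoss squared = (p, l) -> (p - l) * (p - l);
        double[][] pe = computeInitial(x, y, 1.0, f -> 2.0, squared);
        pe = update(x, y, pe, 0.5, f -> f[0] > 0.5 ? 2.0 : -2.0, squared);
        for (double[] row : pe) {
            System.out.println("prediction=" + row[0] + " error=" + row[1]);
        }
    }
}
```

Returning a fresh array from `update` mirrors the way an RDD transformation produces a new dataset instead of mutating the old one.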
public static double predict(Vector features)
public static java.lang.String toString()
public static java.lang.String toDebugString()
public static int numTrees()
public static int totalNumNodes()
public scala.Enumeration.Value algo()
public DecisionTreeModel[] trees()
public double[] treeWeights()
public void save(SparkContext sc, java.lang.String path)
Description from interface Saveable:
This saves:
- human-readable (JSON) model metadata to path/metadata/
- Parquet formatted data to path/data/
The model may be loaded using Loader.load.
public double[] evaluateEachIteration(RDD<LabeledPoint> data, Loss loss)
data
- RDD of LabeledPoint.
loss
- evaluation metric.

protected java.lang.String formatVersion()
Specified by: formatVersion in interface Saveable
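To make the return value of evaluateEachIteration concrete, here is a minimal plain-Java sketch of the idea: entry `i` of the result is the loss of the ensemble truncated to its first `i + 1` trees. It is not the Spark implementation. It assumes squared error averaged over the data points as the loss, and the `Tree` interface and class name are hypothetical stand-ins.

```java
/** Plain-Java illustration of per-iteration loss for a boosted ensemble; no Spark types. */
final class EvaluateEachIterationSketch {

    /** Hypothetical stand-in for DecisionTreeModel. */
    interface Tree {
        double predict(double[] features);
    }

    /**
     * Returns one loss value per boosting iteration.
     * Assumption: the loss is mean squared error over all points.
     */
    static double[] evaluateEachIteration(double[][] features, double[] labels,
                                          Tree[] trees, double[] treeWeights) {
        double[] running = new double[features.length]; // cumulative prediction per point
        double[] lossPerIteration = new double[trees.length];
        for (int t = 0; t < trees.length; t++) {
            double sumSquared = 0.0;
            for (int i = 0; i < features.length; i++) {
                // Add the contribution of tree t, then score the truncated ensemble.
                running[i] += treeWeights[t] * trees[t].predict(features[i]);
                double diff = running[i] - labels[i];
                sumSquared += diff * diff;
            }
            lossPerIteration[t] = sumSquared / features.length;
        }
        return lossPerIteration;
    }

    public static void main(String[] args) {
        double[][] x = {{0.0}, {1.0}};
        double[] y = {1.0, 3.0};
        Tree[] trees = {f -> 2.0, f -> f[0] > 0.5 ? 2.0 : -2.0};
        double[] weights = {1.0, 0.5};
        for (double loss : evaluateEachIteration(x, y, trees, weights)) {
            System.out.println(loss);
        }
    }
}
```

Plotting such an array against the iteration index on held-out data is a common way to pick the number of trees to keep.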
protected scala.Enumeration.Value combiningStrategy()
public double predict(Vector features)
features
- array representing a single data point

public RDD<java.lang.Object> predict(RDD<Vector> features)
features
- RDD representing data points to be predicted

public JavaRDD<java.lang.Double> predict(JavaRDD<Vector> features)
Java-friendly version of TreeEnsembleModel.predict(org.apache.spark.mllib.linalg.Vector).
features
- (undocumented)

public java.lang.String toString()
Overrides: toString in class java.lang.Object
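The relationship between trees(), treeWeights() and the single-point predict can be illustrated with a small plain-Java sketch. This page does not say how the ensemble combines its trees, so the sketch rests on two assumptions: predictions are combined as a weighted sum, and for binary classification the sum is thresholded at zero. The `Tree` interface and all names are hypothetical.

```java
/** Plain-Java illustration of combining weighted tree predictions; no Spark types. */
final class EnsemblePredictSketch {

    /** Hypothetical stand-in for DecisionTreeModel. */
    interface Tree {
        double predict(double[] features);
    }

    /** Weighted sum of the individual tree predictions (assumed combining rule). */
    static double predictBySumming(Tree[] trees, double[] treeWeights, double[] features) {
        double sum = 0.0;
        for (int t = 0; t < trees.length; t++) {
            sum += treeWeights[t] * trees[t].predict(features);
        }
        return sum;
    }

    /** Assumed binary classification rule: label 1.0 if the weighted sum is positive, else 0.0. */
    static double predictClass(Tree[] trees, double[] treeWeights, double[] features) {
        return predictBySumming(trees, treeWeights, features) > 0.0 ? 1.0 : 0.0;
    }

    public static void main(String[] args) {
        Tree[] trees = {f -> 2.0, f -> f[0] > 0.5 ? 2.0 : -2.0};
        double[] weights = {1.0, 0.5};
        System.out.println(predictBySumming(trees, weights, new double[] {0.0}));
        System.out.println(predictClass(trees, weights, new double[] {0.0}));
    }
}
```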
public java.lang.String toDebugString()
public int numTrees()
public int totalNumNodes()