Class GBTClassificationModel

All Implemented Interfaces:
Serializable, org.apache.spark.internal.Logging, ClassifierParams, ProbabilisticClassifierParams, Params, HasCheckpointInterval, HasFeaturesCol, HasLabelCol, HasMaxIter, HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, HasSeed, HasStepSize, HasThresholds, HasValidationIndicatorCol, HasWeightCol, PredictorParams, DecisionTreeParams, GBTClassifierParams, GBTParams, HasVarianceImpurity, TreeEnsembleClassifierParams, TreeEnsembleModel<DecisionTreeRegressionModel>, TreeEnsembleParams, Identifiable, MLWritable

Gradient-Boosted Trees (GBTs) (http://en.wikipedia.org/wiki/Gradient_boosting) model for classification. It supports binary labels, as well as both continuous and categorical features.

param: _trees Decision trees in the ensemble.
param: _treeWeights Weights for the decision trees in the ensemble.

Note:
Multiclass labels are not currently supported.
  • Constructor Details

    • GBTClassificationModel

      public GBTClassificationModel(String uid, DecisionTreeRegressionModel[] _trees, double[] _treeWeights)
      Construct a GBTClassificationModel

      Parameters:
      _trees - Decision trees in the ensemble.
      _treeWeights - Weights for the decision trees in the ensemble.
      uid - (undocumented)
  • Method Details

    • read

      public static MLReader<GBTClassificationModel> read()
    • load

      public static GBTClassificationModel load(String path)
    • totalNumNodes

      public int totalNumNodes()
      Description copied from interface: TreeEnsembleModel
      Total number of nodes, summed over all trees in the ensemble.
      Specified by:
      totalNumNodes in interface TreeEnsembleModel<DecisionTreeRegressionModel>
    • lossType

      public Param<String> lossType()
      Description copied from interface: GBTClassifierParams
      Loss function which GBT tries to minimize. (case-insensitive) Supported: "logistic" (default = logistic)
      Specified by:
      lossType in interface GBTClassifierParams
      Returns:
      (undocumented)
    • impurity

      public final Param<String> impurity()
      Description copied from interface: HasVarianceImpurity
      Criterion used for information gain calculation (case-insensitive). This impurity type is used in DecisionTreeRegressor, RandomForestRegressor, GBTRegressor and GBTClassifier (since GBTClassificationModel is internally composed of DecisionTreeRegressionModels). Supported: "variance". (default = variance)
      Specified by:
      impurity in interface HasVarianceImpurity
      Returns:
      (undocumented)
    • validationTol

      public final DoubleParam validationTol()
      Description copied from interface: GBTParams
      Threshold for stopping early when fit with validation is used. (This parameter is ignored when fit without validation is used.) The decision to stop early follows this logic: if the current loss on the validation set is greater than 0.01, the difference in validation error is compared to a relative tolerance of validationTol * (current loss on the validation set). If the current loss on the validation set is less than or equal to 0.01, the difference in validation error is compared to an absolute tolerance of validationTol * 0.01.
      Specified by:
      validationTol in interface GBTParams
      Returns:
      (undocumented)
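
      The stopping rule described for validationTol can be sketched in plain Java. This is a minimal illustration, not Spark's implementation: the class EarlyStopSketch and the method shouldStop are hypothetical, and reading "diff of validation error" as the best validation loss so far minus the current loss is an assumption made for this example.

```java
// Hypothetical helper illustrating the validationTol rule; not part of Spark.
final class EarlyStopSketch {

    /**
     * Returns true when the improvement in validation loss is below the tolerance.
     * Assumption: the "diff of validation error" is bestLoss - currentLoss.
     */
    static boolean shouldStop(double bestLoss, double currentLoss, double validationTol) {
        // Relative tolerance when currentLoss > 0.01, absolute tolerance
        // (validationTol * 0.01) when currentLoss <= 0.01.
        double tolerance = validationTol * Math.max(currentLoss, 0.01);
        return bestLoss - currentLoss < tolerance;
    }

    public static void main(String[] args) {
        // Large improvement against a relative tolerance.
        System.out.println(shouldStop(0.50, 0.40, 0.01));
        // Tiny improvement against a relative tolerance.
        System.out.println(shouldStop(0.50, 0.499, 0.01));
        // Loss already below 0.01, so the absolute tolerance applies.
        System.out.println(shouldStop(0.008, 0.007, 0.01));
    }
}
```

      Using Math.max folds the two documented cases into one expression, since the tolerance is validationTol times whichever of the current loss and 0.01 is larger.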
    • stepSize

      public final DoubleParam stepSize()
      Description copied from interface: GBTParams
      Param for step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator. (default = 0.1)
      Specified by:
      stepSize in interface GBTParams
      Specified by:
      stepSize in interface HasStepSize
      Returns:
      (undocumented)
    • validationIndicatorCol

      public final Param<String> validationIndicatorCol()
      Description copied from interface: HasValidationIndicatorCol
      Param for name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.
      Specified by:
      validationIndicatorCol in interface HasValidationIndicatorCol
      Returns:
      (undocumented)
    • maxIter

      public final IntParam maxIter()
      Description copied from interface: HasMaxIter
      Param for maximum number of iterations (>= 0).
      Specified by:
      maxIter in interface HasMaxIter
      Returns:
      (undocumented)
    • subsamplingRate

      public final DoubleParam subsamplingRate()
      Description copied from interface: TreeEnsembleParams
      Fraction of the training data used for learning each decision tree, in range (0, 1]. (default = 1.0)
      Specified by:
      subsamplingRate in interface TreeEnsembleParams
      Returns:
      (undocumented)
    • featureSubsetStrategy

      public final Param<String> featureSubsetStrategy()
      Description copied from interface: TreeEnsembleParams
      The number of features to consider for splits at each tree node. Supported options:
      - "auto": Choose automatically for task: If numTrees == 1, set to "all". If numTrees is greater than 1 (forest), set to "sqrt" for classification and to "onethird" for regression.
      - "all": use all features
      - "onethird": use 1/3 of the features
      - "sqrt": use sqrt(number of features)
      - "log2": use log2(number of features)
      - "n": when n is in the range (0, 1.0], use n * number of features. When n is in the range (1, number of features), use n features.
      (default = "auto")

      These various settings are based on the following references:
      - log2: tested in Breiman (2001)
      - sqrt: recommended by Breiman manual for random forests
      - The defaults of sqrt (classification) and onethird (regression) match the R randomForest package.

      Specified by:
      featureSubsetStrategy in interface TreeEnsembleParams
      Returns:
      (undocumented)
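
      The featureSubsetStrategy options can be illustrated with a small resolver that maps a strategy string to a feature count. This is a sketch under stated assumptions, not Spark's code: the class FeatureSubsetSketch and the method featuresPerNode are hypothetical, and rounding up, the floor of one feature for "log2", and capping a numeric count at the number of features are choices made for this example.

```java
// Hypothetical resolver for the documented options; rounding choices are assumptions.
final class FeatureSubsetSketch {

    static int featuresPerNode(String strategy, int numFeatures, int numTrees,
                               boolean classification) {
        String s = strategy;
        if (s.equals("auto")) {
            // A single tree uses every feature; a forest uses sqrt or onethird.
            if (numTrees == 1) {
                s = "all";
            } else {
                s = classification ? "sqrt" : "onethird";
            }
        }
        switch (s) {
            case "all":
                return numFeatures;
            case "onethird":
                return (int) Math.ceil(numFeatures / 3.0);
            case "sqrt":
                return (int) Math.ceil(Math.sqrt(numFeatures));
            case "log2":
                // Integer ceil(log2(numFeatures)), floored at 1 (assumption).
                return Math.max(1, 32 - Integer.numberOfLeadingZeros(numFeatures - 1));
            default:
                // The "n" option: a fraction in (0, 1.0] or a whole count above 1.
                double n = Double.parseDouble(s);
                if (n > 0.0 && n <= 1.0) {
                    return (int) Math.ceil(n * numFeatures);
                }
                if (n > 1.0 && n == Math.rint(n)) {
                    // Capping at numFeatures is an assumption of this sketch.
                    return (int) Math.min(n, numFeatures);
                }
                throw new IllegalArgumentException(
                        "Unsupported featureSubsetStrategy: " + strategy);
        }
    }

    public static void main(String[] args) {
        System.out.println(featuresPerNode("auto", 100, 10, true));
        System.out.println(featuresPerNode("0.5", 100, 10, true));
    }
}
```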
    • leafCol

      public final Param<String> leafCol()
      Description copied from interface: DecisionTreeParams
      Leaf indices column name. Predicted leaf index of each instance in each tree by preorder. (default = "")
      Specified by:
      leafCol in interface DecisionTreeParams
      Returns:
      (undocumented)
    • maxDepth

      public final IntParam maxDepth()
      Description copied from interface: DecisionTreeParams
      Maximum depth of the tree (nonnegative). E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (default = 5)
      Specified by:
      maxDepth in interface DecisionTreeParams
      Returns:
      (undocumented)
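
      The depth examples given for maxDepth generalize to an upper bound on tree size: a binary tree of depth d has at most 2^d leaves and 2^(d+1) - 1 nodes. The sketch below computes these bounds; the class TreeSizeSketch and its methods are hypothetical names used only for illustration.

```java
// Hypothetical helper: upper bounds on the size of a binary tree of a given depth.
final class TreeSizeSketch {

    // Maximum number of leaf nodes: 2^depth.
    static long maxLeaves(int depth) {
        return 1L << depth;
    }

    // Maximum number of nodes (internal + leaf): 2^(depth + 1) - 1.
    static long maxNodes(int depth) {
        return (1L << (depth + 1)) - 1;
    }

    public static void main(String[] args) {
        for (int depth : new int[] {0, 1, 5}) {
            System.out.println("depth " + depth + ": at most " + maxLeaves(depth)
                    + " leaves, " + maxNodes(depth) + " nodes");
        }
    }
}
```

      At the default depth of 5 each tree therefore has at most 63 nodes; a trained tree may have fewer when splits are rejected.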
    • maxBins

      public final IntParam maxBins()
      Description copied from interface: DecisionTreeParams
      Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node. More bins give higher granularity. Must be at least 2 and at least number of categories in any categorical feature. (default = 32)
      Specified by:
      maxBins in interface DecisionTreeParams
      Returns:
      (undocumented)
    • minInstancesPerNode

      public final IntParam minInstancesPerNode()
      Description copied from interface: DecisionTreeParams
      Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Must be at least 1. (default = 1)
      Specified by:
      minInstancesPerNode in interface DecisionTreeParams
      Returns:
      (undocumented)
    • minWeightFractionPerNode

      public final DoubleParam minWeightFractionPerNode()
      Description copied from interface: DecisionTreeParams
      Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in the interval [0.0, 0.5). (default = 0.0)
      Specified by:
      minWeightFractionPerNode in interface DecisionTreeParams
      Returns:
      (undocumented)
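
      The two split constraints, minInstancesPerNode and minWeightFractionPerNode, can be read as a validity check on a candidate split. The following is a minimal sketch, not Spark's implementation: the class SplitCheckSketch and the method isValidSplit are hypothetical, and the total weight against which the fraction is measured is passed in by the caller because which total applies is not stated here.

```java
// Hypothetical validity check for a candidate split; not part of Spark.
final class SplitCheckSketch {

    static boolean isValidSplit(long leftCount, long rightCount,
                                double leftWeight, double rightWeight,
                                double totalWeight,
                                int minInstancesPerNode,
                                double minWeightFractionPerNode) {
        // Each child must hold at least minInstancesPerNode instances.
        boolean enoughInstances =
                leftCount >= minInstancesPerNode && rightCount >= minInstancesPerNode;
        // Each child must hold at least the given fraction of totalWeight.
        // Assumption: totalWeight is supplied by the caller.
        double minWeight = minWeightFractionPerNode * totalWeight;
        boolean enoughWeight = leftWeight >= minWeight && rightWeight >= minWeight;
        return enoughInstances && enoughWeight;
    }

    public static void main(String[] args) {
        // Right child has a single instance but two are required: rejected.
        System.out.println(isValidSplit(5, 1, 5.0, 1.0, 6.0, 2, 0.0));
        // Both children satisfy the defaults (1 instance, fraction 0.0): accepted.
        System.out.println(isValidSplit(5, 1, 5.0, 1.0, 6.0, 1, 0.0));
    }
}
```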
    • minInfoGain

      public final DoubleParam minInfoGain()
      Description copied from interface: DecisionTreeParams
      Minimum information gain for a split to be considered at a tree node. Should be at least 0.0. (default = 0.0)
      Specified by:
      minInfoGain in interface DecisionTreeParams
      Returns:
      (undocumented)
    • maxMemoryInMB

      public final IntParam maxMemoryInMB()
      Description copied from interface: DecisionTreeParams
      Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size. (default = 256 MB)
      Specified by:
      maxMemoryInMB in interface DecisionTreeParams
      Returns:
      (undocumented)
    • cacheNodeIds

      public final BooleanParam cacheNodeIds()
      Description copied from interface: DecisionTreeParams
      If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often the cache is checkpointed, or disable checkpointing, by setting checkpointInterval. (default = false)
      Specified by:
      cacheNodeIds in interface DecisionTreeParams
      Returns:
      (undocumented)
    • weightCol

      public final Param<String> weightCol()
      Description copied from interface: HasWeightCol
      Param for weight column name. If this is not set or empty, we treat all instance weights as 1.0.
      Specified by:
      weightCol in interface HasWeightCol
      Returns:
      (undocumented)
    • seed

      public final LongParam seed()
      Description copied from interface: HasSeed
      Param for random seed.
      Specified by:
      seed in interface HasSeed
      Returns:
      (undocumented)
    • checkpointInterval

      public final IntParam checkpointInterval()
      Description copied from interface: HasCheckpointInterval
      Param for setting the checkpoint interval (>= 1) or disabling checkpointing (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.
      Specified by:
      checkpointInterval in interface HasCheckpointInterval
      Returns:
      (undocumented)
    • uid

      public String uid()
      Description copied from interface: Identifiable
      An immutable unique ID for the object and its derivatives.
      Specified by:
      uid in interface Identifiable
      Returns:
      (undocumented)
    • numFeatures

      public int numFeatures()
      Description copied from class: PredictionModel
      Returns the number of features the model was trained on. If unknown, returns -1.
      Overrides:
      numFeatures in class PredictionModel<Vector,GBTClassificationModel>
    • numClasses

      public int numClasses()
      Description copied from class: ClassificationModel
      Number of classes (values which the label can take).
      Specified by:
      numClasses in class ClassificationModel<Vector,GBTClassificationModel>
    • trees

      public DecisionTreeRegressionModel[] trees()
      Description copied from interface: TreeEnsembleModel
      Trees in this ensemble. Warning: These have null parent Estimators.
      Specified by:
      trees in interface TreeEnsembleModel<DecisionTreeRegressionModel>
    • getNumTrees

      public int getNumTrees()
      Number of trees in ensemble
      Returns:
      (undocumented)
    • treeWeights

      public double[] treeWeights()
      Description copied from interface: TreeEnsembleModel
      Weights for each tree, zippable with TreeEnsembleModel.trees()
      Specified by:
      treeWeights in interface TreeEnsembleModel<DecisionTreeRegressionModel>
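
      "Zippable" means that element i of treeWeights() pairs with element i of trees(). The sketch below pairs the two arrays to form a weighted sum of per-tree outputs. It is an illustration only: the class WeightedEnsembleSketch and the method weightedSum are hypothetical, the per-tree outputs are supplied as plain doubles instead of being computed by DecisionTreeRegressionModel, and treating the ensemble output as a weighted sum is an assumption about how the weights are applied.

```java
// Hypothetical combination of per-tree outputs with their weights; not part of Spark.
final class WeightedEnsembleSketch {

    // treePredictions[i] is the output of tree i for one feature vector.
    static double weightedSum(double[] treePredictions, double[] treeWeights) {
        if (treePredictions.length != treeWeights.length) {
            throw new IllegalArgumentException("trees and weights must have equal length");
        }
        double sum = 0.0;
        for (int i = 0; i < treePredictions.length; i++) {
            sum += treeWeights[i] * treePredictions[i];
        }
        return sum;
    }

    public static void main(String[] args) {
        double[] outputs = {1.0, -0.5, 0.25};
        double[] weights = {1.0, 0.1, 0.1};
        System.out.println(weightedSum(outputs, weights));
    }
}
```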
    • transformSchema

      public StructType transformSchema(StructType schema)
      Description copied from class: PipelineStage
      Check transform validity and derive the output schema from the input schema.

      We check validity for interactions between parameters during transformSchema and raise an exception if any parameter value is invalid. Parameter value checks which do not depend on other parameters are handled by Param.validate().

      Typical implementation should first conduct verification on schema change and parameter validity, including complex parameter interaction checks.

      Overrides:
      transformSchema in class ProbabilisticClassificationModel<Vector,GBTClassificationModel>
      Parameters:
      schema - (undocumented)
      Returns:
      (undocumented)
    • transform

      public Dataset<Row> transform(Dataset<?> dataset)
      Description copied from class: ProbabilisticClassificationModel
      Transforms dataset by reading from PredictionModel.featuresCol(), and appending new columns as specified by parameters:
      - predicted labels as PredictionModel.predictionCol() of type Double
      - raw predictions (confidences) as ClassificationModel.rawPredictionCol() of type Vector
      - probability of each class as ProbabilisticClassificationModel.probabilityCol() of type Vector.

      Overrides:
      transform in class ProbabilisticClassificationModel<Vector,GBTClassificationModel>
      Parameters:
      dataset - input dataset
      Returns:
      transformed dataset
    • predict

      public double predict(Vector features)
      Description copied from class: ClassificationModel
      Predict label for the given features. This method is used to implement transform() and output PredictionModel.predictionCol().

      This default implementation for classification predicts the index of the maximum value from predictRaw().

      Overrides:
      predict in class ClassificationModel<Vector,GBTClassificationModel>
      Parameters:
      features - (undocumented)
      Returns:
      (undocumented)
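
      The default rule described for predict, returning the index of the maximum raw prediction, can be sketched as follows. The class ArgMaxSketch and the method predictFromRaw are hypothetical, the raw predictions are passed as a plain double array instead of a Vector, and breaking ties in favor of the lower index is an assumption of this example.

```java
// Hypothetical argmax over raw predictions; not part of Spark.
final class ArgMaxSketch {

    // Returns the index of the largest raw prediction as a double label.
    static double predictFromRaw(double[] rawPrediction) {
        if (rawPrediction.length == 0) {
            throw new IllegalArgumentException("rawPrediction must not be empty");
        }
        int best = 0;
        for (int i = 1; i < rawPrediction.length; i++) {
            // Strict comparison keeps the lower index on ties (assumption).
            if (rawPrediction[i] > rawPrediction[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        System.out.println(predictFromRaw(new double[] {-0.3, 0.3}));
        System.out.println(predictFromRaw(new double[] {2.0, -2.0}));
    }
}
```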
    • predictRaw

      public Vector predictRaw(Vector features)
      Description copied from class: ClassificationModel
      Raw prediction for each possible label. The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives a measure of confidence in each possible label (where larger = more confident). This internal method is used to implement transform() and output ClassificationModel.rawPredictionCol().

      Specified by:
      predictRaw in class ClassificationModel<Vector,GBTClassificationModel>
      Parameters:
      features - (undocumented)
      Returns:
      vector where element i is the raw prediction for label i. This raw prediction may be any real number, where a larger value indicates greater confidence for that label.
    • copy

      public GBTClassificationModel copy(ParamMap extra)
      Description copied from interface: Params
      Creates a copy of this instance with the same UID and some extra params. Subclasses should implement this method and set the return type properly. See defaultCopy().
      Specified by:
      copy in interface Params
      Specified by:
      copy in class Model<GBTClassificationModel>
      Parameters:
      extra - (undocumented)
      Returns:
      (undocumented)
    • toString

      public String toString()
      Description copied from interface: TreeEnsembleModel
      Summary of the model
      Specified by:
      toString in interface Identifiable
      Specified by:
      toString in interface TreeEnsembleModel<DecisionTreeRegressionModel>
      Overrides:
      toString in class Object
    • featureImportances

      public Vector featureImportances()
    • evaluateEachIteration

      public double[] evaluateEachIteration(Dataset<?> dataset)
      Method to compute error or loss for every iteration of gradient boosting.

      Parameters:
      dataset - Dataset for validation.
      Returns:
      (undocumented)
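
      Computing a loss for every boosting iteration amounts to evaluating each prefix of the ensemble: the first tree alone, then the first two trees, and so on. The sketch below shows that accumulation in plain Java. It is not Spark's implementation: the class PerIterationLossSketch and the method lossPerIteration are hypothetical, per-tree outputs are supplied as arrays, the loss is a caller-provided function, and averaging the loss over rows is an assumption.

```java
import java.util.function.DoubleBinaryOperator;

// Hypothetical per-iteration evaluation of a boosted ensemble; not part of Spark.
final class PerIterationLossSketch {

    /**
     * treeOutputs[t][i] is the output of tree t on row i.
     * Returns the mean loss after using the first 1, 2, ..., T trees.
     * The loss function receives (label, ensemblePrediction).
     */
    static double[] lossPerIteration(double[][] treeOutputs, double[] treeWeights,
                                     double[] labels, DoubleBinaryOperator loss) {
        int numTrees = treeOutputs.length;
        int numRows = labels.length;
        double[] running = new double[numRows]; // ensemble prediction so far, per row
        double[] result = new double[numTrees];
        for (int t = 0; t < numTrees; t++) {
            double total = 0.0;
            for (int i = 0; i < numRows; i++) {
                running[i] += treeWeights[t] * treeOutputs[t][i];
                total += loss.applyAsDouble(labels[i], running[i]);
            }
            result[t] = total / numRows;
        }
        return result;
    }

    public static void main(String[] args) {
        double[][] outputs = {{0.5, 0.5}, {0.5, -0.5}};
        double[] weights = {1.0, 1.0};
        double[] labels = {1.0, 0.0};
        // Squared error is used here only as a stand-in loss.
        double[] losses = lossPerIteration(outputs, weights, labels,
                (y, f) -> (y - f) * (y - f));
        for (double l : losses) {
            System.out.println(l);
        }
    }
}
```

      Keeping a running prediction per row avoids recomputing the earlier trees at every iteration.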
    • write

      public MLWriter write()
      Description copied from interface: MLWritable
      Returns an MLWriter instance for this ML instance.
      Specified by:
      write in interface MLWritable
      Returns:
      (undocumented)