Package org.apache.spark.ml.regression
Class GBTRegressionModel

java.lang.Object
  org.apache.spark.ml.PipelineStage
    org.apache.spark.ml.Transformer
      org.apache.spark.ml.Model<M>
        org.apache.spark.ml.PredictionModel<FeaturesType,M>
          org.apache.spark.ml.regression.RegressionModel<Vector,GBTRegressionModel>
            org.apache.spark.ml.regression.GBTRegressionModel
- All Implemented Interfaces:
- Serializable, org.apache.spark.internal.Logging, Params, HasCheckpointInterval, HasFeaturesCol, HasLabelCol, HasMaxIter, HasPredictionCol, HasSeed, HasStepSize, HasValidationIndicatorCol, HasWeightCol, PredictorParams, DecisionTreeParams, GBTParams, GBTRegressorParams, HasVarianceImpurity, TreeEnsembleModel<DecisionTreeRegressionModel>, TreeEnsembleParams, TreeEnsembleRegressorParams, TreeRegressorParams, Identifiable, MLWritable
public class GBTRegressionModel
extends RegressionModel<Vector,GBTRegressionModel>
implements GBTRegressorParams, TreeEnsembleModel<DecisionTreeRegressionModel>, MLWritable, Serializable 
Gradient-Boosted Trees (GBTs) model for regression. It supports both continuous and categorical features.
param: _trees Decision trees in the ensemble.
param: _treeWeights Weights for the decision trees in the ensemble.
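A GBTRegressionModel is normally obtained by fitting a GBTRegressor rather than by invoking the constructor. A minimal usage sketch (not part of the generated documentation; the toy data, column names, and parameter values are illustrative, and the later sketches on this page reuse the fitted model from this example):

  import org.apache.spark.ml.linalg.Vectors
  import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}
  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder().master("local[*]").appName("gbt-sketch").getOrCreate()
  import spark.implicits._

  // Toy training data: a Vector column "features" and a Double column "label".
  val training = Seq(
    (Vectors.dense(0.0, 1.0), 1.0),
    (Vectors.dense(1.0, 0.0), 2.0),
    (Vectors.dense(1.0, 1.0), 3.0)
  ).toDF("features", "label")

  val gbt = new GBTRegressor()
    .setMaxIter(10)     // number of boosting iterations (one tree per iteration)
    .setMaxDepth(3)
    .setStepSize(0.1)   // learning rate in (0, 1]

  val model: GBTRegressionModel = gbt.fit(training)

  // transform() reads featuresCol and appends predictionCol (Double).
  model.transform(training).show()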
Nested Class Summary
Nested classes/interfaces inherited from interface org.apache.spark.internal.Logging:
org.apache.spark.internal.Logging.LogStringContext, org.apache.spark.internal.Logging.SparkShellLoggingFilter
- 
Constructor Summary
GBTRegressionModel(String uid, DecisionTreeRegressionModel[] _trees, double[] _treeWeights)
Construct a GBTRegressionModel
- 
Method Summary

final BooleanParam cacheNodeIds() - If false, the algorithm will pass trees to executors to match instances with nodes.
final IntParam checkpointInterval() - Param for set checkpoint interval (>= 1) or disable checkpoint (-1).
GBTRegressionModel copy(ParamMap extra) - Creates a copy of this instance with the same UID and some extra params.
long estimatedSize()
double[] evaluateEachIteration(Dataset<?> dataset, String loss) - Method to compute error or loss for every iteration of gradient boosting.
Vector featureImportances() - Estimate of the importance of each feature.
Param<String> featureSubsetStrategy() - The number of features to consider for splits at each tree node.
int getNumTrees() - Number of trees in ensemble.
Param<String> impurity() - Criterion used for information gain calculation (case-insensitive).
Param<String> leafCol() - Leaf indices column name.
static GBTRegressionModel load(String path)
Param<String> lossType() - Loss function which GBT tries to minimize.
final IntParam maxBins() - Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node.
final IntParam maxDepth() - Maximum depth of the tree (nonnegative).
final IntParam maxIter() - Param for maximum number of iterations (>= 0).
final IntParam maxMemoryInMB() - Maximum memory in MB allocated to histogram aggregation.
final DoubleParam minInfoGain() - Minimum information gain for a split to be considered at a tree node.
final IntParam minInstancesPerNode() - Minimum number of instances each child must have after split.
final DoubleParam minWeightFractionPerNode() - Minimum fraction of the weighted sample count that each child must have after split.
int numFeatures() - Returns the number of features the model was trained on.
double predict(Vector features) - Predict label for the given features.
static MLReader<GBTRegressionModel> read()
final LongParam seed() - Param for random seed.
final DoubleParam stepSize() - Param for Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.
final DoubleParam subsamplingRate() - Fraction of the training data used for learning each decision tree, in range (0, 1].
String toString() - Summary of the model.
int totalNumNodes() - Total number of nodes, summed over all trees in the ensemble.
Dataset<Row> transform(Dataset<?> dataset) - Transforms dataset by reading from PredictionModel.featuresCol(), calling predict, and storing the predictions as a new column PredictionModel.predictionCol().
StructType transformSchema(StructType schema) - Check transform validity and derive the output schema from the input schema.
DecisionTreeRegressionModel[] trees() - Trees in this ensemble.
double[] treeWeights() - Weights for each tree, zippable with TreeEnsembleModel.trees().
String uid() - An immutable unique ID for the object and its derivatives.
Param<String> validationIndicatorCol() - Param for name of the column that indicates whether each row is for training or for validation.
final DoubleParam validationTol() - Threshold for stopping early when fit with validation is used.
Param<String> weightCol() - Param for weight column name.
MLWriter write() - Returns an MLWriter instance for this ML instance.

Methods inherited from class org.apache.spark.ml.PredictionModel: featuresCol, labelCol, predictionCol, setFeaturesCol, setPredictionCol
Methods inherited from class org.apache.spark.ml.Transformer: transform, transform, transform
Methods inherited from class org.apache.spark.ml.PipelineStage: params
Methods inherited from class java.lang.Object: equals, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface org.apache.spark.ml.tree.DecisionTreeParams: getCacheNodeIds, getLeafCol, getMaxBins, getMaxDepth, getMaxMemoryInMB, getMinInfoGain, getMinInstancesPerNode, getMinWeightFractionPerNode, getOldStrategy, setLeafCol
Methods inherited from interface org.apache.spark.ml.tree.GBTParams: getOldBoostingStrategy, getValidationTol
Methods inherited from interface org.apache.spark.ml.tree.GBTRegressorParams: convertToOldLossType, getLossType, getOldLossType
Methods inherited from interface org.apache.spark.ml.param.shared.HasCheckpointInterval: getCheckpointInterval
Methods inherited from interface org.apache.spark.ml.param.shared.HasFeaturesCol: featuresCol, getFeaturesCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasLabelCol: getLabelCol, labelCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasMaxIter: getMaxIter
Methods inherited from interface org.apache.spark.ml.param.shared.HasPredictionCol: getPredictionCol, predictionCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasStepSize: getStepSize
Methods inherited from interface org.apache.spark.ml.param.shared.HasValidationIndicatorCol: getValidationIndicatorCol
Methods inherited from interface org.apache.spark.ml.tree.HasVarianceImpurity: getImpurity, getOldImpurity
Methods inherited from interface org.apache.spark.ml.param.shared.HasWeightCol: getWeightCol
Methods inherited from interface org.apache.spark.internal.Logging: initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, isTraceEnabled, log, logBasedOnLevel, logDebug, logDebug, logDebug, logDebug, logError, logError, logError, logError, logInfo, logInfo, logInfo, logInfo, logName, LogStringContext, logTrace, logTrace, logTrace, logTrace, logWarning, logWarning, logWarning, logWarning, MDC, org$apache$spark$internal$Logging$$log_, org$apache$spark$internal$Logging$$log__$eq, withLogContext
Methods inherited from interface org.apache.spark.ml.util.MLWritable: save
Methods inherited from interface org.apache.spark.ml.param.Params: clear, copyValues, defaultCopy, defaultParamMap, estimateMatadataSize, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, onParamChange, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
Methods inherited from interface org.apache.spark.ml.tree.TreeEnsembleModel: getEstimatedSize, getLeafField, getTree, javaTreeWeights, predictLeaf, toDebugString
Methods inherited from interface org.apache.spark.ml.tree.TreeEnsembleParams: getFeatureSubsetStrategy, getOldStrategy, getSubsamplingRate
Methods inherited from interface org.apache.spark.ml.tree.TreeEnsembleRegressorParams: validateAndTransformSchema
- 
Constructor Details

GBTRegressionModel
public GBTRegressionModel(String uid, DecisionTreeRegressionModel[] _trees, double[] _treeWeights)
Construct a GBTRegressionModel
- Parameters:
- _trees - Decision trees in the ensemble.
- _treeWeights - Weights for the decision trees in the ensemble.
- uid - (undocumented)
 
 
- 
- 
Method Details

read
public static MLReader<GBTRegressionModel> read()

load
public static GBTRegressionModel load(String path)
totalNumNodes
public int totalNumNodes()
Description copied from interface: TreeEnsembleModel
Total number of nodes, summed over all trees in the ensemble.
- Specified by:
- totalNumNodes in interface TreeEnsembleModel<DecisionTreeRegressionModel>
 
- 
lossType
Description copied from interface: GBTRegressorParams
Loss function which GBT tries to minimize (case-insensitive). Supported: "squared" (L2) and "absolute" (L1). (default = squared)
- Specified by:
- lossType in interface GBTRegressorParams
- Returns:
- (undocumented)
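A small sketch reading this and related parameter values from a fitted model (reuses the model from the usage sketch above; the getters come from the inherited param traits):

  println(model.getLossType)                    // "squared" or "absolute"
  println(model.getMaxDepth)                    // inherited from DecisionTreeParams
  println(model.explainParam(model.lossType))   // human-readable doc plus current/default value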
 
- 
impurity
Description copied from interface: HasVarianceImpurity
Criterion used for information gain calculation (case-insensitive). This impurity type is used in DecisionTreeRegressor, RandomForestRegressor, GBTRegressor and GBTClassifier (since GBTClassificationModel is internally composed of DecisionTreeRegressionModels). Supported: "variance". (default = variance)
- Specified by:
- impurity in interface HasVarianceImpurity
- Returns:
- (undocumented)
 
- 
validationTol
Description copied from interface: GBTParams
Threshold for stopping early when fit with validation is used. (This parameter is ignored when fit without validation is used.) The decision to stop early is based on the following logic: if the current loss on the validation set is greater than 0.01, the difference in validation error is compared to a relative tolerance of validationTol * (current loss on the validation set); if the current loss on the validation set is less than or equal to 0.01, the difference in validation error is compared to an absolute tolerance of validationTol * 0.01.
- Specified by:
- validationTol in interface GBTParams
- Returns:
- (undocumented)
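For intuition only, a sketch paraphrasing the decision rule described above; this is not the library's internal code and the names are hypothetical:

  // Paraphrase of the documented early-stopping rule (hypothetical helper).
  def shouldStopEarly(currentLoss: Double, previousLoss: Double, validationTol: Double): Boolean = {
    val diff = math.abs(previousLoss - currentLoss)
    if (currentLoss > 0.01) diff < validationTol * currentLoss  // relative tolerance
    else diff < validationTol * 0.01                            // absolute tolerance
  }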
 
- 
stepSize
Description copied from interface: GBTParams
Param for Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator. (default = 0.1)
- Specified by:
- stepSize in interface GBTParams
- Specified by:
- stepSize in interface HasStepSize
- Returns:
- (undocumented)
 
- 
validationIndicatorCol
Description copied from interface: HasValidationIndicatorCol
Param for name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.
- Specified by:
- validationIndicatorCol in interface HasValidationIndicatorCol
- Returns:
- (undocumented)
 
- 
maxIter
Description copied from interface: HasMaxIter
Param for maximum number of iterations (>= 0).
- Specified by:
- maxIter in interface HasMaxIter
- Returns:
- (undocumented)
 
- 
subsamplingRate
Description copied from interface: TreeEnsembleParams
Fraction of the training data used for learning each decision tree, in range (0, 1]. (default = 1.0)
- Specified by:
- subsamplingRate in interface TreeEnsembleParams
- Returns:
- (undocumented)
 
- 
featureSubsetStrategy
Description copied from interface: TreeEnsembleParams
The number of features to consider for splits at each tree node. Supported options:
- "auto": Choose automatically for task: If numTrees == 1, set to "all." If numTrees greater than 1 (forest), set to "sqrt" for classification and to "onethird" for regression.
- "all": use all features
- "onethird": use 1/3 of the features
- "sqrt": use sqrt(number of features)
- "log2": use log2(number of features)
- "n": when n is in the range (0, 1.0], use n * number of features. When n is in the range (1, number of features), use n features.
(default = "auto")
These various settings are based on the following references:
- log2: tested in Breiman (2001)
- sqrt: recommended by Breiman manual for random forests
- The defaults of sqrt (classification) and onethird (regression) match the R randomForest package.
- Specified by:
- featureSubsetStrategy in interface TreeEnsembleParams
- Returns:
- (undocumented)
 
- 
leafCol
Description copied from interface: DecisionTreeParams
Leaf indices column name. Predicted leaf index of each instance in each tree by preorder. (default = "")
- Specified by:
- leafCol in interface DecisionTreeParams
- Returns:
- (undocumented)
 
- 
maxDepth
Description copied from interface: DecisionTreeParams
Maximum depth of the tree (nonnegative). E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (default = 5)
- Specified by:
- maxDepth in interface DecisionTreeParams
- Returns:
- (undocumented)
 
- 
maxBins
Description copied from interface: DecisionTreeParams
Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node. More bins give higher granularity. Must be at least 2 and at least number of categories in any categorical feature. (default = 32)
- Specified by:
- maxBins in interface DecisionTreeParams
- Returns:
- (undocumented)
 
- 
minInstancesPerNode
Description copied from interface: DecisionTreeParams
Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Must be at least 1. (default = 1)
- Specified by:
- minInstancesPerNode in interface DecisionTreeParams
- Returns:
- (undocumented)
 
- 
minWeightFractionPerNode
Description copied from interface: DecisionTreeParams
Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in the interval [0.0, 0.5). (default = 0.0)
- Specified by:
- minWeightFractionPerNode in interface DecisionTreeParams
- Returns:
- (undocumented)
 
- 
minInfoGain
Description copied from interface: DecisionTreeParams
Minimum information gain for a split to be considered at a tree node. Should be at least 0.0. (default = 0.0)
- Specified by:
- minInfoGain in interface DecisionTreeParams
- Returns:
- (undocumented)
 
- 
maxMemoryInMB
Description copied from interface: DecisionTreeParams
Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size. (default = 256 MB)
- Specified by:
- maxMemoryInMB in interface DecisionTreeParams
- Returns:
- (undocumented)
 
- 
cacheNodeIds
Description copied from interface: DecisionTreeParams
If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often the cache should be checkpointed, or disable checkpointing, via checkpointInterval. (default = false)
- Specified by:
- cacheNodeIds in interface DecisionTreeParams
- Returns:
- (undocumented)
 
- 
weightCol
Description copied from interface: HasWeightCol
Param for weight column name. If this is not set or empty, we treat all instance weights as 1.0.
- Specified by:
- weightCol in interface HasWeightCol
- Returns:
- (undocumented)
 
- 
seed
Description copied from interface: HasSeed
Param for random seed.
- 
checkpointInterval
Description copied from interface: HasCheckpointInterval
Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.
- Specified by:
- checkpointInterval in interface HasCheckpointInterval
- Returns:
- (undocumented)
 
- 
uid
Description copied from interface: Identifiable
An immutable unique ID for the object and its derivatives.
- Specified by:
- uid in interface Identifiable
- Returns:
- (undocumented)
 
- 
numFeatures
public int numFeatures()
Description copied from class: PredictionModel
Returns the number of features the model was trained on. If unknown, returns -1.
- Overrides:
- numFeatures in class PredictionModel<Vector,GBTRegressionModel>
 
- 
estimatedSize
public long estimatedSize()
- 
trees
Description copied from interface: TreeEnsembleModel
Trees in this ensemble. Warning: These have null parent Estimators.
- Specified by:
- trees in interface TreeEnsembleModel<DecisionTreeRegressionModel>
 
- 
getNumTrees
public int getNumTrees()
Number of trees in ensemble.
- Returns:
- (undocumented)
 
- 
treeWeights
public double[] treeWeights()
Description copied from interface: TreeEnsembleModel
Weights for each tree, zippable with TreeEnsembleModel.trees()
- Specified by:
- treeWeights in interface TreeEnsembleModel<DecisionTreeRegressionModel>
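A short sketch tying trees(), treeWeights(), getNumTrees() and totalNumNodes() together (reuses the fitted model from the usage sketch above):

  println(s"trees: ${model.getNumTrees}, total nodes: ${model.totalNumNodes}")
  model.trees.zip(model.treeWeights).foreach { case (tree, weight) =>
    println(s"${tree.uid}: weight = $weight, depth = ${tree.depth}, nodes = ${tree.numNodes}")
  }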
 
- 
transformSchema
Description copied from class: PipelineStage
Check transform validity and derive the output schema from the input schema.
We check validity for interactions between parameters during transformSchema and raise an exception if any parameter value is invalid. Parameter value checks which do not depend on other parameters are handled by Param.validate().
Typical implementation should first conduct verification on schema change and parameter validity, including complex parameter interaction checks.
- Overrides:
- transformSchema in class PredictionModel<Vector,GBTRegressionModel>
- Parameters:
- schema - (undocumented)
- Returns:
- (undocumented)
 
- 
transform
Description copied from class: PredictionModel
Transforms dataset by reading from PredictionModel.featuresCol(), calling predict, and storing the predictions as a new column PredictionModel.predictionCol().
- Overrides:
- transform in class PredictionModel<Vector,GBTRegressionModel>
- Parameters:
- dataset - input dataset
- Returns:
- transformed dataset with PredictionModel.predictionCol() of type Double
 
- 
predict
Description copied from class: PredictionModel
Predict label for the given features. This method is used to implement transform() and output PredictionModel.predictionCol().
- Specified by:
- predict in class PredictionModel<Vector,GBTRegressionModel>
- Parameters:
- features - (undocumented)
- Returns:
- (undocumented)
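A single-vector sketch (reuses the fitted model from the usage sketch above; the feature values are arbitrary but the vector length must equal model.numFeatures()):

  import org.apache.spark.ml.linalg.Vectors

  val features = Vectors.dense(0.5, 1.2)           // arbitrary two-feature example
  val predicted: Double = model.predict(features)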
 
- 
copy
Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params. Subclasses should implement this method and set the return type properly. See defaultCopy().
- Specified by:
- copy in interface Params
- Specified by:
- copy in class Model<GBTRegressionModel>
- Parameters:
- extra - (undocumented)
- Returns:
- (undocumented)
 
- 
toString
Description copied from interface: TreeEnsembleModel
Summary of the model
- Specified by:
- toString in interface Identifiable
- Specified by:
- toString in interface TreeEnsembleModel<DecisionTreeRegressionModel>
- Overrides:
- toString in class Object
 
- 
featureImportances
public Vector featureImportances()
Estimate of the importance of each feature.
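A sketch ranking features by importance (reuses the fitted model from the usage sketch above; the importance vector has length numFeatures and is normalized to sum to 1):

  val ranked = model.featureImportances.toArray.zipWithIndex
    .sortBy { case (importance, _) => -importance }
  ranked.take(5).foreach { case (importance, idx) => println(f"feature $idx: $importance%.4f") }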
- 
evaluateEachIteration
public double[] evaluateEachIteration(Dataset<?> dataset, String loss)
Method to compute error or loss for every iteration of gradient boosting.
- Parameters:
- dataset - Dataset for validation.
- loss - The loss function used to compute error. Supported options: squared, absolute
- Returns:
- (undocumented)
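A sketch of using evaluateEachIteration to see how validation loss evolves with the number of trees; validationData is an assumed DataFrame with the same "features"/"label" schema as the training data:

  val lossPerIteration: Array[Double] = model.evaluateEachIteration(validationData, "squared")
  val bestNumTrees = lossPerIteration.zipWithIndex.minBy { case (loss, _) => loss }._2 + 1
  println(s"Lowest validation loss reached after $bestNumTrees trees")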
 
- 
write
Description copied from interface: MLWritable
Returns an MLWriter instance for this ML instance.
- Specified by:
- write in interface MLWritable
- Returns:
- (undocumented)
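A persistence sketch combining write() with the static load() documented above (the path is a placeholder):

  model.write.overwrite().save("/tmp/gbt-regression-model")
  val restored = GBTRegressionModel.load("/tmp/gbt-regression-model")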
 
 