Class RandomForestClassificationModel
- All Implemented Interfaces: Serializable, org.apache.spark.internal.Logging, ClassifierParams, ProbabilisticClassifierParams, Params, HasCheckpointInterval, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, HasSeed, HasThresholds, HasWeightCol, PredictorParams, DecisionTreeParams, RandomForestClassifierParams, RandomForestParams, TreeClassifierParams, TreeEnsembleClassifierParams, TreeEnsembleModel<DecisionTreeClassificationModel>, TreeEnsembleParams, HasTrainingSummary<RandomForestClassificationTrainingSummary>, Identifiable, MLWritable
param: _trees Decision trees in the ensemble. Warning: These have null parents.
- See Also:
-
Nested Class Summary
Nested classes/interfaces inherited from interface org.apache.spark.internal.Logging
org.apache.spark.internal.Logging.LogStringContext, org.apache.spark.internal.Logging.SparkShellLoggingFilter
-
Method Summary
Modifier and Type, Method, Description
binarySummary(): Gets summary of model on training set.
final BooleanParam bootstrap(): Whether bootstrap samples are used when building trees.
final BooleanParam cacheNodeIds(): If false, the algorithm will pass trees to executors to match instances with nodes.
final IntParam checkpointInterval(): Param for setting the checkpoint interval (>= 1) or disabling checkpointing (-1).
copy(ParamMap extra): Creates a copy of this instance with the same UID and some extra params.
evaluate(Dataset<?> dataset): Evaluates the model on a test dataset.
featureSubsetStrategy(): The number of features to consider for splits at each tree node.
impurity(): Criterion used for information gain calculation (case-insensitive).
leafCol(): Leaf indices column name.
final IntParam maxBins(): Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node.
final IntParam maxDepth(): Maximum depth of the tree (nonnegative).
final IntParam maxMemoryInMB(): Maximum memory in MB allocated to histogram aggregation.
final DoubleParam minInfoGain(): Minimum information gain for a split to be considered at a tree node.
final IntParam minInstancesPerNode(): Minimum number of instances each child must have after split.
final DoubleParam minWeightFractionPerNode(): Minimum fraction of the weighted sample count that each child must have after split.
int numClasses(): Number of classes (values which the label can take).
int numFeatures(): Returns the number of features the model was trained on.
final IntParam numTrees(): Number of trees to train (at least 1).
predictRaw(Vector features): Raw prediction for each possible label.
read()
final LongParam seed(): Param for random seed.
final DoubleParam subsamplingRate(): Fraction of the training data used for learning each decision tree, in range (0, 1].
summary(): Gets summary of model on training set.
toString(): Summary of the model.
int totalNumNodes(): Total number of nodes, summed over all trees in the ensemble.
transform(Dataset<?> dataset): Transforms dataset by reading from PredictionModel.featuresCol(), and appending new columns as specified by parameters: predicted labels as PredictionModel.predictionCol() of type Double; raw predictions (confidences) as ClassificationModel.rawPredictionCol() of type Vector; probability of each class as ProbabilisticClassificationModel.probabilityCol() of type Vector.
transformSchema(StructType schema): Check transform validity and derive the output schema from the input schema.
trees(): Trees in this ensemble.
double[] treeWeights(): Weights for each tree, zippable with TreeEnsembleModel.trees().
uid(): An immutable unique ID for the object and its derivatives.
weightCol(): Param for weight column name.
write(): Returns an MLWriter instance for this ML instance.
Methods inherited from class org.apache.spark.ml.classification.ProbabilisticClassificationModel
normalizeToProbabilitiesInPlace, predictProbability, probabilityCol, setProbabilityCol, setThresholds, thresholds
Methods inherited from class org.apache.spark.ml.classification.ClassificationModel
predict, rawPredictionCol, setRawPredictionCol, transformImpl
Methods inherited from class org.apache.spark.ml.PredictionModel
featuresCol, labelCol, predictionCol, setFeaturesCol, setPredictionCol
Methods inherited from class org.apache.spark.ml.Transformer
transform, transform, transform
Methods inherited from class org.apache.spark.ml.PipelineStage
params
Methods inherited from class java.lang.Object
equals, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface org.apache.spark.ml.tree.DecisionTreeParams
getCacheNodeIds, getLeafCol, getMaxBins, getMaxDepth, getMaxMemoryInMB, getMinInfoGain, getMinInstancesPerNode, getMinWeightFractionPerNode, getOldStrategy, setLeafCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasCheckpointInterval
getCheckpointInterval
Methods inherited from interface org.apache.spark.ml.param.shared.HasFeaturesCol
featuresCol, getFeaturesCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasLabelCol
getLabelCol, labelCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasPredictionCol
getPredictionCol, predictionCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasProbabilityCol
getProbabilityCol, probabilityCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasRawPredictionCol
getRawPredictionCol, rawPredictionCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasThresholds
getThresholds, thresholds
Methods inherited from interface org.apache.spark.ml.util.HasTrainingSummary
hasSummary, setSummary
Methods inherited from interface org.apache.spark.ml.param.shared.HasWeightCol
getWeightCol
Methods inherited from interface org.apache.spark.internal.Logging
initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, isTraceEnabled, log, logDebug, logDebug, logDebug, logDebug, logError, logError, logError, logError, logInfo, logInfo, logInfo, logInfo, logName, LogStringContext, logTrace, logTrace, logTrace, logTrace, logWarning, logWarning, logWarning, logWarning, org$apache$spark$internal$Logging$$log_, org$apache$spark$internal$Logging$$log__$eq, withLogContext
Methods inherited from interface org.apache.spark.ml.util.MLWritable
save
Methods inherited from interface org.apache.spark.ml.param.Params
clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, onParamChange, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
Methods inherited from interface org.apache.spark.ml.tree.RandomForestParams
getBootstrap, getNumTrees
Methods inherited from interface org.apache.spark.ml.tree.TreeClassifierParams
getImpurity, getOldImpurity
Methods inherited from interface org.apache.spark.ml.tree.TreeEnsembleClassifierParams
validateAndTransformSchema
Methods inherited from interface org.apache.spark.ml.tree.TreeEnsembleModel
getLeafField, javaTreeWeights, predictLeaf, toDebugString
Methods inherited from interface org.apache.spark.ml.tree.TreeEnsembleParams
getFeatureSubsetStrategy, getOldStrategy, getSubsamplingRate
-
Method Details
-
read
-
load
-
totalNumNodes
public int totalNumNodes()
Description copied from interface: TreeEnsembleModel
Total number of nodes, summed over all trees in the ensemble.
- Specified by: totalNumNodes in interface TreeEnsembleModel<DecisionTreeClassificationModel>
-
impurity
Description copied from interface: TreeClassifierParams
Criterion used for information gain calculation (case-insensitive). This impurity type is used in DecisionTreeClassifier and RandomForestClassifier. Supported: "entropy" and "gini". (default = gini)
- Specified by: impurity in interface TreeClassifierParams
- Returns: (undocumented)
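As an illustration of the two supported criteria, the following is a minimal sketch that computes Gini impurity and entropy from per-class counts. The class `ImpuritySketch` and its method names are hypothetical and not part of the Spark API; using log base 2 for entropy is an assumption made here for readability.

```java
/** Hypothetical helper (not Spark API): impurity of a node from its per-class counts. */
final class ImpuritySketch {

    /** Gini impurity: 1 - sum over classes of p_i^2. Returns 0 for an empty node. */
    static double gini(double[] counts) {
        double total = 0.0;
        for (double c : counts) total += c;
        if (total == 0.0) return 0.0;
        double sumOfSquares = 0.0;
        for (double c : counts) {
            double p = c / total;
            sumOfSquares += p * p;
        }
        return 1.0 - sumOfSquares;
    }

    /** Entropy in bits: -sum over classes of p_i * log2(p_i), skipping empty classes. */
    static double entropy(double[] counts) {
        double total = 0.0;
        for (double c : counts) total += c;
        if (total == 0.0) return 0.0;
        double h = 0.0;
        for (double c : counts) {
            if (c > 0.0) {
                double p = c / total;
                h -= p * (Math.log(p) / Math.log(2.0));
            }
        }
        return h;
    }
}
```

Both measures are 0 for a pure node and largest when the classes are evenly mixed, which is why either can serve as the criterion for information gain.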
-
numTrees
Description copied from interface: RandomForestParams
Number of trees to train (at least 1). If 1, then no bootstrapping is used. If greater than 1, then bootstrapping is done. TODO: Change to always do bootstrapping (simpler). SPARK-7130 (default = 20)
Note: This param cannot be added to both GBT and RF (i.e. in TreeEnsembleParams) because the param maxIter controls how many trees a GBT has. The semantics in the two algorithms are a bit different.
- Specified by: numTrees in interface RandomForestParams
- Returns: (undocumented)
-
bootstrap
Description copied from interface: RandomForestParams
Whether bootstrap samples are used when building trees.
- Specified by: bootstrap in interface RandomForestParams
- Returns: (undocumented)
-
subsamplingRate
Description copied from interface: TreeEnsembleParams
Fraction of the training data used for learning each decision tree, in range (0, 1]. (default = 1.0)
- Specified by: subsamplingRate in interface TreeEnsembleParams
- Returns: (undocumented)
-
featureSubsetStrategy
Description copied from interface: TreeEnsembleParams
The number of features to consider for splits at each tree node. Supported options:
- "auto": Choose automatically for task: If numTrees == 1, set to "all". If numTrees is greater than 1 (forest), set to "sqrt" for classification and to "onethird" for regression.
- "all": use all features
- "onethird": use 1/3 of the features
- "sqrt": use sqrt(number of features)
- "log2": use log2(number of features)
- "n": when n is in the range (0, 1.0], use n * number of features. When n is in the range (1, number of features), use n features.
(default = "auto")
These various settings are based on the following references:
- log2: tested in Breiman (2001)
- sqrt: recommended by Breiman manual for random forests
- The defaults of sqrt (classification) and onethird (regression) match the R randomForest package.
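The options above can be read as a mapping from a strategy string to a feature count per node. The sketch below shows one way to express that mapping; `FeatureSubsetSketch` and `featuresPerNode` are hypothetical names, and rounding up, the floor of 1 for "log2", and clamping to the number of features are assumptions rather than documented behavior. The value n = 1 is treated according to the documented range (0, 1.0], that is, as a fraction.

```java
/** Hypothetical helper (not Spark API): resolves a featureSubsetStrategy string to a feature count. */
final class FeatureSubsetSketch {

    static int featuresPerNode(String strategy, int numFeatures, int numTrees, boolean classification) {
        String s = strategy.toLowerCase(java.util.Locale.ROOT);
        if (s.equals("auto")) {
            // A single tree uses all features; a forest uses sqrt (classification) or onethird (regression).
            if (numTrees == 1) {
                s = "all";
            } else {
                s = classification ? "sqrt" : "onethird";
            }
        }
        switch (s) {
            case "all":
                return numFeatures;
            case "sqrt":
                return (int) Math.ceil(Math.sqrt(numFeatures));
            case "log2":
                return Math.max(1, (int) Math.ceil(Math.log(numFeatures) / Math.log(2.0)));
            case "onethird":
                return (int) Math.ceil(numFeatures / 3.0);
            default:
                // Numeric form "n": a fraction in (0, 1.0] or a whole number of features above 1.
                double n = Double.parseDouble(s);
                if (n > 0.0 && n <= 1.0) {
                    return (int) Math.ceil(n * numFeatures);
                }
                if (n > 1.0 && n == Math.floor(n)) {
                    return (int) Math.min(n, numFeatures);
                }
                throw new IllegalArgumentException("Unsupported featureSubsetStrategy: " + strategy);
        }
    }
}
```

For example, with 10 features and 20 trees, "auto" resolves to "sqrt" for a classifier, which this sketch rounds up to 4 features per node.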
- Specified by: featureSubsetStrategy in interface TreeEnsembleParams
- Returns: (undocumented)
- See Also:
-
leafCol
Description copied from interface: DecisionTreeParams
Leaf indices column name. Predicted leaf index of each instance in each tree by preorder. (default = "")
- Specified by: leafCol in interface DecisionTreeParams
- Returns: (undocumented)
-
maxDepth
Description copied from interface: DecisionTreeParams
Maximum depth of the tree (nonnegative). E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (default = 5)
- Specified by: maxDepth in interface DecisionTreeParams
- Returns: (undocumented)
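The depth examples above generalize to simple upper bounds on tree size. The sketch below assumes binary trees (each internal node has a left and a right child); `TreeSizeSketch` is a hypothetical helper, not part of the Spark API, and the upper limit of 61 only guards against `long` overflow in this sketch.

```java
/** Hypothetical helper (not Spark API): size bounds of a binary tree for a given maxDepth. */
final class TreeSizeSketch {

    private static void check(int maxDepth) {
        if (maxDepth < 0 || maxDepth > 61) {
            throw new IllegalArgumentException("maxDepth must be in [0, 61] for this sketch: " + maxDepth);
        }
    }

    /** A full binary tree of depth d has at most 2^d leaves. */
    static long maxLeaves(int maxDepth) {
        check(maxDepth);
        return 1L << maxDepth;
    }

    /** A full binary tree of depth d has at most 2^(d+1) - 1 nodes in total. */
    static long maxNodes(int maxDepth) {
        check(maxDepth);
        return (1L << (maxDepth + 1)) - 1;
    }
}
```

With the default depth of 5, a single tree can therefore hold up to 63 nodes, 32 of them leaves. Actual trees are often smaller because other stopping rules may end splitting early.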
-
maxBins
Description copied from interface: DecisionTreeParams
Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node. More bins give higher granularity. Must be at least 2 and at least the number of categories in any categorical feature. (default = 32)
- Specified by: maxBins in interface DecisionTreeParams
- Returns: (undocumented)
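The two constraints on maxBins can be checked before training. The following sketch is illustrative only: `MaxBinsSketch` and `isValid` are hypothetical names, and `categoricalArities` stands for the number of categories of each categorical feature, which the caller is assumed to know.

```java
/** Hypothetical helper (not Spark API): checks the documented constraints on maxBins. */
final class MaxBinsSketch {

    /**
     * Returns true when maxBins is at least 2 and at least as large as the
     * number of categories of every categorical feature.
     */
    static boolean isValid(int maxBins, int[] categoricalArities) {
        if (maxBins < 2) {
            return false;
        }
        for (int arity : categoricalArities) {
            if (arity > maxBins) {
                return false;
            }
        }
        return true;
    }
}
```

With the default of 32 bins, a categorical feature with 40 distinct values fails this check, so maxBins would need to be raised to at least 40.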
-
minInstancesPerNode
Description copied from interface: DecisionTreeParams
Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Must be at least 1. (default = 1)
- Specified by: minInstancesPerNode in interface DecisionTreeParams
- Returns: (undocumented)
-
minWeightFractionPerNode
Description copied from interface: DecisionTreeParams
Minimum fraction of the weighted sample count that each child must have after split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in the interval [0.0, 0.5). (default = 0.0)
- Specified by: minWeightFractionPerNode in interface DecisionTreeParams
- Returns: (undocumented)
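The rules for minInstancesPerNode and minWeightFractionPerNode both discard a candidate split based on the size of its children. The sketch below combines them in one check; `SplitValiditySketch` is a hypothetical helper, and passing the total weight in as a plain argument is an assumption about how the fraction is measured.

```java
/** Hypothetical helper (not Spark API): applies both per-child size rules to a candidate split. */
final class SplitValiditySketch {

    static boolean isValidSplit(long leftCount, long rightCount,
                                double leftWeight, double rightWeight, double totalWeight,
                                int minInstancesPerNode, double minWeightFractionPerNode) {
        // Rule 1: each child must hold at least minInstancesPerNode instances.
        if (leftCount < minInstancesPerNode || rightCount < minInstancesPerNode) {
            return false;
        }
        // Rule 2: each child must hold at least the given fraction of the total weight.
        double minWeight = minWeightFractionPerNode * totalWeight;
        return leftWeight >= minWeight && rightWeight >= minWeight;
    }
}
```

With the defaults (1 and 0.0), only splits that leave a child empty are rejected. Raising either value prunes splits that isolate very small groups.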
-
minInfoGain
Description copied from interface: DecisionTreeParams
Minimum information gain for a split to be considered at a tree node. Should be at least 0.0. (default = 0.0)
- Specified by: minInfoGain in interface DecisionTreeParams
- Returns: (undocumented)
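To make the threshold concrete, the sketch below computes the information gain of a binary split as the parent impurity minus the size-weighted impurity of the two children, then compares it with minInfoGain. It uses Gini impurity for brevity. `InfoGainSketch` is a hypothetical helper, and treating a gain exactly equal to the threshold as acceptable is an assumption.

```java
/** Hypothetical helper (not Spark API): information gain of a binary split using Gini impurity. */
final class InfoGainSketch {

    static double sum(double[] counts) {
        double total = 0.0;
        for (double c : counts) total += c;
        return total;
    }

    /** Gini impurity of a node given its per-class counts; 0 for an empty node. */
    static double gini(double[] counts) {
        double total = sum(counts);
        if (total == 0.0) return 0.0;
        double sumOfSquares = 0.0;
        for (double c : counts) {
            double p = c / total;
            sumOfSquares += p * p;
        }
        return 1.0 - sumOfSquares;
    }

    /** gain = impurity(parent) - (nLeft / n) * impurity(left) - (nRight / n) * impurity(right). */
    static double gain(double[] left, double[] right) {
        double[] parent = new double[left.length];
        for (int i = 0; i < left.length; i++) {
            parent[i] = left[i] + right[i];
        }
        double nLeft = sum(left);
        double nRight = sum(right);
        double n = nLeft + nRight;
        if (n == 0.0) return 0.0;
        return gini(parent) - (nLeft / n) * gini(left) - (nRight / n) * gini(right);
    }

    /** A split is kept only when its gain reaches minInfoGain. */
    static boolean passes(double[] left, double[] right, double minInfoGain) {
        return gain(left, right) >= minInfoGain;
    }
}
```

A split that separates two balanced classes perfectly has a Gini gain of 0.5, while a split whose children have the same class mix as the parent has a gain of 0 and is rejected by any positive threshold.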
-
maxMemoryInMB
Description copied from interface: DecisionTreeParams
Maximum memory in MB allocated to histogram aggregation. If too small, then 1 node will be split per iteration, and its aggregates may exceed this size. (default = 256 MB)
- Specified by: maxMemoryInMB in interface DecisionTreeParams
- Returns: (undocumented)
-
cacheNodeIds
Description copied from interface: DecisionTreeParams
If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often the cache should be checkpointed, or disable checkpointing, by setting checkpointInterval. (default = false)
- Specified by: cacheNodeIds in interface DecisionTreeParams
- Returns: (undocumented)
-
weightCol
Description copied from interface: HasWeightCol
Param for weight column name. If this is not set or empty, we treat all instance weights as 1.0.
- Specified by: weightCol in interface HasWeightCol
- Returns: (undocumented)
-
seed
Description copied from interface: HasSeed
Param for random seed.
-
checkpointInterval
Description copied from interface: HasCheckpointInterval
Param for setting the checkpoint interval (>= 1) or disabling checkpointing (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.
- Specified by: checkpointInterval in interface HasCheckpointInterval
- Returns: (undocumented)
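The interval semantics can be summarized as a small decision function. This is a sketch under stated assumptions: `CheckpointSketch` is a hypothetical helper, iterations are assumed to be counted from 1, and the `checkpointDirSet` flag stands in for whether a checkpoint directory has been configured on the SparkContext.

```java
/** Hypothetical helper (not Spark API): decides whether to checkpoint at a given iteration. */
final class CheckpointSketch {

    static boolean shouldCheckpoint(int iteration, int checkpointInterval, boolean checkpointDirSet) {
        // Only -1 (disabled) or a value >= 1 is a legal interval.
        if (checkpointInterval != -1 && checkpointInterval < 1) {
            throw new IllegalArgumentException(
                "checkpointInterval must be -1 or >= 1: " + checkpointInterval);
        }
        // The setting is ignored when checkpointing is disabled or no directory is configured.
        if (checkpointInterval == -1 || !checkpointDirSet) {
            return false;
        }
        // E.g. an interval of 10 checkpoints at iterations 10, 20, 30, ...
        return iteration > 0 && iteration % checkpointInterval == 0;
    }
}
```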
-
uid
Description copied from interface: Identifiable
An immutable unique ID for the object and its derivatives.
- Specified by: uid in interface Identifiable
- Returns: (undocumented)
-
numFeatures
public int numFeatures()
Description copied from class: PredictionModel
Returns the number of features the model was trained on. If unknown, returns -1.
- Overrides: numFeatures in class PredictionModel<Vector, RandomForestClassificationModel>
-
numClasses
public int numClasses()
Description copied from class: ClassificationModel
Number of classes (values which the label can take).
- Specified by: numClasses in class ClassificationModel<Vector, RandomForestClassificationModel>
-
trees
Description copied from interface: TreeEnsembleModel
Trees in this ensemble. Warning: These have null parent Estimators.
- Specified by: trees in interface TreeEnsembleModel<DecisionTreeClassificationModel>
-
treeWeights
public double[] treeWeights()
Description copied from interface: TreeEnsembleModel
Weights for each tree, zippable with TreeEnsembleModel.trees()
- Specified by: treeWeights in interface TreeEnsembleModel<DecisionTreeClassificationModel>
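"Zippable" means that the i-th weight belongs to the i-th tree. The sketch below shows one way such a pairing could be used, combining per-tree class scores into a single weighted score vector. `WeightedVoteSketch` is a hypothetical helper, and the assumption that each tree contributes a vector of per-class scores is made here for illustration only.

```java
/** Hypothetical helper (not Spark API): combines per-tree class scores using per-tree weights. */
final class WeightedVoteSketch {

    /**
     * perTreeScores[t][c] is the score tree t gives to class c;
     * treeWeights[t] is the weight paired with tree t.
     */
    static double[] combine(double[][] perTreeScores, double[] treeWeights) {
        if (perTreeScores.length == 0 || perTreeScores.length != treeWeights.length) {
            throw new IllegalArgumentException("Need one weight per tree and at least one tree");
        }
        int numClasses = perTreeScores[0].length;
        double[] combined = new double[numClasses];
        for (int t = 0; t < perTreeScores.length; t++) {
            for (int c = 0; c < numClasses; c++) {
                combined[c] += treeWeights[t] * perTreeScores[t][c];
            }
        }
        return combined;
    }
}
```

When every weight is 1.0 the result is a plain sum of the per-tree scores.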
-
summary
Gets summary of model on training set. An exception is thrown if hasSummary is false.
- Specified by: summary in interface HasTrainingSummary<RandomForestClassificationTrainingSummary>
- Returns: (undocumented)
-
binarySummary
Gets summary of model on training set. An exception is thrown if hasSummary is false or it is a multiclass model.
- Returns: (undocumented)
-
evaluate
Evaluates the model on a test dataset.
- Parameters: dataset - Test dataset to evaluate model on.
- Returns: (undocumented)
-
transformSchema
Description copied from class: PipelineStage
Check transform validity and derive the output schema from the input schema.
We check validity for interactions between parameters during transformSchema and raise an exception if any parameter value is invalid. Parameter value checks which do not depend on other parameters are handled by Param.validate().
A typical implementation should first verify the schema change and parameter validity, including complex parameter interaction checks.
- Overrides: transformSchema in class ProbabilisticClassificationModel<Vector, RandomForestClassificationModel>
- Parameters: schema - (undocumented)
- Returns: (undocumented)
-
transform
Description copied from class: ProbabilisticClassificationModel
Transforms dataset by reading from PredictionModel.featuresCol(), and appending new columns as specified by parameters:
- predicted labels as PredictionModel.predictionCol() of type Double
- raw predictions (confidences) as ClassificationModel.rawPredictionCol() of type Vector
- probability of each class as ProbabilisticClassificationModel.probabilityCol() of type Vector
- Overrides: transform in class ProbabilisticClassificationModel<Vector, RandomForestClassificationModel>
- Parameters: dataset - input dataset
- Returns: transformed dataset
-
predictRaw
Description copied from class: ClassificationModel
Raw prediction for each possible label. The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives a measure of confidence in each possible label (where larger = more confident). This internal method is used to implement transform() and output ClassificationModel.rawPredictionCol().
- Specified by: predictRaw in class ClassificationModel<Vector, RandomForestClassificationModel>
- Parameters: features - (undocumented)
- Returns: vector where element i is the raw prediction for label i. This raw prediction may be any real number, where a larger value indicates greater confidence for that label.
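To show how a raw prediction vector relates to the probability and prediction columns, the sketch below turns raw scores into probabilities and picks the most confident label. `RawPredictionSketch` is a hypothetical helper; dividing by the sum of the raw scores is an assumption about the normalization (it presumes non-negative scores), and any thresholds are ignored.

```java
/** Hypothetical helper (not Spark API): derives probabilities and a label from raw scores. */
final class RawPredictionSketch {

    /** Scales non-negative raw scores so that they sum to 1; an all-zero input is returned unchanged. */
    static double[] toProbabilities(double[] raw) {
        double total = 0.0;
        for (double r : raw) total += r;
        double[] probabilities = raw.clone();
        if (total != 0.0) {
            for (int i = 0; i < probabilities.length; i++) {
                probabilities[i] /= total;
            }
        }
        return probabilities;
    }

    /** Index of the largest element, i.e. the label with the greatest confidence (first one on ties). */
    static int argmax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }
}
```

For a raw vector of (1.3, 0.7), this gives probabilities of (0.65, 0.35) and a predicted label of 0.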
-
copy
Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params. Subclasses should implement this method and set the return type properly. See defaultCopy().
- Specified by: copy in interface Params
- Specified by: copy in class Model<RandomForestClassificationModel>
- Parameters: extra - (undocumented)
- Returns: (undocumented)
-
toString
Description copied from interface: TreeEnsembleModel
Summary of the model.
- Specified by: toString in interface Identifiable
- Specified by: toString in interface TreeEnsembleModel<DecisionTreeClassificationModel>
- Overrides: toString in class Object
-
featureImportances
-
write
Description copied from interface: MLWritable
Returns an MLWriter instance for this ML instance.
- Specified by: write in interface MLWritable
- Returns: (undocumented)
-