Class DecisionTreeClassifier
Object
org.apache.spark.ml.PipelineStage
org.apache.spark.ml.Estimator<M>
org.apache.spark.ml.Predictor<FeaturesType,E,M>
org.apache.spark.ml.classification.Classifier<FeaturesType,E,M>
org.apache.spark.ml.classification.ProbabilisticClassifier<Vector,DecisionTreeClassifier,DecisionTreeClassificationModel>
org.apache.spark.ml.classification.DecisionTreeClassifier
- All Implemented Interfaces:
Serializable, org.apache.spark.internal.Logging, ClassifierParams, ProbabilisticClassifierParams, Params, HasCheckpointInterval, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, HasSeed, HasThresholds, HasWeightCol, PredictorParams, DecisionTreeClassifierParams, DecisionTreeParams, TreeClassifierParams, DefaultParamsWritable, Identifiable, MLWritable, scala.Serializable
public class DecisionTreeClassifier
extends ProbabilisticClassifier<Vector,DecisionTreeClassifier,DecisionTreeClassificationModel>
implements DecisionTreeClassifierParams, DefaultParamsWritable
Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
for classification.
It supports both binary and multiclass labels, as well as both continuous and categorical
features.
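Illustrative usage sketch (not part of the generated listing): it assumes a SparkSession and a DataFrame named training with a numeric "label" column and a Vector-typed "features" column, e.g. as produced by VectorAssembler.
    import org.apache.spark.ml.classification.DecisionTreeClassifier

    val dt = new DecisionTreeClassifier()
      .setLabelCol("label")
      .setFeaturesCol("features")
      .setMaxDepth(5)
      .setImpurity("gini")

    // fit() is inherited from Predictor and returns a DecisionTreeClassificationModel.
    val model = dt.fit(training)

    // transform() appends the prediction, rawPrediction and probability columns.
    val predictions = model.transform(training)
    predictions.select("prediction", "probability").show(5)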
Nested Class Summary
Nested classes/interfaces inherited from interface org.apache.spark.internal.Logging
org.apache.spark.internal.Logging.SparkShellLoggingFilter
-
Constructor Summary
DecisionTreeClassifier()
-
Method Summary
final BooleanParam cacheNodeIds()
    If false, the algorithm will pass trees to executors to match instances with nodes.
final IntParam checkpointInterval()
    Param for the checkpoint interval (>= 1), or -1 to disable checkpointing.
DecisionTreeClassifier copy(ParamMap extra)
    Creates a copy of this instance with the same UID and some extra params.
Param<String> impurity()
    Criterion used for information gain calculation (case-insensitive).
Param<String> leafCol()
    Leaf indices column name.
static DecisionTreeClassifier load(String path)
final IntParam maxBins()
    Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node.
final IntParam maxDepth()
    Maximum depth of the tree (nonnegative).
final IntParam maxMemoryInMB()
    Maximum memory in MB allocated to histogram aggregation.
final DoubleParam minInfoGain()
    Minimum information gain for a split to be considered at a tree node.
final IntParam minInstancesPerNode()
    Minimum number of instances each child must have after a split.
final DoubleParam minWeightFractionPerNode()
    Minimum fraction of the weighted sample count that each child must have after a split.
static MLReader<T> read()
final LongParam seed()
    Param for random seed.
DecisionTreeClassifier setCacheNodeIds(boolean value)
DecisionTreeClassifier setCheckpointInterval(int value)
    Specifies how often to checkpoint the cached node IDs.
DecisionTreeClassifier setImpurity(String value)
DecisionTreeClassifier setMaxBins(int value)
DecisionTreeClassifier setMaxDepth(int value)
DecisionTreeClassifier setMaxMemoryInMB(int value)
DecisionTreeClassifier setMinInfoGain(double value)
DecisionTreeClassifier setMinInstancesPerNode(int value)
DecisionTreeClassifier setMinWeightFractionPerNode(double value)
DecisionTreeClassifier setSeed(long value)
DecisionTreeClassifier setWeightCol(String value)
    Sets the value of param weightCol().
static final String[] supportedImpurities()
    Accessor for supported impurities: entropy, gini
String uid()
    An immutable unique ID for the object and its derivatives.
Param<String> weightCol()
    Param for weight column name.
Methods inherited from class org.apache.spark.ml.classification.ProbabilisticClassifier
probabilityCol, setProbabilityCol, setThresholds, thresholds
Methods inherited from class org.apache.spark.ml.classification.Classifier
rawPredictionCol, setRawPredictionCol
Methods inherited from class org.apache.spark.ml.Predictor
featuresCol, fit, labelCol, predictionCol, setFeaturesCol, setLabelCol, setPredictionCol, transformSchema
Methods inherited from class org.apache.spark.ml.PipelineStage
params
Methods inherited from class java.lang.Object
equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface org.apache.spark.ml.tree.DecisionTreeClassifierParams
validateAndTransformSchema
Methods inherited from interface org.apache.spark.ml.tree.DecisionTreeParams
getCacheNodeIds, getLeafCol, getMaxBins, getMaxDepth, getMaxMemoryInMB, getMinInfoGain, getMinInstancesPerNode, getMinWeightFractionPerNode, getOldStrategy, setLeafCol
Methods inherited from interface org.apache.spark.ml.util.DefaultParamsWritable
write
Methods inherited from interface org.apache.spark.ml.param.shared.HasCheckpointInterval
getCheckpointInterval
Methods inherited from interface org.apache.spark.ml.param.shared.HasFeaturesCol
featuresCol, getFeaturesCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasLabelCol
getLabelCol, labelCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasPredictionCol
getPredictionCol, predictionCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasProbabilityCol
getProbabilityCol, probabilityCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasRawPredictionCol
getRawPredictionCol, rawPredictionCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasThresholds
getThresholds, thresholds
Methods inherited from interface org.apache.spark.ml.param.shared.HasWeightCol
getWeightCol
Methods inherited from interface org.apache.spark.ml.util.Identifiable
toString
Methods inherited from interface org.apache.spark.internal.Logging
initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, isTraceEnabled, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning, org$apache$spark$internal$Logging$$log_, org$apache$spark$internal$Logging$$log__$eq
Methods inherited from interface org.apache.spark.ml.util.MLWritable
save
Methods inherited from interface org.apache.spark.ml.param.Params
clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, onParamChange, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
Methods inherited from interface org.apache.spark.ml.tree.TreeClassifierParams
getImpurity, getOldImpurity
-
Constructor Details
-
DecisionTreeClassifier
public DecisionTreeClassifier()
-
Method Details
-
supportedImpurities
public static final String[] supportedImpurities()
Accessor for supported impurities: entropy, gini
-
load
public static DecisionTreeClassifier load(String path)
-
read
public static MLReader<T> read()
-
impurity
Description copied from interface: TreeClassifierParams
Criterion used for information gain calculation (case-insensitive). This impurity type is used in DecisionTreeClassifier and RandomForestClassifier. Supported: "entropy" and "gini". (default = gini)
- Specified by:
impurity in interface TreeClassifierParams
- Returns:
- (undocumented)
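A short sketch of inspecting and setting this Param via the generic Params API (values and defaults as documented above):
    import org.apache.spark.ml.classification.DecisionTreeClassifier

    // The accepted impurity values, as exposed by supportedImpurities: entropy, gini.
    println(DecisionTreeClassifier.supportedImpurities.mkString(", "))

    val dt = new DecisionTreeClassifier()
    println(dt.getImpurity)               // "gini" (the documented default)
    println(dt.explainParam(dt.impurity)) // human-readable description plus the default
    dt.setImpurity("entropy")             // value is matched case-insensitively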
-
leafCol
Description copied from interface: DecisionTreeParams
Leaf indices column name. Predicted leaf index of each instance in each tree by preorder. (default = "")
- Specified by:
leafCol in interface DecisionTreeParams
- Returns:
- (undocumented)
-
maxDepth
Description copied from interface: DecisionTreeParams
Maximum depth of the tree (nonnegative). E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (default = 5)
- Specified by:
maxDepth in interface DecisionTreeParams
- Returns:
- (undocumented)
-
maxBins
Description copied from interface: DecisionTreeParams
Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node. More bins give higher granularity. Must be at least 2, and at least the number of categories in any categorical feature. (default = 32)
- Specified by:
maxBins in interface DecisionTreeParams
- Returns:
- (undocumented)
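Illustrative sketch of the maxBins constraint: when categorical features are indexed (here with VectorIndexer), maxBins must be at least as large as the largest category count. The DataFrame data and its column names are assumptions.
    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.classification.DecisionTreeClassifier
    import org.apache.spark.ml.feature.VectorIndexer

    // VectorIndexer treats features with <= 32 distinct values as categorical.
    val indexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(32)

    val dt = new DecisionTreeClassifier()
      .setFeaturesCol("indexedFeatures")
      .setMaxBins(32)   // >= 32, so it covers every categorical feature indexed above

    val pipeline = new Pipeline().setStages(Array(indexer, dt))
    val model = pipeline.fit(data)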
-
minInstancesPerNode
Description copied from interface: DecisionTreeParams
Minimum number of instances each child must have after a split. If a split causes the left or right child to have fewer than minInstancesPerNode instances, the split will be discarded as invalid. Must be at least 1. (default = 1)
- Specified by:
minInstancesPerNode in interface DecisionTreeParams
- Returns:
- (undocumented)
-
minWeightFractionPerNode
Description copied from interface: DecisionTreeParams
Minimum fraction of the weighted sample count that each child must have after a split. If a split causes the fraction of the total weight in the left or right child to be less than minWeightFractionPerNode, the split will be discarded as invalid. Should be in the interval [0.0, 0.5). (default = 0.0)
- Specified by:
minWeightFractionPerNode in interface DecisionTreeParams
- Returns:
- (undocumented)
-
minInfoGain
Description copied from interface: DecisionTreeParams
Minimum information gain for a split to be considered at a tree node. Should be at least 0.0. (default = 0.0)
- Specified by:
minInfoGain in interface DecisionTreeParams
- Returns:
- (undocumented)
-
maxMemoryInMB
Description copied from interface: DecisionTreeParams
Maximum memory in MB allocated to histogram aggregation. If too small, then only 1 node will be split per iteration, and its aggregates may exceed this size. (default = 256 MB)
- Specified by:
maxMemoryInMB in interface DecisionTreeParams
- Returns:
- (undocumented)
-
cacheNodeIds
Description copied from interface: DecisionTreeParams
If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often the cache should be checkpointed, or disable checkpointing, by setting checkpointInterval. (default = false)
- Specified by:
cacheNodeIds in interface DecisionTreeParams
- Returns:
- (undocumented)
-
weightCol
Description copied from interface: HasWeightCol
Param for weight column name. If this is not set or is empty, all instance weights are treated as 1.0.
- Specified by:
weightCol in interface HasWeightCol
- Returns:
- (undocumented)
-
seed
Description copied from interface: HasSeed
Param for random seed.
-
checkpointInterval
Description copied from interface: HasCheckpointInterval
Param for the checkpoint interval (>= 1), or -1 to disable checkpointing. E.g., 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.
- Specified by:
checkpointInterval in interface HasCheckpointInterval
- Returns:
- (undocumented)
-
uid
Description copied from interface: Identifiable
An immutable unique ID for the object and its derivatives.
- Specified by:
uid in interface Identifiable
- Returns:
- (undocumented)
-
setMaxDepth
public DecisionTreeClassifier setMaxDepth(int value)
-
setMaxBins
public DecisionTreeClassifier setMaxBins(int value)
-
setMinInstancesPerNode
public DecisionTreeClassifier setMinInstancesPerNode(int value)
-
setMinWeightFractionPerNode
public DecisionTreeClassifier setMinWeightFractionPerNode(double value)
-
setMinInfoGain
public DecisionTreeClassifier setMinInfoGain(double value)
-
setMaxMemoryInMB
public DecisionTreeClassifier setMaxMemoryInMB(int value)
-
setCacheNodeIds
public DecisionTreeClassifier setCacheNodeIds(boolean value)
-
setCheckpointInterval
public DecisionTreeClassifier setCheckpointInterval(int value)
Specifies how often to checkpoint the cached node IDs. E.g., 10 means that the cache will get checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the checkpoint directory is set in SparkContext. Must be at least 1. (default = 10)
- Parameters:
value - (undocumented)
- Returns:
- (undocumented)
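Illustrative sketch of the settings this setter interacts with; the SparkSession spark and the checkpoint path are placeholders.
    import org.apache.spark.ml.classification.DecisionTreeClassifier

    // Checkpointing of cached node IDs requires a checkpoint directory on the SparkContext.
    spark.sparkContext.setCheckpointDir("/tmp/spark-checkpoints")

    val dt = new DecisionTreeClassifier()
      .setCacheNodeIds(true)       // cache node IDs per instance instead of passing trees to executors
      .setCheckpointInterval(10)   // checkpoint the cached IDs every 10 iterations
      .setMaxDepth(15)             // deeper trees are where caching tends to pay off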
-
setImpurity
public DecisionTreeClassifier setImpurity(String value)
-
setSeed
public DecisionTreeClassifier setSeed(long value)
-
setWeightCol
public DecisionTreeClassifier setWeightCol(String value)
Sets the value of param weightCol(). If this is not set or empty, all instance weights are treated as 1.0. Default is not set, so all instances have weight one.
- Parameters:
value - (undocumented)
- Returns:
- (undocumented)
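Illustrative sketch of per-instance weighting; the DataFrame training, its "label" column, and the weighting rule are assumptions.
    import org.apache.spark.ml.classification.DecisionTreeClassifier
    import org.apache.spark.sql.functions.{col, when}

    // Add an illustrative "weight" column that up-weights the positive class.
    val weighted = training.withColumn("weight",
      when(col("label") === 1.0, 5.0).otherwise(1.0))

    val dt = new DecisionTreeClassifier()
      .setWeightCol("weight")   // unset or empty weightCol means every instance has weight 1.0

    val model = dt.fit(weighted)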
-
copy
public DecisionTreeClassifier copy(ParamMap extra)
Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params. Subclasses should implement this method and set the return type properly. See defaultCopy().
- Specified by:
copy in interface Params
- Specified by:
copy in class Predictor<Vector,DecisionTreeClassifier,DecisionTreeClassificationModel>
- Parameters:
extra - (undocumented)
- Returns:
- (undocumented)
-
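A small sketch of copy() with extra params (values chosen for illustration):
    import org.apache.spark.ml.classification.DecisionTreeClassifier
    import org.apache.spark.ml.param.ParamMap

    val dt = new DecisionTreeClassifier().setMaxDepth(5)

    // Same UID as dt, with impurity additionally overridden; dt itself is unchanged.
    val dt2 = dt.copy(ParamMap(dt.impurity -> "entropy"))

    println(dt.uid == dt2.uid)   // true: the copy keeps the UID
    println(dt2.getImpurity)     // entropy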