public class DecisionTreeRegressionModel extends RegressionModel<Vector,DecisionTreeRegressionModel> implements DecisionTreeModel, DecisionTreeRegressorParams, MLWritable, scala.Serializable
param: rootNode Root of the decision tree
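
To make the role of the root node concrete, here is a minimal, hypothetical sketch of how a regression tree can turn a feature array into a prediction by walking from the root to a leaf. `TreeWalkSketch` and `SketchNode` are illustrative names, not Spark's `Node` API. The "go left when feature <= threshold" rule and the "a lone leaf has depth 0" convention are assumptions of this sketch.

```java
import java.util.Objects;

/** Hypothetical stand-in for a regression tree; not Spark's Node API. */
final class TreeWalkSketch {

    /** A node is either a leaf (no children) or an internal split on one feature. */
    static final class SketchNode {
        final int featureIndex;   // used by internal nodes only
        final double threshold;   // assumed rule: go left when feature <= threshold
        final double prediction;  // used by leaves only
        final SketchNode left;
        final SketchNode right;

        private SketchNode(int featureIndex, double threshold, double prediction,
                           SketchNode left, SketchNode right) {
            this.featureIndex = featureIndex;
            this.threshold = threshold;
            this.prediction = prediction;
            this.left = left;
            this.right = right;
        }

        static SketchNode leaf(double prediction) {
            return new SketchNode(-1, Double.NaN, prediction, null, null);
        }

        static SketchNode split(int featureIndex, double threshold,
                                SketchNode left, SketchNode right) {
            return new SketchNode(featureIndex, threshold, Double.NaN,
                    Objects.requireNonNull(left), Objects.requireNonNull(right));
        }

        boolean isLeaf() {
            return left == null;
        }
    }

    /** Walks from the root to a leaf and returns that leaf's prediction. */
    static double predict(SketchNode root, double[] features) {
        SketchNode node = root;
        while (!node.isLeaf()) {
            node = features[node.featureIndex] <= node.threshold ? node.left : node.right;
        }
        return node.prediction;
    }

    /** Number of edges on the longest root-to-leaf path; a lone leaf has depth 0. */
    static int depth(SketchNode node) {
        return node.isLeaf() ? 0 : 1 + Math.max(depth(node.left), depth(node.right));
    }

    public static void main(String[] args) {
        SketchNode root = SketchNode.split(0, 2.5,
                SketchNode.leaf(10.0),
                SketchNode.split(1, 0.0, SketchNode.leaf(20.0), SketchNode.leaf(30.0)));
        System.out.println(predict(root, new double[] {1.0, 5.0})); // prints 10.0
        System.out.println(predict(root, new double[] {3.0, 5.0})); // prints 30.0
        System.out.println(depth(root));                            // prints 2
    }
}
```

The real model exposes the same ideas through `rootNode()`, `predict(Vector features)`, and `depth()`.
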
| Modifier and Type | Method and Description |
|---|---|
| BooleanParam | cacheNodeIds(): If false, the algorithm will pass trees to executors to match instances with nodes. |
| IntParam | checkpointInterval(): Param for setting the checkpoint interval (>= 1) or disabling checkpointing (-1). |
| DecisionTreeRegressionModel | copy(ParamMap extra): Creates a copy of this instance with the same UID and some extra params. |
| int | depth(): Depth of the tree. |
| Vector | featureImportances() |
| Param<String> | impurity(): Criterion used for information gain calculation (case-insensitive). |
| Param<String> | leafCol(): Leaf indices column name. |
| static DecisionTreeRegressionModel | load(String path) |
| IntParam | maxBins(): Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node. |
| IntParam | maxDepth(): Maximum depth of the tree (nonnegative). |
| IntParam | maxMemoryInMB(): Maximum memory in MB allocated to histogram aggregation. |
| DoubleParam | minInfoGain(): Minimum information gain for a split to be considered at a tree node. |
| IntParam | minInstancesPerNode(): Minimum number of instances each child must have after a split. |
| DoubleParam | minWeightFractionPerNode(): Minimum fraction of the weighted sample count that each child must have after a split. |
| int | numFeatures(): Returns the number of features the model was trained on. |
| double | predict(Vector features): Predict label for the given features. |
| static MLReader<DecisionTreeRegressionModel> | read() |
| Node | rootNode(): Root of the decision tree. |
| LongParam | seed(): Param for random seed. |
| DecisionTreeRegressionModel | setVarianceCol(String value) |
| String | toString(): Summary of the model. |
| Dataset<Row> | transform(Dataset<?> dataset): Transforms dataset by reading from featuresCol, calling predict, and storing the predictions as a new column predictionCol. |
| StructType | transformSchema(StructType schema): Check transform validity and derive the output schema from the input schema. |
| String | uid(): An immutable unique ID for the object and its derivatives. |
| Param<String> | varianceCol(): Param for the column name of the biased sample variance of the prediction. |
| Param<String> | weightCol(): Param for weight column name. |
| MLWriter | write(): Returns an MLWriter instance for this ML instance. |
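
The `impurity`, `minInfoGain`, and `minInstancesPerNode` params jointly decide whether a candidate split is acceptable. The sketch below illustrates one plausible reading of those descriptions: gain is the parent's variance minus the size-weighted variance of the two children, and a split is rejected when either child is too small or the gain falls below the threshold. `SplitGainSketch` and its methods are hypothetical names and an unweighted simplification, not Spark's implementation.

```java
import java.util.Arrays;

/** Hypothetical illustration of a variance-based split check; not Spark code. */
final class SplitGainSketch {

    /** Biased variance (divide by n) of the labels, used here as the impurity. */
    static double variance(double[] labels) {
        if (labels.length == 0) {
            return 0.0;
        }
        double mean = Arrays.stream(labels).average().orElse(0.0);
        return Arrays.stream(labels).map(y -> (y - mean) * (y - mean)).sum() / labels.length;
    }

    /** Parent impurity minus the size-weighted impurity of the two children. */
    static double gain(double[] left, double[] right) {
        double[] parent = new double[left.length + right.length];
        System.arraycopy(left, 0, parent, 0, left.length);
        System.arraycopy(right, 0, parent, left.length, right.length);
        if (parent.length == 0) {
            return 0.0;
        }
        double n = parent.length;
        return variance(parent)
                - (left.length / n) * variance(left)
                - (right.length / n) * variance(right);
    }

    /** Keeps a split only if both children are large enough and the gain is large enough. */
    static boolean isValidSplit(double[] left, double[] right,
                                int minInstancesPerNode, double minInfoGain) {
        if (left.length < minInstancesPerNode || right.length < minInstancesPerNode) {
            return false;
        }
        return gain(left, right) >= minInfoGain;
    }

    public static void main(String[] args) {
        double[] left = {1.0, 2.0};
        double[] right = {9.0, 10.0};
        System.out.println(gain(left, right));                  // prints 16.0
        System.out.println(isValidSplit(left, right, 1, 0.0));  // prints true
        System.out.println(isValidSplit(left, right, 3, 0.0));  // prints false
    }
}
```

Raising `minInfoGain` or `minInstancesPerNode` in this sketch prunes more candidate splits, which is the usual reason to tune those params.
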
Inherited methods (one group per declaring type):

featuresCol, labelCol, predictionCol, setFeaturesCol, setPredictionCol

transform, transform, transform

params

getLeafField, leafIterator, maxSplitFeatureIndex, numNodes, predictLeaf, toDebugString

validateAndTransformSchema

getCacheNodeIds, getLeafCol, getMaxBins, getMaxDepth, getMaxMemoryInMB, getMinInfoGain, getMinInstancesPerNode, getMinWeightFractionPerNode, getOldStrategy, setLeafCol

getLabelCol, labelCol

featuresCol, getFeaturesCol

getPredictionCol, predictionCol

clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, onParamChange, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn

getCheckpointInterval

getWeightCol

getImpurity, getOldImpurity

getVarianceCol

save

$init$, initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, initLock, isTraceEnabled, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning, org$apache$spark$internal$Logging$$log__$eq, org$apache$spark$internal$Logging$$log_, uninitialize

public static MLReader<DecisionTreeRegressionModel> read()
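
public static DecisionTreeRegressionModel load(String path)

public final Param<String> varianceCol()
Description copied from interface: HasVarianceCol
Param for the column name of the biased sample variance of the prediction.
Specified by: varianceCol in interface HasVarianceCol

public final Param<String> impurity()
Description copied from interface: HasVarianceImpurity
Criterion used for information gain calculation (case-insensitive).
Specified by: impurity in interface HasVarianceImpurity

public final Param<String> leafCol()
Description copied from interface: DecisionTreeParams
Leaf indices column name.
Specified by: leafCol in interface DecisionTreeParams

public final IntParam maxDepth()
Description copied from interface: DecisionTreeParams
Maximum depth of the tree (nonnegative).
Specified by: maxDepth in interface DecisionTreeParams

public final IntParam maxBins()
Description copied from interface: DecisionTreeParams
Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node.
Specified by: maxBins in interface DecisionTreeParams

public final IntParam minInstancesPerNode()
Description copied from interface: DecisionTreeParams
Minimum number of instances each child must have after a split.
Specified by: minInstancesPerNode in interface DecisionTreeParams

public final DoubleParam minWeightFractionPerNode()
Description copied from interface: DecisionTreeParams
Minimum fraction of the weighted sample count that each child must have after a split.
Specified by: minWeightFractionPerNode in interface DecisionTreeParams

public final DoubleParam minInfoGain()
Description copied from interface: DecisionTreeParams
Minimum information gain for a split to be considered at a tree node.
Specified by: minInfoGain in interface DecisionTreeParams

public final IntParam maxMemoryInMB()
Description copied from interface: DecisionTreeParams
Maximum memory in MB allocated to histogram aggregation.
Specified by: maxMemoryInMB in interface DecisionTreeParams

public final BooleanParam cacheNodeIds()
Description copied from interface: DecisionTreeParams
If false, the algorithm will pass trees to executors to match instances with nodes.
Specified by: cacheNodeIds in interface DecisionTreeParams

public final Param<String> weightCol()
Description copied from interface: HasWeightCol
Param for weight column name.
Specified by: weightCol in interface HasWeightCol

public final LongParam seed()
Description copied from interface: HasSeed
Param for random seed.

public final IntParam checkpointInterval()
Description copied from interface: HasCheckpointInterval
Param for setting the checkpoint interval (>= 1) or disabling checkpointing (-1).
Specified by: checkpointInterval in interface HasCheckpointInterval

public int depth()
Description copied from interface: DecisionTreeModel
Depth of the tree.
Specified by: depth in interface DecisionTreeModel

public String uid()
Description copied from interface: Identifiable
An immutable unique ID for the object and its derivatives.
Specified by: uid in interface Identifiable

public Node rootNode()
Description copied from interface: DecisionTreeModel
Root of the decision tree.
Specified by: rootNode in interface DecisionTreeModel

public int numFeatures()
Description copied from class: PredictionModel
Returns the number of features the model was trained on.
Overrides: numFeatures in class PredictionModel<Vector,DecisionTreeRegressionModel>

public DecisionTreeRegressionModel setVarianceCol(String value)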
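
The variance column is described as holding the biased sample variance of the prediction. As a rough illustration of that quantity, the sketch below assumes (this page does not confirm it) that each leaf keeps a running count, sum, and sum of squares of the training labels that reached it, and derives both the leaf's mean prediction and its biased variance from those aggregates. `LeafVarianceSketch` and `LeafStats` are hypothetical names.

```java
/** Hypothetical per-leaf statistics; not Spark's internal classes. */
final class LeafVarianceSketch {

    /** Running aggregates for the labels that reached one leaf. */
    static final class LeafStats {
        private long count;
        private double sum;
        private double sumSquares;

        void add(double label) {
            count++;
            sum += label;
            sumSquares += label * label;
        }

        /** Mean label, used as the leaf's prediction in this sketch. */
        double mean() {
            return count == 0 ? 0.0 : sum / count;
        }

        /** Biased sample variance: E[y^2] - (E[y])^2, dividing by n rather than n - 1. */
        double biasedVariance() {
            if (count == 0) {
                return 0.0;
            }
            double mean = sum / count;
            // Clamp at zero to absorb tiny negative values caused by rounding.
            return Math.max(0.0, sumSquares / count - mean * mean);
        }
    }

    public static void main(String[] args) {
        LeafStats stats = new LeafStats();
        for (double label : new double[] {2.0, 4.0, 6.0}) {
            stats.add(label);
        }
        System.out.println(stats.mean());            // prints 4.0
        System.out.println(stats.biasedVariance());  // approximately 2.667 (8/3)
    }
}
```

Dividing by n - 1 instead would give the unbiased estimate (4.0 for the labels above), which is not what the param description names.
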

public double predict(Vector features)
Description copied from class: PredictionModel
Predict label for the given features. This method is used to implement transform() and output predictionCol.
Specified by: predict in class PredictionModel<Vector,DecisionTreeRegressionModel>
Parameters: features - (undocumented)

public StructType transformSchema(StructType schema)
Description copied from class: PipelineStage
Check transform validity and derive the output schema from the input schema.
We check validity for interactions between parameters during transformSchema and raise an exception if any parameter value is invalid. Parameter value checks which do not depend on other parameters are handled by Param.validate().
A typical implementation should first conduct verification on schema change and parameter validity, including complex parameter interaction checks.
Overrides: transformSchema in class PredictionModel<Vector,DecisionTreeRegressionModel>
Parameters: schema - (undocumented)

public Dataset<Row> transform(Dataset<?> dataset)
Description copied from class: PredictionModel
Transforms dataset by reading from featuresCol, calling predict, and storing the predictions as a new column predictionCol.
Overrides: transform in class PredictionModel<Vector,DecisionTreeRegressionModel>
Parameters: dataset - input dataset
Returns: transformed dataset with predictionCol of type Double

public DecisionTreeRegressionModel copy(ParamMap extra)
Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params. See defaultCopy().
Specified by: copy in interface Params
Specified by: copy in class Model<DecisionTreeRegressionModel>
Parameters: extra - (undocumented)

public String toString()
Description copied from interface: DecisionTreeModel
Summary of the model.
Specified by: toString in interface DecisionTreeModel
Specified by: toString in interface Identifiable
Overrides: toString in class Object

public Vector featureImportances()

public MLWriter write()
Description copied from interface: MLWritable
Returns an MLWriter instance for this ML instance.
Specified by: write in interface MLWritable