public class DecisionTreeRegressionModel extends RegressionModel<Vector,DecisionTreeRegressionModel> implements DecisionTreeModel, DecisionTreeRegressorParams, MLWritable, scala.Serializable
param: rootNode Root of the decision tree
Modifier and Type | Method and Description |
---|---|
BooleanParam | cacheNodeIds() - If false, the algorithm will pass trees to executors to match instances with nodes. |
IntParam | checkpointInterval() - Param for setting the checkpoint interval (>= 1) or disabling checkpointing (-1). |
DecisionTreeRegressionModel | copy(ParamMap extra) - Creates a copy of this instance with the same UID and some extra params. |
int | depth() - Depth of the tree. |
Vector | featureImportances() |
Param<String> | impurity() - Criterion used for information gain calculation (case-insensitive). |
Param<String> | leafCol() - Leaf indices column name. |
static DecisionTreeRegressionModel | load(String path) |
IntParam | maxBins() - Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node. |
IntParam | maxDepth() - Maximum depth of the tree (nonnegative). |
IntParam | maxMemoryInMB() - Maximum memory in MB allocated to histogram aggregation. |
DoubleParam | minInfoGain() - Minimum information gain for a split to be considered at a tree node. |
IntParam | minInstancesPerNode() - Minimum number of instances each child must have after split. |
DoubleParam | minWeightFractionPerNode() - Minimum fraction of the weighted sample count that each child must have after split. |
int | numFeatures() - Returns the number of features the model was trained on. |
double | predict(Vector features) - Predict label for the given features. |
static MLReader<DecisionTreeRegressionModel> | read() |
Node | rootNode() - Root of the decision tree. |
LongParam | seed() - Param for random seed. |
DecisionTreeRegressionModel | setVarianceCol(String value) |
String | toString() - Summary of the model. |
Dataset<Row> | transform(Dataset<?> dataset) - Transforms dataset by reading from featuresCol, calling predict, and storing the predictions as a new column predictionCol. |
StructType | transformSchema(StructType schema) - Check transform validity and derive the output schema from the input schema. |
String | uid() - An immutable unique ID for the object and its derivatives. |
Param<String> | varianceCol() - Param for column name for the biased sample variance of prediction. |
Param<String> | weightCol() - Param for weight column name. |
MLWriter | write() - Returns an MLWriter instance for this ML instance. |
featuresCol, labelCol, predictionCol, setFeaturesCol, setPredictionCol
transform, transform, transform
params
getLeafField, leafIterator, maxSplitFeatureIndex, numNodes, predictLeaf, toDebugString
validateAndTransformSchema
getCacheNodeIds, getLeafCol, getMaxBins, getMaxDepth, getMaxMemoryInMB, getMinInfoGain, getMinInstancesPerNode, getMinWeightFractionPerNode, getOldStrategy, setLeafCol
getLabelCol, labelCol
featuresCol, getFeaturesCol
getPredictionCol, predictionCol
clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, onParamChange, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
getCheckpointInterval
getWeightCol
getImpurity, getOldImpurity
getVarianceCol
save
$init$, initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, initLock, isTraceEnabled, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning, org$apache$spark$internal$Logging$$log__$eq, org$apache$spark$internal$Logging$$log_, uninitialize
public static MLReader<DecisionTreeRegressionModel> read()
public static DecisionTreeRegressionModel load(String path)
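public final Param<String> varianceCol()
Description copied from interface: HasVarianceCol
Param for column name for the biased sample variance of prediction.
Specified by: varianceCol in interface HasVarianceCol
public final Param<String> impurity()
Description copied from interface: HasVarianceImpurity
Criterion used for information gain calculation (case-insensitive).
Specified by: impurity in interface HasVarianceImpurity
public final Param<String> leafCol()
Description copied from interface: DecisionTreeParams
Leaf indices column name.
Specified by: leafCol in interface DecisionTreeParams
public final IntParam maxDepth()
Description copied from interface: DecisionTreeParams
Maximum depth of the tree (nonnegative).
Specified by: maxDepth in interface DecisionTreeParams
public final IntParam maxBins()
Description copied from interface: DecisionTreeParams
Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node.
Specified by: maxBins in interface DecisionTreeParams
public final IntParam minInstancesPerNode()
Description copied from interface: DecisionTreeParams
Minimum number of instances each child must have after split.
Specified by: minInstancesPerNode in interface DecisionTreeParams
public final DoubleParam minWeightFractionPerNode()
Description copied from interface: DecisionTreeParams
Minimum fraction of the weighted sample count that each child must have after split.
Specified by: minWeightFractionPerNode in interface DecisionTreeParams
public final DoubleParam minInfoGain()
Description copied from interface: DecisionTreeParams
Minimum information gain for a split to be considered at a tree node.
Specified by: minInfoGain in interface DecisionTreeParams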
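The interplay between impurity, minInfoGain, and minInstancesPerNode can be illustrated with a small stand-alone sketch. It is plain Java, not Spark code: the class and method names are hypothetical, instances are unweighted, the impurity is assumed to be the biased (population) variance of the labels, and both thresholds are treated as inclusive lower bounds, which is an assumption.

```java
import java.util.Arrays;

// Illustrative sketch only; none of these names belong to the Spark API.
public class VarianceGainSketch {

    // Biased (population) variance of the labels; 0.0 for an empty array.
    static double variance(double[] labels) {
        if (labels.length == 0) {
            return 0.0;
        }
        double mean = Arrays.stream(labels).average().orElse(0.0);
        double sumSq = 0.0;
        for (double y : labels) {
            sumSq += (y - mean) * (y - mean);
        }
        return sumSq / labels.length;
    }

    // Gain = parent impurity minus the children's impurities,
    // each weighted by the child's share of the parent's instances.
    static double infoGain(double[] left, double[] right) {
        double[] parent = new double[left.length + right.length];
        System.arraycopy(left, 0, parent, 0, left.length);
        System.arraycopy(right, 0, parent, left.length, right.length);
        double n = parent.length;
        return variance(parent)
                - (left.length / n) * variance(left)
                - (right.length / n) * variance(right);
    }

    // Hypothetical check: both children must be large enough
    // and the split must gain at least minInfoGain.
    static boolean isValidSplit(double[] left, double[] right,
                                int minInstancesPerNode, double minInfoGain) {
        return left.length >= minInstancesPerNode
                && right.length >= minInstancesPerNode
                && infoGain(left, right) >= minInfoGain;
    }

    public static void main(String[] args) {
        double[] left = {1.0, 1.0};
        double[] right = {3.0, 3.0};
        System.out.println(infoGain(left, right));            // 1.0
        System.out.println(isValidSplit(left, right, 1, 0.0)); // true
        System.out.println(isValidSplit(left, right, 3, 0.0)); // false
    }
}
```

Raising minInstancesPerNode to 3 rejects the split even though its gain is high, because each child holds only two instances.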
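public final IntParam maxMemoryInMB()
Description copied from interface: DecisionTreeParams
Maximum memory in MB allocated to histogram aggregation.
Specified by: maxMemoryInMB in interface DecisionTreeParams
public final BooleanParam cacheNodeIds()
Description copied from interface: DecisionTreeParams
If false, the algorithm will pass trees to executors to match instances with nodes.
Specified by: cacheNodeIds in interface DecisionTreeParams
public final Param<String> weightCol()
Description copied from interface: HasWeightCol
Param for weight column name.
Specified by: weightCol in interface HasWeightCol
public final LongParam seed()
Description copied from interface: HasSeed
Param for random seed.
public final IntParam checkpointInterval()
Description copied from interface: HasCheckpointInterval
Param for setting the checkpoint interval (>= 1) or disabling checkpointing (-1).
Specified by: checkpointInterval in interface HasCheckpointInterval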
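public int depth()
Description copied from interface: DecisionTreeModel
Depth of the tree.
Specified by: depth in interface DecisionTreeModel
public String uid()
Description copied from interface: Identifiable
An immutable unique ID for the object and its derivatives.
Specified by: uid in interface Identifiable
public Node rootNode()
Description copied from interface: DecisionTreeModel
Root of the decision tree.
Specified by: rootNode in interface DecisionTreeModel
public int numFeatures()
Description copied from class: PredictionModel
Returns the number of features the model was trained on.
Overrides: numFeatures in class PredictionModel<Vector,DecisionTreeRegressionModel>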
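To make the relationship between a root node, the tree depth, and a prediction concrete, here is a minimal stand-alone sketch. It does not use Spark's Node classes: the Leaf and Internal types, the threshold-based split, and the convention that a lone leaf has depth 0 are all assumptions made for illustration.

```java
// Illustrative sketch only; these types are hypothetical, not Spark's Node hierarchy.
public class TreeSketch {

    interface Node {}

    // A leaf stores the value predicted for every instance that reaches it.
    static final class Leaf implements Node {
        final double prediction;
        Leaf(double prediction) {
            this.prediction = prediction;
        }
    }

    // An internal node sends an instance left when its feature value
    // is at or below the threshold, and right otherwise.
    static final class Internal implements Node {
        final int featureIndex;
        final double threshold;
        final Node left;
        final Node right;
        Internal(int featureIndex, double threshold, Node left, Node right) {
            this.featureIndex = featureIndex;
            this.threshold = threshold;
            this.left = left;
            this.right = right;
        }
    }

    // Walk from the given node down to a leaf and return its prediction.
    static double predict(Node node, double[] features) {
        while (node instanceof Internal) {
            Internal in = (Internal) node;
            node = features[in.featureIndex] <= in.threshold ? in.left : in.right;
        }
        return ((Leaf) node).prediction;
    }

    // Depth under the assumed convention: a single leaf has depth 0.
    static int depth(Node node) {
        if (node instanceof Leaf) {
            return 0;
        }
        Internal in = (Internal) node;
        return 1 + Math.max(depth(in.left), depth(in.right));
    }

    // Small example tree splitting on feature 0, then on feature 1.
    static Node exampleTree() {
        return new Internal(0, 0.5,
                new Leaf(1.0),
                new Internal(1, 2.0, new Leaf(2.0), new Leaf(3.0)));
    }

    public static void main(String[] args) {
        Node root = exampleTree();
        System.out.println(predict(root, new double[]{0.2, 9.0})); // 1.0
        System.out.println(predict(root, new double[]{0.9, 5.0})); // 3.0
        System.out.println(depth(root));                           // 2
    }
}
```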
public DecisionTreeRegressionModel setVarianceCol(String value)
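public double predict(Vector features)
Description copied from class: PredictionModel
Predict label for the given features. This method is used to implement transform() and output predictionCol.
Specified by: predict in class PredictionModel<Vector,DecisionTreeRegressionModel>
Parameters: features - (undocumented)
public StructType transformSchema(StructType schema)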
Description copied from class: PipelineStage
Check transform validity and derive the output schema from the input schema.
We check validity for interactions between parameters during transformSchema and raise an exception if any parameter value is invalid. Parameter value checks which do not depend on other parameters are handled by Param.validate().
A typical implementation should first verify the schema change and parameter validity, including complex parameter interaction checks.
Overrides: transformSchema in class PredictionModel<Vector,DecisionTreeRegressionModel>
Parameters: schema - (undocumented)
public Dataset<Row> transform(Dataset<?> dataset)
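Description copied from class: PredictionModel
Transforms dataset by reading from featuresCol, calling predict, and storing the predictions as a new column predictionCol.
Overrides: transform in class PredictionModel<Vector,DecisionTreeRegressionModel>
Parameters: dataset - input dataset
Returns: transformed dataset with predictionCol of type Double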
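The read, predict, store pattern described for transform can be sketched without Spark. In the example below a row is modeled as a Map from column name to value and the model as a function from a feature array to a double. These representations and the TransformSketch name are assumptions for illustration, not how Dataset<Row> works internally.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.ToDoubleFunction;

// Illustrative sketch only; rows are plain maps, not Spark Rows.
public class TransformSketch {

    // For each row: read the features column, call predict,
    // and store the result under the prediction column in a copy of the row.
    static List<Map<String, Object>> transform(
            List<Map<String, Object>> rows,
            String featuresCol,
            String predictionCol,
            ToDoubleFunction<double[]> predict) {
        List<Map<String, Object>> out = new ArrayList<>();
        for (Map<String, Object> row : rows) {
            Map<String, Object> copy = new LinkedHashMap<>(row); // input row stays unchanged
            double[] features = (double[]) row.get(featuresCol);
            copy.put(predictionCol, predict.applyAsDouble(features));
            out.add(copy);
        }
        return out;
    }

    public static void main(String[] args) {
        Map<String, Object> row = new LinkedHashMap<>();
        row.put("features", new double[]{1.0, 2.0});

        // Stand-in for a fitted model: predicts the sum of the features.
        List<Map<String, Object>> result = transform(
                Collections.singletonList(row), "features", "prediction",
                f -> f[0] + f[1]);

        System.out.println(result.get(0).get("prediction")); // 3.0
    }
}
```

Copying each row mirrors the fact that transform returns a new dataset with an extra column rather than modifying its input.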
public DecisionTreeRegressionModel copy(ParamMap extra)
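Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params. See defaultCopy().
Specified by: copy in interface Params
Specified by: copy in class Model<DecisionTreeRegressionModel>
Parameters: extra - (undocumented)
public String toString()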
Description copied from interface: DecisionTreeModel
Summary of the model.
Specified by: toString in interface DecisionTreeModel
Specified by: toString in interface Identifiable
Overrides: toString in class Object
public Vector featureImportances()
public MLWriter write()
Description copied from interface: MLWritable
Returns an MLWriter instance for this ML instance.
Specified by: write in interface MLWritable