public class RandomForestClassificationModel extends ProbabilisticClassificationModel<Vector,RandomForestClassificationModel> implements RandomForestClassifierParams, TreeEnsembleModel<DecisionTreeClassificationModel>, MLWritable, scala.Serializable, HasTrainingSummary<RandomForestClassificationTrainingSummary>
param: _trees Decision trees in the ensemble. Warning: These have null parents.
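A model of this class is normally obtained by fitting a RandomForestClassifier rather than constructed directly. A minimal sketch, assuming a SparkSession and a DataFrame `training` with "label" and "features" columns (both column names are illustrative):

```scala
import org.apache.spark.ml.classification.{RandomForestClassifier, RandomForestClassificationModel}

// Fit an estimator; fit() returns the RandomForestClassificationModel documented here.
val rf = new RandomForestClassifier()
  .setLabelCol("label")        // assumed column names
  .setFeaturesCol("features")
  .setNumTrees(20)

val model: RandomForestClassificationModel = rf.fit(training)

println(model.numClasses)         // values the label can take
println(model.totalNumNodes)      // nodes summed over all trees in the ensemble
println(model.featureImportances) // per-feature importance vector
```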
| Modifier and Type | Method and Description |
|---|---|
| BinaryRandomForestClassificationTrainingSummary | binarySummary() Gets summary of model on training set. |
| BooleanParam | bootstrap() Whether bootstrap samples are used when building trees. |
| BooleanParam | cacheNodeIds() If false, the algorithm will pass trees to executors to match instances with nodes. |
| IntParam | checkpointInterval() Param for set checkpoint interval (>= 1) or disable checkpoint (-1). |
| RandomForestClassificationModel | copy(ParamMap extra) Creates a copy of this instance with the same UID and some extra params. |
| RandomForestClassificationSummary | evaluate(Dataset<?> dataset) Evaluates the model on a test dataset. |
| Vector | featureImportances() |
| Param<String> | featureSubsetStrategy() The number of features to consider for splits at each tree node. |
| Param<String> | impurity() Criterion used for information gain calculation (case-insensitive). |
| Param<String> | leafCol() Leaf indices column name. |
| static RandomForestClassificationModel | load(String path) |
| IntParam | maxBins() Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node. |
| IntParam | maxDepth() Maximum depth of the tree (nonnegative). |
| IntParam | maxMemoryInMB() Maximum memory in MB allocated to histogram aggregation. |
| DoubleParam | minInfoGain() Minimum information gain for a split to be considered at a tree node. |
| IntParam | minInstancesPerNode() Minimum number of instances each child must have after split. |
| DoubleParam | minWeightFractionPerNode() Minimum fraction of the weighted sample count that each child must have after split. |
| int | numClasses() Number of classes (values which the label can take). |
| int | numFeatures() Returns the number of features the model was trained on. |
| IntParam | numTrees() Number of trees to train (at least 1). |
| Vector | predictRaw(Vector features) Raw prediction for each possible label. |
| static MLReader<RandomForestClassificationModel> | read() |
| LongParam | seed() Param for random seed. |
| DoubleParam | subsamplingRate() Fraction of the training data used for learning each decision tree, in range (0, 1]. |
| RandomForestClassificationTrainingSummary | summary() Gets summary of model on training set. |
| String | toString() Summary of the model. |
| int | totalNumNodes() Total number of nodes, summed over all trees in the ensemble. |
| Dataset<Row> | transform(Dataset<?> dataset) Transforms dataset by reading from featuresCol and appending new columns as specified by parameters: predicted labels as predictionCol of type Double, raw predictions (confidences) as rawPredictionCol of type Vector, and probability of each class as probabilityCol of type Vector. |
| StructType | transformSchema(StructType schema) Check transform validity and derive the output schema from the input schema. |
| DecisionTreeClassificationModel[] | trees() Trees in this ensemble. |
| double[] | treeWeights() Weights for each tree, zippable with trees. |
| String | uid() An immutable unique ID for the object and its derivatives. |
| Param<String> | weightCol() Param for weight column name. |
| MLWriter | write() Returns an MLWriter instance for this ML instance. |
Methods inherited from ProbabilisticClassificationModel: normalizeToProbabilitiesInPlace, predictProbability, probabilityCol, setProbabilityCol, setThresholds, thresholds
Methods inherited from ClassificationModel: predict, rawPredictionCol, setRawPredictionCol, transformImpl
Methods inherited from PredictionModel: featuresCol, labelCol, predictionCol, setFeaturesCol, setPredictionCol
Methods inherited from Transformer: transform, transform, transform
Methods inherited from PipelineStage: params
Methods inherited from RandomForestParams: getBootstrap, getNumTrees
Methods inherited from RandomForestClassifierParams: validateAndTransformSchema
Methods inherited from TreeEnsembleParams: getFeatureSubsetStrategy, getOldStrategy, getSubsamplingRate
Methods inherited from DecisionTreeParams: getCacheNodeIds, getLeafCol, getMaxBins, getMaxDepth, getMaxMemoryInMB, getMinInfoGain, getMinInstancesPerNode, getMinWeightFractionPerNode, getOldStrategy, setLeafCol
Methods inherited from HasCheckpointInterval: getCheckpointInterval
Methods inherited from HasWeightCol: getWeightCol
Methods inherited from HasLabelCol: getLabelCol, labelCol
Methods inherited from HasFeaturesCol: featuresCol, getFeaturesCol
Methods inherited from HasPredictionCol: getPredictionCol, predictionCol
Methods inherited from Params: clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, onParamChange, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
Methods inherited from HasRawPredictionCol: getRawPredictionCol, rawPredictionCol
Methods inherited from HasProbabilityCol: getProbabilityCol, probabilityCol
Methods inherited from HasThresholds: getThresholds, thresholds
Methods inherited from TreeClassifierParams: getImpurity, getOldImpurity
Methods inherited from TreeEnsembleModel: getLeafField, javaTreeWeights, predictLeaf, toDebugString
Methods inherited from MLWritable: save
Methods inherited from HasTrainingSummary: hasSummary, setSummary
Methods inherited from org.apache.spark.internal.Logging: $init$, initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, initLock, isTraceEnabled, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning, org$apache$spark$internal$Logging$$log__$eq, org$apache$spark$internal$Logging$$log_, uninitialize
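Because the model mixes in Params, the inherited explainParams(), extractParamMap(), and related methods listed above can be used to inspect every param value on a fitted instance. A brief sketch, assuming `model` is a fitted RandomForestClassificationModel:

```scala
// Assumes `model` is a fitted RandomForestClassificationModel.
println(model.uid)               // immutable unique ID
println(model.explainParams())   // one line per param: name, doc, and current/default value
println(model.extractParamMap()) // the resolved ParamMap for this instance
```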
public static MLReader<RandomForestClassificationModel> read()

public static RandomForestClassificationModel load(String path)
public int totalNumNodes()
Description copied from interface: TreeEnsembleModel
Total number of nodes, summed over all trees in the ensemble.
Specified by: totalNumNodes in interface TreeEnsembleModel<DecisionTreeClassificationModel>

public final Param<String> impurity()
Description copied from interface: TreeClassifierParams
Criterion used for information gain calculation (case-insensitive).
Specified by: impurity in interface TreeClassifierParams

public final IntParam numTrees()
Description copied from interface: RandomForestParams
Number of trees to train (at least 1).
Note: The reason that we cannot add this to both GBT and RF (i.e. in TreeEnsembleParams) is that the param maxIter controls how many trees a GBT has; the semantics differ slightly between the two algorithms.
Specified by: numTrees in interface RandomForestParams

public final BooleanParam bootstrap()
Description copied from interface: RandomForestParams
Whether bootstrap samples are used when building trees.
Specified by: bootstrap in interface RandomForestParams

public final DoubleParam subsamplingRate()
Description copied from interface: TreeEnsembleParams
Fraction of the training data used for learning each decision tree, in range (0, 1].
Specified by: subsamplingRate in interface TreeEnsembleParams

public final Param<String> featureSubsetStrategy()
Description copied from interface: TreeEnsembleParams
The number of features to consider for splits at each tree node. These various settings are based on the following references:
- log2: tested in Breiman (2001)
- sqrt: recommended by Breiman manual for random forests
- The defaults of sqrt (classification) and onethird (regression) match the R randomForest package.
Specified by: featureSubsetStrategy in interface TreeEnsembleParams
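These ensemble-level params (numTrees, bootstrap, subsamplingRate, featureSubsetStrategy) are set on the RandomForestClassifier estimator before fitting; the fitted model then carries the chosen values. A hedged sketch, assuming the estimator setters available in recent Spark releases and an illustrative DataFrame `training`:

```scala
import org.apache.spark.ml.classification.RandomForestClassifier

// Ensemble params are configured on the estimator; values shown are illustrative.
val rf = new RandomForestClassifier()
  .setNumTrees(100)                 // numTrees: at least 1
  .setBootstrap(true)               // bootstrap sampling of the training data
  .setSubsamplingRate(0.8)          // fraction of data per tree, in (0, 1]
  .setFeatureSubsetStrategy("sqrt") // features considered per split
  .setSeed(42L)

val model = rf.fit(training)        // `training` is an assumed DataFrame
```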
public final Param<String> leafCol()
Description copied from interface: DecisionTreeParams
Leaf indices column name.
Specified by: leafCol in interface DecisionTreeParams

public final IntParam maxDepth()
Description copied from interface: DecisionTreeParams
Maximum depth of the tree (nonnegative).
Specified by: maxDepth in interface DecisionTreeParams

public final IntParam maxBins()
Description copied from interface: DecisionTreeParams
Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node.
Specified by: maxBins in interface DecisionTreeParams

public final IntParam minInstancesPerNode()
Description copied from interface: DecisionTreeParams
Minimum number of instances each child must have after split.
Specified by: minInstancesPerNode in interface DecisionTreeParams

public final DoubleParam minWeightFractionPerNode()
Description copied from interface: DecisionTreeParams
Minimum fraction of the weighted sample count that each child must have after split.
Specified by: minWeightFractionPerNode in interface DecisionTreeParams

public final DoubleParam minInfoGain()
Description copied from interface: DecisionTreeParams
Minimum information gain for a split to be considered at a tree node.
Specified by: minInfoGain in interface DecisionTreeParams

public final IntParam maxMemoryInMB()
Description copied from interface: DecisionTreeParams
Maximum memory in MB allocated to histogram aggregation.
Specified by: maxMemoryInMB in interface DecisionTreeParams

public final BooleanParam cacheNodeIds()
Description copied from interface: DecisionTreeParams
If false, the algorithm will pass trees to executors to match instances with nodes.
Specified by: cacheNodeIds in interface DecisionTreeParams

public final Param<String> weightCol()
Description copied from interface: HasWeightCol
Param for weight column name.
Specified by: weightCol in interface HasWeightCol

public final LongParam seed()
Description copied from interface: HasSeed
Param for random seed.

public final IntParam checkpointInterval()
Description copied from interface: HasCheckpointInterval
Param for set checkpoint interval (>= 1) or disable checkpoint (-1).
Specified by: checkpointInterval in interface HasCheckpointInterval
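The tree-structure params above can also be read back from a fitted model through the inherited getters (getMaxDepth, getMaxBins, and so on, listed earlier). A small sketch, assuming `model` is a fitted RandomForestClassificationModel:

```scala
// Assumes `model` is a fitted RandomForestClassificationModel.
println(model.getMaxDepth)            // tree depth limit used for training
println(model.getMaxBins)             // bins for discretizing continuous features
println(model.getMinInstancesPerNode) // minimum instances per child after a split
println(model.getCheckpointInterval)  // -1 means checkpointing is disabled
```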
public String uid()
Description copied from interface: Identifiable
An immutable unique ID for the object and its derivatives.
Specified by: uid in interface Identifiable

public int numFeatures()
Description copied from class: PredictionModel
Returns the number of features the model was trained on.
Overrides: numFeatures in class PredictionModel<Vector,RandomForestClassificationModel>

public int numClasses()
Description copied from class: ClassificationModel
Number of classes (values which the label can take).
Specified by: numClasses in class ClassificationModel<Vector,RandomForestClassificationModel>

public DecisionTreeClassificationModel[] trees()
Description copied from interface: TreeEnsembleModel
Trees in this ensemble.
Specified by: trees in interface TreeEnsembleModel<DecisionTreeClassificationModel>

public double[] treeWeights()
Description copied from interface: TreeEnsembleModel
Weights for each tree, zippable with trees.
Specified by: treeWeights in interface TreeEnsembleModel<DecisionTreeClassificationModel>
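Since trees() and treeWeights() are aligned by index, they can be zipped to inspect each ensemble member together with its weight. A short sketch, assuming a fitted `model`:

```scala
// Assumes `model` is a fitted RandomForestClassificationModel.
model.trees.zip(model.treeWeights).foreach { case (tree, weight) =>
  println(s"${tree.uid}: weight=$weight, depth=${tree.depth}, nodes=${tree.numNodes}")
}
```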
public RandomForestClassificationTrainingSummary summary()
Gets summary of model on training set. An exception is thrown if hasSummary is false.
Specified by: summary in interface HasTrainingSummary<RandomForestClassificationTrainingSummary>

public BinaryRandomForestClassificationTrainingSummary binarySummary()
Gets summary of model on training set. An exception is thrown if hasSummary is false or it is a multiclass model.

public RandomForestClassificationSummary evaluate(Dataset<?> dataset)
Evaluates the model on a test dataset.
Parameters: dataset - Test dataset to evaluate model on.
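A sketch of the training-summary and evaluation methods above, assuming `model` was produced by RandomForestClassifier.fit and `test` is a DataFrame with the same label and features columns (the summary fields shown come from the classification summary API):

```scala
// Training-set summary (typically not available on a model reloaded from disk).
if (model.hasSummary) {
  val trainingSummary = model.summary
  println(trainingSummary.accuracy)
  // For binary models only, model.binarySummary additionally exposes ROC-based metrics.
}

// Held-out evaluation; `test` is an assumed DataFrame.
val testSummary = model.evaluate(test)
println(testSummary.weightedFMeasure)
```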
public StructType transformSchema(StructType schema)
Description copied from class: PipelineStage
Check transform validity and derive the output schema from the input schema.
We check validity for interactions between parameters during transformSchema and raise an exception if any parameter value is invalid. Parameter value checks which do not depend on other parameters are handled by Param.validate().
A typical implementation should first conduct verification on schema change and parameter validity, including complex parameter interaction checks.
Overrides: transformSchema in class ProbabilisticClassificationModel<Vector,RandomForestClassificationModel>
Parameters: schema - (undocumented)
public Dataset<Row> transform(Dataset<?> dataset)
Description copied from class: ProbabilisticClassificationModel
Transforms dataset by reading from featuresCol, and appending new columns as specified by parameters:
- predicted labels as predictionCol of type Double
- raw predictions (confidences) as rawPredictionCol of type Vector
- probability of each class as probabilityCol of type Vector
Overrides: transform in class ProbabilisticClassificationModel<Vector,RandomForestClassificationModel>
Parameters: dataset - input dataset
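A sketch of transform() with the default output column names ("prediction", "rawPrediction", "probability"); these names are params and will differ if they were overridden:

```scala
// `test` is an assumed DataFrame containing the model's featuresCol.
val predictions = model.transform(test)
predictions.select("features", "prediction", "rawPrediction", "probability").show(5)
```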
public Vector predictRaw(Vector features)
Description copied from class: ClassificationModel
Raw prediction for each possible label. This internal method is used to implement transform() and output rawPredictionCol.
Specified by: predictRaw in class ClassificationModel<Vector,RandomForestClassificationModel>
Parameters: features - (undocumented)
public RandomForestClassificationModel copy(ParamMap extra)
Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params. See defaultCopy().
Specified by: copy in interface Params
Specified by: copy in class Model<RandomForestClassificationModel>
Parameters: extra - (undocumented)

public String toString()
Description copied from interface: TreeEnsembleModel
Summary of the model.
Specified by: toString in interface TreeEnsembleModel<DecisionTreeClassificationModel>
Specified by: toString in interface Identifiable
Overrides: toString in class Object
public Vector featureImportances()

public MLWriter write()
Description copied from interface: MLWritable
Returns an MLWriter instance for this ML instance.
Specified by: write in interface MLWritable
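A sketch of the save/load round trip via write() and the static load(); the path is illustrative:

```scala
import org.apache.spark.ml.classification.RandomForestClassificationModel

// Persist the model, then reload it; the path is illustrative.
model.write.overwrite().save("/tmp/rf-classification-model")

val reloaded = RandomForestClassificationModel.load("/tmp/rf-classification-model")
assert(reloaded.uid == model.uid)
```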