public class DecisionTree extends Object implements scala.Serializable, Logging
| Constructor and Description |
|---|
| `DecisionTree(Strategy strategy)` |
| Modifier and Type | Method and Description |
|---|---|
| `static scala.collection.immutable.List<Object>` | `extractMultiClassCategories(int input, int maxFeatureValue)` Nested method to extract the list of eligible categories given an index. |
| `static void` | `findBestSplits(RDD<BaggedPoint<TreePoint>> input, DecisionTreeMetadata metadata, Node[] topNodes, scala.collection.immutable.Map<Object,Node[]> nodesForGroup, scala.collection.immutable.Map<Object,scala.collection.immutable.Map<Object,RandomForest.NodeIndexInfo>> treeToNodeToIndexInfo, Split[][] splits, Bin[][] bins, scala.collection.mutable.Queue<scala.Tuple2<Object,Node>> nodeQueue, TimeTracker timer, scala.Option<NodeIdCache> nodeIdCache)` Given a group of nodes, this finds the best split for each node. |
| `static double[]` | `findSplitsForContinuousFeature(double[] featureSamples, DecisionTreeMetadata metadata, int featureIndex)` Find splits for a continuous feature. NOTE: The returned number of splits is set based on `featureSamples` and could be different from the specified `numSplits`. |
| `DecisionTreeModel` | `run(RDD<LabeledPoint> input)` Method to train a decision tree model over an RDD. |
| `static DecisionTreeModel` | `train(RDD<LabeledPoint> input, scala.Enumeration.Value algo, Impurity impurity, int maxDepth)` Method to train a decision tree model. |
| `static DecisionTreeModel` | `train(RDD<LabeledPoint> input, scala.Enumeration.Value algo, Impurity impurity, int maxDepth, int numClasses)` Method to train a decision tree model. |
| `static DecisionTreeModel` | `train(RDD<LabeledPoint> input, scala.Enumeration.Value algo, Impurity impurity, int maxDepth, int numClasses, int maxBins, scala.Enumeration.Value quantileCalculationStrategy, scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo)` Method to train a decision tree model. |
| `static DecisionTreeModel` | `train(RDD<LabeledPoint> input, Strategy strategy)` Method to train a decision tree model. |
| `static DecisionTreeModel` | `trainClassifier(JavaRDD<LabeledPoint> input, int numClasses, java.util.Map<Integer,Integer> categoricalFeaturesInfo, String impurity, int maxDepth, int maxBins)` Java-friendly API for `DecisionTree$.trainClassifier(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, int, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int)` |
| `static DecisionTreeModel` | `trainClassifier(RDD<LabeledPoint> input, int numClasses, scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo, String impurity, int maxDepth, int maxBins)` Method to train a decision tree model for binary or multiclass classification. |
| `static DecisionTreeModel` | `trainRegressor(JavaRDD<LabeledPoint> input, java.util.Map<Integer,Integer> categoricalFeaturesInfo, String impurity, int maxDepth, int maxBins)` Java-friendly API for `DecisionTree$.trainRegressor(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int)` |
| `static DecisionTreeModel` | `trainRegressor(RDD<LabeledPoint> input, scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo, String impurity, int maxDepth, int maxBins)` Method to train a decision tree model for regression. |
Methods inherited from class java.lang.Object: `equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait`
Methods inherited from interface Logging: `initializeIfNecessary, initializeLogging, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning`
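The `extractMultiClassCategories` entry in the summary above takes a bit-encoded set of categories. As a rough illustration of that idea (a hypothetical pure-Java sketch with names of our choosing, not the actual MLlib implementation, which returns a Scala `List`), each set bit `j` below the feature's arity marks category `j` as eligible:

```java
import java.util.ArrayList;
import java.util.List;

public class MultiClassCategories {
    // Hypothetical sketch: decode a bit-encoded category set.
    // Bit j of `input` set => category j is eligible, for j in [0, maxFeatureValue).
    public static List<Integer> extractCategories(int input, int maxFeatureValue) {
        List<Integer> categories = new ArrayList<>();
        for (int j = 0; j < maxFeatureValue; j++) {
            if (((input >> j) & 1) == 1) {
                categories.add(j);
            }
        }
        return categories;
    }

    public static void main(String[] args) {
        // 0b101 = 5 encodes categories {0, 2} for a 3-category feature
        System.out.println(extractCategories(5, 3));
    }
}
```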
public DecisionTree(Strategy strategy)
public static DecisionTreeModel train(RDD<LabeledPoint> input, Strategy strategy)

Note: Using `DecisionTree$.trainClassifier(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, int, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int)` and `DecisionTree$.trainRegressor(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int)` is recommended to clearly separate classification and regression.

Parameters:
- `input` - Training dataset: RDD of `LabeledPoint`. For classification, labels should take values {0, 1, ..., numClasses-1}. For regression, labels are real numbers.
- `strategy` - The configuration parameters for the tree algorithm which specify the type of algorithm (classification, regression, etc.), feature type (continuous, categorical), depth of the tree, quantile calculation strategy, etc.

public static DecisionTreeModel train(RDD<LabeledPoint> input, scala.Enumeration.Value algo, Impurity impurity, int maxDepth)
Note: Using `DecisionTree$.trainClassifier(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, int, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int)` and `DecisionTree$.trainRegressor(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int)` is recommended to clearly separate classification and regression.

Parameters:
- `input` - Training dataset: RDD of `LabeledPoint`. For classification, labels should take values {0, 1, ..., numClasses-1}. For regression, labels are real numbers.
- `algo` - algorithm, classification or regression
- `impurity` - impurity criterion used for information gain calculation
- `maxDepth` - Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.

public static DecisionTreeModel train(RDD<LabeledPoint> input, scala.Enumeration.Value algo, Impurity impurity, int maxDepth, int numClasses)
Note: Using `DecisionTree$.trainClassifier(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, int, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int)` and `DecisionTree$.trainRegressor(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int)` is recommended to clearly separate classification and regression.

Parameters:
- `input` - Training dataset: RDD of `LabeledPoint`. For classification, labels should take values {0, 1, ..., numClasses-1}. For regression, labels are real numbers.
- `algo` - algorithm, classification or regression
- `impurity` - impurity criterion used for information gain calculation
- `maxDepth` - Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
- `numClasses` - number of classes for classification. Default value of 2.

public static DecisionTreeModel train(RDD<LabeledPoint> input, scala.Enumeration.Value algo, Impurity impurity, int maxDepth, int numClasses, int maxBins, scala.Enumeration.Value quantileCalculationStrategy, scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo)
Note: Using `DecisionTree$.trainClassifier(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, int, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int)` and `DecisionTree$.trainRegressor(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int)` is recommended to clearly separate classification and regression.

Parameters:
- `input` - Training dataset: RDD of `LabeledPoint`. For classification, labels should take values {0, 1, ..., numClasses-1}. For regression, labels are real numbers.
- `algo` - classification or regression
- `impurity` - criterion used for information gain calculation
- `maxDepth` - Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
- `numClasses` - number of classes for classification. Default value of 2.
- `maxBins` - maximum number of bins used for splitting features
- `quantileCalculationStrategy` - algorithm for calculating quantiles
- `categoricalFeaturesInfo` - Map storing arity of categorical features. E.g., an entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.

public static DecisionTreeModel trainClassifier(RDD<LabeledPoint> input, int numClasses, scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo, String impurity, int maxDepth, int maxBins)
Parameters:
- `input` - Training dataset: RDD of `LabeledPoint`. Labels should take values {0, 1, ..., numClasses-1}.
- `numClasses` - number of classes for classification.
- `categoricalFeaturesInfo` - Map storing arity of categorical features. E.g., an entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
- `impurity` - Criterion used for information gain calculation. Supported values: "gini" (recommended) or "entropy".
- `maxDepth` - Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (suggested value: 5)
- `maxBins` - maximum number of bins used for splitting features (suggested value: 32)

public static DecisionTreeModel trainClassifier(JavaRDD<LabeledPoint> input, int numClasses, java.util.Map<Integer,Integer> categoricalFeaturesInfo, String impurity, int maxDepth, int maxBins)
Java-friendly API for `DecisionTree$.trainClassifier(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, int, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int)`.
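The `categoricalFeaturesInfo` argument of the Java-friendly overloads is a plain `java.util.Map` from feature index to arity. As a small illustration (the feature layout here is made up for the example), marking feature 0 as binary and feature 4 as having 10 categories while all other features stay continuous:

```java
import java.util.HashMap;
import java.util.Map;

public class CategoricalFeaturesInfoExample {
    // Example map for the Java-friendly trainClassifier/trainRegressor overloads.
    public static Map<Integer, Integer> exampleInfo() {
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
        categoricalFeaturesInfo.put(0, 2);  // feature 0 is categorical with categories {0, 1}
        categoricalFeaturesInfo.put(4, 10); // feature 4 is categorical with categories {0, ..., 9}
        // features absent from the map are treated as continuous
        return categoricalFeaturesInfo;
    }
}
```

Passing an empty map means every feature is treated as continuous.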
public static DecisionTreeModel trainRegressor(RDD<LabeledPoint> input, scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo, String impurity, int maxDepth, int maxBins)

Parameters:
- `input` - Training dataset: RDD of `LabeledPoint`. Labels are real numbers.
- `categoricalFeaturesInfo` - Map storing arity of categorical features. E.g., an entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
- `impurity` - Criterion used for information gain calculation. Supported values: "variance".
- `maxDepth` - Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (suggested value: 5)
- `maxBins` - maximum number of bins used for splitting features (suggested value: 32)

public static DecisionTreeModel trainRegressor(JavaRDD<LabeledPoint> input, java.util.Map<Integer,Integer> categoricalFeaturesInfo, String impurity, int maxDepth, int maxBins)
Java-friendly API for `DecisionTree$.trainRegressor(org.apache.spark.rdd.RDD<org.apache.spark.mllib.regression.LabeledPoint>, scala.collection.immutable.Map<java.lang.Object, java.lang.Object>, java.lang.String, int, int)`.
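The `maxDepth` notes above imply that a tree of depth d has at most 2^(d+1) - 1 nodes, which is why deep trees get expensive quickly. A small illustration of that count (the helper name is ours, not part of the API):

```java
public class TreeDepth {
    // Maximum number of nodes in a binary decision tree of the given depth:
    // depth 0 -> 1 leaf; depth 1 -> 1 internal node + 2 leaves; and so on.
    public static int maxNodes(int depth) {
        return (1 << (depth + 1)) - 1;
    }

    public static void main(String[] args) {
        // upper bound on node count for the suggested maxDepth of 5
        System.out.println(maxNodes(5));
    }
}
```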
public static void findBestSplits(RDD<BaggedPoint<TreePoint>> input, DecisionTreeMetadata metadata, Node[] topNodes, scala.collection.immutable.Map<Object,Node[]> nodesForGroup, scala.collection.immutable.Map<Object,scala.collection.immutable.Map<Object,RandomForest.NodeIndexInfo>> treeToNodeToIndexInfo, Split[][] splits, Bin[][] bins, scala.collection.mutable.Queue<scala.Tuple2<Object,Node>> nodeQueue, TimeTracker timer, scala.Option<NodeIdCache> nodeIdCache)
Parameters:
- `input` - Training data: RDD of `TreePoint`
- `metadata` - Learning and dataset metadata
- `topNodes` - Root node for each tree. Used for matching instances with nodes.
- `nodesForGroup` - Mapping: treeIndex --> nodes to be split in tree
- `treeToNodeToIndexInfo` - Mapping: treeIndex --> nodeIndex --> nodeIndexInfo, where nodeIndexInfo stores the index in the group and the feature subsets (if using feature subsets).
- `splits` - possible splits for all features, indexed (numFeatures)(numSplits)
- `bins` - possible bins for all features, indexed (numFeatures)(numBins)
- `nodeQueue` - Queue of nodes to split, with values (treeIndex, node). Updated with new non-leaf nodes which are created.
- `nodeIdCache` - Node ID cache containing an RDD of Array[Int] where each value in the array is the data point's node ID for a corresponding tree. This is used to avoid passing the entire tree to the executors during the node stat aggregation phase.

public static scala.collection.immutable.List<Object> extractMultiClassCategories(int input, int maxFeatureValue)
public static double[] findSplitsForContinuousFeature(double[] featureSamples, DecisionTreeMetadata metadata, int featureIndex)
Find splits for a continuous feature. NOTE: The returned number of splits is set based on `featureSamples` and could be different from the specified `numSplits`. The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.

Parameters:
- `featureSamples` - feature values of each sample
- `metadata` - decision tree metadata. NOTE: `metadata.numbins` will be changed accordingly if there are not enough splits to be found.
- `featureIndex` - feature index to find splits

public DecisionTreeModel run(RDD<LabeledPoint> input)
Parameters:
- `input` - Training data: RDD of `LabeledPoint`
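The `findSplitsForContinuousFeature` note above, that the returned number of splits may differ from the requested `numSplits`, can be illustrated with a simplified sketch. This is not the MLlib implementation; it just picks thresholds at evenly spaced positions in the sorted samples and drops duplicates, so constant or low-cardinality features yield fewer splits:

```java
import java.util.Arrays;
import java.util.TreeSet;

public class ContinuousSplitsSketch {
    // Simplified sketch: choose up to numSplits thresholds from the sorted
    // samples; duplicate values collapse, so fewer splits may come back.
    public static double[] candidateSplits(double[] featureSamples, int numSplits) {
        double[] sorted = featureSamples.clone();
        Arrays.sort(sorted);
        TreeSet<Double> thresholds = new TreeSet<>();
        for (int i = 1; i <= numSplits; i++) {
            // index of the i-th of (numSplits + 1) evenly spaced quantiles
            int idx = (int) ((long) i * sorted.length / (numSplits + 1));
            thresholds.add(sorted[Math.min(idx, sorted.length - 1)]);
        }
        return thresholds.stream().mapToDouble(Double::doubleValue).toArray();
    }

    public static void main(String[] args) {
        // 8 distinct values, 3 requested splits -> 3 thresholds
        System.out.println(Arrays.toString(candidateSplits(new double[]{1, 2, 3, 4, 5, 6, 7, 8}, 3)));
        // constant feature -> a single usable threshold, fewer than requested
        System.out.println(Arrays.toString(candidateSplits(new double[]{5, 5, 5, 5}, 3)));
    }
}
```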