Object
org.apache.spark.mllib.tree.configuration.Strategy
All Implemented Interfaces:
Serializable

public class Strategy extends Object implements Serializable
Stores all the configuration options for tree construction.
param: algo Learning goal. Supported: org.apache.spark.mllib.tree.configuration.Algo.Classification, org.apache.spark.mllib.tree.configuration.Algo.Regression.
param: impurity Criterion used for information gain calculation. Supported for Classification: Gini, Entropy. Supported for Regression: Variance.
param: maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes).
param: numClasses Number of classes for classification (ignored for regression). Default value is 2 (binary classification).
param: maxBins Maximum number of bins used for discretizing continuous features and for choosing how to split on features at each node. More bins give higher granularity.
param: quantileCalculationStrategy Algorithm for calculating quantiles. Supported: org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort.
param: categoricalFeaturesInfo A map storing information about the categorical features and the number of discrete values they take. An entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
param: minInstancesPerNode Minimum number of instances each child must have after a split. Default value is 1. If a split causes the left or right child to have fewer than minInstancesPerNode instances, the split is not considered valid.
param: minInfoGain Minimum information gain a split must achieve. Default value is 0.0. If a split has less information gain than minInfoGain, the split is not considered valid.
param: maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is 256 MB. If it is too small, only 1 node is split per iteration, and its aggregates may exceed this size.
param: subsamplingRate Fraction of the training data used for learning the decision tree.
param: useNodeIdCache If true, instead of passing trees to executors, the algorithm maintains a separate RDD that caches the node ID of each row.
param: checkpointInterval How often to checkpoint when the node ID cache is updated. E.g. 10 means that the cache is checkpointed every 10 updates. If the checkpoint directory is not set in SparkContext, this setting is ignored.
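The tree-shape, split-validity and checkpoint rules in the class description can be made concrete with a small Java model. TreeRulesSketch and its methods are hypothetical names invented for illustration; they restate the documented semantics of maxDepth, minInstancesPerNode, minInfoGain and checkpointInterval and are not Spark's internal implementation.

```java
// Hypothetical model of the rules documented for Strategy.
// It mirrors the described semantics, not Spark's internal code.
final class TreeRulesSketch {

    // A tree of depth d has at most 2^d leaves
    // (depth 0 -> 1 leaf; depth 1 -> 1 internal node + 2 leaves).
    static long maxLeaves(int maxDepth) {
        // Guard against overflowing a long shift.
        if (maxDepth < 0 || maxDepth > 62) {
            throw new IllegalArgumentException("maxDepth out of range: " + maxDepth);
        }
        return 1L << maxDepth;
    }

    // A full binary tree with L leaves has L - 1 internal nodes.
    static long maxInternalNodes(int maxDepth) {
        return maxLeaves(maxDepth) - 1;
    }

    // A candidate split is valid only if both children keep at least
    // minInstancesPerNode instances and the gain is not below minInfoGain.
    static boolean isValidSplit(long leftCount, long rightCount, double gain,
                                int minInstancesPerNode, double minInfoGain) {
        if (leftCount < minInstancesPerNode || rightCount < minInstancesPerNode) {
            return false;
        }
        return gain >= minInfoGain;
    }

    // With checkpointInterval = 10 the node ID cache is checkpointed on the
    // 10th, 20th, ... update (updateCount is 1-based). Without a checkpoint
    // directory the setting is ignored. Treating a non-positive interval as
    // "never" is an assumption of this sketch.
    static boolean shouldCheckpoint(int updateCount, int checkpointInterval,
                                    boolean checkpointDirSet) {
        if (!checkpointDirSet || checkpointInterval <= 0) {
            return false;
        }
        return updateCount > 0 && updateCount % checkpointInterval == 0;
    }
}
```

For example, maxLeaves(5) returns 32, and with minInstancesPerNode = 1 a split that leaves one child empty is rejected regardless of its gain.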
  • Constructor Details

    • Strategy

      public Strategy(scala.Enumeration.Value algo, Impurity impurity, int maxDepth, int numClasses, int maxBins, scala.Enumeration.Value quantileCalculationStrategy, scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo, int minInstancesPerNode, double minInfoGain, int maxMemoryInMB, double subsamplingRate, boolean useNodeIdCache, int checkpointInterval, double minWeightFractionPerNode, boolean bootstrap)
    • Strategy

      public Strategy(scala.Enumeration.Value algo, Impurity impurity, int maxDepth, int numClasses, int maxBins, scala.Enumeration.Value quantileCalculationStrategy, scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo, int minInstancesPerNode, double minInfoGain, int maxMemoryInMB, double subsamplingRate, boolean useNodeIdCache, int checkpointInterval)
      Backwards-compatible constructor for Strategy. Unlike the full constructor, it takes no minWeightFractionPerNode or bootstrap argument.
      Parameters:
      algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval - see the corresponding param: entries in the class description above.
    • Strategy

      public Strategy(scala.Enumeration.Value algo, Impurity impurity, int maxDepth, int numClasses, int maxBins, Map<Integer,Integer> categoricalFeaturesInfo)
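      Java-friendly constructor for Strategy. It takes categoricalFeaturesInfo as a java.util.Map<Integer,Integer> instead of a Scala Map and does not take the remaining options.
      Parameters:
      algo, impurity, maxDepth, numClasses, maxBins - see the corresponding param: entries in the class description above.
      categoricalFeaturesInfo - map from feature index n to the number of categories k of that feature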
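A sketch of the categoricalFeaturesInfo convention expressed as a java.util.Map<Integer,Integer>, the type this constructor accepts. CategoricalInfoSketch is a hypothetical helper, not part of Spark; it assumes that category values are stored as whole-number doubles (MLlib feature vectors hold doubles) and that features missing from the map are continuous.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper illustrating categoricalFeaturesInfo:
// an entry (n -> k) marks feature n as categorical with categories 0..k-1.
final class CategoricalInfoSketch {

    // Features absent from the map are treated as continuous.
    static boolean isCategorical(Map<Integer, Integer> info, int featureIndex) {
        return info.containsKey(featureIndex);
    }

    // True if `value` is a legal category index for a categorical feature.
    static boolean isValidCategory(Map<Integer, Integer> info, int featureIndex, double value) {
        Integer arity = info.get(featureIndex);
        if (arity == null) {
            throw new IllegalArgumentException("feature " + featureIndex + " is not categorical");
        }
        // Categories are whole numbers in [0, arity).
        return value >= 0 && value < arity && value == Math.floor(value);
    }

    // Example map of the kind passed to the Java-friendly constructor.
    static Map<Integer, Integer> example() {
        Map<Integer, Integer> info = new HashMap<>();
        info.put(0, 3); // feature 0 takes values {0, 1, 2}
        info.put(4, 2); // feature 4 is binary: {0, 1}
        return info;
    }
}
```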
  • Method Details

    • defaultStrategy

      public static Strategy defaultStrategy(String algo)
      Constructs a default set of parameters for DecisionTree.
      Parameters:
      algo - "Classification" or "Regression"
      Returns:
      a Strategy with default parameters for the given algorithm
    • defaultStrategy

      public static Strategy defaultStrategy(scala.Enumeration.Value algo)
      Constructs a default set of parameters for DecisionTree.
      Parameters:
      algo - Algo.Classification or Algo.Regression
      Returns:
      a Strategy with default parameters for the given algorithm
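The two defaultStrategy overloads can be modeled as a String overload that parses the name and delegates to the enumeration overload. Goal, DefaultStrategySketch and Config are hypothetical stand-ins, not Spark types. The sketch records only the goal and an impurity name; picking Gini for classification and Variance for regression (both listed as supported in the class description) as defaults is an assumption, and rejecting unknown names with IllegalArgumentException is this sketch's own choice.

```java
// Hypothetical stand-in for the Algo enumeration.
enum Goal { CLASSIFICATION, REGRESSION }

final class DefaultStrategySketch {

    // Minimal stand-in for a Strategy: only the goal and impurity name.
    static final class Config {
        final Goal goal;
        final String impurity;

        Config(Goal goal, String impurity) {
            this.goal = goal;
            this.impurity = impurity;
        }
    }

    // String overload: accepts exactly "Classification" or "Regression"
    // and delegates to the enumeration overload.
    static Config defaultStrategy(String algo) {
        switch (algo) {
            case "Classification":
                return defaultStrategy(Goal.CLASSIFICATION);
            case "Regression":
                return defaultStrategy(Goal.REGRESSION);
            default:
                throw new IllegalArgumentException("Unsupported algo: " + algo);
        }
    }

    // Enumeration overload: chooses an impurity supported for the goal.
    static Config defaultStrategy(Goal goal) {
        return goal == Goal.CLASSIFICATION
                ? new Config(goal, "Gini")
                : new Config(goal, "Variance");
    }
}
```

Keeping the logic in the enumeration overload means the String form adds only parsing, so both entry points always produce the same defaults.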
    • algo

      public scala.Enumeration.Value algo()
    • impurity

      public Impurity impurity()
    • maxDepth

      public int maxDepth()
    • numClasses

      public int numClasses()
    • maxBins

      public int maxBins()
    • quantileCalculationStrategy

      public scala.Enumeration.Value quantileCalculationStrategy()
    • categoricalFeaturesInfo

      public scala.collection.immutable.Map<Object,Object> categoricalFeaturesInfo()
    • minInstancesPerNode

      public int minInstancesPerNode()
    • minInfoGain

      public double minInfoGain()
    • maxMemoryInMB

      public int maxMemoryInMB()
    • subsamplingRate

      public double subsamplingRate()
    • useNodeIdCache

      public boolean useNodeIdCache()
    • checkpointInterval

      public int checkpointInterval()
    • minWeightFractionPerNode

      public double minWeightFractionPerNode()
    • isMulticlassClassification

      public boolean isMulticlassClassification()
      Returns:
      true if algo is Classification and numClasses is greater than 2
    • isMulticlassWithCategoricalFeatures

      public boolean isMulticlassWithCategoricalFeatures()
      Returns:
      true if isMulticlassClassification() holds and categoricalFeaturesInfo is non-empty
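A sketch of the two multiclass predicates, assuming (as in Spark's source) that multiclass means a classification task with more than two classes. Because numClasses is ignored for regression, the goal is checked first. MulticlassSketch is a hypothetical class, and its nested Algo enum only stands in for Spark's Algo enumeration.

```java
import java.util.Map;

// Hypothetical model of isMulticlassClassification and
// isMulticlassWithCategoricalFeatures.
final class MulticlassSketch {
    enum Algo { CLASSIFICATION, REGRESSION }

    // Regression is never multiclass, whatever numClasses holds.
    static boolean isMulticlassClassification(Algo algo, int numClasses) {
        return algo == Algo.CLASSIFICATION && numClasses > 2;
    }

    // Multiclass classification where at least one feature is categorical.
    static boolean isMulticlassWithCategoricalFeatures(Algo algo, int numClasses,
                                                       Map<Integer, Integer> categoricalFeaturesInfo) {
        return isMulticlassClassification(algo, numClasses)
                && !categoricalFeaturesInfo.isEmpty();
    }
}
```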
    • setAlgo

      public void setAlgo(String algo)
      Sets the algorithm (learning goal) using a String.
      Parameters:
      algo - "Classification" or "Regression"
    • setCategoricalFeaturesInfo

      public void setCategoricalFeaturesInfo(Map<Integer,Integer> categoricalFeaturesInfo)
      Sets categoricalFeaturesInfo using a Java Map.
      Parameters:
      categoricalFeaturesInfo - map from feature index n to the number of categories k of that feature
    • copy

      public Strategy copy()
      Returns a shallow copy of this instance.
      Returns:
      a new Strategy with the same parameter values, sharing any referenced objects with this instance
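What "shallow" means can be shown with a hypothetical mutable config class, ShallowCopySketch, which is not a Spark type: primitive fields are copied by value, while object fields such as a map are shared between the original and the copy. In Strategy itself, categoricalFeaturesInfo is a scala.collection.immutable.Map, so the shared reference cannot be mutated through either instance; the sketch uses a mutable java.util.Map only to make the sharing visible.

```java
import java.util.Map;

// Hypothetical mutable config used only to illustrate shallow-copy semantics.
final class ShallowCopySketch {
    int maxDepth;
    Map<Integer, Integer> categoricalFeaturesInfo;

    ShallowCopySketch(int maxDepth, Map<Integer, Integer> info) {
        this.maxDepth = maxDepth;
        this.categoricalFeaturesInfo = info;
    }

    // Shallow copy: maxDepth is copied by value, but the new instance
    // refers to the same map object as the original.
    ShallowCopySketch copy() {
        return new ShallowCopySketch(maxDepth, categoricalFeaturesInfo);
    }
}
```

Changing maxDepth on the copy leaves the original untouched, but putting an entry into the copy's map is also visible through the original, because both fields point to the same object.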
    • getAlgo

      public scala.Enumeration.Value getAlgo()
    • getCategoricalFeaturesInfo

      public scala.collection.immutable.Map<Object,Object> getCategoricalFeaturesInfo()
    • getCheckpointInterval

      public int getCheckpointInterval()
    • getImpurity

      public Impurity getImpurity()
    • getMaxBins

      public int getMaxBins()
    • getMaxDepth

      public int getMaxDepth()
    • getMaxMemoryInMB

      public int getMaxMemoryInMB()
    • getMinInfoGain

      public double getMinInfoGain()
    • getMinInstancesPerNode

      public int getMinInstancesPerNode()
    • getMinWeightFractionPerNode

      public double getMinWeightFractionPerNode()
    • getNumClasses

      public int getNumClasses()
    • getQuantileCalculationStrategy

      public scala.Enumeration.Value getQuantileCalculationStrategy()
    • getSubsamplingRate

      public double getSubsamplingRate()
    • getUseNodeIdCache

      public boolean getUseNodeIdCache()
    • setAlgo

      public void setAlgo(scala.Enumeration.Value x$1)
    • setCategoricalFeaturesInfo

      public void setCategoricalFeaturesInfo(scala.collection.immutable.Map<Object,Object> x$1)
    • setCheckpointInterval

      public void setCheckpointInterval(int x$1)
    • setImpurity

      public void setImpurity(Impurity x$1)
    • setMaxBins

      public void setMaxBins(int x$1)
    • setMaxDepth

      public void setMaxDepth(int x$1)
    • setMaxMemoryInMB

      public void setMaxMemoryInMB(int x$1)
    • setMinInfoGain

      public void setMinInfoGain(double x$1)
    • setMinInstancesPerNode

      public void setMinInstancesPerNode(int x$1)
    • setMinWeightFractionPerNode

      public void setMinWeightFractionPerNode(double x$1)
    • setNumClasses

      public void setNumClasses(int x$1)
    • setQuantileCalculationStrategy

      public void setQuantileCalculationStrategy(scala.Enumeration.Value x$1)
    • setSubsamplingRate

      public void setSubsamplingRate(double x$1)
    • setUseNodeIdCache

      public void setUseNodeIdCache(boolean x$1)