Class VectorIndexer

All Implemented Interfaces:
Serializable, org.apache.spark.internal.Logging, VectorIndexerParams, Params, HasHandleInvalid, HasInputCol, HasOutputCol, DefaultParamsWritable, Identifiable, MLWritable, scala.Serializable

public class VectorIndexer extends Estimator<VectorIndexerModel> implements VectorIndexerParams, DefaultParamsWritable
Class for indexing categorical feature columns in a dataset of Vector.

VectorIndexer has two usage modes:
- Automatically identify categorical features (default behavior)
  - This helps process a dataset of unknown vectors into a dataset with some continuous features and some categorical features. The choice between continuous and categorical is based on the maxCategories parameter.
  - Set maxCategories to the maximum number of categories any categorical feature should have.
  - E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 has values {1.0, 3.0, 5.0}. If maxCategories = 2, then feature 0 will be declared categorical and use indices {0, 1}, and feature 1 will be declared continuous.
- Index all features, if all features are categorical
  - If maxCategories is set very large, then this will build an index of unique values for all features.
  - Warning: This can cause problems if features are continuous, since this will collect ALL unique values to the driver.
  - E.g.: Feature 0 has unique values {-1.0, 0.0}, and feature 1 has values {1.0, 3.0, 5.0}. If maxCategories is greater than or equal to 3, then both features will be declared categorical.
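The threshold rule can be sketched in plain Java, independent of Spark. CategoricalDetection.classify below is a hypothetical helper, not part of Spark's API. It assumes the distinct values of each feature have already been collected, applies the "more than maxCategories distinct values means continuous" rule, and rejects maxCategories values below 2, as documented for the maxCategories param.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/** Hypothetical helper illustrating the categorical-vs-continuous rule; not Spark code. */
final class CategoricalDetection {
    private CategoricalDetection() {}

    /**
     * Returns, for each feature index, true if the feature is treated as categorical,
     * i.e. it has at most maxCategories distinct values.
     */
    static Map<Integer, Boolean> classify(List<Set<Double>> uniqueValuesPerFeature, int maxCategories) {
        if (maxCategories < 2) {
            throw new IllegalArgumentException("maxCategories must be >= 2, got " + maxCategories);
        }
        Map<Integer, Boolean> isCategorical = new LinkedHashMap<>();
        for (int i = 0; i < uniqueValuesPerFeature.size(); i++) {
            // A feature with more than maxCategories distinct values is declared continuous.
            isCategorical.put(i, uniqueValuesPerFeature.get(i).size() <= maxCategories);
        }
        return isCategorical;
    }
}
```

With the feature sets from the examples, classify(values, 2) marks feature 0 categorical and feature 1 continuous, while classify(values, 3) marks both as categorical.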

This returns a model which can transform categorical features to use 0-based indices.

Index stability:
- This is not guaranteed to choose the same category index across multiple runs.
- If a categorical feature includes value 0, then this is guaranteed to map value 0 to index 0. This maintains vector sparsity.
- More stability may be added in the future.
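A minimal sketch of building the value-to-index map for one categorical feature so that 0.0, when present, lands at index 0. CategoryMaps.build is hypothetical. Ordering the remaining values ascending is an assumption made here for determinism, since the documentation above explicitly does not promise stable indices across runs.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

/** Hypothetical sketch of per-feature category indexing; not Spark code. */
final class CategoryMaps {
    private CategoryMaps() {}

    /**
     * Builds a value-to-index map for one categorical feature. If 0.0 is present it
     * always receives index 0; the remaining values get indices in ascending order
     * (an assumption of this sketch, not a documented guarantee).
     */
    static Map<Double, Integer> build(Set<Double> uniqueValues) {
        double[] nonZero = uniqueValues.stream()
                .mapToDouble(Double::doubleValue)
                .filter(v -> v != 0.0)
                .sorted()
                .toArray();
        Map<Double, Integer> index = new HashMap<>();
        int next = 0;
        if (nonZero.length < uniqueValues.size()) {
            // Zero maps to zero, so implicit zeros in sparse vectors stay implicit.
            index.put(0.0, next++);
        }
        for (double v : nonZero) {
            index.put(v, next++);
        }
        return index;
    }
}
```

For feature 0 from the examples, {-1.0, 0.0} yields 0.0 -> 0 and -1.0 -> 1. Because zero keeps index 0, a sparse vector's implicit zeros need no rewriting.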

TODO: Future extensions: The following functionality is planned for the future:
- Preserve metadata in transform; if a feature's metadata is already present, do not recompute.
- Specify certain features to not index, either via a parameter or via existing metadata.
- Add a warning if a categorical feature has only 1 category.

  • Constructor Details

    • VectorIndexer

      public VectorIndexer(String uid)
    • VectorIndexer

      public VectorIndexer()
  • Method Details

    • load

      public static VectorIndexer load(String path)
    • read

      public static MLReader<VectorIndexer> read()
    • handleInvalid

      public Param<String> handleInvalid()
      Description copied from interface: VectorIndexerParams
      Param for how to handle invalid data (unseen labels or NULL values). Note: this param only applies to categorical features, not continuous ones. Options are:
      - 'skip': filter out rows with invalid data.
      - 'error': throw an error.
      - 'keep': put invalid data in a special additional bucket, at an index equal to the number of categories of the feature.
      Default value: "error"
      Specified by:
      handleInvalid in interface HasHandleInvalid
      Specified by:
      handleInvalid in interface VectorIndexerParams
      Returns:
      (undocumented)
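The three handleInvalid options can be illustrated with a hypothetical lookup for a single categorical value. InvalidHandling.indexOf is not Spark code; it only covers unseen values (not NULLs), and it models 'skip' as an empty result that a caller would use to drop the row.

```java
import java.util.Map;
import java.util.OptionalInt;

/** Hypothetical illustration of the handleInvalid options; not Spark code. */
final class InvalidHandling {
    private InvalidHandling() {}

    /**
     * Maps a raw categorical value to its index. An empty result means the row
     * should be dropped ('skip').
     */
    static OptionalInt indexOf(double value, Map<Double, Integer> categoryMap, String handleInvalid) {
        Integer idx = categoryMap.get(value);
        if (idx != null) {
            return OptionalInt.of(idx);
        }
        switch (handleInvalid) {
            case "skip":
                return OptionalInt.empty();                // caller filters the row out
            case "keep":
                return OptionalInt.of(categoryMap.size()); // extra bucket after the known categories
            case "error":
                throw new IllegalArgumentException("Unseen value " + value + " for a categorical feature");
            default:
                throw new IllegalArgumentException("Unknown handleInvalid option: " + handleInvalid);
        }
    }
}
```

With categories {0.0 -> 0, -1.0 -> 1}, an unseen value 7.0 maps to index 2 under 'keep', yields no index under 'skip', and throws under 'error'.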
    • maxCategories

      public IntParam maxCategories()
      Description copied from interface: VectorIndexerParams
      Threshold for the number of values a categorical feature can take. If a feature is found to have more than maxCategories distinct values, then it is declared continuous. Must be greater than or equal to 2.

      (default = 20)

      Specified by:
      maxCategories in interface VectorIndexerParams
      Returns:
      (undocumented)
    • outputCol

      public final Param<String> outputCol()
      Description copied from interface: HasOutputCol
      Param for output column name.
      Specified by:
      outputCol in interface HasOutputCol
      Returns:
      (undocumented)
    • inputCol

      public final Param<String> inputCol()
      Description copied from interface: HasInputCol
      Param for input column name.
      Specified by:
      inputCol in interface HasInputCol
      Returns:
      (undocumented)
    • uid

      public String uid()
      Description copied from interface: Identifiable
      An immutable unique ID for the object and its derivatives.
      Specified by:
      uid in interface Identifiable
      Returns:
      (undocumented)
    • setMaxCategories

      public VectorIndexer setMaxCategories(int value)
    • setInputCol

      public VectorIndexer setInputCol(String value)
    • setOutputCol

      public VectorIndexer setOutputCol(String value)
    • setHandleInvalid

      public VectorIndexer setHandleInvalid(String value)
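Each setter returns the VectorIndexer itself, so calls can be chained, for example new VectorIndexer().setInputCol("features").setOutputCol("indexed").setMaxCategories(10), where the column names are placeholders. The Spark-free stand-in below, IndexerConfig, is hypothetical: it mirrors that chaining pattern, uses the documented defaults (maxCategories = 20, handleInvalid = "error"), and checks the documented constraints when a value is set.

```java
import java.util.Set;

/** Hypothetical stand-in for the VectorIndexer setter chain; not Spark code. */
final class IndexerConfig {
    private static final Set<String> HANDLE_INVALID_OPTIONS = Set.of("skip", "error", "keep");

    private int maxCategories = 20;          // documented default
    private String handleInvalid = "error";  // documented default
    private String inputCol;
    private String outputCol;

    IndexerConfig setMaxCategories(int value) {
        if (value < 2) {
            throw new IllegalArgumentException("maxCategories must be >= 2");
        }
        this.maxCategories = value;
        return this; // returning this enables chaining
    }

    IndexerConfig setHandleInvalid(String value) {
        if (!HANDLE_INVALID_OPTIONS.contains(value)) {
            throw new IllegalArgumentException("handleInvalid must be one of " + HANDLE_INVALID_OPTIONS);
        }
        this.handleInvalid = value;
        return this;
    }

    IndexerConfig setInputCol(String value) { this.inputCol = value; return this; }

    IndexerConfig setOutputCol(String value) { this.outputCol = value; return this; }

    int maxCategories() { return maxCategories; }
    String handleInvalid() { return handleInvalid; }
    String inputCol() { return inputCol; }
    String outputCol() { return outputCol; }
}
```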
    • fit

      public VectorIndexerModel fit(Dataset<?> dataset)
      Description copied from class: Estimator
      Fits a model to the input data.
      Specified by:
      fit in class Estimator<VectorIndexerModel>
      Parameters:
      dataset - (undocumented)
      Returns:
      (undocumented)
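fit needs the distinct values of each feature before it can decide which features are categorical. The single-machine sketch below, DistinctValueCollector.collect, is hypothetical and assumes dense rows held in memory. It stops collecting a feature once that feature exceeds maxCategories distinct values, since it is already known to be continuous. This bounding strategy is an assumption of the sketch; it also shows why a very large maxCategories means every distinct value must be kept.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/** Hypothetical single-machine sketch of the statistics pass behind fit(); not Spark code. */
final class DistinctValueCollector {
    private DistinctValueCollector() {}

    /**
     * Collects distinct values per feature from dense rows. Once a feature holds more
     * than maxCategories values, no further values are added for it, so its set ends
     * with at most maxCategories + 1 entries.
     */
    static List<Set<Double>> collect(List<double[]> rows, int numFeatures, int maxCategories) {
        List<Set<Double>> seen = new ArrayList<>();
        for (int i = 0; i < numFeatures; i++) {
            seen.add(new HashSet<>());
        }
        for (double[] row : rows) {
            for (int i = 0; i < numFeatures; i++) {
                Set<Double> values = seen.get(i);
                if (values.size() <= maxCategories) {
                    values.add(row[i]);
                }
            }
        }
        return seen;
    }
}
```

With maxCategories = 2 and feature 1 taking values 1.0, 3.0, 5.0, 7.0, collection for feature 1 stops at three values, which is already enough to declare it continuous.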
    • transformSchema

      public StructType transformSchema(StructType schema)
      Description copied from class: PipelineStage
      Check transform validity and derive the output schema from the input schema.

      We check validity for interactions between parameters during transformSchema and raise an exception if any parameter value is invalid. Parameter value checks which do not depend on other parameters are handled by Param.validate().

      A typical implementation should first verify schema changes and parameter validity, including complex parameter-interaction checks.

      Specified by:
      transformSchema in class PipelineStage
      Parameters:
      schema - (undocumented)
      Returns:
      (undocumented)
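As an illustration of the kind of checks transformSchema performs, here is a hypothetical sketch in which a schema is modeled as an ordered map from column name to a type name. The specific rules (the input column must exist and hold vectors, the output column must not already exist, and the output is a vector column) are assumptions for this sketch rather than a transcription of Spark's implementation.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical schema check; a schema is modeled as column name -> type name. */
final class SchemaCheck {
    private SchemaCheck() {}

    static Map<String, String> transformSchema(Map<String, String> schema, String inputCol, String outputCol) {
        String inputType = schema.get(inputCol);
        if (inputType == null) {
            throw new IllegalArgumentException("Input column " + inputCol + " does not exist");
        }
        if (!inputType.equals("vector")) {
            throw new IllegalArgumentException("Input column " + inputCol + " must be a vector, got " + inputType);
        }
        if (schema.containsKey(outputCol)) {
            throw new IllegalArgumentException("Output column " + outputCol + " already exists");
        }
        // Derive the output schema: all input columns, plus the indexed vector column.
        Map<String, String> out = new LinkedHashMap<>(schema);
        out.put(outputCol, "vector");
        return out;
    }
}
```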
    • copy

      public VectorIndexer copy(ParamMap extra)
      Description copied from interface: Params
      Creates a copy of this instance with the same UID and some extra params. Subclasses should implement this method and set the return type properly. See defaultCopy().
      Specified by:
      copy in interface Params
      Specified by:
      copy in class Estimator<VectorIndexerModel>
      Parameters:
      extra - (undocumented)
      Returns:
      (undocumented)
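copy keeps the UID and applies the extra params on top of the existing ones. ParamsHolder below is a hypothetical stand-in that does not use Spark's ParamMap or Params; it shows those semantics with a plain map: the copy has the same uid, values from extra take precedence, and the original is left unchanged.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical stand-in showing copy(extra) semantics: same uid, extra params win. */
final class ParamsHolder {
    final String uid;
    final Map<String, Object> params;

    ParamsHolder(String uid, Map<String, Object> params) {
        this.uid = uid;
        this.params = Map.copyOf(params); // immutable snapshot
    }

    ParamsHolder copy(Map<String, Object> extra) {
        Map<String, Object> merged = new HashMap<>(params);
        merged.putAll(extra); // values in extra override existing ones
        return new ParamsHolder(uid, merged);
    }
}
```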