Class Estimator<M extends Model<M>>

Object
org.apache.spark.ml.PipelineStage
org.apache.spark.ml.Estimator<M>
All Implemented Interfaces:
Serializable, org.apache.spark.internal.Logging, Params, Identifiable
Direct Known Subclasses:
ALS, BisectingKMeans, BucketedRandomProjectionLSH, ChiSqSelector, CountVectorizer, CrossValidator, FPGrowth, GaussianMixture, IDF, Imputer, IsotonicRegression, KMeans, LDA, MaxAbsScaler, MinHashLSH, MinMaxScaler, OneHotEncoder, OneVsRest, PCA, Pipeline, Predictor, QuantileDiscretizer, RFormula, RobustScaler, StandardScaler, StringIndexer, TrainValidationSplit, UnivariateFeatureSelector, VarianceThresholdSelector, VectorIndexer, Word2Vec

public abstract class Estimator<M extends Model<M>> extends PipelineStage
Abstract class for estimators that fit models to data.
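The bound M extends Model<M> in the declaration above makes fit return the concrete model type, so callers can use model-specific members without a cast. The following is a minimal plain-Java sketch of that pattern, not Spark's implementation. SimpleModel, SimpleEstimator, MeanModel and MeanEstimator are hypothetical stand-ins, and List<Double> replaces Dataset<?>.

```java
import java.util.List;

// Hypothetical stand-in for Model<M>: the self-referential bound ties copy() to the concrete type.
abstract class SimpleModel<M extends SimpleModel<M>> {
    abstract M copy();
}

// Hypothetical stand-in for Estimator<M>; List<Double> replaces Dataset<?>.
abstract class SimpleEstimator<M extends SimpleModel<M>> {
    abstract M fit(List<Double> dataset);
}

// Toy model that stores the mean learned from the data.
final class MeanModel extends SimpleModel<MeanModel> {
    final double mean;

    MeanModel(double mean) {
        this.mean = mean;
    }

    @Override
    MeanModel copy() {
        return new MeanModel(mean);
    }

    // Model-specific method, reachable without a cast because fit() returns MeanModel.
    double center(double x) {
        return x - mean;
    }
}

// Toy estimator: fit() computes the mean of the data (0.0 for empty input).
final class MeanEstimator extends SimpleEstimator<MeanModel> {
    @Override
    MeanModel fit(List<Double> dataset) {
        double sum = 0.0;
        for (double x : dataset) {
            sum += x;
        }
        return new MeanModel(dataset.isEmpty() ? 0.0 : sum / dataset.size());
    }
}
```

Spark's real classes also carry params, a UID and a Dataset input, all of which this sketch leaves out.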
  • Nested Class Summary

    Nested classes/interfaces inherited from interface org.apache.spark.internal.Logging

    org.apache.spark.internal.Logging.LogStringContext, org.apache.spark.internal.Logging.SparkShellLoggingFilter
  • Constructor Summary

    Constructors
    Estimator()
  • Method Summary

    Modifier and Type / Method / Description
    abstract Estimator<M>  copy(ParamMap extra)
        Creates a copy of this instance with the same UID and some extra params.
    abstract M  fit(Dataset<?> dataset)
        Fits a model to the input data.
    M  fit(Dataset<?> dataset, ParamMap paramMap)
        Fits a single model to the input data with the provided parameter map.
    M  fit(Dataset<?> dataset, ParamPair<?> firstParamPair, ParamPair<?>... otherParamPairs)
        Fits a single model to the input data with optional parameters.
    M  fit(Dataset<?> dataset, ParamPair<?> firstParamPair, scala.collection.immutable.Seq<ParamPair<?>> otherParamPairs)
        Fits a single model to the input data with optional parameters.
    scala.collection.immutable.Seq<M>  fit(Dataset<?> dataset, scala.collection.immutable.Seq<ParamMap> paramMaps)
        Fits multiple models to the input data with multiple sets of parameters.

    Methods inherited from class org.apache.spark.ml.PipelineStage

    params, transformSchema

    Methods inherited from class java.lang.Object

    equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait

    Methods inherited from interface org.apache.spark.ml.util.Identifiable

    toString, uid

    Methods inherited from interface org.apache.spark.internal.Logging

    initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, isTraceEnabled, log, logDebug, logDebug, logDebug, logDebug, logError, logError, logError, logError, logInfo, logInfo, logInfo, logInfo, logName, LogStringContext, logTrace, logTrace, logTrace, logTrace, logWarning, logWarning, logWarning, logWarning, org$apache$spark$internal$Logging$$log_, org$apache$spark$internal$Logging$$log__$eq, withLogContext
  • Constructor Details

    • Estimator

      public Estimator()
  • Method Details

    • copy

      public abstract Estimator<M> copy(ParamMap extra)
      Description copied from interface: Params
      Creates a copy of this instance with the same UID and some extra params. Subclasses should implement this method and set the return type properly. See defaultCopy().
      Specified by:
      copy in interface Params
      Specified by:
      copy in class PipelineStage
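      Parameters:
      extra - parameter values to apply to the copy; they override the values in this instance
      Returns:
      a copy of this estimator with the same UID and the extra params applied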
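The copy contract above has three parts: the copy keeps the same UID, values in extra take precedence, and the original is left unchanged. The following plain-Java sketch illustrates that contract. CopyableEstimator and its string-keyed parameter map are hypothetical stand-ins for Params and ParamMap, not Spark classes.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical Params-like object: a UID plus parameter values keyed by name.
final class CopyableEstimator {
    final String uid;
    final Map<String, Object> params;

    CopyableEstimator(String uid, Map<String, Object> params) {
        this.uid = uid;
        this.params = new HashMap<>(params); // defensive copy
    }

    // Same UID; values in extra override the current ones; this instance is not modified.
    CopyableEstimator copy(Map<String, Object> extra) {
        Map<String, Object> merged = new HashMap<>(params);
        merged.putAll(extra);
        return new CopyableEstimator(uid, merged);
    }
}
```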
    • fit

      public M fit(Dataset<?> dataset, ParamPair<?> firstParamPair, ParamPair<?>... otherParamPairs)
      Fits a single model to the input data with optional parameters.

      Parameters:
      dataset - input dataset
      firstParamPair - the first param pair, overrides embedded params
      otherParamPairs - other param pairs. These values override any specified in this Estimator's embedded ParamMap.
      Returns:
      fitted model
    • fit

      public M fit(Dataset<?> dataset, ParamPair<?> firstParamPair, scala.collection.immutable.Seq<ParamPair<?>> otherParamPairs)
      Fits a single model to the input data with optional parameters.

      Parameters:
      dataset - input dataset
      firstParamPair - the first param pair, overrides embedded params
      otherParamPairs - other param pairs. These values override any specified in this Estimator's embedded ParamMap.
      Returns:
      fitted model
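Both pair-based overloads can be understood as collecting the pairs into one parameter map and then fitting with that map. The following plain-Java sketch shows that pattern under several assumptions:

- Pair, PairEstimator and the string-keyed map are hypothetical stand-ins for ParamPair, Estimator and ParamMap.
- fit returns the effective parameter values rather than a trained model, to keep the sketch short.
- A later pair for the same name replaces an earlier one. This is this sketch's rule.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for ParamPair: a parameter name and its value.
final class Pair {
    final String name;
    final Object value;

    Pair(String name, Object value) {
        this.name = name;
        this.value = value;
    }
}

// Hypothetical estimator with embedded parameter values.
final class PairEstimator {
    final Map<String, Object> params = new HashMap<>();

    // Stand-in for fit(dataset, paramMap): returns the values a model would be
    // trained with (embedded values overridden by paramMap). dataset is unused here.
    Map<String, Object> fit(List<Double> dataset, Map<String, Object> paramMap) {
        Map<String, Object> effective = new HashMap<>(params);
        effective.putAll(paramMap);
        return effective;
    }

    // Pair overload: put the first pair, then the others in order, into one map
    // (a later pair for the same name replaces an earlier one), then delegate.
    Map<String, Object> fit(List<Double> dataset, Pair firstParamPair, Pair... otherParamPairs) {
        Map<String, Object> paramMap = new HashMap<>();
        paramMap.put(firstParamPair.name, firstParamPair.value);
        for (Pair p : otherParamPairs) {
            paramMap.put(p.name, p.value);
        }
        return fit(dataset, paramMap);
    }
}
```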
    • fit

      public M fit(Dataset<?> dataset, ParamMap paramMap)
      Fits a single model to the input data with the provided parameter map.

      Parameters:
      dataset - input dataset
      paramMap - Parameter map. These values override any specified in this Estimator's embedded ParamMap.
      Returns:
      fitted model
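Spark's Scala source implements this overload as copy(paramMap).fit(dataset). That is why the supplied values override the embedded ones without changing the estimator itself. The following plain-Java sketch mirrors that delegation. ThresholdEstimator, ThresholdModel, the "offset" parameter and List<Double> (in place of Dataset<?>) are all hypothetical.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical toy model: a single learned threshold.
final class ThresholdModel {
    final double threshold;

    ThresholdModel(double threshold) {
        this.threshold = threshold;
    }

    boolean predict(double x) {
        return x >= threshold;
    }
}

// Hypothetical toy estimator with a UID and embedded parameter values.
final class ThresholdEstimator {
    final String uid;
    final Map<String, Object> params;

    ThresholdEstimator(String uid, Map<String, Object> params) {
        this.uid = uid;
        this.params = new HashMap<>(params);
    }

    // Same UID; extra values override the embedded ones.
    ThresholdEstimator copy(Map<String, Object> extra) {
        Map<String, Object> merged = new HashMap<>(params);
        merged.putAll(extra);
        return new ThresholdEstimator(uid, merged);
    }

    // Single-argument fit: threshold = mean of the data + "offset" (default 0.0).
    ThresholdModel fit(List<Double> dataset) {
        double offset = ((Number) params.getOrDefault("offset", 0.0)).doubleValue();
        double sum = 0.0;
        for (double x : dataset) {
            sum += x;
        }
        double mean = dataset.isEmpty() ? 0.0 : sum / dataset.size();
        return new ThresholdModel(mean + offset);
    }

    // Fit a copy that carries paramMap; this estimator's own params stay unchanged.
    ThresholdModel fit(List<Double> dataset, Map<String, Object> paramMap) {
        return copy(paramMap).fit(dataset);
    }
}
```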
    • fit

      public abstract M fit(Dataset<?> dataset)
      Fits a model to the input data.
      Parameters:
      dataset - input dataset
      Returns:
      fitted model
    • fit

      public scala.collection.immutable.Seq<M> fit(Dataset<?> dataset, scala.collection.immutable.Seq<ParamMap> paramMaps)
      Fits multiple models to the input data with multiple sets of parameters. The default implementation calls fit(dataset, paramMap) once for each parameter map, in order. Subclasses may override this method to optimize multi-model training.

      Parameters:
      dataset - input dataset
      paramMaps - A sequence of parameter maps. These values override any specified in this Estimator's embedded ParamMap.
      Returns:
      fitted models, in the same order as the input parameter maps
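The default behaviour described above is one fit per parameter map, with the models returned in the same order. A simple hyperparameter sweep relies on that ordering. The following plain-Java sketch mirrors the loop. SweepEstimator, SumModel and the "scale" parameter are hypothetical, and java.util.List stands in for scala.collection.immutable.Seq.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical toy model: the data sum multiplied by a "scale" parameter.
final class SumModel {
    final double value;

    SumModel(double value) {
        this.value = value;
    }
}

// Hypothetical toy estimator with embedded parameter values.
final class SweepEstimator {
    final Map<String, Object> params = new HashMap<>();

    // Single fit: paramMap overrides the embedded values; "scale" defaults to 1.0.
    SumModel fit(List<Double> dataset, Map<String, Object> paramMap) {
        Map<String, Object> effective = new HashMap<>(params);
        effective.putAll(paramMap);
        double scale = ((Number) effective.getOrDefault("scale", 1.0)).doubleValue();
        double sum = 0.0;
        for (double x : dataset) {
            sum += x;
        }
        return new SumModel(scale * sum);
    }

    // Default multi-model fit: one call per parameter map, in order,
    // so models.get(i) was trained with paramMaps.get(i).
    List<SumModel> fit(List<Double> dataset, List<Map<String, Object>> paramMaps) {
        List<SumModel> models = new ArrayList<>(paramMaps.size());
        for (Map<String, Object> paramMap : paramMaps) {
            models.add(fit(dataset, paramMap));
        }
        return models;
    }
}
```

A subclass that can share work across parameter settings would override the list-based method while keeping the same ordering guarantee.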