Class GaussianMixtureModel

All Implemented Interfaces:
Serializable, org.apache.spark.internal.Logging, GaussianMixtureParams, Params, HasAggregationDepth, HasFeaturesCol, HasMaxIter, HasPredictionCol, HasProbabilityCol, HasSeed, HasTol, HasWeightCol, HasTrainingSummary<GaussianMixtureSummary>, Identifiable, MLWritable

Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points are drawn from each Gaussian i with probability weights(i).

param: weights Weight for each Gaussian distribution in the mixture. This is a multinomial probability distribution over the k Gaussians, where weights(i) is the weight for Gaussian i, and the weights sum to 1.
param: gaussians Array of MultivariateGaussian, where gaussians(i) represents the multivariate Gaussian (Normal) distribution for Gaussian i.
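In symbols, this is the standard finite-mixture density. The notation is introduced here only for illustration: w_i stands for weights(i), mu_i and Sigma_i for the mean and covariance of gaussians(i), indices run from 0 to match weights(i), and d is the feature dimension (presumably the value numFeatures reports).

```latex
\begin{aligned}
p(x) &= \sum_{i=0}^{k-1} w_i \, \mathcal{N}(x \mid \mu_i, \Sigma_i),
  \qquad w_i \ge 0, \quad \sum_{i=0}^{k-1} w_i = 1, \\
\mathcal{N}(x \mid \mu, \Sigma) &= (2\pi)^{-d/2} \, \lvert \Sigma \rvert^{-1/2}
  \exp\!\left( -\tfrac{1}{2} (x - \mu)^{\top} \Sigma^{-1} (x - \mu) \right).
\end{aligned}
```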

  • Method Details

    • read

      public static MLReader<GaussianMixtureModel> read()
    • load

      public static GaussianMixtureModel load(String path)
    • k

      public final IntParam k()
      Number of independent Gaussians in the mixture model. Must be greater than 1. Default: 2.

    • aggregationDepth

      public final IntParam aggregationDepth()
      Param for suggested depth for treeAggregate (>= 2).
    • tol

      public final DoubleParam tol()
      Param for the convergence tolerance for iterative algorithms (>= 0).
    • probabilityCol

      public final Param<String> probabilityCol()
      Param for column name for predicted class conditional probabilities. Note: not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.
    • weightCol

      public final Param<String> weightCol()
      Param for weight column name. If this is not set or is empty, all instance weights are treated as 1.0.
    • predictionCol

      public final Param<String> predictionCol()
      Param for prediction column name.
    • seed

      public final LongParam seed()
      Param for random seed.
    • featuresCol

      public final Param<String> featuresCol()
      Param for features column name.
    • maxIter

      public final IntParam maxIter()
      Param for maximum number of iterations (>= 0).
    • uid

      public String uid()
      An immutable unique ID for the object and its derivatives.
    • weights

      public double[] weights()
    • gaussians

      public MultivariateGaussian[] gaussians()
    • numFeatures

      public int numFeatures()
    • setFeaturesCol

      public GaussianMixtureModel setFeaturesCol(String value)
    • setPredictionCol

      public GaussianMixtureModel setPredictionCol(String value)
    • setProbabilityCol

      public GaussianMixtureModel setProbabilityCol(String value)
    • copy

      public GaussianMixtureModel copy(ParamMap extra)
      Creates a copy of this instance with the same UID and some extra params. Subclasses should implement this method and set the return type properly. See defaultCopy().
    • transform

      public Dataset<Row> transform(Dataset<?> dataset)
      Transforms the input dataset.
    • transformSchema

      public StructType transformSchema(StructType schema)
      Description copied from class: PipelineStage
      Check transform validity and derive the output schema from the input schema.

      We check validity for interactions between parameters during transformSchema and raise an exception if any parameter value is invalid. Parameter value checks which do not depend on other parameters are handled by Param.validate().

      A typical implementation should first verify the schema change and parameter validity, including complex parameter interaction checks.

    • predict

      public int predict(Vector features)
    • predictProbability

      public Vector predictProbability(Vector features)
    • gaussiansDF

      public Dataset<Row> gaussiansDF()
      Retrieves the Gaussian distributions as a DataFrame. Each row represents one Gaussian distribution, and two columns are defined: mean and cov. Schema:
         |-- mean: vector (nullable = true)
         |-- cov: matrix (nullable = true)
    • write

      public MLWriter write()
      Returns an MLWriter instance for this ML instance.

      For GaussianMixtureModel, this does NOT currently save the training summary(). An option to save summary() may be added in the future.

    • toString

      public String toString()
    • summary

      public GaussianMixtureSummary summary()
      Gets the summary of the model on the training set. An exception is thrown if hasSummary is false.
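The predict and predictProbability entries above carry no description. The usual GMM reading, which is presumably what they compute, is that predictProbability returns each component's posterior membership probability (weights(i) times the component density at the point, normalized over all k components) and predict returns the index of the largest one. The sketch below illustrates that computation with the Java standard library only. It is not Spark's implementation: the class name GmmPosteriorSketch is hypothetical, plain double[] arrays stand in for Spark's Vector, and each component is assumed to have a diagonal covariance, whereas MultivariateGaussian supports full covariance matrices.

```java
import java.util.Arrays;

/**
 * Hypothetical, stdlib-only sketch of GMM posterior membership probabilities.
 * ASSUMPTION: each component has a diagonal covariance (per-feature variances);
 * Spark's MultivariateGaussian supports full covariance matrices.
 */
public final class GmmPosteriorSketch {

    private GmmPosteriorSketch() {}

    /** Log-density of a Gaussian with mean `mean` and covariance diag(variance). */
    public static double logDiagGaussianPdf(double[] x, double[] mean, double[] variance) {
        double logPdf = 0.0;
        for (int j = 0; j < x.length; j++) {
            double diff = x[j] - mean[j];
            logPdf += -0.5 * (Math.log(2.0 * Math.PI * variance[j]) + diff * diff / variance[j]);
        }
        return logPdf;
    }

    /** Posterior P(i | x) = w_i N_i(x) / sum_j w_j N_j(x), computed in log space. */
    public static double[] predictProbability(double[] x, double[] weights,
                                              double[][] means, double[][] variances) {
        int k = weights.length;
        double[] logTerms = new double[k];
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < k; i++) {
            logTerms[i] = Math.log(weights[i]) + logDiagGaussianPdf(x, means[i], variances[i]);
            max = Math.max(max, logTerms[i]);
        }
        // Log-sum-exp: subtract the largest term before exponentiating to avoid underflow.
        double[] probs = new double[k];
        double sum = 0.0;
        for (int i = 0; i < k; i++) {
            probs[i] = Math.exp(logTerms[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < k; i++) {
            probs[i] /= sum;
        }
        return probs;
    }

    /** Index of the component with the highest posterior probability. */
    public static int predict(double[] x, double[] weights,
                              double[][] means, double[][] variances) {
        double[] probs = predictProbability(x, weights, means, variances);
        int best = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] weights = {0.5, 0.5};
        double[][] means = {{0.0, 0.0}, {5.0, 5.0}};
        double[][] variances = {{1.0, 1.0}, {1.0, 1.0}};
        double[] x = {4.0, 5.0};
        System.out.println(Arrays.toString(predictProbability(x, weights, means, variances)));
        System.out.println(predict(x, weights, means, variances)); // prints 1
    }
}
```

Normalizing in log space with the maximum subtracted is a design choice for this sketch: in higher dimensions the raw densities can underflow to zero, which would make a direct weights(i) * pdf normalization divide by zero. Spark's own code may normalize differently.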
