Class GaussianMixtureModel

Object
org.apache.spark.mllib.clustering.GaussianMixtureModel
All Implemented Interfaces:
Serializable, Saveable, scala.Serializable

public class GaussianMixtureModel extends Object implements scala.Serializable, Saveable
Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points are drawn from Gaussian i (i = 1..k) with probability w(i), and mu(i) and sigma(i) are the mean and covariance matrix of Gaussian i.

param: weights Weights for each Gaussian distribution in the mixture, where weights(i) is the weight for Gaussian i, and weights.sum == 1
param: gaussians Array of MultivariateGaussian, where gaussians(i) represents the multivariate Gaussian (normal) distribution for Gaussian i
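Written as a formula, the mixture described above has the standard GMM density below, stated here for reference. Here d is the dimension of the points, w_i is w(i), mu_i is mu(i), and the covariance sigma(i) is written Σ_i.

```latex
p(x) = \sum_{i=1}^{k} w_i \, \mathcal{N}(x \mid \mu_i, \Sigma_i),
\qquad w_i \ge 0, \quad \sum_{i=1}^{k} w_i = 1,

\mathcal{N}(x \mid \mu, \Sigma) =
\frac{1}{(2\pi)^{d/2} \, \lvert \Sigma \rvert^{1/2}}
\exp\!\left( -\tfrac{1}{2} (x - \mu)^{\top} \Sigma^{-1} (x - \mu) \right)
```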

  • Constructor Details

    • GaussianMixtureModel

      public GaussianMixtureModel(double[] weights, MultivariateGaussian[] gaussians)
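The constructor takes parallel arrays: weights[i] belongs to gaussians[i], and the class description requires the weights to sum to 1. The plain-Java sketch below shows one way a caller might check those conditions before building a model. MixtureArgsCheck and validate are hypothetical names, the tolerance is an assumption, and the library's constructor may enforce only some of these conditions itself. The sketch takes the number of Gaussians as an int so that it runs without Spark on the classpath.

```java
public class MixtureArgsCheck {

    /**
     * Hypothetical pre-construction check (not part of Spark): one weight per
     * Gaussian, every weight non-negative, and the weights summing to 1 within tol.
     */
    static void validate(double[] weights, int numGaussians, double tol) {
        if (weights.length != numGaussians) {
            throw new IllegalArgumentException("got " + weights.length
                + " weights for " + numGaussians + " Gaussians");
        }
        double sum = 0.0;
        for (double w : weights) {
            // NaN fails every comparison, so test for it explicitly.
            if (Double.isNaN(w) || w < 0.0) {
                throw new IllegalArgumentException("weights must be non-negative, got " + w);
            }
            sum += w;
        }
        if (Math.abs(sum - 1.0) > tol) {
            throw new IllegalArgumentException("weights must sum to 1, got " + sum);
        }
    }

    public static void main(String[] args) {
        validate(new double[] {0.25, 0.75}, 2, 1e-9); // passes silently
        try {
            validate(new double[] {0.3, 0.3}, 2, 1e-9); // sums to 0.6
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```

A tolerance is used instead of exact equality because weights produced by floating-point arithmetic rarely sum to exactly 1.0.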
  • Method Details

    • load

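      public static GaussianMixtureModel load(SparkContext sc, String path)
      Loads a model previously written with save(sc, path).
      Parameters:
      sc - Spark context used to read the model data.
      path - Path of the directory the model was saved to.
      Returns:
      the loaded GaussianMixtureModel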
    • weights

      public double[] weights()
      Mixture weights, where weights()[i] is the weight of Gaussian i; the weights sum to 1.
    • gaussians

      public MultivariateGaussian[] gaussians()
      Mixture components, where gaussians()[i] is the multivariate Gaussian distribution of component i.
    • save

      public void save(SparkContext sc, String path)
      Description copied from interface: Saveable
      Save this model to the given path.

      This saves:
      - human-readable (JSON) model metadata to path/metadata/
      - Parquet-formatted data to path/data/

      The model may be loaded using Loader.load; for this class, use the static method GaussianMixtureModel.load(sc, path).

      Specified by:
      save in interface Saveable
      Parameters:
      sc - Spark context used to save model data.
      path - Path specifying the directory in which to save this model. If the directory already exists, this method throws an exception.
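Because save writes metadata/ and data/ under a directory that must not already exist, a caller may want to check the target first. The sketch below does this for a local file system path using java.nio.file only. Real Spark paths may point to HDFS, S3, or other Hadoop-compatible file systems, where such a check would need the Hadoop FileSystem API instead. SaveLayoutSketch and its methods are hypothetical helpers; the subdirectory names come from the description above.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

public class SaveLayoutSketch {

    /** Where the documented layout puts the JSON metadata. */
    static Path metadataDir(Path modelDir) {
        return modelDir.resolve("metadata");
    }

    /** Where the documented layout puts the Parquet data. */
    static Path dataDir(Path modelDir) {
        return modelDir.resolve("data");
    }

    /** save() is documented to throw if the directory exists, so only a missing path counts as safe. */
    static boolean safeToSave(Path modelDir) {
        return !Files.exists(modelDir);
    }

    public static void main(String[] args) throws IOException {
        Path base = Files.createTempDirectory("gmm");
        Path target = base.resolve("model");
        System.out.println(safeToSave(target)); // true: nothing exists yet
        System.out.println(metadataDir(target));
        System.out.println(dataDir(target));
        Files.createDirectory(target);
        System.out.println(safeToSave(target)); // false: save() would fail here
        Files.delete(target);
        Files.delete(base);
    }
}
```

A check like this is advisory only: another process can still create the directory between the check and the call to save.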
    • k

      public int k()
      Number of Gaussians in the mixture.
      Returns:
      the number of mixture components, k (equal to weights().length)
    • predict

      public RDD<Object> predict(RDD<Vector> points)
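      Maps the given points to their cluster indices.
      Parameters:
      points - RDD of input vectors.
      Returns:
      RDD of cluster indices, one per input point, each in the range [0, k). The element type appears as Object in this Java signature because it is a Scala Int; Java callers can use the JavaRDD overload instead.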
    • predict

      public int predict(Vector point)
      Maps the given point to its cluster index.
      Parameters:
      point - the input vector.
      Returns:
      the index, in the range [0, k), of the mixture component with the highest membership value for this point
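Assigning a point to a single component amounts to picking the i that maximizes w(i) · N(x | mu(i), sigma(i)), or equivalently its logarithm; this is the same component that has the highest membership value, since the normalizing denominator is shared by all components. The sketch below is a self-contained Java illustration of that rule, not Spark's implementation. It assumes diagonal covariances (one variance per dimension) to avoid matrix inversion, and HardAssignmentSketch, logPdfDiag, and predict are hypothetical names.

```java
public class HardAssignmentSketch {

    /** Log-density of a Gaussian with diagonal covariance (one variance per dimension). */
    static double logPdfDiag(double[] x, double[] mean, double[] variance) {
        double sum = 0.0;
        for (int d = 0; d < x.length; d++) {
            double diff = x[d] - mean[d];
            sum += Math.log(2.0 * Math.PI * variance[d]) + diff * diff / variance[d];
        }
        return -0.5 * sum;
    }

    /** Index i maximizing log w(i) + log N(x | mu(i), diag(variance(i))); ties go to the lowest index. */
    static int predict(double[] x, double[] weights, double[][] means, double[][] variances) {
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < weights.length; i++) {
            double score = Math.log(weights[i]) + logPdfDiag(x, means[i], variances[i]);
            if (score > bestScore) {
                bestScore = score;
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] weights = {0.5, 0.5};
        double[][] means = {{0.0, 0.0}, {5.0, 5.0}};
        double[][] variances = {{1.0, 1.0}, {1.0, 1.0}};
        System.out.println(predict(new double[] {0.2, -0.1}, weights, means, variances)); // 0
        System.out.println(predict(new double[] {4.8, 5.3}, weights, means, variances));  // 1
    }
}
```

Working in log space avoids multiplying very small densities together. Conceptually, the RDD and JavaRDD overloads apply the same per-point rule to every element of their input.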
    • predict

      public JavaRDD<Integer> predict(JavaRDD<Vector> points)
      Java-friendly version of predict(RDD<Vector>).
      Parameters:
      points - JavaRDD of input vectors.
      Returns:
      JavaRDD of cluster indices, one per input point, each in the range [0, k)
    • predictSoft

      public RDD<double[]> predictSoft(RDD<Vector> points)
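      Given the input vectors, returns, for each vector, its membership values for all mixture components.
      Parameters:
      points - RDD of input vectors.
      Returns:
      RDD of arrays of length k, one per input point, where element i is the membership value of that point for Gaussian i; each array sums to 1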
    • predictSoft

      public double[] predictSoft(Vector point)
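      Given the input vector, returns its membership values for all mixture components.
      Parameters:
      point - the input vector.
      Returns:
      an array of length k, where element i is the membership value of the point for Gaussian i; the values sum to 1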
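Membership values of this kind are usually the posterior probabilities gamma_i(x) = w(i) N(x | mu(i), sigma(i)) / sum_j w(j) N(x | mu(j), sigma(j)), so the values for one point sum to 1. The Java sketch below computes them under stated assumptions: diagonal covariances (one variance per dimension) instead of full covariance matrices, and hypothetical names SoftAssignmentSketch, logPdfDiag, and predictSoft. It normalizes with the log-sum-exp trick, which is a numerical choice swapped in for this sketch; the library's own numerics may differ, particularly for points far from every component.

```java
import java.util.Arrays;

public class SoftAssignmentSketch {

    /** Log-density of a Gaussian with diagonal covariance (one variance per dimension). */
    static double logPdfDiag(double[] x, double[] mean, double[] variance) {
        double sum = 0.0;
        for (int d = 0; d < x.length; d++) {
            double diff = x[d] - mean[d];
            sum += Math.log(2.0 * Math.PI * variance[d]) + diff * diff / variance[d];
        }
        return -0.5 * sum;
    }

    /** Posterior membership of x to each component; entries are non-negative and sum to 1. */
    static double[] predictSoft(double[] x, double[] weights, double[][] means, double[][] variances) {
        int k = weights.length;
        double[] logP = new double[k];
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < k; i++) {
            logP[i] = Math.log(weights[i]) + logPdfDiag(x, means[i], variances[i]);
            max = Math.max(max, logP[i]);
        }
        // Log-sum-exp: subtract the largest term before exponentiating so that
        // tiny densities do not all underflow to 0 and produce 0/0.
        double[] membership = new double[k];
        double sum = 0.0;
        for (int i = 0; i < k; i++) {
            membership[i] = Math.exp(logP[i] - max);
            sum += membership[i];
        }
        for (int i = 0; i < k; i++) {
            membership[i] /= sum;
        }
        return membership;
    }

    public static void main(String[] args) {
        double[] weights = {0.5, 0.5};
        double[][] means = {{0.0}, {2.0}};
        double[][] variances = {{1.0}, {1.0}};
        // x = 1.0 is equidistant from both means with equal weights and variances.
        System.out.println(Arrays.toString(predictSoft(new double[] {1.0}, weights, means, variances))); // [0.5, 0.5]
    }
}
```

Taking the index of the largest entry of this array gives a hard cluster assignment for the same point.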