StreamingKMeansModel

class pyspark.mllib.clustering.StreamingKMeansModel(clusterCenters: List[VectorLike], clusterWeights: VectorLike)

Clustering model which can perform an online update of the centroids.

The update formula for each centroid is given by

  • $c_{t+1} = \dfrac{c_t\, n_t\, a + x_t\, m_t}{n_t\, a + m_t}$

  • $n_{t+1} = n_t\, a + m_t$

where

  • c_t: Centroid at iteration t.

  • n_t: Number of samples (or total weight) associated with the centroid at iteration t.

  • x_t: Centroid (mean) of the new data points closest to c_t.

  • m_t: Number of samples (or total weight) of the new data points closest to c_t.

  • c_{t+1}: Updated centroid.

  • n_{t+1}: Updated weight of the centroid.

  • a: Decay factor, which controls how quickly past data is forgotten.
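The rule can be checked numerically. Below is a minimal NumPy sketch of one update for a single centroid, assuming the denominator is n_t·a + m_t (the same quantity as n_{t+1}); that is the form that reproduces the centers and weights shown in the Examples section. update_centroid is a hypothetical helper for illustration, not part of PySpark.

```python
import numpy as np


def update_centroid(c_t, n_t, x_t, m_t, a):
    """Hypothetical helper: apply one decayed update to a single centroid.

    c_t: current centroid, n_t: its current weight,
    x_t: mean of the new points assigned to it, m_t: how many there are,
    a: decay factor. Assumes n_t * a + m_t > 0 (e.g. m_t > 0).
    """
    c_t = np.asarray(c_t, dtype=float)
    x_t = np.asarray(x_t, dtype=float)
    # Discount the old weight, then add the weight of the new points.
    n_next = n_t * a + m_t
    # Weighted average of the discounted old centroid and the new batch mean.
    c_next = (c_t * n_t * a + x_t * m_t) / n_next
    return c_next, n_next


# Second update from the Examples section: centroid [1, 1] with weight 3
# receives the single point [1.5, 1.5] with decay factor a = 0.
c, n = update_centroid([1.0, 1.0], 3.0, [1.5, 1.5], 1.0, 0.0)
print(c, n)
```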

New in version 1.5.0.

Parameters
clusterCenters : list of pyspark.mllib.linalg.Vector or convertible

Initial cluster centers.

clusterWeights : pyspark.mllib.linalg.Vector or convertible

Initial weights assigned to the clusters, one per center.

Notes

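If a is set to 1, each updated centroid is the weighted mean of the previous centroid and the new data. If it is set to 0, the old centroids are completely forgotten: each centroid that receives new points moves to the mean of those points.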
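Setting a to its two extremes in the update rule makes this concrete (with the denominator taken as n_t a + m_t = n_{t+1}, the form consistent with the Examples section):

```latex
\begin{aligned}
a = 1:&\quad c_{t+1} = \frac{n_t\, c_t + m_t\, x_t}{n_t + m_t}, && n_{t+1} = n_t + m_t \\
a = 0:&\quad c_{t+1} = \frac{0 + m_t\, x_t}{0 + m_t} = x_t, && n_{t+1} = m_t
\end{aligned}
```

The a = 0 case assumes m_t > 0, that is, the centroid received at least one new point.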

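Examples

The examples below assume an active SparkContext bound to the name sc.

>>> from pyspark.mllib.clustering import StreamingKMeansModel
>>> from pyspark.mllib.linalg import DenseVector
>>> initCenters = [[0.0, 0.0], [1.0, 1.0]]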
>>> initWeights = [1.0, 1.0]
>>> stkm = StreamingKMeansModel(initCenters, initWeights)
>>> data = sc.parallelize([[-0.1, -0.1], [0.1, 0.1],
...                        [0.9, 0.9], [1.1, 1.1]])
>>> stkm = stkm.update(data, 1.0, "batches")
>>> stkm.centers
array([[ 0.,  0.],
       [ 1.,  1.]])
>>> stkm.predict([-0.1, -0.1])
0
>>> stkm.predict([0.9, 0.9])
1
>>> stkm.clusterWeights
[3.0, 3.0]
>>> decayFactor = 0.0
>>> data = sc.parallelize([DenseVector([1.5, 1.5]), DenseVector([0.2, 0.2])])
>>> stkm = stkm.update(data, decayFactor, "batches")
>>> stkm.centers
array([[ 0.2,  0.2],
       [ 1.5,  1.5]])
>>> stkm.clusterWeights
[1.0, 1.0]
>>> stkm.predict([0.2, 0.2])
0
>>> stkm.predict([1.5, 1.5])
1

Methods

computeCost(rdd)

Return the K-means cost (sum of squared distances of points to their nearest center) for this model on the given data.

load(sc, path)

Load a model from the given path.

predict(x)

Find the cluster that each of the points belongs to in this model.

save(sc, path)

Save this model to the given path.

update(data, decayFactor, timeUnit)

Update the centroids according to data.

Attributes

clusterCenters

Get the cluster centers, represented as a list of NumPy arrays.

clusterWeights

Return the cluster weights.

k

Total number of clusters.

Methods Documentation

computeCost(rdd: pyspark.rdd.RDD[VectorLike]) → float

Return the K-means cost (sum of squared distances of points to their nearest center) for this model on the given data.

New in version 1.4.0.

Parameters
rdd : pyspark.RDD

The RDD of points to compute the cost on.
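For intuition about the quantity computeCost returns, here is a NumPy sketch of the same sum computed on an in-memory array of points rather than an RDD, assuming plain squared Euclidean distance. kmeans_cost is a hypothetical name, not a PySpark function.

```python
import numpy as np


def kmeans_cost(points, centers):
    """Hypothetical helper: sum of squared distances to the nearest center."""
    points = np.asarray(points, dtype=float)    # shape (n_points, dim)
    centers = np.asarray(centers, dtype=float)  # shape (k, dim)
    # Pairwise squared distances, shape (n_points, k), via broadcasting.
    sq_dists = ((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    # Keep only the distance to the closest center for each point, then sum.
    return float(sq_dists.min(axis=1).sum())


centers = [[0.0, 0.0], [1.0, 1.0]]
points = [[-0.1, -0.1], [0.1, 0.1], [0.9, 0.9], [1.1, 1.1]]
print(kmeans_cost(points, centers))
```

Each of the four points lies at squared distance of about 0.02 from its nearest center, so the total is roughly 0.08.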

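classmethod load(sc: pyspark.context.SparkContext, path: str) → pyspark.mllib.clustering.KMeansModel

Load a model from the given path. This method is inherited from KMeansModel, so, as the return type indicates, it returns a KMeansModel rather than a StreamingKMeansModel.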

New in version 1.4.0.

predict(x: Union[VectorLike, pyspark.rdd.RDD[VectorLike]]) → Union[int, pyspark.rdd.RDD[int]]

Find the cluster that each of the points belongs to in this model.

New in version 0.9.0.

Parameters
x : pyspark.mllib.linalg.Vector or pyspark.RDD

A data point (or RDD of points) to determine cluster index. pyspark.mllib.linalg.Vector can be replaced with equivalent objects (list, tuple, numpy.ndarray).

Returns
int or pyspark.RDD of int

Predicted cluster index or an RDD of predicted cluster indices if the input is an RDD.
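The single-point assignment can be sketched in NumPy as picking the center with the smallest squared Euclidean distance, under the assumption that the model uses Euclidean distance as in standard k-means. predict_index is a hypothetical helper; it handles one point only, not an RDD.

```python
import numpy as np


def predict_index(x, centers):
    """Hypothetical helper: index of the center closest to point x."""
    x = np.asarray(x, dtype=float)          # accepts list, tuple or ndarray
    centers = np.asarray(centers, dtype=float)
    # Squared Euclidean distance from x to every center; np.argmin returns
    # the first (lowest) index when two centers are equally close.
    sq_dists = ((centers - x) ** 2).sum(axis=1)
    return int(np.argmin(sq_dists))


centers = [[0.2, 0.2], [1.5, 1.5]]  # centers after the second update in Examples
print(predict_index([0.2, 0.2], centers), predict_index([1.5, 1.5], centers))
```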

save(sc: pyspark.context.SparkContext, path: str) → None

Save this model to the given path.

New in version 1.4.0.

update(data: pyspark.rdd.RDD[VectorLike], decayFactor: float, timeUnit: str) → StreamingKMeansModel

Update the centroids according to data.

New in version 1.5.0.

Parameters
data : pyspark.RDD

RDD with new data for the model update.

decayFactor : float

Decay factor (a in the update formula), which controls the forgetfulness of the previous centroids.

timeUnit : str

Either “batches” or “points”. With “points”, the decay factor is raised to the power of the number of new points; with “batches”, it is used as is. Any other value raises a ValueError.
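To show how decayFactor and timeUnit fit together, the following is a minimal pure-NumPy sketch of a single update step, written from the update formula and the parameter descriptions rather than taken from PySpark's implementation. streaming_update is a hypothetical name; the sketch assumes a non-empty 2-D batch and Euclidean nearest-center assignment.

```python
import numpy as np


def streaming_update(centers, weights, batch, decay_factor, time_unit):
    """Hypothetical sketch of one streaming k-means update step.

    Returns new (centers, weights) arrays; the inputs are left unchanged.
    Assumes `batch` is a non-empty 2-D array-like of points.
    """
    if time_unit not in ("batches", "points"):
        raise ValueError("time_unit must be 'batches' or 'points'")
    centers = np.array(centers, dtype=float)  # copies, so callers' data is untouched
    weights = np.array(weights, dtype=float)
    batch = np.asarray(batch, dtype=float)

    # 1. Assign each new point to its nearest current center.
    sq_dists = ((batch[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    labels = sq_dists.argmin(axis=1)

    # 2. Discount all old weights: "batches" uses the factor as is,
    #    "points" raises it to the number of points in this batch.
    if time_unit == "batches":
        discount = decay_factor
    else:
        discount = decay_factor ** len(batch)
    weights *= discount

    # 3. For each center that received points, blend in the batch mean:
    #    c_{t+1} = (c_t * n_t * a + x_t * m_t) / (n_t * a + m_t)
    for j in np.unique(labels):
        members = batch[labels == j]
        m = len(members)
        new_weight = weights[j] + m  # weights[j] already holds n_t * a
        centers[j] = (centers[j] * weights[j] + members.mean(axis=0) * m) / new_weight
        weights[j] = new_weight
    return centers, weights


# The two batches from the Examples section.
c, w = streaming_update([[0.0, 0.0], [1.0, 1.0]], [1.0, 1.0],
                        [[-0.1, -0.1], [0.1, 0.1], [0.9, 0.9], [1.1, 1.1]],
                        1.0, "batches")
c, w = streaming_update(c, w, [[1.5, 1.5], [0.2, 0.2]], 0.0, "batches")
print(c, w)
```

Run on those two batches, the sketch ends with centers [[0.2, 0.2], [1.5, 1.5]] and weights [1.0, 1.0] up to floating-point rounding, matching the doctest output. Note that clusters receiving no new points still have their weights discounted but keep their centers.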


Attributes Documentation

clusterCenters

Get the cluster centers, represented as a list of NumPy arrays.

New in version 1.0.0.

clusterWeights

Return the cluster weights.

New in version 1.5.0.

k

Total number of clusters.

New in version 1.4.0.