NaiveBayesModel

class pyspark.mllib.classification.NaiveBayesModel(labels: numpy.ndarray, pi: numpy.ndarray, theta: numpy.ndarray)

Model for Naive Bayes classifiers.

New in version 0.9.0.

Parameters
labels : numpy.ndarray

List of labels.

pi : numpy.ndarray

Log of class priors, with dimension C, where C is the number of labels.

theta : numpy.ndarray

Log of class conditional probabilities, with dimension C-by-D, where D is the number of features.
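The shapes of these parameters can be illustrated with plain NumPy. The following is a minimal sketch, not the library's training code: it assumes a multinomial model with additive smoothing of 1.0, and the names `X`, `y`, `smoothing`, `counts`, and `feature_sums` are hypothetical.

```python
import numpy as np

# Toy data: three points, two features (D = 2), two classes (C = 2).
X = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]])
y = np.array([0.0, 0.0, 1.0])
smoothing = 1.0  # assumed additive smoothing constant (hypothetical name)

labels = np.unique(y)                      # shape (C,)
C, D = labels.size, X.shape[1]

# Log of smoothed class priors, shape (C,).
counts = np.array([(y == c).sum() for c in labels])
pi = np.log((counts + smoothing) / (y.size + C * smoothing))

# Log of smoothed class conditional probabilities, shape (C, D).
feature_sums = np.array([X[y == c].sum(axis=0) for c in labels])
totals = feature_sums.sum(axis=1, keepdims=True)
theta = np.log((feature_sums + smoothing) / (totals + D * smoothing))

print(labels.shape, pi.shape, theta.shape)
```

Under these assumptions, `pi` has one entry per label and `theta` has one row per label and one column per feature.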

Examples

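>>> import numpy
>>> from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel
>>> from pyspark.mllib.linalg import SparseVector
>>> from pyspark.mllib.regression import LabeledPoint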
>>> data = [
...     LabeledPoint(0.0, [0.0, 0.0]),
...     LabeledPoint(0.0, [0.0, 1.0]),
...     LabeledPoint(1.0, [1.0, 0.0]),
... ]
>>> model = NaiveBayes.train(sc.parallelize(data))
>>> model.predict(numpy.array([0.0, 1.0]))
0.0
>>> model.predict(numpy.array([1.0, 0.0]))
1.0
>>> model.predict(sc.parallelize([[1.0, 0.0]])).collect()
[1.0]
>>> sparse_data = [
...     LabeledPoint(0.0, SparseVector(2, {1: 0.0})),
...     LabeledPoint(0.0, SparseVector(2, {1: 1.0})),
...     LabeledPoint(1.0, SparseVector(2, {0: 1.0}))
... ]
>>> model = NaiveBayes.train(sc.parallelize(sparse_data))
>>> model.predict(SparseVector(2, {1: 1.0}))
0.0
>>> model.predict(SparseVector(2, {0: 1.0}))
1.0
>>> import tempfile
>>> path = tempfile.mkdtemp()
>>> model.save(sc, path)
>>> sameModel = NaiveBayesModel.load(sc, path)
>>> sameModel.predict(SparseVector(2, {0: 1.0})) == model.predict(SparseVector(2, {0: 1.0}))
True
>>> from shutil import rmtree
>>> try:
...     rmtree(path)
... except OSError:
...     pass

Methods

load(sc, path)

Load a model from the given path.

predict(x)

Return the most likely class for a data vector or an RDD of vectors.

save(sc, path)

Save this model to the given path.

Methods Documentation

classmethod load(sc: pyspark.context.SparkContext, path: str) → pyspark.mllib.classification.NaiveBayesModel

Load a model from the given path.

New in version 1.4.0.

predict(x: Union[VectorLike, pyspark.rdd.RDD[VectorLike]]) → Union[numpy.float64, pyspark.rdd.RDD[numpy.float64]]

Return the most likely class for a data vector or an RDD of vectors.

New in version 0.9.0.
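For a single dense vector, picking the most likely class from log priors and log conditional probabilities can be sketched in plain NumPy. This is an illustration under assumptions, not the library's source: the values of `labels`, `pi`, and `theta` are made up for the example, and `predict_one` is a hypothetical helper.

```python
import numpy as np

# Hypothetical model parameters for C = 2 classes and D = 2 features.
labels = np.array([0.0, 1.0])
pi = np.log(np.array([0.6, 0.4]))            # log class priors, shape (C,)
theta = np.log(np.array([[1 / 3, 2 / 3],     # log conditional probabilities,
                         [2 / 3, 1 / 3]]))   # shape (C, D)


def predict_one(x):
    """Return the label whose log score pi + theta . x is largest."""
    scores = pi + theta.dot(x)   # one log score per class, shape (C,)
    return labels[np.argmax(scores)]


print(predict_one(np.array([0.0, 1.0])))
print(predict_one(np.array([1.0, 0.0])))
```

Working in log space turns the product of probabilities into a sum, which avoids numerical underflow when D is large.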

save(sc: pyspark.context.SparkContext, path: str) → None

Save this model to the given path.