public class NaiveBayesModel extends Object implements ClassificationModel, scala.Serializable, Saveable
param: labels - list of labels
param: pi - log of class priors, whose dimension is C, the number of labels
param: theta - log of class-conditional probabilities, whose dimension is C-by-D, where D is the number of features
param: modelType - the type of NB model to fit; can be "multinomial" or "bernoulli"
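Because `pi` and `theta` are stored as logarithms, scoring a data point x reduces to sums: for the multinomial model, the score of class c is pi[c] + sum over j of theta[c][j] * x[j], and the predicted label is the label of the highest-scoring class. The sketch below shows this for both model types using plain Java arrays instead of Spark's `Vector` and matrix types. It is a Spark-free illustration, not Spark's implementation: the class and method names are hypothetical, the toy numbers are made up, and the Bernoulli branch assumes (and enforces) feature values of exactly 0 or 1.

```java
/**
 * Spark-free sketch of Naive Bayes scoring from log priors (pi) and log
 * class-conditional probabilities (theta). Names are hypothetical; this is
 * not Spark's implementation.
 */
public class NaiveBayesScoringSketch {

    /** Returns the unnormalized log-space score of every class for feature vector x. */
    static double[] classScores(double[] pi, double[][] theta, double[] x, String modelType) {
        boolean bernoulli;
        if ("multinomial".equals(modelType)) {
            bernoulli = false;
        } else if ("bernoulli".equals(modelType)) {
            bernoulli = true;
        } else {
            throw new IllegalArgumentException("Unknown modelType: " + modelType);
        }
        double[] scores = new double[pi.length];
        for (int c = 0; c < pi.length; c++) {
            // Start from the log prior of class c.
            double s = pi[c];
            for (int j = 0; j < x.length; j++) {
                if (!bernoulli) {
                    // Multinomial: feature value x[j] acts as a count weighting log P(feature j | class c).
                    s += theta[c][j] * x[j];
                } else if (x[j] == 1.0) {
                    // Bernoulli, feature present: add log P(feature j | class c).
                    s += theta[c][j];
                } else if (x[j] == 0.0) {
                    // Bernoulli, feature absent: add log(1 - P(feature j | class c)).
                    s += Math.log1p(-Math.exp(theta[c][j]));
                } else {
                    throw new IllegalArgumentException("Bernoulli features must be 0 or 1, got " + x[j]);
                }
            }
            scores[c] = s;
        }
        return scores;
    }

    /** Returns the label of the class with the highest log-space score. */
    static double predict(double[] labels, double[] pi, double[][] theta, double[] x, String modelType) {
        double[] scores = classScores(pi, theta, x, modelType);
        int best = 0;
        for (int c = 1; c < scores.length; c++) {
            if (scores[c] > scores[best]) {
                best = c;
            }
        }
        return labels[best];
    }

    public static void main(String[] args) {
        // Made-up toy model: 2 classes (labels 0.0 and 1.0), 3 features.
        double[] labels = {0.0, 1.0};
        double[] pi = {Math.log(0.5), Math.log(0.5)};
        double[][] theta = {
            {Math.log(0.7), Math.log(0.2), Math.log(0.1)},
            {Math.log(0.1), Math.log(0.2), Math.log(0.7)}
        };
        System.out.println(predict(labels, pi, theta, new double[] {3.0, 0.0, 1.0}, "multinomial")); // prints 0.0
        System.out.println(predict(labels, pi, theta, new double[] {0.0, 0.0, 1.0}, "bernoulli"));   // prints 1.0
    }
}
```

Working in log space turns a product of many small probabilities into a sum, which avoids floating-point underflow; Spark may compute the same scores in matrix form rather than with explicit loops.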
Modifier and Type | Class and Description |
---|---|
static class | NaiveBayesModel.SaveLoadV1_0$ |
static class | NaiveBayesModel.SaveLoadV2_0$ |
Modifier and Type | Method and Description |
---|---|
double[] | labels() |
static NaiveBayesModel | load(SparkContext sc, String path) |
String | modelType() |
double[] | pi() |
RDD<Object> | predict(RDD<Vector> testData) - Predict values for the given data set using the model trained. |
double | predict(Vector testData) - Predict values for a single data point using the model trained. |
RDD<Vector> | predictProbabilities(RDD<Vector> testData) - Predict posterior class probabilities for the given data set using the model trained. |
Vector | predictProbabilities(Vector testData) - Predict posterior class probabilities for a single data point using the model trained. |
void | save(SparkContext sc, String path) - Save this model to the given path. |
double[][] | theta() |
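`predictProbabilities` returns posterior class probabilities rather than a single label. One standard way to obtain them from per-class log-space scores is to exponentiate and normalize, subtracting the maximum score first so that the exponentials do not all underflow to zero. The sketch below shows that normalization on its own; the class and method names are hypothetical, the input scores are made up, and Spark's internal code may differ in detail.

```java
import java.util.Arrays;

/**
 * Sketch: turn per-class log-space scores into posterior class probabilities.
 * Hypothetical helper, not Spark's API.
 */
public class NaiveBayesPosteriorSketch {

    static double[] posteriorProbabilities(double[] logScores) {
        // Find the largest log score so it can be shifted to 0 before exponentiating.
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logScores) {
            max = Math.max(max, v);
        }
        double[] probs = new double[logScores.length];
        double sum = 0.0;
        for (int c = 0; c < logScores.length; c++) {
            // exp(score - max) is at most 1, so the largest term never underflows.
            probs[c] = Math.exp(logScores[c] - max);
            sum += probs[c];
        }
        // Normalize so the entries sum to 1; order matches the order of the class labels.
        for (int c = 0; c < probs.length; c++) {
            probs[c] /= sum;
        }
        return probs;
    }

    public static void main(String[] args) {
        // Scores this negative would underflow to 0.0 if exponentiated directly.
        double[] logScores = {-1000.0, -1001.0};
        System.out.println(Arrays.toString(posteriorProbabilities(logScores)));
    }
}
```

Subtracting the same constant from every score leaves the normalized probabilities unchanged, because the factor exp(-max) cancels between numerator and denominator.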
Methods inherited from class Object: equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface ClassificationModel: predict
public static NaiveBayesModel load(SparkContext sc, String path)
public double[] labels()
public double[] pi()
public double[][] theta()
public String modelType()
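public RDD<Object> predict(RDD<Vector> testData)
Description copied from interface: ClassificationModel
Predict values for the given data set using the model trained.
Specified by: predict in interface ClassificationModel
Parameters: testData - RDD representing data points to be predicted

public double predict(Vector testData)
Description copied from interface: ClassificationModel
Predict values for a single data point using the model trained.
Specified by: predict in interface ClassificationModel
Parameters: testData - vector representing a single data point

public RDD<Vector> predictProbabilities(RDD<Vector> testData)
Predict posterior class probabilities for the given data set using the model trained.
Parameters: testData - RDD representing data points to be predicted

public Vector predictProbabilities(Vector testData)
Predict posterior class probabilities for a single data point using the model trained.
Parameters: testData - vector representing a single data point

public void save(SparkContext sc, String path)
Description copied from interface: Saveable
Save this model to the given path.
This saves:
- human-readable (JSON) model metadata to path/metadata/
- Parquet-formatted data to path/data/
The model may be loaded using Loader.load (for this class, the static load(SparkContext sc, String path) method).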