public class LogisticRegressionModel extends ProbabilisticClassificationModel<Vector,LogisticRegressionModel> implements LogisticRegressionParams, MLWritable
Model produced by LogisticRegression.
Modifier and Type | Method and Description |
---|---|
BinaryLogisticRegressionTrainingSummary | binarySummary(): Gets summary of model on training set. |
Matrix | coefficientMatrix() |
Vector | coefficients(): A vector of model coefficients for "binomial" logistic regression. |
LogisticRegressionModel | copy(ParamMap extra): Creates a copy of this instance with the same UID and some extra params. |
LogisticRegressionSummary | evaluate(Dataset<?> dataset): Evaluates the model on a test dataset. |
double | getThreshold(): Get threshold for binary classification. |
double[] | getThresholds(): Get thresholds for binary or multiclass classification. |
boolean | hasSummary(): Indicates whether a training summary exists for this model instance. |
double | intercept(): The model intercept for "binomial" logistic regression. |
Vector | interceptVector() |
static LogisticRegressionModel | load(String path) |
int | numClasses(): Number of classes (values which the label can take). |
int | numFeatures(): Returns the number of features the model was trained on. |
double | predict(Vector features): Predict label for the given feature vector. |
static MLReader<LogisticRegressionModel> | read() |
LogisticRegressionModel | setThreshold(double value): Set threshold in binary classification, in range [0, 1]. |
LogisticRegressionModel | setThresholds(double[] value): Set thresholds in multiclass (or binary) classification to adjust the probability of predicting each class. |
LogisticRegressionTrainingSummary | summary(): Gets summary of model on training set. |
String | toString() |
String | uid(): An immutable unique ID for the object and its derivatives. |
MLWriter | write(): Returns an MLWriter instance for this ML instance. |
Methods inherited from superclasses and interfaces:
ProbabilisticClassificationModel: normalizeToProbabilitiesInPlace, setProbabilityCol, transform
ClassificationModel: setRawPredictionCol
PredictionModel: setFeaturesCol, setPredictionCol, transformSchema
Transformer: transform, transform, transform
LogisticRegressionParams: checkThresholdConsistency, family, getFamily, getLowerBoundsOnCoefficients, getLowerBoundsOnIntercepts, getUpperBoundsOnCoefficients, getUpperBoundsOnIntercepts, lowerBoundsOnCoefficients, lowerBoundsOnIntercepts, upperBoundsOnCoefficients, upperBoundsOnIntercepts, usingBoundConstrainedOptimization, validateAndTransformSchema
HasLabelCol: getLabelCol, labelCol
HasFeaturesCol: featuresCol, getFeaturesCol
HasPredictionCol: getPredictionCol, predictionCol
Params: clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
HasRawPredictionCol: getRawPredictionCol, rawPredictionCol
HasProbabilityCol: getProbabilityCol, probabilityCol
HasThresholds: thresholds
HasRegParam: getRegParam, regParam
HasElasticNetParam: elasticNetParam, getElasticNetParam
HasMaxIter: getMaxIter, maxIter
HasFitIntercept: fitIntercept, getFitIntercept
HasStandardization: getStandardization, standardization
HasWeightCol: getWeightCol, weightCol
HasThreshold: threshold
HasAggregationDepth: aggregationDepth, getAggregationDepth
MLWritable: save
Logging: initializeLogging, initializeLogIfNecessary, initializeLogIfNecessary, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning
public static MLReader<LogisticRegressionModel> read()
public static LogisticRegressionModel load(String path)
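public String uid()
Description copied from interface: Identifiable
An immutable unique ID for the object and its derivatives.
Specified by:
uid in interface Identifiable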
public Matrix coefficientMatrix()
public Vector interceptVector()
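As a rough illustration of how these two accessors fit together, assume a multinomial model in which coefficientMatrix holds one row of coefficients per class and interceptVector holds one intercept per class. Class probabilities can then be obtained with a softmax over the per-class margins. The class name MultinomialSketch and its methods are hypothetical, and plain Java arrays stand in for Spark's Matrix and Vector types; this is a sketch under those assumptions, not Spark's implementation.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch (not Spark code): converts a coefficient matrix with one
 * row per class plus one intercept per class into class probabilities.
 * double[][] and double[] stand in for Spark's Matrix and Vector.
 */
final class MultinomialSketch {

    /** Per-class margins: margin[k] = coefficients[k] . x + intercepts[k]. */
    static double[] margins(double[][] coefficients, double[] intercepts, double[] x) {
        double[] m = new double[coefficients.length];
        for (int k = 0; k < coefficients.length; k++) {
            double dot = intercepts[k];
            for (int j = 0; j < x.length; j++) {
                dot += coefficients[k][j] * x[j];
            }
            m[k] = dot;
        }
        return m;
    }

    /** Numerically stable softmax: subtracts the largest margin before exponentiating. */
    static double[] softmax(double[] margins) {
        double max = Arrays.stream(margins).max().orElse(0.0);
        double[] p = new double[margins.length];
        double sum = 0.0;
        for (int k = 0; k < margins.length; k++) {
            p[k] = Math.exp(margins[k] - max);
            sum += p[k];
        }
        for (int k = 0; k < p.length; k++) {
            p[k] /= sum;
        }
        return p;
    }

    public static void main(String[] args) {
        double[][] coef = {{1.0, 0.0}, {0.0, 1.0}, {0.0, 0.0}}; // 3 classes x 2 features
        double[] intercepts = {0.0, 0.0, 0.0};
        double[] p = softmax(margins(coef, intercepts, new double[] {2.0, 0.0}));
        System.out.println(Arrays.toString(p));
    }
}
```

Subtracting the maximum margin leaves the softmax result unchanged while avoiding overflow in Math.exp for large margins.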
public int numClasses()
Description copied from class: ClassificationModel
Number of classes (values which the label can take).
Specified by:
numClasses in class ClassificationModel<Vector,LogisticRegressionModel>
public Vector coefficients()
public double intercept()
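For a "binomial" model, coefficients() and intercept() combine into a single margin, and the rule described under setThreshold (predict 1 if the estimated probability of class 1 is greater than the threshold, else 0) turns the resulting probability into a label. The sketch below assumes the standard logistic (sigmoid) link; BinomialSketch is a hypothetical name, and plain arrays replace Spark's Vector.

```java
/**
 * Hypothetical sketch (not Spark code): binomial logistic regression scoring
 * from a coefficient vector, an intercept, and a decision threshold.
 */
final class BinomialSketch {

    /** P(label = 1 | x) = 1 / (1 + exp(-(coefficients . x + intercept))). */
    static double probabilityOfOne(double[] coefficients, double intercept, double[] x) {
        double margin = intercept;
        for (int j = 0; j < x.length; j++) {
            margin += coefficients[j] * x[j];
        }
        return 1.0 / (1.0 + Math.exp(-margin));
    }

    /** Returns 1.0 if P(label = 1 | x) is strictly greater than threshold, else 0.0. */
    static double predict(double[] coefficients, double intercept, double[] x, double threshold) {
        return probabilityOfOne(coefficients, intercept, x) > threshold ? 1.0 : 0.0;
    }

    public static void main(String[] args) {
        double[] w = {0.8, -0.5};
        double b = 0.1;
        double[] x = {1.0, 1.0}; // margin = 0.8 - 0.5 + 0.1 = 0.4
        System.out.println("P(1|x) = " + probabilityOfOne(w, b, x));
        System.out.println("threshold 0.5 -> " + predict(w, b, x, 0.5));
        System.out.println("threshold 0.7 -> " + predict(w, b, x, 0.7));
    }
}
```

Raising the threshold from 0.5 to 0.7 flips this example from 1.0 to 0.0, which is the "high threshold encourages the model to predict 0 more often" behavior in concrete form.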
public LogisticRegressionModel setThreshold(double value)
Description copied from interface: LogisticRegressionParams
Set threshold in binary classification, in range [0, 1].
If the estimated probability of class label 1 is greater than threshold, then predict 1, else 0. A high threshold encourages the model to predict 0 more often; a low threshold encourages the model to predict 1 more often.
Note: Calling this with threshold p is equivalent to calling setThresholds(Array(1-p, p)). When setThreshold() is called, any user-set value for thresholds will be cleared. If both threshold and thresholds are set in a ParamMap, then they must be equivalent.
Default is 0.5.
Specified by:
setThreshold in interface LogisticRegressionParams
Parameters:
value - (undocumented)

public double getThreshold()
Description copied from interface: LogisticRegressionParams
Get threshold for binary classification.
If thresholds is set with length 2 (i.e., binary classification), this returns the equivalent threshold: 1 / (1 + thresholds(0) / thresholds(1)). Otherwise, returns threshold if set, or its default value if unset.
Specified by:
getThreshold in interface LogisticRegressionParams
getThreshold in interface HasThreshold
Throws:
IllegalArgumentException - if thresholds is set to an array of length other than 2.
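The equivalence between threshold and a two-element thresholds array can be checked numerically. The hypothetical helper below (not part of the Spark API) maps p to {1 - p, p}, as described for setThreshold, and applies the getThreshold formula 1 / (1 + thresholds(0) / thresholds(1)) in the other direction, rejecting arrays whose length is not 2.

```java
/**
 * Hypothetical helper (not part of the Spark API) illustrating the documented
 * equivalence between a binary threshold p and a thresholds array {1 - p, p}.
 */
final class ThresholdEquivalence {

    /** setThreshold(p) is documented as equivalent to setThresholds({1 - p, p}). */
    static double[] toThresholds(double threshold) {
        return new double[] {1.0 - threshold, threshold};
    }

    /** Applies the documented getThreshold() formula: 1 / (1 + t[0] / t[1]). */
    static double toThreshold(double[] thresholds) {
        if (thresholds.length != 2) {
            throw new IllegalArgumentException(
                "thresholds must have length 2 for binary classification, got " + thresholds.length);
        }
        return 1.0 / (1.0 + thresholds[0] / thresholds[1]);
    }

    public static void main(String[] args) {
        // Round trip p -> {1 - p, p} -> p, up to floating-point rounding.
        System.out.println(toThreshold(toThresholds(0.3)));
        // Only the ratio t[0] / t[1] matters, so {2, 6} maps to the same threshold as {0.25, 0.75}.
        System.out.println(toThreshold(new double[] {2.0, 6.0}));
    }
}
```

Because the formula depends only on the ratio thresholds(0) / thresholds(1), any pair with the same ratio yields the same threshold; {2, 6} and {0.25, 0.75} both correspond to 0.75.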
public LogisticRegressionModel setThresholds(double[] value)
Description copied from interface: LogisticRegressionParams
Set thresholds in multiclass (or binary) classification to adjust the probability of predicting each class.
Note: When setThresholds() is called, any user-set value for threshold will be cleared. If both threshold and thresholds are set in a ParamMap, then they must be equivalent.
Specified by:
setThresholds in interface LogisticRegressionParams
Overrides:
setThresholds in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
Parameters:
value - (undocumented)

public double[] getThresholds()
Description copied from interface: LogisticRegressionParams
Get thresholds for binary or multiclass classification.
If thresholds is set, return its value. Otherwise, if threshold is set, return the equivalent thresholds for binary classification: (1-threshold, threshold). If neither is set, throw an exception.
Specified by:
getThresholds in interface LogisticRegressionParams
getThresholds in interface HasThresholds
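setThresholds speaks of adjusting the probability of predicting each class but does not spell out the rule on this page. Spark's HasThresholds parameter documentation states that the class with the largest value p/t is predicted, where p is the class probability and t is that class's threshold; the sketch below assumes that rule. ThresholdedArgmax is a hypothetical name, plain arrays stand in for Spark's Vector, and mapping a zero threshold to positive infinity is a choice made for this sketch.

```java
/**
 * Hypothetical sketch (not Spark's implementation): predicts the class with the
 * largest probabilities[i] / thresholds[i], as described for Spark's thresholds param.
 */
final class ThresholdedArgmax {

    static int predict(double[] probabilities, double[] thresholds) {
        if (probabilities.length != thresholds.length) {
            throw new IllegalArgumentException("need exactly one threshold per class");
        }
        int best = -1;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < probabilities.length; i++) {
            // Sketch choice: a zero threshold scores +infinity instead of dividing by zero.
            double score = thresholds[i] == 0.0
                ? Double.POSITIVE_INFINITY
                : probabilities[i] / thresholds[i];
            // Strict comparison: on ties, the lowest class index wins.
            if (score > bestScore) {
                best = i;
                bestScore = score;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] probabilities = {0.5, 0.3, 0.2};
        // Equal thresholds reduce to a plain argmax over the probabilities.
        System.out.println(predict(probabilities, new double[] {1.0, 1.0, 1.0}));
        // Lowering class 1's threshold makes class 1 easier to predict.
        System.out.println(predict(probabilities, new double[] {0.9, 0.4, 0.5}));
    }
}
```

With thresholds {1 - p, p}, this rule predicts class 1 exactly when its probability exceeds p, which matches the binary threshold behavior described under setThreshold.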
public int numFeatures()
Description copied from class: PredictionModel
Returns the number of features the model was trained on.
Overrides:
numFeatures in class PredictionModel<Vector,LogisticRegressionModel>
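public LogisticRegressionTrainingSummary summary()
Gets summary of model on training set. An exception is thrown if trainingSummary == None.

public BinaryLogisticRegressionTrainingSummary binarySummary()
Gets summary of model on training set. An exception is thrown if trainingSummary == None or it is a multiclass model.

public boolean hasSummary()
Indicates whether a training summary exists for this model instance.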
public LogisticRegressionSummary evaluate(Dataset<?> dataset)
Evaluates the model on a test dataset.
Parameters:
dataset - Test dataset to evaluate model on.

public double predict(Vector features)
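Predict label for the given feature vector. The behavior of this can be adjusted using thresholds.
Overrides:
predict in class ClassificationModel<Vector,LogisticRegressionModel>
Parameters:
features - (undocumented)

public LogisticRegressionModel copy(ParamMap extra)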
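Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params. See defaultCopy().
Specified by:
copy in interface Params
copy in class Model<LogisticRegressionModel>
Parameters:
extra - (undocumented)

public MLWriter write()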
Returns an MLWriter instance for this ML instance.
For LogisticRegressionModel, this does NOT currently save the training summary. An option to save the summary may be added in the future. This also does not currently save the parent.
Specified by:
write in interface MLWritable
public String toString()
Specified by:
toString in interface Identifiable
Overrides:
toString in class Object