public class LogisticRegressionModel extends ProbabilisticClassificationModel<Vector,LogisticRegressionModel> implements MLWritable
Model produced by LogisticRegression.
Modifier and Type | Method and Description |
---|---|
void | checkThresholdConsistency() If threshold and thresholds are both set, ensures they are consistent. |
Vector | coefficients() |
LogisticRegressionModel | copy(ParamMap extra) Creates a copy of this instance with the same UID and some extra params. |
Param<String> | featuresCol() Param for features column name. |
String | getFeaturesCol() |
String | getLabelCol() |
String | getPredictionCol() |
String | getRawPredictionCol() |
double | getThreshold() Get threshold for binary classification. |
double[] | getThresholds() Get thresholds for binary or multiclass classification. |
boolean | hasSummary() Indicates whether a training summary exists for this model instance. |
double | intercept() |
Param<String> | labelCol() Param for label column name. |
static LogisticRegressionModel | load(String path) |
int | numClasses() Number of classes (values which the label can take). |
int | numFeatures() Returns the number of features the model was trained on. |
Param<String> | predictionCol() Param for prediction column name. |
Param<String> | rawPredictionCol() Param for raw prediction (a.k.a. confidence) column name. |
static MLReader<LogisticRegressionModel> | read() |
LogisticRegressionModel | setThreshold(double value) Set threshold in binary classification, in range [0, 1]. |
LogisticRegressionModel | setThresholds(double[] value) Set thresholds in multiclass (or binary) classification to adjust the probability of predicting each class. |
LogisticRegressionTrainingSummary | summary() Gets summary of model on training set. |
String | uid() An immutable unique ID for the object and its derivatives. |
StructType | validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType) Validates and transforms the input schema with the provided param map. |
void | validateParams() |
Vector | weights() |
MLWriter | write() Returns an MLWriter instance for this ML instance. |
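Methods inherited from class ProbabilisticClassificationModel: normalizeToProbabilitiesInPlace, setProbabilityCol, transform
Methods inherited from class ClassificationModel: setRawPredictionCol
Methods inherited from class PredictionModel: setFeaturesCol, setPredictionCol, transformSchema
Methods inherited from class Transformer: transform, transform, transform
Methods inherited from class Object: equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface Params: clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
Methods inherited from interface Identifiable: toString
Methods inherited from interface MLWritable: save
Methods inherited from interface Logging: initializeIfNecessary, initializeLogging, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning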
public static MLReader<LogisticRegressionModel> read()
public static LogisticRegressionModel load(String path)
public String uid()
An immutable unique ID for the object and its derivatives.
Specified by:
uid in interface Identifiable
public Vector coefficients()
public double intercept()
public Vector weights()
public LogisticRegressionModel setThreshold(double value)
Set threshold in binary classification, in range [0, 1].
If the estimated probability of class label 1 is > threshold, then predict 1, else 0. A high threshold encourages the model to predict 0 more often; a low threshold encourages the model to predict 1 more often.
Note: Calling this with threshold p is equivalent to calling setThresholds(Array(1-p, p)).
When setThreshold() is called, any user-set value for thresholds will be cleared. If both threshold and thresholds are set in a ParamMap, then they must be equivalent.
Default is 0.5.
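The binary decision rule can be illustrated with a small plain-Java sketch that does not depend on Spark. It assumes the standard binary logistic form, where the probability of class 1 is the sigmoid of (coefficients · features + intercept). The class and method names are hypothetical, and this is not Spark's implementation.

```java
// Hypothetical sketch, not Spark's code. Assumes the standard binary
// logistic model: P(y = 1 | x) = 1 / (1 + exp(-(w . x + b))).
final class BinaryThresholdSketch {

    /** Probability of class 1 for features x, coefficients w and intercept b. */
    static double probabilityOfOne(double[] w, double b, double[] x) {
        if (w.length != x.length) {
            throw new IllegalArgumentException("coefficients and features differ in length");
        }
        double margin = b;
        for (int i = 0; i < w.length; i++) {
            margin += w[i] * x[i];
        }
        return 1.0 / (1.0 + Math.exp(-margin));
    }

    /** Predicts 1.0 if the probability of class 1 is strictly greater than threshold, else 0.0. */
    static double predict(double probabilityOfOne, double threshold) {
        if (threshold < 0.0 || threshold > 1.0) {
            throw new IllegalArgumentException("threshold must be in [0, 1]");
        }
        return probabilityOfOne > threshold ? 1.0 : 0.0;
    }
}
```

With coefficients {1.0, -1.0}, intercept 0.0 and features {2.0, 2.0}, the margin is 0 and the probability is exactly 0.5. The default threshold of 0.5 therefore predicts 0 (the comparison is strict), while a threshold of 0.4 predicts 1.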
Parameters:
value - (undocumented)
public double getThreshold()
Get threshold for binary classification.
If threshold is set, returns that value. Otherwise, if thresholds is set with length 2 (i.e., binary classification), returns the equivalent threshold: 1 / (1 + thresholds(0) / thresholds(1)). Otherwise, returns the default value of threshold.
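The lookup order for getThreshold() can be sketched as follows, modelling an unset param as null. This is a hypothetical helper for illustration, not Spark's code; only the 0.5 default, the length-2 requirement and the conversion formula come from this page.

```java
// Hypothetical sketch of the getThreshold() resolution order.
// A null argument stands for "param not set by the user".
final class ThresholdResolutionSketch {
    static final double DEFAULT_THRESHOLD = 0.5; // documented default

    static double resolveThreshold(Double threshold, double[] thresholds) {
        if (threshold != null) {
            return threshold; // an explicitly set threshold wins
        }
        if (thresholds != null) {
            if (thresholds.length != 2) {
                throw new IllegalArgumentException(
                    "thresholds has length " + thresholds.length + ", expected 2");
            }
            // Equivalent binary threshold: 1 / (1 + thresholds(0) / thresholds(1))
            return 1.0 / (1.0 + thresholds[0] / thresholds[1]);
        }
        return DEFAULT_THRESHOLD; // neither param set: fall back to the default
    }
}
```

Passing thresholds {1 - p, p} returns p, because 1 / (1 + (1 - p) / p) = p. This is the inverse of the Note on setThreshold.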
Throws:
IllegalArgumentException - if thresholds is set to an array of length other than 2.
public LogisticRegressionModel setThresholds(double[] value)
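Set thresholds in multiclass (or binary) classification to adjust the probability of predicting each class.
Note: When setThresholds() is called, any user-set value for threshold will be cleared. If both threshold and thresholds are set in a ParamMap, then they must be equivalent.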
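How per-class thresholds change the prediction is not spelled out on this page. The sketch below assumes the rule described in Spark's documentation for the thresholds param: the predicted class is the one with the largest p / t, where p is the class probability and t is its threshold. Names are hypothetical, and thresholds are required to be positive here for simplicity.

```java
// Hypothetical sketch: predict the class maximizing probability / threshold.
// Requires every threshold to be > 0 so that the division is defined.
final class MulticlassThresholdSketch {

    static int predict(double[] probabilities, double[] thresholds) {
        if (probabilities.length != thresholds.length) {
            throw new IllegalArgumentException("one threshold per class is required");
        }
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < probabilities.length; k++) {
            if (thresholds[k] <= 0.0) {
                throw new IllegalArgumentException("thresholds must be positive in this sketch");
            }
            double score = probabilities[k] / thresholds[k];
            if (score > bestScore) { // strict '>' keeps the lowest index on ties
                bestScore = score;
                best = k;
            }
        }
        return best;
    }
}
```

For a binary model, thresholds {0.4, 0.6} predict class 1 for probabilities {0.35, 0.65} and class 0 for {0.45, 0.55}. A single threshold of 0.6 gives the same answers, matching the equivalence stated in the Note on setThreshold.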
Overrides:
setThresholds in class ProbabilisticClassificationModel<Vector,LogisticRegressionModel>
Parameters:
value - (undocumented)
public double[] getThresholds()
If thresholds is set, returns its value. Otherwise, if threshold is set, returns the equivalent thresholds for binary classification: (1-threshold, threshold). If neither is set, throws an exception.
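The reverse conversion, together with the equivalence test that checkThresholdConsistency() performs when both params are set, can be sketched as below. The null-as-unset modelling, the IllegalStateException for the "neither is set" case, and the 1e-5 tolerance are all assumptions made for illustration, not documented values.

```java
// Hypothetical sketch of getThresholds() resolution plus an equivalence check
// in the spirit of checkThresholdConsistency(). A null argument means "unset".
final class ThresholdsResolutionSketch {
    static final double TOLERANCE = 1e-5; // assumed tolerance

    static double[] resolveThresholds(Double threshold, double[] thresholds) {
        if (thresholds != null) {
            return thresholds.clone(); // explicitly set thresholds win
        }
        if (threshold != null) {
            return new double[] {1.0 - threshold, threshold}; // binary equivalent
        }
        throw new IllegalStateException("neither threshold nor thresholds is set");
    }

    static void checkConsistency(Double threshold, double[] thresholds) {
        if (threshold == null || thresholds == null) {
            return; // only one (or neither) is set: nothing to compare
        }
        if (thresholds.length != 2) {
            throw new IllegalArgumentException(
                "threshold is set, but thresholds has length " + thresholds.length);
        }
        double equivalent = 1.0 / (1.0 + thresholds[0] / thresholds[1]);
        if (Math.abs(equivalent - threshold) > TOLERANCE) {
            throw new IllegalArgumentException(
                "threshold " + threshold + " is not equivalent to thresholds");
        }
    }
}
```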
public int numFeatures()
Returns the number of features the model was trained on.
Overrides:
numFeatures in class PredictionModel<Vector,LogisticRegressionModel>
public int numClasses()
Number of classes (values which the label can take).
Specified by:
numClasses in class ClassificationModel<Vector,LogisticRegressionModel>
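public LogisticRegressionTrainingSummary summary()
Gets summary of model on training set. An exception is thrown if trainingSummary == None.
public boolean hasSummary()
Indicates whether a training summary exists for this model instance.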
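Because summary() throws when no training summary exists, callers can guard it with hasSummary(). A model restored with load() presumably has no summary, since write() does not save it. The hypothetical holder below mimics that contract with java.util.Optional. The String summary type and the IllegalStateException are placeholders, not Spark's types.

```java
import java.util.Optional;

// Hypothetical stand-in for the hasSummary()/summary() contract.
final class SummaryHolderSketch {
    private final Optional<String> trainingSummary; // placeholder summary type

    SummaryHolderSketch(Optional<String> trainingSummary) {
        this.trainingSummary = trainingSummary;
    }

    boolean hasSummary() {
        return trainingSummary.isPresent();
    }

    String summary() {
        // Throws when no summary is present (exception type is an assumption).
        return trainingSummary.orElseThrow(
            () -> new IllegalStateException("No training summary available"));
    }
}
```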
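public LogisticRegressionModel copy(ParamMap extra)
Creates a copy of this instance with the same UID and some extra params.
Specified by:
copy in interface Params
Specified by:
copy in class Model<LogisticRegressionModel>
Parameters:
extra - (undocumented)
See Also:
defaultCopy()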
public MLWriter write()
Returns an MLWriter instance for this ML instance.
For LogisticRegressionModel, this does NOT currently save the training summary. An option to save the summary may be added in the future.
This also does not currently save the parent.
Specified by:
write in interface MLWritable
public void checkThresholdConsistency()
If threshold and thresholds are both set, ensures they are consistent.
Throws:
IllegalArgumentException - if threshold and thresholds are not equivalent
public void validateParams()
public Param<String> rawPredictionCol()
public String getRawPredictionCol()
public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
Validates and transforms the input schema with the provided param map.
Parameters:
schema - input schema
fitting - whether this is in fitting
featuresDataType - SQL DataType for FeaturesType. E.g., VectorUDT for vector features.
public Param<String> labelCol()
public String getLabelCol()
public Param<String> featuresCol()
public String getFeaturesCol()
public Param<String> predictionCol()
public String getPredictionCol()