public abstract class ProbabilisticClassificationModel<FeaturesType,M extends ProbabilisticClassificationModel<FeaturesType,M>> extends ClassificationModel<FeaturesType,M>
Model produced by a ProbabilisticClassifier.
Classes are indexed {0, 1, ..., numClasses - 1}.
Constructor and Description |
---|
ProbabilisticClassificationModel() |
Modifier and Type | Method and Description |
---|---|
protected static <T> T | $(Param<T> param) |
static Params | clear(Param<?> param) |
abstract static M | copy(ParamMap extra) |
protected static <T extends Params> | copyValues(T to, ParamMap extra) |
protected static <T extends Params> | copyValues$default$2() |
protected static <T extends Params> | defaultCopy(ParamMap extra) |
static java.lang.String | explainParam(Param<?> param) |
static java.lang.String | explainParams() |
static ParamMap | extractParamMap() |
static ParamMap | extractParamMap(ParamMap extra) |
static Param<java.lang.String> | featuresCol() |
Param<java.lang.String> | featuresCol() Param for features column name. |
protected static DataType | featuresDataType() |
static <T> scala.Option<T> | get(Param<T> param) |
static <T> scala.Option<T> | getDefault(Param<T> param) |
static java.lang.String | getFeaturesCol() |
java.lang.String | getFeaturesCol() |
static java.lang.String | getLabelCol() |
java.lang.String | getLabelCol() |
static <T> T | getOrDefault(Param<T> param) |
static Param<java.lang.Object> | getParam(java.lang.String paramName) |
static java.lang.String | getPredictionCol() |
java.lang.String | getPredictionCol() |
static java.lang.String | getProbabilityCol() |
static java.lang.String | getRawPredictionCol() |
java.lang.String | getRawPredictionCol() |
static double[] | getThresholds() |
static <T> boolean | hasDefault(Param<T> param) |
static boolean | hasParam(java.lang.String paramName) |
static boolean | hasParent() |
protected static void | initializeLogIfNecessary(boolean isInterpreter) |
static boolean | isDefined(Param<?> param) |
static boolean | isSet(Param<?> param) |
protected static boolean | isTraceEnabled() |
static Param<java.lang.String> | labelCol() |
Param<java.lang.String> | labelCol() Param for label column name. |
protected static org.slf4j.Logger | log() |
protected static void | logDebug(scala.Function0<java.lang.String> msg) |
protected static void | logDebug(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) |
protected static void | logError(scala.Function0<java.lang.String> msg) |
protected static void | logError(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) |
protected static void | logInfo(scala.Function0<java.lang.String> msg) |
protected static void | logInfo(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) |
protected static java.lang.String | logName() |
protected static void | logTrace(scala.Function0<java.lang.String> msg) |
protected static void | logTrace(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) |
protected static void | logWarning(scala.Function0<java.lang.String> msg) |
protected static void | logWarning(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable) |
static void | normalizeToProbabilitiesInPlace(DenseVector v) Normalize a vector of raw predictions to be a multinomial probability vector, in place. |
abstract static int | numClasses() |
static int | numFeatures() |
static Param<?>[] | params() |
static void | parent_$eq(Estimator<M> x$1) |
static Estimator<M> | parent() |
protected static double | predict(FeaturesType features) |
static Param<java.lang.String> | predictionCol() |
Param<java.lang.String> | predictionCol() Param for prediction column name. |
protected Vector | predictProbability(FeaturesType features) Predict the probability of each class given the features. |
protected abstract static Vector | predictRaw(FeaturesType features) |
protected double | probability2prediction(Vector probability) Given a vector of class conditional probabilities, select the predicted label. |
static Param<java.lang.String> | probabilityCol() |
protected double | raw2prediction(Vector rawPrediction) Given a vector of raw predictions, select the predicted label. |
protected Vector | raw2probability(Vector rawPrediction) Non-in-place version of raw2probabilityInPlace(). |
protected abstract Vector | raw2probabilityInPlace(Vector rawPrediction) Estimate the probability of each class given the raw prediction, doing the computation in-place. |
static Param<java.lang.String> | rawPredictionCol() |
Param<java.lang.String> | rawPredictionCol() Param for raw prediction (a.k.a. confidence) column name. |
static <T> Params | set(Param<T> param, T value) |
protected static Params | set(ParamPair<?> paramPair) |
protected static Params | set(java.lang.String param, java.lang.Object value) |
protected static <T> Params | setDefault(Param<T> param, T value) |
protected static Params | setDefault(scala.collection.Seq<ParamPair<?>> paramPairs) |
static M | setFeaturesCol(java.lang.String value) |
static M | setParent(Estimator<M> parent) |
static M | setPredictionCol(java.lang.String value) |
M | setProbabilityCol(java.lang.String value) |
static M | setRawPredictionCol(java.lang.String value) |
M | setThresholds(double[] value) |
static DoubleArrayParam | thresholds() |
static java.lang.String | toString() |
Dataset<Row> | transform(Dataset<?> dataset) Transforms dataset by reading from featuresCol and appending new columns as specified by parameters: predicted labels as predictionCol of type Double; raw predictions (confidences) as rawPredictionCol of type Vector; probability of each class as probabilityCol of type Vector. |
protected static Dataset<Row> | transformImpl(Dataset<?> dataset) |
static StructType | transformSchema(StructType schema) |
protected static StructType | transformSchema(StructType schema, boolean logging) |
abstract static java.lang.String | uid() |
protected static StructType | validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType) |
StructType | validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType) |
StructType | validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType) Validates and transforms the input schema with the provided param map. |
static void | validateParams() |
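Methods inherited from class ClassificationModel: numClasses, predict, predictRaw, setRawPredictionCol
Methods inherited from class PredictionModel: featuresDataType, numFeatures, setFeaturesCol, setPredictionCol, transformImpl, transformSchema
Methods inherited from class Transformer: transform, transform, transform
Methods inherited from class PipelineStage: transformSchema
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface Params: clear, copy, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn, validateParams
Methods inherited from interface Identifiable: toString, uid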
public static void normalizeToProbabilitiesInPlace(DenseVector v)
Normalize a vector of raw predictions to be a multinomial probability vector, in place.
The input raw predictions should be >= 0. The output vector sums to 1, unless the input vector is all-0 (in which case the output is all-0 too).
NOTE: This is NOT applicable to all models, only ones which effectively use class instance counts for raw predictions.
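The normalization described above can be sketched in plain Java as follows. This is a minimal illustration, not Spark's implementation: it works on a `double[]` rather than a Spark `DenseVector`, the class name `NormalizeSketch` is hypothetical, and rejecting negative entries is a choice of this sketch (the documentation only says inputs should be >= 0).

```java
/** Hypothetical stand-in for normalizeToProbabilitiesInPlace, operating on a double[]. */
final class NormalizeSketch {
    private NormalizeSketch() {}

    /**
     * Divides every entry by the sum of all entries so the array sums to 1.
     * Assumes non-negative inputs such as per-class instance counts.
     * An all-zero array is left unchanged, so the output is all-zero too.
     */
    static void normalizeToProbabilitiesInPlace(double[] v) {
        double sum = 0.0;
        for (double x : v) {
            if (x < 0.0) {
                // Sketch-specific check: the documented contract expects raw predictions >= 0
                throw new IllegalArgumentException("raw predictions must be >= 0, got " + x);
            }
            sum += x;
        }
        if (sum == 0.0) {
            return; // all-zero input: nothing to normalize
        }
        for (int i = 0; i < v.length; i++) {
            v[i] /= sum; // overwrite the caller's array in place
        }
    }
}
```

For example, class counts `{2, 6, 0, 2}` become `{0.2, 0.6, 0.0, 0.2}`, which is why the NOTE restricts this helper to models whose raw predictions behave like counts.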
Parameters:
v - (undocumented)
public abstract static java.lang.String uid()
public static java.lang.String toString()
public static Param<?>[] params()
public static void validateParams()
public static java.lang.String explainParam(Param<?> param)
public static java.lang.String explainParams()
public static final boolean isSet(Param<?> param)
public static final boolean isDefined(Param<?> param)
public static boolean hasParam(java.lang.String paramName)
public static Param<java.lang.Object> getParam(java.lang.String paramName)
protected static final Params set(java.lang.String param, java.lang.Object value)
public static final <T> scala.Option<T> get(Param<T> param)
public static final <T> T getOrDefault(Param<T> param)
protected static final <T> T $(Param<T> param)
public static final <T> scala.Option<T> getDefault(Param<T> param)
public static final <T> boolean hasDefault(Param<T> param)
public static final ParamMap extractParamMap()
protected static java.lang.String logName()
protected static org.slf4j.Logger log()
protected static void logInfo(scala.Function0<java.lang.String> msg)
protected static void logDebug(scala.Function0<java.lang.String> msg)
protected static void logTrace(scala.Function0<java.lang.String> msg)
protected static void logWarning(scala.Function0<java.lang.String> msg)
protected static void logError(scala.Function0<java.lang.String> msg)
protected static void logInfo(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logDebug(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logTrace(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logWarning(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static void logError(scala.Function0<java.lang.String> msg, java.lang.Throwable throwable)
protected static boolean isTraceEnabled()
protected static void initializeLogIfNecessary(boolean isInterpreter)
protected static StructType transformSchema(StructType schema, boolean logging)
public static Estimator<M> parent()
public static void parent_$eq(Estimator<M> x$1)
public static M setParent(Estimator<M> parent)
public static boolean hasParent()
public abstract static M copy(ParamMap extra)
public static final Param<java.lang.String> labelCol()
public static final java.lang.String getLabelCol()
public static final Param<java.lang.String> featuresCol()
public static final java.lang.String getFeaturesCol()
public static final Param<java.lang.String> predictionCol()
public static final java.lang.String getPredictionCol()
public static M setFeaturesCol(java.lang.String value)
public static M setPredictionCol(java.lang.String value)
public static int numFeatures()
protected static DataType featuresDataType()
public static StructType transformSchema(StructType schema)
public static final Param<java.lang.String> rawPredictionCol()
public static final java.lang.String getRawPredictionCol()
public static M setRawPredictionCol(java.lang.String value)
public abstract static int numClasses()
protected static double predict(FeaturesType features)
protected abstract static Vector predictRaw(FeaturesType features)
public static final Param<java.lang.String> probabilityCol()
public static final java.lang.String getProbabilityCol()
public static final DoubleArrayParam thresholds()
public static double[] getThresholds()
protected static StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
public M setProbabilityCol(java.lang.String value)
public M setThresholds(double[] value)
public Dataset<Row> transform(Dataset<?> dataset)
Transforms dataset by reading from featuresCol and appending new columns as specified by parameters:
- predicted labels as predictionCol of type Double
- raw predictions (confidences) as rawPredictionCol of type Vector
- probability of each class as probabilityCol of type Vector.
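To make the three appended columns concrete, here is a row-at-a-time sketch in plain Java. It is not how Spark executes transform (Spark operates on a Dataset); the names `TransformSketch` and `Prediction` are hypothetical, each "row" is just a `double[]` of features, and the model's `predictRaw` is passed in as a function. As assumptions of this sketch, probabilities come from normalizing non-negative raw scores and the prediction is the index of the largest probability (no thresholds).

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/** Hypothetical row-level sketch of the columns transform() appends. */
final class TransformSketch {
    private TransformSketch() {}

    /** The values one output row would carry in the appended columns. */
    static final class Prediction {
        final double[] rawPrediction; // rawPredictionCol (a Vector in Spark)
        final double[] probability;   // probabilityCol (a Vector in Spark)
        final double prediction;      // predictionCol (a Double in Spark)

        Prediction(double[] rawPrediction, double[] probability, double prediction) {
            this.rawPrediction = rawPrediction;
            this.probability = probability;
            this.prediction = prediction;
        }
    }

    static List<Prediction> transform(List<double[]> rows, Function<double[], double[]> predictRaw) {
        List<Prediction> out = new ArrayList<>();
        for (double[] features : rows) {
            double[] raw = predictRaw.apply(features);
            // Assumed raw-to-probability step: normalize a copy so the raw column is untouched
            double[] prob = raw.clone();
            double sum = 0.0;
            for (double x : prob) {
                sum += x;
            }
            if (sum != 0.0) {
                for (int i = 0; i < prob.length; i++) {
                    prob[i] /= sum;
                }
            }
            // Assumed label selection without thresholds: index of the largest probability
            int best = 0;
            for (int i = 1; i < prob.length; i++) {
                if (prob[i] > prob[best]) {
                    best = i;
                }
            }
            out.add(new Prediction(raw, prob, best));
        }
        return out;
    }
}
```

With a hypothetical `predictRaw` that returns the features unchanged, the row `{1.0, 3.0}` yields raw `{1.0, 3.0}`, probability `{0.25, 0.75}` and prediction `1.0`.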
Overrides:
transform in class ClassificationModel<FeaturesType,M extends ProbabilisticClassificationModel<FeaturesType,M>>
Parameters:
dataset - input dataset
protected abstract Vector raw2probabilityInPlace(Vector rawPrediction)
Estimate the probability of each class given the raw prediction, doing the computation in-place.
This internal method is used to implement transform() and output probabilityCol.
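Because raw2probabilityInPlace is abstract, each concrete model supplies its own conversion. As one hedged example, a model whose raw predictions are unnormalized log-scores could convert them with a softmax. The sketch below assumes that interpretation of the raw vector, uses a `double[]` in place of Spark's `Vector`, and the class name `SoftmaxSketch` is hypothetical.

```java
/**
 * Hypothetical raw2probabilityInPlace for raw predictions that are unnormalized
 * log-scores: applies a numerically stable softmax, overwriting the input.
 */
final class SoftmaxSketch {
    private SoftmaxSketch() {}

    static double[] raw2probabilityInPlace(double[] raw) {
        // Subtracting the maximum before exp() avoids overflow without changing the result
        double max = Double.NEGATIVE_INFINITY;
        for (double x : raw) {
            max = Math.max(max, x);
        }
        double sum = 0.0;
        for (int i = 0; i < raw.length; i++) {
            raw[i] = Math.exp(raw[i] - max);
            sum += raw[i];
        }
        for (int i = 0; i < raw.length; i++) {
            raw[i] /= sum;
        }
        return raw; // the same array the caller passed in, now holding probabilities
    }
}
```

Returning the mutated argument mirrors the documented signature, which returns a Vector even though the work happens in place.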
Parameters:
rawPrediction - (undocumented)
protected Vector raw2probability(Vector rawPrediction)
Non-in-place version of raw2probabilityInPlace().
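A non-in-place version can be built by copying the input and delegating to the in-place conversion. This is a minimal sketch of that pattern, not Spark's code; the class name `CopyThenConvert` is hypothetical and the in-place conversion is passed in as a `UnaryOperator<double[]>`.

```java
import java.util.function.UnaryOperator;

/** Hypothetical sketch: non-in-place wrapper around any in-place raw-to-probability conversion. */
final class CopyThenConvert {
    private CopyThenConvert() {}

    static double[] raw2probability(double[] rawPrediction, UnaryOperator<double[]> raw2probabilityInPlace) {
        // Convert a copy so the caller's raw predictions stay intact
        return raw2probabilityInPlace.apply(rawPrediction.clone());
    }
}
```

Keeping the original vector unmodified matters when the same raw predictions are still needed elsewhere, for example as the rawPredictionCol output.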
protected double raw2prediction(Vector rawPrediction)
Description copied from class: ClassificationModel
Given a vector of raw predictions, select the predicted label.
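Since classes are indexed 0 to numClasses - 1, selecting a label from raw predictions can be sketched as taking the index of the largest raw score. This assumes no thresholds are applied and uses a `double[]` instead of Spark's `Vector`; the class name `ArgmaxSketch` and the lowest-index tie rule are choices of this sketch.

```java
/** Hypothetical sketch of raw2prediction without thresholds: label = index of the largest raw score. */
final class ArgmaxSketch {
    private ArgmaxSketch() {}

    static double raw2prediction(double[] rawPrediction) {
        if (rawPrediction.length == 0) {
            throw new IllegalArgumentException("rawPrediction must have one entry per class");
        }
        int best = 0;
        for (int i = 1; i < rawPrediction.length; i++) {
            if (rawPrediction[i] > rawPrediction[best]) { // strict '>' keeps the lowest index on ties
                best = i;
            }
        }
        return best; // class index returned as a double label, matching the method's return type
    }
}
```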
Overrides:
raw2prediction in class ClassificationModel<FeaturesType,M extends ProbabilisticClassificationModel<FeaturesType,M>>
Parameters:
rawPrediction - (undocumented)
protected Vector predictProbability(FeaturesType features)
Predict the probability of each class given the features.
This internal method is used to implement transform() and output probabilityCol.
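One natural way to obtain per-class probabilities from features is to compute the raw prediction and then convert it in place. The sketch below composes the two steps under that assumption; the generic class `PredictProbabilitySketch` is hypothetical, and both model steps are supplied as functions operating on `double[]` rather than Spark's `Vector`.

```java
import java.util.function.Function;
import java.util.function.UnaryOperator;

/** Hypothetical sketch: predictProbability = predictRaw followed by an in-place conversion. */
final class PredictProbabilitySketch<F> {
    private final Function<F, double[]> predictRaw;
    private final UnaryOperator<double[]> raw2probabilityInPlace;

    PredictProbabilitySketch(Function<F, double[]> predictRaw, UnaryOperator<double[]> raw2probabilityInPlace) {
        this.predictRaw = predictRaw;
        this.raw2probabilityInPlace = raw2probabilityInPlace;
    }

    double[] predictProbability(F features) {
        // Assumes predictRaw returns a new array per call, so converting it in place needs no copy
        double[] raw = predictRaw.apply(features);
        return raw2probabilityInPlace.apply(raw);
    }
}
```

Skipping the defensive copy here is the design point: the raw vector is a temporary that nobody else sees, unlike in raw2probability.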
Parameters:
features - (undocumented)
protected double probability2prediction(Vector probability)
Given a vector of class conditional probabilities, select the predicted label.
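This page lists a thresholds param (setThresholds, getThresholds) but does not spell out how thresholds affect the chosen label. The sketch below assumes the rule given in Spark's documentation for that param: the class with the largest ratio p/t is predicted, where p is the class probability and t its threshold. The class name `ThresholdSketch`, the use of `null` for "no thresholds", and treating a zero threshold as an infinite score are assumptions of this sketch.

```java
/** Hypothetical threshold-aware label selection: predict the class with the largest p / t. */
final class ThresholdSketch {
    private ThresholdSketch() {}

    /** thresholds may be null, meaning no thresholds: plain argmax over the probabilities. */
    static double probability2prediction(double[] probability, double[] thresholds) {
        if (thresholds != null && thresholds.length != probability.length) {
            throw new IllegalArgumentException("need one threshold per class");
        }
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < probability.length; i++) {
            double score;
            if (thresholds == null) {
                score = probability[i];
            } else if (thresholds[i] == 0.0) {
                score = Double.POSITIVE_INFINITY; // zero threshold: score treated as +infinity
            } else {
                score = probability[i] / thresholds[i];
            }
            if (score > bestScore) { // strict '>' keeps the lowest index on ties
                bestScore = score;
                best = i;
            }
        }
        return best;
    }
}
```

With probabilities `{0.6, 0.4}` and thresholds `{0.9, 0.3}`, the scores are about 0.667 and 1.333, so class 1 is predicted even though class 0 is more probable.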
Parameters:
probability - (undocumented)
public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
public Param<java.lang.String> rawPredictionCol()
public java.lang.String getRawPredictionCol()
public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
Validates and transforms the input schema with the provided param map.
Parameters:
schema - input schema
fitting - whether this is being called during fitting
featuresDataType - SQL DataType for FeaturesType. E.g., VectorUDT for vector features.
public Param<java.lang.String> labelCol()
public java.lang.String getLabelCol()
public Param<java.lang.String> featuresCol()
public java.lang.String getFeaturesCol()
public Param<java.lang.String> predictionCol()
public java.lang.String getPredictionCol()