public class FMClassifier extends ProbabilisticClassifier<Vector,FMClassifier,FMClassificationModel> implements FactorizationMachines, FMClassifierParams, DefaultParamsWritable, org.apache.spark.internal.Logging
The implementation is based upon: S. Rendle. "Factorization machines" 2010.
FM can estimate feature interactions even in problems with very sparse data (such as advertising and recommendation systems). The FM formula is:
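$$ \begin{align} y = \sigma\left( w_0 + \sum\limits^n_{i=1} w_i x_i + \sum\limits^n_{i=1} \sum\limits^n_{j=i+1} \langle v_i, v_j \rangle x_i x_j \right) \end{align} $$
Here σ is the sigmoid function. The first two terms are the global bias and the linear term (the same as in linear regression), and the last term is the pairwise interaction term. v_i is the factor vector describing the i-th variable with k factors, where k is set by factorSize. The fitIntercept and fitLinear params control whether w_0 and the linear term are fitted.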
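To make the formula concrete, the following Java sketch computes the score for a dense feature vector. It evaluates the pairwise term twice: directly from the definition, and via the reformulation from Rendle's paper, where the sum over i < j of ⟨v_i, v_j⟩ x_i x_j equals one half of the sum over f of [(Σ_i v_{i,f} x_i)^2 - Σ_i v_{i,f}^2 x_i^2], which lowers the cost from O(kn^2) to O(kn). The class and method names (FmScore, pairwiseNaive, pairwiseFast, predict) are hypothetical; this illustrates the formula and is not Spark's implementation.

```java
/**
 * Minimal sketch of the FM classification score (not Spark's implementation).
 * Assumptions: dense feature vector x of length n, linear weights w (length n),
 * global bias w0, and factor matrix v with n rows and k columns (row i is v_i).
 */
final class FmScore {

    /** Interaction term straight from the definition: O(k * n^2). */
    static double pairwiseNaive(double[] x, double[][] v) {
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            for (int j = i + 1; j < x.length; j++) {
                double dot = 0.0; // <v_i, v_j>
                for (int f = 0; f < v[i].length; f++) {
                    dot += v[i][f] * v[j][f];
                }
                sum += dot * x[i] * x[j];
            }
        }
        return sum;
    }

    /** Same term via Rendle's reformulation: O(k * n). */
    static double pairwiseFast(double[] x, double[][] v) {
        int k = v.length == 0 ? 0 : v[0].length;
        double sum = 0.0;
        for (int f = 0; f < k; f++) {
            double s = 0.0;  // sum_i v_{i,f} x_i
            double s2 = 0.0; // sum_i (v_{i,f} x_i)^2
            for (int i = 0; i < x.length; i++) {
                double t = v[i][f] * x[i];
                s += t;
                s2 += t * t;
            }
            sum += s * s - s2;
        }
        return 0.5 * sum;
    }

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** y = sigma(w0 + sum_i w_i x_i + sum_{i<j} <v_i, v_j> x_i x_j). */
    static double predict(double w0, double[] w, double[][] v, double[] x) {
        double linear = 0.0;
        for (int i = 0; i < x.length; i++) {
            linear += w[i] * x[i];
        }
        return sigmoid(w0 + linear + pairwiseFast(x, v));
    }

    public static void main(String[] args) {
        double[] x = {1.0, 0.0, 2.0};
        double[] w = {0.1, -0.2, 0.05};
        double[][] v = {{0.3, -0.1}, {0.2, 0.4}, {-0.5, 0.1}};
        System.out.println("naive = " + pairwiseNaive(x, v) + ", fast = " + pairwiseFast(x, v));
        System.out.println("P(y = 1) = " + predict(0.0, w, v, x));
    }
}
```

Both interaction routines should agree up to floating-point rounding; the fast form is the one worth using on wide, sparse inputs.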
The FM classification model minimizes logistic loss, which can be optimized by gradient descent; a regularization term such as L2 is usually added to the loss function to prevent overfitting.
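A minimal sketch of the loss described above for a single example, assuming a label y in {0, 1} and a raw score z (the argument of σ). It uses the numerically stable form softplus(z) - y·z and the fact that the derivative of the loss with respect to z is σ(z) - y. The L2 penalty is assumed to take the form (regParam / 2)·w_i^2 on the linear weights; Spark's exact scaling of the regularization term is not stated on this page. The names FmLogLoss, softplus, logLoss, lossGradient and linearStep are hypothetical.

```java
/**
 * Sketch of the logistic loss for one example with label y in {0, 1}, where z is
 * the raw FM score (the expression inside sigma). The L2 form
 * (regParam / 2) * w_i^2 is an assumed convention for illustration.
 */
final class FmLogLoss {

    /** Numerically stable log(1 + exp(z)). */
    static double softplus(double z) {
        return Math.max(z, 0.0) + Math.log1p(Math.exp(-Math.abs(z)));
    }

    /** -[y log p + (1 - y) log(1 - p)] with p = sigma(z), rewritten as softplus(z) - y * z. */
    static double logLoss(double z, double y) {
        return softplus(z) - y * z;
    }

    /** dLoss/dz = sigma(z) - y; the chain rule turns this into gradients for w0, w_i and v_i. */
    static double lossGradient(double z, double y) {
        return 1.0 / (1.0 + Math.exp(-z)) - y;
    }

    /** Hypothetical helper: one gradient-descent step on the linear weights with L2. */
    static void linearStep(double[] w, double[] x, double z, double y,
                           double stepSize, double regParam) {
        double g = lossGradient(z, y);
        for (int i = 0; i < w.length; i++) {
            // d/dw_i [loss + (regParam / 2) * w_i^2] = g * x_i + regParam * w_i
            w[i] -= stepSize * (g * x[i] + regParam * w[i]);
        }
    }

    public static void main(String[] args) {
        System.out.println("loss(z = 0, y = 1) = " + logLoss(0.0, 1.0));
        double[] w = {0.0, 0.0};
        linearStep(w, new double[]{1.0, 2.0}, 0.0, 1.0, 0.1, 0.01);
        System.out.println("w after one step = " + java.util.Arrays.toString(w));
    }
}
```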
Constructor and Description |
---|
FMClassifier() |
FMClassifier(String uid) |
Modifier and Type | Method | Description |
---|---|---|
FMClassifier | copy(ParamMap extra) | Creates a copy of this instance with the same UID and some extra params. |
IntParam | factorSize() | Param for dimensionality of the factors (>= 0). |
BooleanParam | fitIntercept() | Param for whether to fit an intercept term. |
BooleanParam | fitLinear() | Param for whether to fit the linear term (aka 1-way term). |
DoubleParam | initStd() | Param for the standard deviation of initial coefficients. |
static FMClassifier | load(String path) | |
IntParam | maxIter() | Param for maximum number of iterations (>= 0). |
DoubleParam | miniBatchFraction() | Param for mini-batch fraction, must be in range (0, 1]. |
static MLReader<T> | read() | |
DoubleParam | regParam() | Param for regularization parameter (>= 0). |
LongParam | seed() | Param for random seed. |
FMClassifier | setFactorSize(int value) | Set the dimensionality of the factors. |
FMClassifier | setFitIntercept(boolean value) | Set whether to fit the intercept term. |
FMClassifier | setFitLinear(boolean value) | Set whether to fit the linear term. |
FMClassifier | setInitStd(double value) | Set the standard deviation of initial coefficients. |
FMClassifier | setMaxIter(int value) | Set the maximum number of iterations. |
FMClassifier | setMiniBatchFraction(double value) | Set the mini-batch fraction parameter. |
FMClassifier | setRegParam(double value) | Set the L2 regularization parameter. |
FMClassifier | setSeed(long value) | Set the random seed for weight initialization. |
FMClassifier | setSolver(String value) | Set the solver algorithm used for optimization. |
FMClassifier | setStepSize(double value) | Set the initial step size for the first step (like learning rate). |
FMClassifier | setTol(double value) | Set the convergence tolerance of iterations. |
Param<String> | solver() | The solver algorithm for optimization. |
DoubleParam | stepSize() | Param for step size to be used for each iteration of optimization (> 0). |
DoubleParam | tol() | Param for the convergence tolerance for iterative algorithms (>= 0). |
String | uid() | An immutable unique ID for the object and its derivatives. |
Param<String> | weightCol() | Param for weight column name. |
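The value ranges stated in the method summary can be restated in code as a small pre-check. FmParamRanges is a hypothetical helper; Spark validates its own params, so this only mirrors the documented bounds (factorSize >= 0, maxIter >= 0, miniBatchFraction in (0, 1], regParam >= 0, stepSize > 0, tol >= 0). Negated comparisons are used so that NaN is also rejected.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical helper that checks values against the ranges stated in the
 * method summary. Spark validates params itself; this only illustrates the ranges.
 */
final class FmParamRanges {

    static List<String> violations(int factorSize, int maxIter, double miniBatchFraction,
                                   double regParam, double stepSize, double tol) {
        List<String> bad = new ArrayList<>();
        if (factorSize < 0) bad.add("factorSize");                 // >= 0
        if (maxIter < 0) bad.add("maxIter");                       // >= 0
        if (!(miniBatchFraction > 0.0 && miniBatchFraction <= 1.0)) {
            bad.add("miniBatchFraction");                          // (0, 1]
        }
        if (!(regParam >= 0.0)) bad.add("regParam");               // >= 0, also rejects NaN
        if (!(stepSize > 0.0)) bad.add("stepSize");                // > 0
        if (!(tol >= 0.0)) bad.add("tol");                         // >= 0
        return bad;
    }

    public static void main(String[] args) {
        System.out.println(violations(8, 100, 1.0, 0.0, 0.01, 1e-6));
        System.out.println(violations(8, 100, 1.5, 0.0, 0.0, 1e-6));
    }
}
```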
Methods inherited from superclasses and interfaces, grouped by declaring type:
probabilityCol, setProbabilityCol, setThresholds, thresholds
rawPredictionCol, setRawPredictionCol
featuresCol, fit, labelCol, predictionCol, setFeaturesCol, setLabelCol, setPredictionCol, transformSchema
params
equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
initCoefficients, trainImpl
validateAndTransformSchema
getLabelCol, labelCol
featuresCol, getFeaturesCol
getPredictionCol, predictionCol
clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, onParamChange, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
toString
getRawPredictionCol, rawPredictionCol
getProbabilityCol, probabilityCol
getThresholds, thresholds
getFactorSize, getFitLinear, getInitStd, getMiniBatchFraction
getMaxIter
getStepSize
getFitIntercept
getRegParam
getWeightCol
write
save
$init$, initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, initLock, isTraceEnabled, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning, org$apache$spark$internal$Logging$$log__$eq, org$apache$spark$internal$Logging$$log_, uninitialize
public static FMClassifier load(String path)
public static MLReader<T> read()
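public final IntParam factorSize()
Description copied from interface: FactorizationMachinesParams
Param for dimensionality of the factors (>= 0).
Specified by: factorSize in interface FactorizationMachinesParams
public final BooleanParam fitLinear()
Description copied from interface: FactorizationMachinesParams
Param for whether to fit the linear term (aka 1-way term).
Specified by: fitLinear in interface FactorizationMachinesParams
public final DoubleParam miniBatchFraction()
Description copied from interface: FactorizationMachinesParams
Param for mini-batch fraction, must be in range (0, 1].
Specified by: miniBatchFraction in interface FactorizationMachinesParams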
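miniBatchFraction controls what share of the training data each iteration sees. The sketch below assumes per-row Bernoulli sampling with a seeded generator, which is one common way to realize a fraction in (0, 1]; it is not a description of Spark's internal sampling. MiniBatchSketch and sample are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Sketch of mini-batch selection with a fraction in (0, 1]: each row is kept
 * independently with probability `fraction` (Bernoulli sampling). This is an
 * assumed scheme for illustration, not a description of Spark internals.
 */
final class MiniBatchSketch {

    static List<Integer> sample(int numRows, double fraction, long seed) {
        if (!(fraction > 0.0 && fraction <= 1.0)) {
            throw new IllegalArgumentException("fraction must be in (0, 1], got " + fraction);
        }
        Random rng = new Random(seed); // same seed -> same batch
        List<Integer> batch = new ArrayList<>();
        for (int i = 0; i < numRows; i++) {
            // nextDouble() is in [0, 1), so fraction = 1.0 keeps every row
            if (rng.nextDouble() < fraction) {
                batch.add(i);
            }
        }
        return batch;
    }

    public static void main(String[] args) {
        List<Integer> batch = sample(1000, 0.1, 42L);
        System.out.println("batch size = " + batch.size() + " (about 100 on average)");
    }
}
```

With fraction 1.0 every row is used, which corresponds to full-batch gradient descent.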
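public final DoubleParam initStd()
Description copied from interface: FactorizationMachinesParams
Param for the standard deviation of initial coefficients.
Specified by: initStd in interface FactorizationMachinesParams
public final Param<String> solver()
Description copied from interface: FactorizationMachinesParams
The solver algorithm for optimization.
Specified by: solver in interface HasSolver; solver in interface FactorizationMachinesParams
public final Param<String> weightCol()
Description copied from interface: HasWeightCol
Param for weight column name.
Specified by: weightCol in interface HasWeightCol
public final DoubleParam regParam()
Description copied from interface: HasRegParam
Param for regularization parameter (>= 0).
Specified by: regParam in interface HasRegParam
public final BooleanParam fitIntercept()
Description copied from interface: HasFitIntercept
Param for whether to fit an intercept term.
Specified by: fitIntercept in interface HasFitIntercept
public final LongParam seed()
Description copied from interface: HasSeed
Param for random seed.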
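initStd and seed together determine the starting factors. Assuming, for illustration, that each factor entry is drawn independently from a normal distribution with mean 0 and standard deviation initStd, a seeded java.util.Random makes the initialization reproducible. FmInitSketch and initFactors are hypothetical names, and the distribution is an assumption rather than a statement about Spark's initialization.

```java
import java.util.Random;

/**
 * Sketch of seeded factor initialization, assuming (for illustration) that
 * factor entries are drawn from a normal distribution with mean 0 and
 * standard deviation initStd.
 */
final class FmInitSketch {

    /** Returns an n x k factor matrix; the same seed always yields the same matrix. */
    static double[][] initFactors(int numFeatures, int factorSize, double initStd, long seed) {
        Random rng = new Random(seed);
        double[][] v = new double[numFeatures][factorSize];
        for (int i = 0; i < numFeatures; i++) {
            for (int f = 0; f < factorSize; f++) {
                v[i][f] = rng.nextGaussian() * initStd; // N(0, initStd^2)
            }
        }
        return v;
    }

    public static void main(String[] args) {
        double[][] v = initFactors(3, 2, 0.01, 42L);
        System.out.println(java.util.Arrays.deepToString(v));
    }
}
```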
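public final DoubleParam tol()
Description copied from interface: HasTol
Param for the convergence tolerance for iterative algorithms (>= 0).
public DoubleParam stepSize()
Description copied from interface: HasStepSize
Param for step size to be used for each iteration of optimization (> 0).
Specified by: stepSize in interface HasStepSize
public final IntParam maxIter()
Description copied from interface: HasMaxIter
Param for maximum number of iterations (>= 0).
Specified by: maxIter in interface HasMaxIter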
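maxIter, stepSize and tol jointly control an iterative optimizer. The one-dimensional gradient-descent sketch below caps the loop at maxIter steps, moves by stepSize times the gradient, and stops early once the update is small relative to tol. The stopping rule |w_new - w_old| < tol * max(|w_new|, 1) is an assumed convention and GdLoopSketch is a hypothetical name; Spark's solvers may use a different criterion.

```java
import java.util.function.DoubleUnaryOperator;

/**
 * Sketch of how maxIter, stepSize and tol typically interact in an iterative
 * optimizer, on a one-dimensional problem. The stopping rule
 * |w_new - w_old| < tol * max(|w_new|, 1) is an assumed convention.
 */
final class GdLoopSketch {

    static final class Result {
        final double w;
        final int iterations;

        Result(double w, int iterations) {
            this.w = w;
            this.iterations = iterations;
        }
    }

    static Result minimize(DoubleUnaryOperator gradient, double w0,
                           double stepSize, int maxIter, double tol) {
        double w = w0;
        int iter = 0;
        while (iter < maxIter) { // never more than maxIter steps
            double next = w - stepSize * gradient.applyAsDouble(w);
            iter++;
            boolean converged = Math.abs(next - w) < tol * Math.max(Math.abs(next), 1.0);
            w = next;
            if (converged) {
                break; // update fell below the tolerance
            }
        }
        return new Result(w, iter);
    }

    public static void main(String[] args) {
        // Minimize f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
        Result r = minimize(x -> 2.0 * (x - 3.0), 0.0, 0.1, 1000, 1e-8);
        System.out.println("w = " + r.w + " after " + r.iterations + " iterations");
    }
}
```

Setting maxIter to 0 returns the starting point unchanged, which is consistent with the documented lower bound (>= 0).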
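public String uid()
Description copied from interface: Identifiable
An immutable unique ID for the object and its derivatives.
Specified by: uid in interface Identifiable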
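public FMClassifier setFactorSize(int value)
Set the dimensionality of the factors.
Parameters: value - (undocumented)
public FMClassifier setFitIntercept(boolean value)
Set whether to fit the intercept term.
Parameters: value - (undocumented)
public FMClassifier setFitLinear(boolean value)
Set whether to fit the linear term.
Parameters: value - (undocumented)
public FMClassifier setRegParam(double value)
Set the L2 regularization parameter.
Parameters: value - (undocumented)
public FMClassifier setMiniBatchFraction(double value)
Set the mini-batch fraction parameter.
Parameters: value - (undocumented)
public FMClassifier setInitStd(double value)
Set the standard deviation of initial coefficients.
Parameters: value - (undocumented)
public FMClassifier setMaxIter(int value)
Set the maximum number of iterations.
Parameters: value - (undocumented)
public FMClassifier setStepSize(double value)
Set the initial step size for the first step (like learning rate).
Parameters: value - (undocumented)
public FMClassifier setTol(double value)
Set the convergence tolerance of iterations.
Parameters: value - (undocumented)
public FMClassifier setSolver(String value)
Set the solver algorithm used for optimization.
Parameters: value - (undocumented)
public FMClassifier setSeed(long value)
Set the random seed for weight initialization.
Parameters: value - (undocumented)
public FMClassifier copy(ParamMap extra)
Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params. See defaultCopy().
Specified by: copy in interface Params
Specified by: copy in class Predictor<Vector,FMClassifier,FMClassificationModel>
Parameters: extra - (undocumented)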
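The copy contract (same UID, with extra param values applied on top of the current ones) can be illustrated without Spark using a map-backed stand-in. ParamsSketch and the uid "fmc_example" are hypothetical; real code would call copy with a ParamMap, and this sketch does not reproduce how Spark stores params.

```java
import java.util.HashMap;
import java.util.Map;

/**
 * Hypothetical stand-in for a Params object: a uid plus a map of param values.
 * It mimics the documented contract of copy(ParamMap extra): same UID, with
 * extra param values applied on top of the current ones.
 */
final class ParamsSketch {
    final String uid;
    final Map<String, Object> values;

    ParamsSketch(String uid, Map<String, Object> values) {
        this.uid = uid;
        this.values = new HashMap<>(values); // defensive copy
    }

    ParamsSketch copy(Map<String, Object> extra) {
        Map<String, Object> merged = new HashMap<>(values);
        merged.putAll(extra); // extra values take precedence
        return new ParamsSketch(uid, merged);
    }

    public static void main(String[] args) {
        Map<String, Object> init = new HashMap<>();
        init.put("factorSize", 8);
        ParamsSketch original = new ParamsSketch("fmc_example", init);
        Map<String, Object> extra = new HashMap<>();
        extra.put("factorSize", 4);
        ParamsSketch copied = original.copy(extra);
        System.out.println(copied.uid + " " + copied.values + " / original " + original.values);
    }
}
```

The original instance is left unchanged, so a copy can be tuned independently while keeping the same identifier.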