Class FMClassifier
Object
org.apache.spark.ml.PipelineStage
org.apache.spark.ml.Estimator<M>
org.apache.spark.ml.Predictor<FeaturesType,E,M>
  
org.apache.spark.ml.classification.Classifier<FeaturesType,E,M>
  
org.apache.spark.ml.classification.ProbabilisticClassifier<Vector,FMClassifier,FMClassificationModel>
  
org.apache.spark.ml.classification.FMClassifier
- All Implemented Interfaces:
- Serializable, org.apache.spark.internal.Logging, ClassifierParams, FMClassifierParams, ProbabilisticClassifierParams, Params, HasFeaturesCol, HasFitIntercept, HasLabelCol, HasMaxIter, HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, HasRegParam, HasSeed, HasSolver, HasStepSize, HasThresholds, HasTol, HasWeightCol, PredictorParams, FactorizationMachines, FactorizationMachinesParams, DefaultParamsWritable, Identifiable, MLWritable
public class FMClassifier
extends ProbabilisticClassifier<Vector,FMClassifier,FMClassificationModel>
implements FactorizationMachines, FMClassifierParams, DefaultParamsWritable, org.apache.spark.internal.Logging  
Factorization Machines learning algorithm for classification.
 It supports the gradient descent ("gd") and AdamW ("adamW") solvers.
 
The implementation is based on: S. Rendle. "Factorization machines" 2010.
FM can estimate feature interactions even in problems with very sparse data (such as advertising and recommendation systems). The FM formula is:
$$ \begin{align} y = \sigma\left( w_0 + \sum\limits^n_{i=1} w_i x_i + \sum\limits^n_{i=1} \sum\limits^n_{j=i+1} \langle v_i, v_j \rangle x_i x_j \right) \end{align} $$
The first two terms are the global bias and the linear term (the same as in linear regression), and the last term is the pairwise interaction term. v_i describes the i-th variable with k factors.
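As a rough illustration of the formula above, the following self-contained Java sketch evaluates the prediction directly from w_0, w, and the factor vectors v_i. It does not use Spark and does not reflect how FMClassifier is implemented internally. The class and method names (FmPredictionSketch, rawScore, predictProbability) are hypothetical, and the dense array layout is an assumption made for readability.

```java
// Illustrative sketch only: names and data layout are assumptions, not Spark API.
final class FmPredictionSketch {
    private FmPredictionSketch() {}

    // Logistic function sigma(z) = 1 / (1 + exp(-z)).
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Computes w0 + sum_i w[i]*x[i] + sum_{i<j} <v[i], v[j]> * x[i] * x[j].
    // v[i] is the k-dimensional factor vector of the i-th feature.
    static double rawScore(double w0, double[] w, double[][] v, double[] x) {
        double score = w0;
        // Linear term.
        for (int i = 0; i < x.length; i++) {
            score += w[i] * x[i];
        }
        // Pairwise interaction term, written exactly as in the double sum.
        for (int i = 0; i < x.length; i++) {
            for (int j = i + 1; j < x.length; j++) {
                double dot = 0.0;
                for (int f = 0; f < v[i].length; f++) {
                    dot += v[i][f] * v[j][f];
                }
                score += dot * x[i] * x[j];
            }
        }
        return score;
    }

    // Applies the logistic function to the raw score, giving y in (0, 1).
    static double predictProbability(double w0, double[] w, double[][] v, double[] x) {
        return sigmoid(rawScore(w0, w, v, x));
    }
}
```

For example, with w_0 = 0.5, w = (1, -1), v_1 = (1, 2), v_2 = (3, 4), and x = (1, 2), the raw score is 0.5 - 1 + 11 * 2 = 21.5. The nested loops cost O(k n^2) per example; they are kept here because they mirror the formula term by term.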
The FM classification model uses logistic loss, which can be minimized by gradient descent. Regularization terms such as L2 are usually added to the loss function to prevent overfitting.
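The loss described above can be sketched in plain Java as follows. This is an illustration under stated assumptions, not Spark's code: labels are assumed to be encoded as 0.0 or 1.0, the penalty is taken to be regParam times the sum of squared parameters (the exact scaling Spark applies is not stated here), and the names FmLossSketch, logisticLoss, l2Penalty, and regularizedLoss are hypothetical.

```java
// Illustrative sketch only: names, label encoding, and penalty scaling are assumptions.
final class FmLossSketch {
    private FmLossSketch() {}

    // Logistic loss for one example with label in {0, 1} and raw score z:
    // log(1 + exp(z)) - label * z, rearranged so exp() never overflows.
    static double logisticLoss(double rawScore, double label) {
        return Math.max(rawScore, 0.0)
                + Math.log1p(Math.exp(-Math.abs(rawScore)))
                - label * rawScore;
    }

    // L2 penalty: regParam * sum of squared parameters (assumed scaling).
    static double l2Penalty(double regParam, double[] parameters) {
        double sumOfSquares = 0.0;
        for (double p : parameters) {
            sumOfSquares += p * p;
        }
        return regParam * sumOfSquares;
    }

    // Mean logistic loss over the examples plus the L2 penalty.
    static double regularizedLoss(double[] rawScores, double[] labels,
                                  double regParam, double[] parameters) {
        if (rawScores.length == 0 || rawScores.length != labels.length) {
            throw new IllegalArgumentException("rawScores and labels must be non-empty and of equal length");
        }
        double total = 0.0;
        for (int i = 0; i < rawScores.length; i++) {
            total += logisticLoss(rawScores[i], labels[i]);
        }
        return total / rawScores.length + l2Penalty(regParam, parameters);
    }
}
```

A raw score of 0 gives a loss of ln 2 for either label, and a large positive score with label 1 gives a loss close to 0, which is the behavior gradient descent exploits when it moves scores toward the correct side.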
- Note:
- Multiclass labels are not currently supported.
- 
Nested Class Summary
Nested classes/interfaces inherited from interface org.apache.spark.internal.Logging
org.apache.spark.internal.Logging.LogStringContext, org.apache.spark.internal.Logging.SparkShellLoggingFilter
- 
Constructor Summary
Constructors
- 
Method Summary
Modifier and Type, Method: Description
- copy: Creates a copy of this instance with the same UID and some extra params.
- long estimateModelSize(Dataset<?> dataset)
- final IntParam factorSize(): Param for dimensionality of the factors (>= 0)
- final BooleanParam fitIntercept(): Param for whether to fit an intercept term.
- final BooleanParam fitLinear(): Param for whether to fit linear term (aka 1-way term)
- final DoubleParam initStd(): Param for standard deviation of initial coefficients
- static FMClassifier load
- final IntParam maxIter(): Param for maximum number of iterations (>= 0).
- final DoubleParam miniBatchFraction(): Param for mini-batch fraction, must be in range (0, 1]
- static MLReader<T> read()
- final DoubleParam regParam(): Param for regularization parameter (>= 0).
- final LongParam seed(): Param for random seed.
- setFactorSize(int value): Set the dimensionality of the factors.
- setFitIntercept(boolean value): Set whether to fit intercept term.
- setFitLinear(boolean value): Set whether to fit linear term.
- setInitStd(double value): Set the standard deviation of initial coefficients.
- setMaxIter(int value): Set the maximum number of iterations.
- setMiniBatchFraction(double value): Set the mini-batch fraction parameter.
- setRegParam(double value): Set the L2 regularization parameter.
- setSeed(long value): Set the random seed for weight initialization.
- setSolver: Set the solver algorithm used for optimization.
- setStepSize(double value): Set the initial step size for the first step (like learning rate).
- setTol(double value): Set the convergence tolerance of iterations.
- solver(): The solver algorithm for optimization.
- stepSize(): Param for Step size to be used for each iteration of optimization (> 0).
- final DoubleParam tol(): Param for the convergence tolerance for iterative algorithms (>= 0).
- uid(): An immutable unique ID for the object and its derivatives.
- weightCol(): Param for weight column name.

Methods inherited from class org.apache.spark.ml.classification.ProbabilisticClassifier
probabilityCol, setProbabilityCol, setThresholds, thresholds
Methods inherited from class org.apache.spark.ml.classification.Classifier
rawPredictionCol, setRawPredictionCol
Methods inherited from class org.apache.spark.ml.Predictor
featuresCol, fit, labelCol, predictionCol, setFeaturesCol, setLabelCol, setPredictionCol, transformSchema
Methods inherited from class org.apache.spark.ml.PipelineStage
params
Methods inherited from class java.lang.Object
equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface org.apache.spark.ml.util.DefaultParamsWritable
write
Methods inherited from interface org.apache.spark.ml.regression.FactorizationMachines
initCoefficients, trainImpl
Methods inherited from interface org.apache.spark.ml.regression.FactorizationMachinesParams
getFactorSize, getFitLinear, getInitStd, getMiniBatchFraction
Methods inherited from interface org.apache.spark.ml.param.shared.HasFeaturesCol
featuresCol, getFeaturesCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasFitIntercept
getFitIntercept
Methods inherited from interface org.apache.spark.ml.param.shared.HasLabelCol
getLabelCol, labelCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasMaxIter
getMaxIter
Methods inherited from interface org.apache.spark.ml.param.shared.HasPredictionCol
getPredictionCol, predictionCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasProbabilityCol
getProbabilityCol, probabilityCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasRawPredictionCol
getRawPredictionCol, rawPredictionCol
Methods inherited from interface org.apache.spark.ml.param.shared.HasRegParam
getRegParam
Methods inherited from interface org.apache.spark.ml.param.shared.HasStepSize
getStepSize
Methods inherited from interface org.apache.spark.ml.param.shared.HasThresholds
getThresholds, thresholds
Methods inherited from interface org.apache.spark.ml.param.shared.HasWeightCol
getWeightCol
Methods inherited from interface org.apache.spark.ml.util.Identifiable
toString
Methods inherited from interface org.apache.spark.internal.Logging
initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, isTraceEnabled, log, logBasedOnLevel, logDebug, logDebug, logDebug, logDebug, logError, logError, logError, logError, logInfo, logInfo, logInfo, logInfo, logName, LogStringContext, logTrace, logTrace, logTrace, logTrace, logWarning, logWarning, logWarning, logWarning, MDC, org$apache$spark$internal$Logging$$log_, org$apache$spark$internal$Logging$$log__$eq, withLogContext
Methods inherited from interface org.apache.spark.ml.util.MLWritable
save
Methods inherited from interface org.apache.spark.ml.param.Params
clear, copyValues, defaultCopy, defaultParamMap, estimateMatadataSize, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, onParamChange, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
Methods inherited from interface org.apache.spark.ml.classification.ProbabilisticClassifierParams
validateAndTransformSchema
- 
Constructor Details
FMClassifier
- 
FMClassifier
public FMClassifier()
 
- 
- 
Method Details
load
- 
read
- 
factorSize
Description copied from interface: FactorizationMachinesParams
Param for dimensionality of the factors (>= 0)
- Specified by:
- factorSize in interface FactorizationMachinesParams
- Returns:
- (undocumented)
 
- 
fitLinear
Description copied from interface: FactorizationMachinesParams
Param for whether to fit linear term (aka 1-way term)
- Specified by:
- fitLinear in interface FactorizationMachinesParams
- Returns:
- (undocumented)
 
- 
miniBatchFraction
Description copied from interface: FactorizationMachinesParams
Param for mini-batch fraction, must be in range (0, 1]
- Specified by:
- miniBatchFraction in interface FactorizationMachinesParams
- Returns:
- (undocumented)
 
- 
initStd
Description copied from interface: FactorizationMachinesParams
Param for standard deviation of initial coefficients
- Specified by:
- initStd in interface FactorizationMachinesParams
- Returns:
- (undocumented)
 
- 
solver
Description copied from interface: FactorizationMachinesParams
The solver algorithm for optimization. Supported options: "gd", "adamW". Default: "adamW"
- Specified by:
- solver in interface FactorizationMachinesParams
- Specified by:
- solver in interface HasSolver
- Returns:
- (undocumented)
 
- 
weightCol
Description copied from interface: HasWeightCol
Param for weight column name. If this is not set or empty, we treat all instance weights as 1.0.
- Specified by:
- weightCol in interface HasWeightCol
- Returns:
- (undocumented)
 
- 
regParam
Description copied from interface: HasRegParam
Param for regularization parameter (>= 0).
- Specified by:
- regParam in interface HasRegParam
- Returns:
- (undocumented)
 
- 
fitIntercept
Description copied from interface: HasFitIntercept
Param for whether to fit an intercept term.
- Specified by:
- fitIntercept in interface HasFitIntercept
- Returns:
- (undocumented)
 
- 
seed
Description copied from interface: HasSeed
Param for random seed.
- 
tol
Description copied from interface: HasTol
Param for the convergence tolerance for iterative algorithms (>= 0).
- 
stepSize
Description copied from interface: HasStepSize
Param for Step size to be used for each iteration of optimization (> 0).
- Specified by:
- stepSize in interface HasStepSize
- Returns:
- (undocumented)
 
- 
maxIter
Description copied from interface: HasMaxIter
Param for maximum number of iterations (>= 0).
- Specified by:
- maxIter in interface HasMaxIter
- Returns:
- (undocumented)
 
- 
uid
Description copied from interface: Identifiable
An immutable unique ID for the object and its derivatives.
- Specified by:
- uid in interface Identifiable
- Returns:
- (undocumented)
 
- 
setFactorSize
Set the dimensionality of the factors. Default is 8.
- Parameters:
- value - (undocumented)
- Returns:
- (undocumented)
 
- 
setFitIntercept
Set whether to fit intercept term. Default is true.
- Parameters:
- value - (undocumented)
- Returns:
- (undocumented)
 
- 
setFitLinear
Set whether to fit linear term. Default is true.
- Parameters:
- value - (undocumented)
- Returns:
- (undocumented)
 
- 
setRegParam
Set the L2 regularization parameter. Default is 0.0.
- Parameters:
- value - (undocumented)
- Returns:
- (undocumented)
 
- 
setMiniBatchFraction
Set the mini-batch fraction parameter. Default is 1.0.
- Parameters:
- value - (undocumented)
- Returns:
- (undocumented)
 
- 
setInitStd
Set the standard deviation of initial coefficients. Default is 0.01.
- Parameters:
- value - (undocumented)
- Returns:
- (undocumented)
 
- 
setMaxIter
Set the maximum number of iterations. Default is 100.
- Parameters:
- value - (undocumented)
- Returns:
- (undocumented)
 
- 
setStepSize
Set the initial step size for the first step (like learning rate). Default is 1.0.
- Parameters:
- value - (undocumented)
- Returns:
- (undocumented)
 
- 
setTol
Set the convergence tolerance of iterations. Default is 1E-6.
- Parameters:
- value - (undocumented)
- Returns:
- (undocumented)
 
- 
setSolver
Set the solver algorithm used for optimization. Supported options: "gd", "adamW". Default: "adamW"
- Parameters:
- value - (undocumented)
- Returns:
- (undocumented)
 
- 
setSeed
Set the random seed for weight initialization.
- Parameters:
- value - (undocumented)
- Returns:
- (undocumented)
 
- 
copy
Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params. Subclasses should implement this method and set the return type properly. See defaultCopy().
- Specified by:
- copy in interface Params
- Specified by:
- copy in class Predictor<Vector,FMClassifier,FMClassificationModel>
- Parameters:
- extra - (undocumented)
- Returns:
- (undocumented)
 
- 
estimateModelSize
 
-