org.apache.spark.ml.regression

## Class LinearRegression

• All Implemented Interfaces:
java.io.Serializable, Logging, Params, DefaultParamsWritable, Identifiable, MLWritable

public class LinearRegression
extends Predictor<Vector,LinearRegression,LinearRegressionModel>
implements DefaultParamsWritable, Logging
Linear regression.

The learning objective is to minimize the specified loss function, with regularization. This supports two kinds of loss:
- squaredError (a.k.a. squared loss)
- huber (a hybrid of squared error for relatively small errors and absolute error for relatively large ones; the scale parameter is estimated from the training data)

This supports multiple types of regularization:
- none (a.k.a. ordinary least squares)
- L2 (ridge regression)
- L1 (Lasso)
- L2 + L1 (elastic net)

The squared error objective function is:

\begin{align} \min_{w}\frac{1}{2n}{\sum_{i=1}^n(X_{i}w - y_{i})^{2} + \lambda\left[\frac{1-\alpha}{2}{||w||_{2}}^{2} + \alpha{||w||_{1}}\right]} \end{align}
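To make the squared-error objective concrete, the following Java sketch evaluates it for a fixed coefficient vector w. It is an illustration only, not Spark's implementation: it omits the intercept and instance weights, and the class and method names (`SquaredErrorObjective`, `objective`, `penalty`) are hypothetical.

```java
/**
 * Hypothetical sketch: evaluates
 *   (1/(2n)) * sum_i (x_i . w - y_i)^2 + lambda * [ (1 - alpha)/2 * ||w||_2^2 + alpha * ||w||_1 ]
 * for a fixed w. No intercept, no instance weights.
 */
final class SquaredErrorObjective {
    private SquaredErrorObjective() {}

    /** Dot product of one feature row with the coefficient vector. */
    static double dot(double[] row, double[] w) {
        double s = 0.0;
        for (int j = 0; j < w.length; j++) {
            s += row[j] * w[j];
        }
        return s;
    }

    /** Elastic-net penalty: lambda * [(1 - alpha)/2 * ||w||_2^2 + alpha * ||w||_1]. */
    static double penalty(double[] w, double lambda, double alpha) {
        double l2sq = 0.0;
        double l1 = 0.0;
        for (double wj : w) {
            l2sq += wj * wj;
            l1 += Math.abs(wj);
        }
        return lambda * ((1.0 - alpha) / 2.0 * l2sq + alpha * l1);
    }

    /** Data term (half mean squared residual) plus the elastic-net penalty. */
    static double objective(double[][] x, double[] y, double[] w, double lambda, double alpha) {
        int n = y.length;
        double sse = 0.0;
        for (int i = 0; i < n; i++) {
            double r = dot(x[i], w) - y[i];
            sse += r * r;
        }
        return sse / (2.0 * n) + penalty(w, lambda, alpha);
    }
}
```

For example, with X = [[1, 0], [0, 1]], y = [1, 2], w = [1, 1], λ = 0.5 and α = 0.5, the data term is 0.25 and the penalty is 0.75, for a total of 1.0. Setting α = 0 leaves only the L2 part and α = 1 only the L1 part.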

The huber objective function is:

\begin{align} \min_{w, \sigma}\frac{1}{2n}{\sum_{i=1}^n\left(\sigma + H_m\left(\frac{X_{i}w - y_{i}}{\sigma}\right)\sigma\right) + \frac{1}{2}\lambda {||w||_2}^2} \end{align}

where

\begin{align} H_m(z) = \begin{cases} z^2, & \text {if } |z| < \epsilon, \\ 2\epsilon|z| - \epsilon^2, & \text{otherwise} \end{cases} \end{align}
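A minimal sketch of the huber objective for fixed w and σ, following the two formulas above term by term. The names `HuberObjective`, `hm` and `objective` are hypothetical, the intercept is omitted, and this is not how Spark computes the loss internally.

```java
/**
 * Hypothetical sketch: evaluates
 *   (1/(2n)) * sum_i ( sigma + H_m((x_i . w - y_i) / sigma) * sigma ) + (1/2) * lambda * ||w||_2^2
 * for fixed (w, sigma). No intercept.
 */
final class HuberObjective {
    private HuberObjective() {}

    /** H_m(z): quadratic when |z| < epsilon, linear in |z| otherwise. */
    static double hm(double z, double epsilon) {
        double a = Math.abs(z);
        return a < epsilon ? z * z : 2.0 * epsilon * a - epsilon * epsilon;
    }

    static double objective(double[][] x, double[] y, double[] w,
                            double sigma, double lambda, double epsilon) {
        if (sigma <= 0.0) {
            throw new IllegalArgumentException("sigma must be positive");
        }
        int n = y.length;
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            double pred = 0.0;
            for (int j = 0; j < w.length; j++) {
                pred += x[i][j] * w[j];
            }
            double z = (pred - y[i]) / sigma; // residual scaled by sigma
            sum += sigma + hm(z, epsilon) * sigma;
        }
        double l2sq = 0.0;
        for (double wj : w) {
            l2sq += wj * wj;
        }
        return sum / (2.0 * n) + 0.5 * lambda * l2sq;
    }
}
```

Below the threshold, H_m grows quadratically; beyond it, it grows linearly, which is what makes the loss less sensitive to large residuals. The two branches agree at |z| = ε, where both equal ε².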

Note: Fitting with huber loss only supports none and L2 regularization.
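The restriction in the note above can be expressed as a parameter check. This is a hypothetical sketch (`HuberParamCheck` and `validate` are illustrative names, not Spark's code) that also includes the related restriction on the "normal" solver documented under setSolver.

```java
/**
 * Hypothetical check of the huber-loss restrictions described on this page:
 * huber loss allows only none/L2 regularization (elasticNetParam == 0)
 * and does not allow the "normal" solver.
 */
final class HuberParamCheck {
    private HuberParamCheck() {}

    static void validate(String loss, double elasticNetParam, String solver) {
        if ("huber".equals(loss)) {
            if (elasticNetParam != 0.0) {
                throw new IllegalArgumentException(
                    "huber loss supports only none and L2 regularization, but elasticNetParam = "
                        + elasticNetParam);
            }
            if ("normal".equals(solver)) {
                throw new IllegalArgumentException("huber loss does not support the normal solver");
            }
        }
        // squaredError places no extra constraints in this sketch.
    }
}
```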

• ### Constructor Summary

Constructors
Constructor and Description
LinearRegression()
LinearRegression(String uid)
• ### Method Summary

Modifier and Type Method and Description
static IntParam aggregationDepth()
static Params clear(Param<?> param)
LinearRegression copy(ParamMap extra)
Creates a copy of this instance with the same UID and some extra params.
static DoubleParam elasticNetParam()
static DoubleParam epsilon()
DoubleParam epsilon()
The shape parameter to control the amount of robustness.
static String explainParam(Param<?> param)
static String explainParams()
static ParamMap extractParamMap()
static ParamMap extractParamMap(ParamMap extra)
static Param<String> featuresCol()
static M fit(Dataset<?> dataset)
static M fit(Dataset<?> dataset, ParamMap paramMap)
static scala.collection.Seq<M> fit(Dataset<?> dataset, ParamMap[] paramMaps)
static M fit(Dataset<?> dataset, ParamPair<?> firstParamPair, ParamPair<?>... otherParamPairs)
static M fit(Dataset<?> dataset, ParamPair<?> firstParamPair, scala.collection.Seq<ParamPair<?>> otherParamPairs)
static BooleanParam fitIntercept()
static <T> scala.Option<T> get(Param<T> param)
static int getAggregationDepth()
static <T> scala.Option<T> getDefault(Param<T> param)
static double getElasticNetParam()
static double getEpsilon()
double getEpsilon()
static String getFeaturesCol()
static boolean getFitIntercept()
static String getLabelCol()
static String getLoss()
static int getMaxIter()
static <T> T getOrDefault(Param<T> param)
static Param<Object> getParam(String paramName)
static String getPredictionCol()
static double getRegParam()
static String getSolver()
static boolean getStandardization()
static double getTol()
static String getWeightCol()
static <T> boolean hasDefault(Param<T> param)
static boolean hasParam(String paramName)
static boolean isDefined(Param<?> param)
static boolean isSet(Param<?> param)
static Param<String> labelCol()
static LinearRegression load(String path)
static Param<String> loss()
Param<String> loss()
The loss function to be optimized.
static int MAX_FEATURES_FOR_NORMAL_SOLVER()
When using LinearRegression.solver == "normal", the solver must limit the number of features to at most this number.
static IntParam maxIter()
static Param<?>[] params()
static Param<String> predictionCol()
static DoubleParam regParam()
static void save(String path)
static <T> Params set(Param<T> param, T value)
LinearRegression setAggregationDepth(int value)
Suggested depth for treeAggregate (greater than or equal to 2).
LinearRegression setElasticNetParam(double value)
Set the ElasticNet mixing parameter.
LinearRegression setEpsilon(double value)
Sets the value of param epsilon.
static Learner setFeaturesCol(String value)
LinearRegression setFitIntercept(boolean value)
Set whether to fit an intercept term.
static Learner setLabelCol(String value)
LinearRegression setLoss(String value)
Sets the value of param loss.
LinearRegression setMaxIter(int value)
Set the maximum number of iterations.
static Learner setPredictionCol(String value)
LinearRegression setRegParam(double value)
Set the regularization parameter.
LinearRegression setSolver(String value)
Set the solver algorithm used for optimization.
LinearRegression setStandardization(boolean value)
Whether to standardize the training features before fitting the model.
LinearRegression setTol(double value)
Set the convergence tolerance of iterations.
LinearRegression setWeightCol(String value)
Set the weight column (weightCol) used to over-/under-sample training instances.
static Param<String> solver()
Param<String> solver()
The solver algorithm for optimization.
static BooleanParam standardization()
static DoubleParam tol()
static String toString()
static StructType transformSchema(StructType schema)
String uid()
An immutable unique ID for the object and its derivatives.
StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
Validates and transforms the input schema with the provided param map.
static Param<String> weightCol()
static MLWriter write()
• ### Methods inherited from class org.apache.spark.ml.Predictor

fit, setFeaturesCol, setLabelCol, setPredictionCol, transformSchema
• ### Methods inherited from class org.apache.spark.ml.Estimator

fit, fit, fit, fit
• ### Methods inherited from class java.lang.Object

equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
• ### Methods inherited from interface org.apache.spark.ml.param.shared.HasRegParam

getRegParam, regParam
• ### Methods inherited from interface org.apache.spark.ml.param.shared.HasElasticNetParam

elasticNetParam, getElasticNetParam
• ### Methods inherited from interface org.apache.spark.ml.param.shared.HasMaxIter

getMaxIter, maxIter
• ### Methods inherited from interface org.apache.spark.ml.param.shared.HasTol

getTol, tol
• ### Methods inherited from interface org.apache.spark.ml.param.shared.HasFitIntercept

fitIntercept, getFitIntercept
• ### Methods inherited from interface org.apache.spark.ml.param.shared.HasStandardization

getStandardization, standardization
• ### Methods inherited from interface org.apache.spark.ml.param.shared.HasWeightCol

getWeightCol, weightCol
• ### Methods inherited from interface org.apache.spark.ml.param.shared.HasSolver

getSolver
• ### Methods inherited from interface org.apache.spark.ml.param.shared.HasAggregationDepth

aggregationDepth, getAggregationDepth
• ### Methods inherited from interface org.apache.spark.ml.param.shared.HasLoss

getLoss
• ### Methods inherited from interface org.apache.spark.ml.param.Params

clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
• ### Methods inherited from interface org.apache.spark.ml.util.Identifiable

toString
• ### Methods inherited from interface org.apache.spark.ml.util.DefaultParamsWritable

write
• ### Methods inherited from interface org.apache.spark.ml.util.MLWritable

save
• ### Methods inherited from interface org.apache.spark.internal.Logging

initializeLogging, initializeLogIfNecessary, initializeLogIfNecessary, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning
• ### Methods inherited from interface org.apache.spark.ml.param.shared.HasLabelCol

getLabelCol, labelCol
• ### Methods inherited from interface org.apache.spark.ml.param.shared.HasFeaturesCol

featuresCol, getFeaturesCol
• ### Methods inherited from interface org.apache.spark.ml.param.shared.HasPredictionCol

getPredictionCol, predictionCol
• ### Constructor Detail

• #### LinearRegression

public LinearRegression(String uid)
• #### LinearRegression

public LinearRegression()
• ### Method Detail

• #### load

public static LinearRegression load(String path)
• #### MAX_FEATURES_FOR_NORMAL_SOLVER

public static int MAX_FEATURES_FOR_NORMAL_SOLVER()
When LinearRegression.solver is "normal", the number of features must be at most this number, because the entire covariance matrix X^T X is collected on the driver. This limit helps prevent out-of-memory errors.
Returns:
the maximum number of features supported by the "normal" solver
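To illustrate why the normal solver needs the whole X^T X on one machine, here is a minimal Java sketch of solving the normal equation (X^T X) w = X^T y directly. It assumes dense rows, no regularization and no instance weights, and uses Gaussian elimination with partial pivoting for simplicity; the class name `NormalEquationSketch` is hypothetical and this is not Spark's implementation. The d × d matrix it builds is what grows quadratically with the number of features.

```java
/**
 * Hypothetical sketch: builds the augmented system [X^T X | X^T y] and solves it
 * by Gaussian elimination with partial pivoting. To fit an intercept, append a
 * column of ones to X.
 */
final class NormalEquationSketch {
    private NormalEquationSketch() {}

    static double[] solve(double[][] x, double[] y) {
        int n = x.length;
        int d = x[0].length;
        double[][] a = new double[d][d + 1]; // d x d Gram matrix plus right-hand side
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < d; j++) {
                for (int k = 0; k < d; k++) {
                    a[j][k] += x[i][j] * x[i][k];
                }
                a[j][d] += x[i][j] * y[i];
            }
        }
        // Forward elimination with partial pivoting.
        for (int col = 0; col < d; col++) {
            int pivot = col;
            for (int r = col + 1; r < d; r++) {
                if (Math.abs(a[r][col]) > Math.abs(a[pivot][col])) {
                    pivot = r;
                }
            }
            if (Math.abs(a[pivot][col]) < 1e-12) {
                throw new ArithmeticException("X^T X is singular");
            }
            double[] tmp = a[col];
            a[col] = a[pivot];
            a[pivot] = tmp;
            for (int r = col + 1; r < d; r++) {
                double f = a[r][col] / a[col][col];
                for (int c = col; c <= d; c++) {
                    a[r][c] -= f * a[col][c];
                }
            }
        }
        // Back substitution.
        double[] w = new double[d];
        for (int r = d - 1; r >= 0; r--) {
            double s = a[r][d];
            for (int c = r + 1; c < d; c++) {
                s -= a[r][c] * w[c];
            }
            w[r] = s / a[r][r];
        }
        return w;
    }
}
```

For the points (1, 3), (2, 5), (3, 7) with a column of ones appended, X^T X = [[14, 6], [6, 3]] and X^T y = [34, 15], which gives slope 2 and intercept 1.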
• #### toString

public static String toString()
• #### params

public static Param<?>[] params()
• #### explainParam

public static String explainParam(Param<?> param)
• #### explainParams

public static String explainParams()
• #### isSet

public static final boolean isSet(Param<?> param)
• #### isDefined

public static final boolean isDefined(Param<?> param)
• #### hasParam

public static boolean hasParam(String paramName)
• #### getParam

public static Param<Object> getParam(String paramName)
• #### set

public static final <T> Params set(Param<T> param,
T value)
• #### get

public static final <T> scala.Option<T> get(Param<T> param)
• #### clear

public static final Params clear(Param<?> param)
• #### getOrDefault

public static final <T> T getOrDefault(Param<T> param)
• #### getDefault

public static final <T> scala.Option<T> getDefault(Param<T> param)
• #### hasDefault

public static final <T> boolean hasDefault(Param<T> param)
• #### extractParamMap

public static final ParamMap extractParamMap(ParamMap extra)
• #### extractParamMap

public static final ParamMap extractParamMap()
• #### fit

public static M fit(Dataset<?> dataset,
ParamPair<?> firstParamPair,
scala.collection.Seq<ParamPair<?>> otherParamPairs)
• #### fit

public static M fit(Dataset<?> dataset,
ParamMap paramMap)
• #### fit

public static scala.collection.Seq<M> fit(Dataset<?> dataset,
ParamMap[] paramMaps)
• #### fit

public static M fit(Dataset<?> dataset,
ParamPair<?> firstParamPair,
ParamPair<?>... otherParamPairs)
• #### labelCol

public static final Param<String> labelCol()
• #### getLabelCol

public static final String getLabelCol()
• #### featuresCol

public static final Param<String> featuresCol()
• #### getFeaturesCol

public static final String getFeaturesCol()
• #### predictionCol

public static final Param<String> predictionCol()
• #### getPredictionCol

public static final String getPredictionCol()
• #### setLabelCol

public static Learner setLabelCol(String value)
• #### setFeaturesCol

public static Learner setFeaturesCol(String value)
• #### setPredictionCol

public static Learner setPredictionCol(String value)
• #### fit

public static M fit(Dataset<?> dataset)
• #### transformSchema

public static StructType transformSchema(StructType schema)
• #### regParam

public static final DoubleParam regParam()
• #### getRegParam

public static final double getRegParam()
• #### elasticNetParam

public static final DoubleParam elasticNetParam()
• #### getElasticNetParam

public static final double getElasticNetParam()
• #### maxIter

public static final IntParam maxIter()
• #### getMaxIter

public static final int getMaxIter()
• #### tol

public static final DoubleParam tol()
• #### getTol

public static final double getTol()
• #### fitIntercept

public static final BooleanParam fitIntercept()
• #### getFitIntercept

public static final boolean getFitIntercept()
• #### standardization

public static final BooleanParam standardization()
• #### getStandardization

public static final boolean getStandardization()
• #### weightCol

public static final Param<String> weightCol()
• #### getWeightCol

public static final String getWeightCol()
• #### getSolver

public static final String getSolver()
• #### aggregationDepth

public static final IntParam aggregationDepth()
• #### getAggregationDepth

public static final int getAggregationDepth()
• #### getLoss

public static final String getLoss()
• #### solver

public static final Param<String> solver()
• #### loss

public static final Param<String> loss()
• #### epsilon

public static final DoubleParam epsilon()
• #### getEpsilon

public static double getEpsilon()
• #### save

public static void save(String path)
throws java.io.IOException
Throws:
java.io.IOException
• #### write

public static MLWriter write()
• #### uid

public String uid()
Description copied from interface: Identifiable
An immutable unique ID for the object and its derivatives.
Specified by:
uid in interface Identifiable
Returns:
the unique ID of this instance
• #### setRegParam

public LinearRegression setRegParam(double value)
Set the regularization parameter. Default is 0.0.

Parameters:
value - the regularization parameter λ (>= 0)
Returns:
this LinearRegression instance
• #### setFitIntercept

public LinearRegression setFitIntercept(boolean value)
Set whether to fit an intercept term. Default is true.

Parameters:
value - true to fit an intercept term
Returns:
this LinearRegression instance
• #### setStandardization

public LinearRegression setStandardization(boolean value)
Whether to standardize the training features before fitting the model. The model coefficients are always returned on the original scale, so standardization is transparent to users. Default is true.

Parameters:
value - true to standardize the features before fitting
Returns:
this LinearRegression instance
Note:
With or without standardization, the model should always converge to the same solution when no regularization is applied. R's GLMNET package also standardizes by default.
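The claim that standardization is transparent can be checked on a toy problem. This hypothetical sketch (`StandardizationSketch`) fits a single-feature least-squares line with an intercept, once on the raw feature and once on the feature divided by its sample standard deviation, then maps the second slope back to the original scale. With no regularization, both fits agree up to rounding. Using the sample standard deviation and closed-form ordinary least squares are assumptions made for illustration, not a description of Spark's internals.

```java
/**
 * Hypothetical sketch: unregularized fits on raw vs. scaled features give the
 * same coefficients once the scaled slope is mapped back to the original scale.
 */
final class StandardizationSketch {
    private StandardizationSketch() {}

    static double mean(double[] v) {
        double s = 0.0;
        for (double e : v) {
            s += e;
        }
        return s / v.length;
    }

    /** Sample standard deviation (divides by n - 1). */
    static double std(double[] v) {
        double m = mean(v);
        double s = 0.0;
        for (double e : v) {
            s += (e - m) * (e - m);
        }
        return Math.sqrt(s / (v.length - 1));
    }

    /** Returns {slope, intercept} of the least-squares line y = slope * x + intercept. */
    static double[] ols(double[] x, double[] y) {
        double mx = mean(x);
        double my = mean(y);
        double sxy = 0.0;
        double sxx = 0.0;
        for (int i = 0; i < x.length; i++) {
            sxy += (x[i] - mx) * (y[i] - my);
            sxx += (x[i] - mx) * (x[i] - mx);
        }
        double slope = sxy / sxx;
        return new double[] {slope, my - slope * mx};
    }

    /** Fits on x / std(x), then divides the slope by std(x) to return to the original scale. */
    static double[] olsStandardized(double[] x, double[] y) {
        double s = std(x);
        double[] xs = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            xs[i] = x[i] / s;
        }
        double[] fit = ols(xs, y);
        return new double[] {fit[0] / s, fit[1]};
    }
}
```

With regularization the two fits generally differ, because the penalty is applied to coefficients on different scales.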

• #### setElasticNetParam

public LinearRegression setElasticNetParam(double value)
Set the ElasticNet mixing parameter alpha. For alpha = 0, the penalty is an L2 penalty; for alpha = 1, it is an L1 penalty; for alpha in (0, 1), it is a combination of L1 and L2. Default is 0.0, which is an L2 penalty.

Note: Fitting with huber loss supports only none and L2 regularization, so an exception is thrown if this param is set to a non-zero value.

Parameters:
value - the mixing parameter alpha, in [0, 1]
Returns:
this LinearRegression instance
• #### setMaxIter

public LinearRegression setMaxIter(int value)
Set the maximum number of iterations. Default is 100.

Parameters:
value - the maximum number of iterations (>= 0)
Returns:
this LinearRegression instance
• #### setTol

public LinearRegression setTol(double value)
Set the convergence tolerance of iterations. A smaller value leads to higher accuracy at the cost of more iterations. Default is 1E-6.

Parameters:
value - the convergence tolerance (>= 0)
Returns:
this LinearRegression instance
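maxIter and tol together bound an iterative optimizer: it stops after maxIter iterations or once progress falls below tol. The sketch below illustrates that interplay with plain gradient descent on the unregularized squared-error objective. It is not the L-BFGS solver described under setSolver, and the stopping rule (relative decrease in loss at most tol), the fixed step size and all names (`IterativeSolverSketch`, `fit`, `Result`) are assumptions.

```java
/**
 * Hypothetical sketch: gradient descent on (1/(2n)) * sum_i (x_i . w - y_i)^2,
 * stopping after maxIter iterations or when the relative loss decrease is at most tol.
 */
final class IterativeSolverSketch {
    private IterativeSolverSketch() {}

    /** Fitted coefficients plus the number of iterations actually run. */
    static final class Result {
        final double[] w;
        final int iterations;

        Result(double[] w, int iterations) {
            this.w = w;
            this.iterations = iterations;
        }
    }

    static double dot(double[] row, double[] w) {
        double s = 0.0;
        for (int j = 0; j < w.length; j++) {
            s += row[j] * w[j];
        }
        return s;
    }

    static double loss(double[][] x, double[] y, double[] w) {
        double sse = 0.0;
        for (int i = 0; i < y.length; i++) {
            double r = dot(x[i], w) - y[i];
            sse += r * r;
        }
        return sse / (2.0 * y.length);
    }

    static Result fit(double[][] x, double[] y, double stepSize, int maxIter, double tol) {
        int n = y.length;
        int d = x[0].length;
        double[] w = new double[d]; // start from all-zero coefficients
        double prev = loss(x, y, w);
        int iter = 0;
        while (iter < maxIter) {
            // Gradient of the objective: (1/n) * sum_i (x_i . w - y_i) * x_i
            double[] grad = new double[d];
            for (int i = 0; i < n; i++) {
                double r = dot(x[i], w) - y[i];
                for (int j = 0; j < d; j++) {
                    grad[j] += r * x[i][j] / n;
                }
            }
            for (int j = 0; j < d; j++) {
                w[j] -= stepSize * grad[j];
            }
            iter++;
            double cur = loss(x, y, w);
            // Stop when the loss decreased by at most tol relative to its previous value
            // (this also stops if the loss did not decrease at all).
            if (prev - cur <= tol * prev) {
                break;
            }
            prev = cur;
        }
        return new Result(w, iter);
    }
}
```

For example, fitting the points (1, 2.9), (2, 5.2), (3, 6.8), (4, 9.1) with a column of ones appended for the intercept, a step size of 0.1 and tol = 1e-10 should approach the least-squares solution (slope 2.02, intercept 0.95) and stop well before a large maxIter. A larger tol stops earlier with a less accurate result.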
• #### setWeightCol

public LinearRegression setWeightCol(String value)
Set the weight column (weightCol) used to over-/under-sample training instances. If weightCol is not set or is empty, all instances are treated equally (weight 1.0). By default it is not set.

Parameters:
value - the name of the column containing instance weights
Returns:
this LinearRegression instance
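A minimal sketch of how instance weights enter the squared-error loss, assuming the weighted loss is normalized by the sum of the weights (which reduces to the 1/(2n) factor of the objective above when every weight is 1.0). Under that assumption, giving an instance weight 2.0 has the same effect as duplicating it, which is the over-sampling interpretation in the description. The class and method names are hypothetical.

```java
/**
 * Hypothetical sketch, assuming normalization by the total weight:
 *   sum_i c_i * (x_i . w - y_i)^2 / (2 * sum_i c_i)
 * where c_i are the instance weights and w the coefficients.
 */
final class WeightedLossSketch {
    private WeightedLossSketch() {}

    static double weightedLoss(double[][] x, double[] y, double[] instanceWeights, double[] w) {
        double num = 0.0;
        double weightSum = 0.0;
        for (int i = 0; i < y.length; i++) {
            double pred = 0.0;
            for (int j = 0; j < w.length; j++) {
                pred += x[i][j] * w[j];
            }
            double r = pred - y[i];
            num += instanceWeights[i] * r * r; // each squared residual scaled by its weight
            weightSum += instanceWeights[i];
        }
        return num / (2.0 * weightSum);
    }
}
```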
• #### setSolver

public LinearRegression setSolver(String value)
Set the solver algorithm used for optimization. For linear regression, this can be "l-bfgs", "normal", or "auto":
- "l-bfgs" denotes Limited-memory BFGS, a limited-memory quasi-Newton optimization method.
- "normal" denotes using the Normal Equation as an analytical solution to the linear regression problem. This solver is limited to at most LinearRegression.MAX_FEATURES_FOR_NORMAL_SOLVER features.
- "auto" (default) means that the solver algorithm is selected automatically. The Normal Equation solver is used when possible, and fitting automatically falls back to iterative optimization methods when needed.

Note: Fitting with huber loss does not support the "normal" solver, so an exception is thrown if this param is set to "normal".

Parameters:
value - the solver name: "l-bfgs", "normal", or "auto"
Returns:
this LinearRegression instance
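The "auto" behavior can be pictured as a small decision rule. This hypothetical sketch (`SolverChoiceSketch`, `resolveSolver`) chooses "normal" only when the normal solver is applicable according to this page (squared-error loss and at most MAX_FEATURES_FOR_NORMAL_SOLVER features, passed in as a parameter rather than hard-coded) and otherwise falls back to "l-bfgs". It does not model any fallback that happens during fitting, and it is not Spark's actual selection code.

```java
/**
 * Hypothetical sketch of resolving "auto" to a concrete solver name.
 * Explicit choices ("l-bfgs", "normal") are returned unchanged.
 */
final class SolverChoiceSketch {
    private SolverChoiceSketch() {}

    static String resolveSolver(String solver, String loss, int numFeatures, int maxFeaturesForNormal) {
        if (!solver.equals("auto")) {
            return solver;
        }
        // The normal equation applies only to squared error and a bounded feature count.
        boolean normalApplicable = loss.equals("squaredError") && numFeatures <= maxFeaturesForNormal;
        return normalApplicable ? "normal" : "l-bfgs";
    }
}
```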
• #### setAggregationDepth

public LinearRegression setAggregationDepth(int value)
Suggested depth for treeAggregate (greater than or equal to 2). If the feature dimension or the number of partitions is large, this param can be set to a larger value. Default is 2.

Parameters:
value - the suggested aggregation depth (>= 2)
Returns:
this LinearRegression instance
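To show what "depth" means, here is a simplified, hypothetical model of tree aggregation: per-partition partial sums are combined in rounds, with a fan-in chosen so that the number of rounds does not exceed the depth. A larger depth means a smaller fan-in per round and fewer partial results merged at the final step, at the cost of more rounds. This is not Spark's exact treeAggregate algorithm; the fan-in formula and all names (`TreeAggregateSketch`, `treeSum`) are assumptions.

```java
/**
 * Hypothetical model of tree aggregation: partial results are merged in rounds
 * with fan-in max(ceil(n^(1/depth)), 2), so at most `depth` rounds are needed.
 */
final class TreeAggregateSketch {
    private TreeAggregateSketch() {}

    static final class Result {
        final double total;
        final int rounds;

        Result(double total, int rounds) {
            this.total = total;
            this.rounds = rounds;
        }
    }

    static Result treeSum(double[] partials, int depth) {
        if (depth < 1) {
            throw new IllegalArgumentException("depth must be >= 1");
        }
        int fanIn = Math.max((int) Math.ceil(Math.pow(partials.length, 1.0 / depth)), 2);
        double[] level = partials.clone();
        int rounds = 0;
        while (level.length > 1) {
            // Merge consecutive groups of `fanIn` partial results into one.
            int groups = (level.length + fanIn - 1) / fanIn;
            double[] next = new double[groups];
            for (int i = 0; i < level.length; i++) {
                next[i / fanIn] += level[i];
            }
            level = next;
            rounds++;
        }
        return new Result(level.length == 0 ? 0.0 : level[0], rounds);
    }
}
```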
• #### setLoss

public LinearRegression setLoss(String value)
Sets the value of param loss. Supported options: "squaredError" and "huber". Default is "squaredError".

Parameters:
value - the loss function name: "squaredError" or "huber"
Returns:
this LinearRegression instance
• #### setEpsilon

public LinearRegression setEpsilon(double value)
Sets the value of param epsilon. Must be > 1.0 and is only used when loss is "huber". Default is 1.35.

Parameters:
value - the huber shape parameter epsilon (> 1.0)
Returns:
this LinearRegression instance
• #### copy

public LinearRegression copy(ParamMap extra)
Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params. Subclasses should implement this method and set the return type properly. See defaultCopy().
Specified by:
copy in interface Params
Specified by:
copy in class Predictor<Vector,LinearRegression,LinearRegressionModel>
Parameters:
extra - extra params to set on the copy; these take precedence over existing values
Returns:
a copy of this instance with the same UID
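The copy contract can be modeled with plain maps. In this hypothetical sketch (`CopySketch`, with params modeled as `Map<String, Object>`, not Spark's ParamMap API), the copy keeps the UID, starts from the original param values, and lets entries in `extra` take precedence; the original instance is left unchanged.

```java
/**
 * Hypothetical model of copy(extra): same UID, original params overlaid with extra.
 */
final class CopySketch {
    final String uid;
    final java.util.Map<String, Object> params;

    CopySketch(String uid, java.util.Map<String, Object> params) {
        this.uid = uid;
        // Defensive, read-only copy so later changes to the caller's map have no effect.
        this.params = java.util.Collections.unmodifiableMap(new java.util.HashMap<>(params));
    }

    CopySketch copy(java.util.Map<String, Object> extra) {
        java.util.Map<String, Object> merged = new java.util.HashMap<>(params);
        merged.putAll(extra); // extra params take precedence
        return new CopySketch(uid, merged);
    }
}
```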
• #### epsilon

public DoubleParam epsilon()
The shape parameter to control the amount of robustness. Must be > 1.0. At larger values of epsilon, the huber criterion becomes more similar to least squares regression; for small values of epsilon, the criterion is more similar to L1 regression. Default is 1.35, which gives as much robustness as possible while retaining 95% statistical efficiency for normally distributed data. It matches sklearn HuberRegressor and is "M" from "A robust hybrid of lasso and ridge regression". Only valid when "loss" is "huber".

Returns:
the epsilon param
• #### getEpsilon

public double getEpsilon()
• #### loss

public Param<String> loss()
The loss function to be optimized. Supported options: "squaredError" and "huber". Default: "squaredError".

Specified by:
loss in interface HasLoss
Returns:
the loss param
• #### solver

public Param<String> solver()
The solver algorithm for optimization. Supported options: "l-bfgs", "normal" and "auto". Default: "auto".

Specified by:
solver in interface HasSolver
Returns:
the solver param
• #### validateAndTransformSchema

public StructType validateAndTransformSchema(StructType schema,
boolean fitting,
DataType featuresDataType)
Validates and transforms the input schema with the provided param map.

Parameters:
schema - input schema
fitting - whether the schema is being validated for fitting (training) rather than for transformation
featuresDataType - SQL DataType for FeaturesType. E.g., VectorUDT for vector features.
Returns:
output schema