org.apache.spark.ml.regression
Class LinearRegression

Object
  extended by org.apache.spark.ml.PipelineStage
      extended by org.apache.spark.ml.Estimator<M>
          extended by org.apache.spark.ml.Predictor<FeaturesType,Learner,M>
              extended by org.apache.spark.ml.regression.LinearRegression
All Implemented Interfaces:
java.io.Serializable, Logging, Params

public class LinearRegression
extends Predictor<Vector,LinearRegression,LinearRegressionModel>
implements Logging

:: Experimental :: Linear regression.

The learning objective is to minimize the squared error, with regularization. The specific squared error loss function used is: L = 1/(2n) ||A weights - y||^2

This supports multiple types of regularization:
- none (a.k.a. ordinary least squares)
- L2 (ridge regression)
- L1 (Lasso)
- L2 + L1 (elastic net)
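
To make the loss function above concrete, the following is a minimal plain-Java sketch that evaluates L = 1/(2n) ||A weights - y||^2 for a small dense matrix. The class name SquaredLossSketch and its method are hypothetical illustrations and are not part of the Spark API; the regularization term is left out here.

```java
// Illustrative only: a plain-Java evaluation of L = 1/(2n) * ||A w - y||^2.
// SquaredLossSketch and loss(...) are hypothetical names, not part of Spark.
final class SquaredLossSketch {

    /** Returns 1/(2n) * sum_i (a_i . w - y_i)^2 over the n rows a_i of A. */
    static double loss(double[][] a, double[] y, double[] w) {
        int n = a.length;
        double sumSq = 0.0;
        for (int i = 0; i < n; i++) {
            // Prediction for row i is the dot product of the row with the weights.
            double prediction = 0.0;
            for (int j = 0; j < w.length; j++) {
                prediction += a[i][j] * w[j];
            }
            double residual = prediction - y[i];
            sumSq += residual * residual;
        }
        return sumSq / (2.0 * n);
    }

    public static void main(String[] args) {
        double[][] a = {{1.0, 0.0}, {0.0, 1.0}};
        double[] y = {1.0, 2.0};
        double[] w = {1.0, 1.0};
        // Residuals are 0 and -1, so the loss is (0 + 1) / (2 * 2) = 0.25.
        System.out.println(loss(a, y, w));
    }
}
```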

See Also:
Serialized Form

Constructor Summary
LinearRegression()
           
LinearRegression(String uid)
           
 
Method Summary
 LinearRegression copy(ParamMap extra)
          Creates a copy of this instance with the same UID and some extra params.
 LinearRegression setElasticNetParam(double value)
          Set the ElasticNet mixing parameter.
 LinearRegression setMaxIter(int value)
          Set the maximum number of iterations.
 LinearRegression setRegParam(double value)
          Set the regularization parameter.
 LinearRegression setTol(double value)
          Set the convergence tolerance of iterations.
 String uid()
           
 StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
          Validates and transforms the input schema with the provided param map.
 
Methods inherited from class org.apache.spark.ml.Predictor
fit, setFeaturesCol, setLabelCol, setPredictionCol, transformSchema
 
Methods inherited from class org.apache.spark.ml.Estimator
fit, fit, fit, fit
 
Methods inherited from class Object
equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 
Methods inherited from interface org.apache.spark.Logging
initializeIfNecessary, initializeLogging, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning
 
Methods inherited from interface org.apache.spark.ml.param.Params
clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, setDefault, shouldOwn, validateParams
 

Constructor Detail

LinearRegression

public LinearRegression(String uid)

LinearRegression

public LinearRegression()
Method Detail

uid

public String uid()

setRegParam

public LinearRegression setRegParam(double value)
Set the regularization parameter. Default is 0.0.

Parameters:
value - the regularization parameter
Returns:
this LinearRegression instance, so that setter calls can be chained

setElasticNetParam

public LinearRegression setElasticNetParam(double value)
Set the ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of L1 and L2. Default is 0.0 which is an L2 penalty.

Parameters:
value - the ElasticNet mixing parameter (alpha)
Returns:
this LinearRegression instance, so that setter calls can be chained
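
The way the mixing parameter blends the two penalties can be sketched in plain Java. The sketch assumes the common glmnet-style form regParam * (alpha * ||w||_1 + (1 - alpha)/2 * ||w||_2^2); this page does not state the exact formula, so treat that form as an assumption. The class name ElasticNetPenaltySketch is hypothetical and not part of the Spark API.

```java
// Illustrative only: elastic net penalty under the ASSUMED glmnet-style form
//   regParam * (alpha * ||w||_1 + (1 - alpha) / 2 * ||w||_2^2).
// ElasticNetPenaltySketch is a hypothetical name, not part of Spark.
final class ElasticNetPenaltySketch {

    static double penalty(double[] w, double regParam, double alpha) {
        double l1 = 0.0;        // sum of absolute values of the weights
        double l2Squared = 0.0; // sum of squared weights
        for (double wj : w) {
            l1 += Math.abs(wj);
            l2Squared += wj * wj;
        }
        return regParam * (alpha * l1 + (1.0 - alpha) / 2.0 * l2Squared);
    }

    public static void main(String[] args) {
        double[] w = {3.0, -4.0}; // ||w||_1 = 7, ||w||_2^2 = 25
        double regParam = 0.5;
        System.out.println(penalty(w, regParam, 0.0)); // pure L2: 0.5 * 12.5 = 6.25
        System.out.println(penalty(w, regParam, 1.0)); // pure L1: 0.5 * 7 = 3.5
        System.out.println(penalty(w, regParam, 0.5)); // mix: 0.5 * (3.5 + 6.25) = 4.875
    }
}
```

With regParam left at its default of 0.0 the penalty vanishes for any alpha, which corresponds to the unregularized (ordinary least squares) case.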

setMaxIter

public LinearRegression setMaxIter(int value)
Set the maximum number of iterations. Default is 100.

Parameters:
value - the maximum number of iterations
Returns:
this LinearRegression instance, so that setter calls can be chained

setTol

public LinearRegression setTol(double value)
Set the convergence tolerance of iterations. A smaller value leads to higher accuracy at the cost of more iterations. Default is 1E-6.

Parameters:
value - the convergence tolerance
Returns:
this LinearRegression instance, so that setter calls can be chained
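
The interplay between the maximum number of iterations and the convergence tolerance can be illustrated with a toy optimizer. The sketch below uses plain gradient descent on a one-feature squared error loss and stops when the loss changes by less than tol or when maxIter is reached. This is a stand-in chosen for brevity: the optimizer and the exact convergence test used by LinearRegression are not described on this page, and IterationStopSketch, its stepSize argument, and its Result type are hypothetical.

```java
// Illustrative only: how a tolerance and an iteration cap bound an iterative fit.
// Plain gradient descent is a stand-in, NOT the optimizer Spark uses.
// IterationStopSketch, Result, and stepSize are hypothetical names.
final class IterationStopSketch {

    /** Fitted weight together with the number of iterations actually run. */
    static final class Result {
        final double weight;
        final int iterations;
        Result(double weight, int iterations) {
            this.weight = weight;
            this.iterations = iterations;
        }
    }

    /** Loss 1/(2n) * sum_i (x_i * w - y_i)^2 for a single feature. */
    static double loss(double[] x, double[] y, double w) {
        double sumSq = 0.0;
        for (int i = 0; i < x.length; i++) {
            double r = x[i] * w - y[i];
            sumSq += r * r;
        }
        return sumSq / (2.0 * x.length);
    }

    /** Derivative of the loss with respect to w. */
    static double gradient(double[] x, double[] y, double w) {
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            sum += x[i] * (x[i] * w - y[i]);
        }
        return sum / x.length;
    }

    static Result fit(double[] x, double[] y, int maxIter, double tol, double stepSize) {
        double w = 0.0;
        double prevLoss = loss(x, y, w);
        int iter = 0;
        while (iter < maxIter) {
            w -= stepSize * gradient(x, y, w);
            iter++;
            double curLoss = loss(x, y, w);
            // Stop early once the loss improvement falls below the tolerance.
            if (Math.abs(prevLoss - curLoss) < tol) {
                break;
            }
            prevLoss = curLoss;
        }
        return new Result(w, iter);
    }

    public static void main(String[] args) {
        double[] x = {1.0, 2.0, 3.0};
        double[] y = {2.0, 4.0, 6.0}; // exact fit at w = 2
        Result loose = fit(x, y, 100, 1e-6, 0.1);
        Result tight = fit(x, y, 100, 1e-12, 0.1);
        System.out.println("tol=1e-6:  w=" + loose.weight + " after " + loose.iterations + " iterations");
        System.out.println("tol=1e-12: w=" + tight.weight + " after " + tight.iterations + " iterations");
    }
}
```

In this toy setting the tighter tolerance runs for more iterations and ends closer to the exact weight, while a small maxIter cuts the run short regardless of tol.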

copy

public LinearRegression copy(ParamMap extra)
Description copied from interface: Params
Creates a copy of this instance with the same UID and some extra params. Subclasses should implement this method and set the return type properly.

Specified by:
copy in interface Params
Specified by:
copy in class Predictor<Vector,LinearRegression,LinearRegressionModel>
Parameters:
extra - extra params to apply to the copy
Returns:
the copied instance
See Also:
defaultCopy()

validateAndTransformSchema

public StructType validateAndTransformSchema(StructType schema,
                                             boolean fitting,
                                             DataType featuresDataType)
Validates and transforms the input schema with the provided param map.

Parameters:
schema - input schema
fitting - whether this is called during fitting
featuresDataType - SQL DataType for FeaturesType. E.g., VectorUDT for vector features.
Returns:
output schema