Class AFTSurvivalRegression

All Implemented Interfaces:
Serializable, org.apache.spark.internal.Logging, Params, HasAggregationDepth, HasFeaturesCol, HasFitIntercept, HasLabelCol, HasMaxBlockSizeInMB, HasMaxIter, HasPredictionCol, HasTol, PredictorParams, AFTSurvivalRegressionParams, DefaultParamsWritable, Identifiable, MLWritable

public class AFTSurvivalRegression extends Regressor<Vector,AFTSurvivalRegression,AFTSurvivalRegressionModel> implements AFTSurvivalRegressionParams, DefaultParamsWritable, org.apache.spark.internal.Logging
Fits a parametric survival regression model, the accelerated failure time (AFT) model (see Accelerated failure time model (Wikipedia)), based on the Weibull distribution of the survival time.

Since 3.1.0, it supports stacking instances into blocks and using GEMV for better performance. If param maxBlockSizeInMB is left at its default of 0.0, a block size of 1.0 MB is used.
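To make the blocking idea concrete, the following is a minimal sketch of what a GEMV (matrix-vector product) over a block of stacked instances computes. It is an illustration only and not Spark's internal code: the class name BlockGemvSketch, the row-major layout, and the method signature are assumptions made for this example.

```java
// Illustrative sketch only; this is NOT Spark's internal implementation.
// BlockGemvSketch and its row-major layout are hypothetical.
final class BlockGemvSketch {

    /**
     * Computes y = A * x, where A is a block of stacked instances stored in
     * row-major order (one row per instance, one column per feature).
     */
    static double[] gemv(double[] block, int numRows, int numCols, double[] x) {
        if (block.length != numRows * numCols || x.length != numCols) {
            throw new IllegalArgumentException("dimension mismatch");
        }
        double[] y = new double[numRows];
        for (int i = 0; i < numRows; i++) {
            int offset = i * numCols;
            double sum = 0.0;
            for (int j = 0; j < numCols; j++) {
                sum += block[offset + j] * x[j];
            }
            y[i] = sum;
        }
        return y;
    }

    public static void main(String[] args) {
        // Three instances with two features each, stacked into one block.
        double[] block = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
        double[] coefficients = {0.5, -1.0};
        // One call yields the linear predictor for every instance in the block.
        System.out.println(java.util.Arrays.toString(gemv(block, 3, 2, coefficients)));
    }
}
```

Processing a whole block per call lets a linear-algebra routine replace many per-instance dot products, which is the performance benefit the description refers to.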

  • Constructor Details

    • AFTSurvivalRegression

      public AFTSurvivalRegression(String uid)
    • AFTSurvivalRegression

      public AFTSurvivalRegression()
  • Method Details

    • load

      public static AFTSurvivalRegression load(String path)
    • read

      public static MLReader<T> read()
    • censorCol

      public final Param<String> censorCol()
      Description copied from interface: AFTSurvivalRegressionParams
      Param for censor column name. The value of this column can be 0 or 1. If the value is 1, the event has occurred (i.e., the observation is uncensored); otherwise the observation is censored.
      Specified by:
      censorCol in interface AFTSurvivalRegressionParams
      Returns:
      (undocumented)
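The role of the censor indicator can be sketched with the per-observation log-likelihood of a Weibull AFT model. This is a minimal sketch that assumes the standard parameterization, epsilon = (log(t) - margin) / sigma, with contribution -censor * log(sigma) + censor * epsilon - exp(epsilon). The class and method names are hypothetical and are not part of the Spark API.

```java
// Illustrative sketch only; AftLogLikelihoodSketch is a hypothetical helper,
// not a Spark class.
final class AftLogLikelihoodSketch {

    /**
     * Log-likelihood contribution of a single observation.
     *
     * @param time   observed time, must be > 0
     * @param margin linear predictor (features dot coefficients, plus intercept)
     * @param sigma  scale parameter, must be > 0
     * @param censor 1.0 if the event occurred (uncensored), 0.0 if censored
     */
    static double logLikelihood(double time, double margin, double sigma, double censor) {
        double epsilon = (Math.log(time) - margin) / sigma;
        // Uncensored rows contribute the log density; censored rows contribute
        // only the log survival function, which is the -exp(epsilon) term.
        return -censor * Math.log(sigma) + censor * epsilon - Math.exp(epsilon);
    }

    public static void main(String[] args) {
        System.out.println(logLikelihood(2.0, 0.5, 1.5, 1.0)); // event observed
        System.out.println(logLikelihood(2.0, 0.5, 1.5, 0.0)); // censored
    }
}
```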
    • quantileProbabilities

      public final DoubleArrayParam quantileProbabilities()
      Description copied from interface: AFTSurvivalRegressionParams
      Param for quantile probabilities array. Values of the quantile probabilities array should be in the range (0, 1) and the array should be non-empty.
      Specified by:
      quantileProbabilities in interface AFTSurvivalRegressionParams
      Returns:
      (undocumented)
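The constraint described above (a non-empty array with every value strictly between 0 and 1) can be expressed as a small check. This is a sketch of the rule as stated, not Spark's validator; QuantileProbabilitiesCheck is a hypothetical name.

```java
// Illustrative sketch only; QuantileProbabilitiesCheck is hypothetical.
final class QuantileProbabilitiesCheck {

    /** Returns true if the array is non-empty and every value lies in the open interval (0, 1). */
    static boolean isValid(double[] probabilities) {
        if (probabilities == null || probabilities.length == 0) {
            return false;
        }
        for (double p : probabilities) {
            // Written this way so that NaN is rejected as well.
            if (!(p > 0.0 && p < 1.0)) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(isValid(new double[]{0.25, 0.5, 0.75}));
        System.out.println(isValid(new double[]{0.0, 0.5}));
    }
}
```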
    • quantilesCol

      public final Param<String> quantilesCol()
      Description copied from interface: AFTSurvivalRegressionParams
      Param for quantiles column name. This column will output quantiles of corresponding quantileProbabilities if it is set.
      Specified by:
      quantilesCol in interface AFTSurvivalRegressionParams
      Returns:
      (undocumented)
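As a rough illustration of what the quantiles column holds, the sketch below evaluates Weibull AFT quantiles for one instance. It assumes the standard relation q(p) = exp(margin) * (-log(1 - p))^sigma, where margin is the linear predictor and sigma is the fitted scale. AftQuantileSketch and its parameters are hypothetical names used only for this example.

```java
// Illustrative sketch only; AftQuantileSketch is a hypothetical helper,
// not a Spark class.
final class AftQuantileSketch {

    /**
     * Survival-time quantiles for one instance under a Weibull AFT model.
     *
     * @param margin        linear predictor (features dot coefficients, plus intercept)
     * @param sigma         scale parameter, must be > 0
     * @param probabilities quantile probabilities, each in (0, 1)
     */
    static double[] quantiles(double margin, double sigma, double[] probabilities) {
        double lambda = Math.exp(margin);
        double[] out = new double[probabilities.length];
        for (int i = 0; i < probabilities.length; i++) {
            out[i] = lambda * Math.pow(-Math.log(1.0 - probabilities[i]), sigma);
        }
        return out;
    }

    public static void main(String[] args) {
        double[] q = quantiles(1.2, 0.8, new double[]{0.1, 0.5, 0.9});
        System.out.println(java.util.Arrays.toString(q));
    }
}
```

Each output entry corresponds positionally to an entry of the quantile probabilities array, so larger probabilities yield larger survival times.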
    • maxBlockSizeInMB

      public final DoubleParam maxBlockSizeInMB()
      Description copied from interface: HasMaxBlockSizeInMB
      Param for the maximum memory in MB for stacking input data into blocks. Data is stacked within partitions. If this value exceeds the remaining data size in a partition, it is adjusted to the data size. The default, 0.0, means an optimal value is chosen, which depends on the specific algorithm. Must be >= 0.
      Specified by:
      maxBlockSizeInMB in interface HasMaxBlockSizeInMB
      Returns:
      (undocumented)
    • aggregationDepth

      public final IntParam aggregationDepth()
      Description copied from interface: HasAggregationDepth
      Param for suggested depth for treeAggregate (>= 2).
      Specified by:
      aggregationDepth in interface HasAggregationDepth
      Returns:
      (undocumented)
    • fitIntercept

      public final BooleanParam fitIntercept()
      Description copied from interface: HasFitIntercept
      Param for whether to fit an intercept term.
      Specified by:
      fitIntercept in interface HasFitIntercept
      Returns:
      (undocumented)
    • tol

      public final DoubleParam tol()
      Description copied from interface: HasTol
      Param for the convergence tolerance for iterative algorithms (>= 0).
      Specified by:
      tol in interface HasTol
      Returns:
      (undocumented)
    • maxIter

      public final IntParam maxIter()
      Description copied from interface: HasMaxIter
      Param for maximum number of iterations (>= 0).
      Specified by:
      maxIter in interface HasMaxIter
      Returns:
      (undocumented)
    • uid

      public String uid()
      Description copied from interface: Identifiable
      An immutable unique ID for the object and its derivatives.
      Specified by:
      uid in interface Identifiable
      Returns:
      (undocumented)
    • setCensorCol

      public AFTSurvivalRegression setCensorCol(String value)
    • setQuantileProbabilities

      public AFTSurvivalRegression setQuantileProbabilities(double[] value)
    • setQuantilesCol

      public AFTSurvivalRegression setQuantilesCol(String value)
    • setFitIntercept

      public AFTSurvivalRegression setFitIntercept(boolean value)
      Sets whether to fit the intercept. Default is true.
      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • setMaxIter

      public AFTSurvivalRegression setMaxIter(int value)
      Set the maximum number of iterations. Default is 100.
      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • setTol

      public AFTSurvivalRegression setTol(double value)
      Set the convergence tolerance of iterations. A smaller value leads to higher accuracy at the cost of more iterations. Default is 1E-6.
      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
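The trade-off between the tolerance and the iteration count can be seen in a toy optimizer. The sketch below runs fixed-step gradient descent on f(x) = (x - 3)^2 and stops when successive iterates differ by less than tol or when maxIter is reached. This is not the optimizer Spark uses for this estimator; it is a hypothetical stand-in chosen only to show how the two stopping conditions interact.

```java
// Illustrative sketch only; a toy optimizer, not Spark's optimization routine.
final class ToleranceSketch {

    /**
     * Minimizes f(x) = (x - 3)^2 from x = 0 with a fixed step size and
     * returns the number of iterations actually performed.
     */
    static int iterationsUntilConverged(double tol, int maxIter) {
        double x = 0.0;
        double step = 0.1;
        int iter = 0;
        while (iter < maxIter) {
            double next = x - step * 2.0 * (x - 3.0); // gradient of f is 2 * (x - 3)
            iter++;
            boolean converged = Math.abs(next - x) < tol;
            x = next;
            if (converged) {
                break;
            }
        }
        return iter;
    }

    public static void main(String[] args) {
        // A tighter tolerance needs more iterations; maxIter caps the total.
        System.out.println(iterationsUntilConverged(1e-2, 100));
        System.out.println(iterationsUntilConverged(1e-6, 100));
        System.out.println(iterationsUntilConverged(1e-12, 5));
    }
}
```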
    • setAggregationDepth

      public AFTSurvivalRegression setAggregationDepth(int value)
      Suggested depth for treeAggregate (greater than or equal to 2). If the feature dimension or the number of partitions is large, this param can be set to a larger value. Default is 2.
      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
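The reason a deeper aggregation tree helps with many partitions can be sketched with a simplified multi-level reduction. The code below is not Spark's treeAggregate. It is a hypothetical model in which each combiner merges at most fanIn partial results per level, so a smaller fan-in (a deeper tree) means less work for any single combiner at the price of more levels.

```java
// Illustrative sketch only; a simplified model, not Spark's treeAggregate.
final class TreeAggregateSketch {

    /** Sums partial results level by level, merging at most fanIn values per group. */
    static double treeSum(double[] partials, int fanIn) {
        if (fanIn < 2) {
            throw new IllegalArgumentException("fanIn must be >= 2");
        }
        double[] level = partials.clone();
        while (level.length > 1) {
            int groups = (level.length + fanIn - 1) / fanIn;
            double[] next = new double[groups];
            for (int i = 0; i < level.length; i++) {
                next[i / fanIn] += level[i];
            }
            level = next;
        }
        return level.length == 0 ? 0.0 : level[0];
    }

    /** Number of merge levels needed to reduce numPartitions partial results to one. */
    static int levels(int numPartitions, int fanIn) {
        if (fanIn < 2) {
            throw new IllegalArgumentException("fanIn must be >= 2");
        }
        int count = numPartitions;
        int levels = 0;
        while (count > 1) {
            count = (count + fanIn - 1) / fanIn;
            levels++;
        }
        return levels;
    }

    public static void main(String[] args) {
        System.out.println(levels(64, 64)); // flat: one combiner merges everything
        System.out.println(levels(64, 8));  // deeper tree, smaller merges per combiner
    }
}
```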
    • setMaxBlockSizeInMB

      public AFTSurvivalRegression setMaxBlockSizeInMB(double value)
      Sets the value of param maxBlockSizeInMB(). Default is 0.0, in which case 1.0 MB is chosen.

      Parameters:
      value - (undocumented)
      Returns:
      (undocumented)
    • transformSchema

      public StructType transformSchema(StructType schema)
      Description copied from class: PipelineStage
      Check transform validity and derive the output schema from the input schema.

      We check validity for interactions between parameters during transformSchema and raise an exception if any parameter value is invalid. Parameter value checks which do not depend on other parameters are handled by Param.validate().

      A typical implementation should first verify the schema change and parameter validity, including complex parameter interaction checks.

      Overrides:
      transformSchema in class Predictor<Vector,AFTSurvivalRegression,AFTSurvivalRegressionModel>
      Parameters:
      schema - (undocumented)
      Returns:
      (undocumented)
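The "validate, then derive the output schema" pattern described above can be sketched without Spark types. In the example below a schema is modeled as an ordered map from column name to a type label. The class name, the type labels, and the specific checks are assumptions made for illustration and do not reproduce the checks Spark performs.

```java
// Illustrative sketch only; SchemaCheckSketch and its string type labels are
// hypothetical and do not mirror Spark's StructType or its actual checks.
final class SchemaCheckSketch {

    static java.util.Map<String, String> transformSchema(
            java.util.Map<String, String> schema,
            String featuresCol,
            String censorCol,
            String predictionCol,
            String quantilesCol) {
        // Step 1: verify the input columns the stage depends on.
        if (!"vector".equals(schema.get(featuresCol))) {
            throw new IllegalArgumentException("features column must be a vector: " + featuresCol);
        }
        if (!"double".equals(schema.get(censorCol))) {
            throw new IllegalArgumentException("censor column must be numeric: " + censorCol);
        }
        if (schema.containsKey(predictionCol)) {
            throw new IllegalArgumentException("output column already exists: " + predictionCol);
        }
        // Step 2: derive the output schema by appending the new columns.
        java.util.Map<String, String> out = new java.util.LinkedHashMap<>(schema);
        out.put(predictionCol, "double");
        if (quantilesCol != null && !quantilesCol.isEmpty()) {
            // The quantiles column is only produced when its name is set.
            out.put(quantilesCol, "vector");
        }
        return out;
    }

    public static void main(String[] args) {
        java.util.Map<String, String> in = new java.util.LinkedHashMap<>();
        in.put("features", "vector");
        in.put("label", "double");
        in.put("censor", "double");
        System.out.println(transformSchema(in, "features", "censor", "prediction", "quantiles"));
    }
}
```

Failing fast on an invalid schema means a misconfigured stage is reported before any data is processed.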
    • copy

      public AFTSurvivalRegression copy(ParamMap extra)
      Description copied from interface: Params
      Creates a copy of this instance with the same UID and some extra params. Subclasses should implement this method and set the return type properly. See defaultCopy().
      Specified by:
      copy in interface Params
      Specified by:
      copy in class Predictor<Vector,AFTSurvivalRegression,AFTSurvivalRegressionModel>
      Parameters:
      extra - (undocumented)
      Returns:
      (undocumented)