Class PredictionModel<FeaturesType,M extends PredictionModel<FeaturesType,M>>

Type Parameters:
FeaturesType - Type of features. E.g., VectorUDT for vector features.
M - Specialization of PredictionModel. If you subclass this type, use this type parameter to specify the concrete type for the corresponding model.
All Implemented Interfaces:
Serializable, org.apache.spark.internal.Logging, Params, HasFeaturesCol, HasLabelCol, HasPredictionCol, PredictorParams, Identifiable
Direct Known Subclasses:
ClassificationModel, RegressionModel

public abstract class PredictionModel<FeaturesType,M extends PredictionModel<FeaturesType,M>> extends Model<M> implements PredictorParams
Abstraction for a model for prediction tasks (regression and classification).
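The self-referential bound on M (described under Type Parameters) is what lets the inherited setters return the concrete model type instead of the abstract base. The following is a minimal plain-Java sketch of that pattern, not Spark's implementation: SketchModel, MeanModel, self(), and setScale are hypothetical names introduced only for illustration.

```java
// Hypothetical stand-in for a self-typed base class; this is NOT Spark's PredictionModel.
abstract class SketchModel<F, M extends SketchModel<F, M>> {
    private String featuresCol = "features";
    private String predictionCol = "prediction";

    // Assumption for this sketch: each subclass hands back itself as M.
    protected abstract M self();

    // Setters return M, so call chains keep the concrete subclass type.
    public M setFeaturesCol(String value) { this.featuresCol = value; return self(); }
    public M setPredictionCol(String value) { this.predictionCol = value; return self(); }

    public String getFeaturesCol() { return featuresCol; }
    public String getPredictionCol() { return predictionCol; }

    public abstract double predict(F features);
}

// Hypothetical concrete model: binds M to its own type.
final class MeanModel extends SketchModel<double[], MeanModel> {
    private double scale = 1.0;

    @Override protected MeanModel self() { return this; }

    // Subclass-only setter; reachable after base-class setters because they return MeanModel.
    public MeanModel setScale(double value) { this.scale = value; return this; }

    // Predicts the scaled mean of the feature values (0.0 for an empty array).
    @Override public double predict(double[] features) {
        double sum = 0.0;
        for (double v : features) sum += v;
        return features.length == 0 ? 0.0 : scale * sum / features.length;
    }
}
```

Because setFeaturesCol returns MeanModel rather than the base type, a chain such as new MeanModel().setFeaturesCol("x").setScale(2.0) compiles without a cast.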

  • Constructor Details

    • PredictionModel

      public PredictionModel()
  • Method Details

    • featuresCol

      public final Param<String> featuresCol()
      Description copied from interface: HasFeaturesCol
      Param for features column name.
      Specified by:
      featuresCol in interface HasFeaturesCol
      Returns:
      the Param for the features column name
    • labelCol

      public final Param<String> labelCol()
      Description copied from interface: HasLabelCol
      Param for label column name.
      Specified by:
      labelCol in interface HasLabelCol
      Returns:
      the Param for the label column name
    • numFeatures

      public int numFeatures()
      Returns the number of features the model was trained on, or -1 if unknown.
    • predict

      public abstract double predict(FeaturesType features)
      Predicts the label for the given features. transform() calls this method to produce the values stored in predictionCol().
      Parameters:
      features - the features to predict a label for
      Returns:
      the predicted label
    • predictionCol

      public final Param<String> predictionCol()
      Description copied from interface: HasPredictionCol
      Param for prediction column name.
      Specified by:
      predictionCol in interface HasPredictionCol
      Returns:
      the Param for the prediction column name
    • setFeaturesCol

      public M setFeaturesCol(String value)
    • setPredictionCol

      public M setPredictionCol(String value)
    • transform

      public Dataset<Row> transform(Dataset<?> dataset)
      Transforms dataset by reading from featuresCol(), calling predict, and storing the predictions as a new column predictionCol().

      Specified by:
      transform in class Transformer
      Parameters:
      dataset - input dataset
      Returns:
      transformed dataset with predictionCol() of type Double
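The read, predict, store sequence described for transform can be sketched without Spark. This is an illustration under stated assumptions only: rows are modelled as maps rather than a Dataset<Row>, features as double[], and TransformSketch is a hypothetical helper, not part of the Spark API.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.ToDoubleFunction;

// Hypothetical helper; rows are plain maps here, not Spark Rows.
final class TransformSketch {
    static List<Map<String, Object>> transform(
            List<Map<String, Object>> rows,
            String featuresCol,
            String predictionCol,
            ToDoubleFunction<double[]> predict) {
        List<Map<String, Object>> out = new ArrayList<>();
        for (Map<String, Object> row : rows) {
            // Copy the row so the input is left unchanged.
            Map<String, Object> copy = new LinkedHashMap<>(row);
            // Read the features, call predict, and store the result as a new column.
            double[] features = (double[]) row.get(featuresCol);
            copy.put(predictionCol, predict.applyAsDouble(features));
            out.add(copy);
        }
        return out;
    }
}
```

The input rows are copied rather than mutated, mirroring the fact that transform returns a new dataset with the prediction column added.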
    • transformSchema

      public StructType transformSchema(StructType schema)
      Description copied from class: PipelineStage
      Check transform validity and derive the output schema from the input schema.

      We check validity for interactions between parameters during transformSchema and raise an exception if any parameter value is invalid. Parameter value checks which do not depend on other parameters are handled by Param.validate().

      A typical implementation should first verify the schema change and parameter validity, including complex parameter interaction checks.

      Specified by:
      transformSchema in class PipelineStage
      Parameters:
      schema - the input schema
      Returns:
      the output schema derived from the input schema
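As a rough illustration of deriving an output schema while validating the input, the sketch below models a schema as an ordered map from column name to type name. The specific checks (features column present, prediction column not already present) and the SchemaSketch name are assumptions made for this example, not a statement of what Spark validates.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper; a schema is modelled as column name -> type name, not a Spark StructType.
final class SchemaSketch {
    static Map<String, String> transformSchema(
            Map<String, String> schema, String featuresCol, String predictionCol) {
        // Assumed check: the column to read features from must exist.
        if (!schema.containsKey(featuresCol)) {
            throw new IllegalArgumentException("Missing features column: " + featuresCol);
        }
        // Assumed check: the output column must not collide with an existing one.
        if (schema.containsKey(predictionCol)) {
            throw new IllegalArgumentException("Column already exists: " + predictionCol);
        }
        // Derive the output schema by appending the prediction column.
        Map<String, String> out = new LinkedHashMap<>(schema);
        out.put(predictionCol, "double");
        return out;
    }
}
```

Running the validation against the schema alone, before any rows are processed, lets an invalid column configuration fail early.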