Package org.apache.spark.ml.ann
Interface TopologyModel
- All Superinterfaces:
Serializable
Trait for an ANN topology model.
-
Method Summary
Modifier and Type | Method | Description
double | computeGradient(breeze.linalg.DenseMatrix<Object> data, breeze.linalg.DenseMatrix<Object> target, Vector cumGradient, int blockSize) | Computes the gradient for the network.
breeze.linalg.DenseMatrix<Object>[] | forward(breeze.linalg.DenseMatrix<Object> data, boolean includeLastLayer) | Forward propagation.
LayerModel[] | layerModels() | Array of layer models.
Layer[] | layers() | Array of layers.
Vector | predict(Vector features) | Prediction of the model.
Vector | predictRaw(Vector features) | Raw prediction of the model.
Vector | raw2ProbabilityInPlace(Vector rawPrediction) | Probability of the model.
Vector | weights() |
-
Method Details
-
computeGradient
double computeGradient(breeze.linalg.DenseMatrix<Object> data, breeze.linalg.DenseMatrix<Object> target, Vector cumGradient, int blockSize)
Computes the gradient for the network.
- Parameters:
data - input data
target - target output
cumGradient - cumulative gradient
blockSize - block size
- Returns:
error
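To make the contract of computeGradient concrete, here is a minimal sketch, not Spark's implementation, for a hypothetical network with a single affine layer followed by softmax and cross-entropy loss. It uses plain double[] arrays in place of breeze.linalg.DenseMatrix and Vector, passes the weights explicitly, treats each row of data as one example (so the block size is data.length), and assumes a row-major weight layout. The class name GradientSketch, the numClasses parameter, and the choice to return the average loss as the "error" are all assumptions. The block's gradient is added into cumGradient rather than overwriting it, in line with that parameter being a cumulative gradient.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch (not Spark's implementation): the gradient of a single
 * affine layer followed by softmax with cross-entropy loss.
 */
final class GradientSketch {

    /**
     * Accumulates the gradient for a block of examples into cumGradient and
     * returns the average cross-entropy loss over the block.
     * Assumed weight layout: a numClasses x n matrix W in row-major order,
     * followed by numClasses bias terms b.
     */
    static double computeGradient(double[][] data, double[][] target,
                                  double[] weights, double[] cumGradient,
                                  int numClasses) {
        int n = data[0].length;
        int k = numClasses;
        double totalLoss = 0.0;
        for (int e = 0; e < data.length; e++) {
            double[] x = data[e];
            double[] t = target[e];
            // Forward pass: z = W x + b.
            double[] z = new double[k];
            for (int i = 0; i < k; i++) {
                double s = weights[k * n + i];
                for (int j = 0; j < n; j++) {
                    s += weights[i * n + j] * x[j];
                }
                z[i] = s;
            }
            // Softmax, shifted by the maximum score for numerical stability.
            double max = Arrays.stream(z).max().getAsDouble();
            double sum = 0.0;
            double[] p = new double[k];
            for (int i = 0; i < k; i++) {
                p[i] = Math.exp(z[i] - max);
                sum += p[i];
            }
            for (int i = 0; i < k; i++) {
                p[i] /= sum;
                // Cross-entropy loss: -sum_i t_i * log(p_i).
                if (t[i] > 0.0) {
                    totalLoss -= t[i] * Math.log(p[i]);
                }
            }
            // Backward pass: for a target that sums to 1 (e.g. one-hot),
            // dLoss/dz = p - t. Accumulate dW = (p - t) x^T and db = p - t.
            for (int i = 0; i < k; i++) {
                double delta = p[i] - t[i];
                for (int j = 0; j < n; j++) {
                    cumGradient[i * n + j] += delta * x[j];
                }
                cumGradient[k * n + i] += delta;
            }
        }
        return totalLoss / data.length;
    }

    public static void main(String[] args) {
        double[][] data = {{1.0, 2.0}};
        double[][] target = {{1.0, 0.0}};
        double[] weights = new double[2 * 2 + 2]; // all zeros
        double[] cumGradient = new double[weights.length];
        double loss = computeGradient(data, target, weights, cumGradient, 2);
        System.out.println(loss);
        System.out.println(Arrays.toString(cumGradient));
    }
}
```

With all-zero weights both classes get probability 0.5, so the loss is ln 2 and the bias gradient is p - t = [-0.5, 0.5]. Calling the method again on the same block doubles cumGradient, which is the accumulating behavior the sketch is meant to show.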
-
forward
breeze.linalg.DenseMatrix<Object>[] forward(breeze.linalg.DenseMatrix<Object> data, boolean includeLastLayer)
Forward propagation.
- Parameters:
data - input data
includeLastLayer - whether to include the last layer in the output. In MultilayerPerceptronClassifier, the last layer is always softmax; its output is needed for class predictions, but not for rawPrediction.
- Returns:
array of outputs for each of the layers
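The minimal sketch below illustrates what the includeLastLayer flag changes. It is not Spark's code: layers are modeled as hypothetical UnaryOperator<double[]> functions, data is a single example vector rather than a breeze.linalg.DenseMatrix block, and ForwardSketch is a made-up name. As the includeLastLayer description states for MultilayerPerceptronClassifier, the last layer is softmax.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.UnaryOperator;

/** Hypothetical sketch of forward propagation with an includeLastLayer flag. */
final class ForwardSketch {

    // Assumed model: each layer maps its input vector to an output vector.
    private final List<UnaryOperator<double[]>> layers;

    ForwardSketch(List<UnaryOperator<double[]>> layers) {
        this.layers = layers;
    }

    /**
     * Applies the layers in order and returns the output of every layer that ran.
     * With includeLastLayer == false the final layer is skipped, so the last
     * element is the input that the final (softmax) layer would receive.
     */
    double[][] forward(double[] data, boolean includeLastLayer) {
        int count = includeLastLayer ? layers.size() : layers.size() - 1;
        double[][] outputs = new double[count][];
        double[] current = data;
        for (int i = 0; i < count; i++) {
            current = layers.get(i).apply(current);
            outputs[i] = current;
        }
        return outputs;
    }

    /** Softmax over a vector, shifted by the maximum for numerical stability. */
    static double[] softmax(double[] z) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) {
            max = Math.max(max, v);
        }
        double[] out = new double[z.length];
        double sum = 0.0;
        for (int i = 0; i < z.length; i++) {
            out[i] = Math.exp(z[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < z.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    public static void main(String[] args) {
        // Hypothetical two-layer network: a fixed affine map, then softmax.
        UnaryOperator<double[]> affine = x -> new double[] {x[0] + x[1], x[0] - x[1]};
        UnaryOperator<double[]> softmaxLayer = ForwardSketch::softmax;
        ForwardSketch model = new ForwardSketch(List.of(affine, softmaxLayer));

        double[] input = {1.0, 1.0};
        double[][] raw = model.forward(input, false); // affine output only
        double[][] full = model.forward(input, true); // affine and softmax outputs
        System.out.println(Arrays.toString(raw[raw.length - 1]));
        System.out.println(Arrays.toString(full[full.length - 1]));
    }
}
```

Here forward(input, false) stops at the affine output [2.0, 0.0], the kind of value a raw prediction would be built from, while forward(input, true) additionally returns the softmax probabilities as its final element.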
-
layerModels
LayerModel[] layerModels()
Array of layer models.
- Returns:
(undocumented)
-
layers
Layer[] layers()
Array of layers.
- Returns:
(undocumented)
-
predict
Vector predict(Vector features)
Prediction of the model. See ProbabilisticClassificationModel.
- Parameters:
features - input features
- Returns:
prediction
-
predictRaw
Vector predictRaw(Vector features)
Raw prediction of the model. See ProbabilisticClassificationModel.
- Parameters:
features - input features
- Returns:
raw prediction
Note: This method is only used for classification models.
-
raw2ProbabilityInPlace
Vector raw2ProbabilityInPlace(Vector rawPrediction)
Probability of the model. See ProbabilisticClassificationModel.
- Parameters:
rawPrediction - raw prediction vector
- Returns:
probability
Note: This method is only used for classification models.
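Assuming, as the forward documentation states for MultilayerPerceptronClassifier, that the last layer is softmax, converting a raw prediction to probabilities amounts to applying softmax to it. The sketch below does this in place on a plain double[] standing in for Spark's Vector; Raw2ProbabilitySketch is a hypothetical name, and this is an illustration rather than Spark's implementation. Subtracting the maximum score first keeps Math.exp from overflowing on large raw values.

```java
import java.util.Arrays;

/** Hypothetical sketch: softmax applied in place to a raw prediction vector. */
final class Raw2ProbabilitySketch {

    /**
     * Converts raw scores to probabilities by applying softmax in place,
     * overwriting rawPrediction and returning the same array.
     */
    static double[] raw2ProbabilityInPlace(double[] rawPrediction) {
        // Subtract the maximum before exponentiating to avoid overflow.
        double max = Double.NEGATIVE_INFINITY;
        for (double v : rawPrediction) {
            max = Math.max(max, v);
        }
        double sum = 0.0;
        for (int i = 0; i < rawPrediction.length; i++) {
            rawPrediction[i] = Math.exp(rawPrediction[i] - max);
            sum += rawPrediction[i];
        }
        // Normalize so the entries sum to 1.
        for (int i = 0; i < rawPrediction.length; i++) {
            rawPrediction[i] /= sum;
        }
        return rawPrediction;
    }

    public static void main(String[] args) {
        double[] raw = {1.0, 2.0, 3.0};
        double[] prob = raw2ProbabilityInPlace(raw);
        System.out.println(Arrays.toString(prob));
        System.out.println(prob == raw); // same array object: the conversion was in place
    }
}
```

Because the input is overwritten, a caller that still needs the raw scores should copy them first (for example with rawPrediction.clone()). Under the softmax assumption above, applying this conversion to the output of predictRaw should give the same probabilities that predict returns.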
-
weights
Vector weights()
-