Interface TopologyModel

All Superinterfaces:
Serializable

public interface TopologyModel extends Serializable
Trait for an ANN (artificial neural network) topology model.
  • Method Details

    • computeGradient

      double computeGradient(breeze.linalg.DenseMatrix<Object> data, breeze.linalg.DenseMatrix<Object> target, Vector cumGradient, int blockSize)
      Computes the gradient for the network on a block of data.

      Parameters:
      data - input data
      target - target output
      cumGradient - cumulative gradient
      blockSize - block size
      Returns:
      error
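      The contract above can be illustrated with a minimal, hypothetical sketch: the block's gradient is added into cumGradient and the block's error is returned. This is not Spark's implementation. It assumes a single linear layer with squared-error loss, plain Java arrays instead of Breeze matrices, one example per column of data and target, and weights stored row-major in a flat array. GradientSketch is an illustrative name. The explicit weights argument is also an assumption of the sketch, because a real TopologyModel holds its own weights. blockSize is omitted because the sketch takes the number of examples from data.

```java
/** Hypothetical sketch of a computeGradient-style method; not Spark's implementation. */
public final class GradientSketch {
    private GradientSketch() {}

    /**
     * Single linear layer y = W x with squared-error loss (assumption).
     *
     * @param data        inputs, data[i][j] = feature i of example j (one example per column)
     * @param target      targets, target[k][j] = output k of example j
     * @param weights     W in row-major order, length outDim * inDim
     * @param cumGradient accumulator with the same layout as weights; this block's gradient is ADDED to it
     * @return half the squared error, averaged over the examples in the block
     */
    public static double computeGradient(double[][] data, double[][] target,
                                         double[] weights, double[] cumGradient) {
        int inDim = data.length;
        int outDim = target.length;
        int n = data[0].length; // number of examples in this block
        double error = 0.0;
        for (int j = 0; j < n; j++) {
            for (int k = 0; k < outDim; k++) {
                // Forward pass: output k for example j.
                double y = 0.0;
                for (int i = 0; i < inDim; i++) {
                    y += weights[k * inDim + i] * data[i][j];
                }
                double diff = y - target[k][j];
                error += 0.5 * diff * diff / n;
                // dError/dW[k][i] = diff * x_i / n, accumulated rather than overwritten.
                for (int i = 0; i < inDim; i++) {
                    cumGradient[k * inDim + i] += diff * data[i][j] / n;
                }
            }
        }
        return error;
    }
}
```

Calling the sketch twice on the same block doubles the contents of cumGradient. That is the point of passing an accumulator: gradients from several blocks can be summed without extra allocation.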
    • forward

      breeze.linalg.DenseMatrix<Object>[] forward(breeze.linalg.DenseMatrix<Object> data, boolean includeLastLayer)
      Performs forward propagation through the network.

      Parameters:
      data - input data
      includeLastLayer - whether to include the last layer in the output. In MultilayerPerceptronClassifier the last layer is always softmax; its output is needed for class predictions but not for rawPrediction.

      Returns:
      array containing the output of each layer
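      A minimal, hypothetical sketch of the forward pass and the includeLastLayer flag follows. Each layer is modeled as a function from one input vector to one output vector. This is a simplification, because the real method works on a matrix of inputs. When includeLastLayer is false, this sketch simply returns one fewer output. ForwardSketch is an illustrative name, not a Spark class.

```java
import java.util.List;
import java.util.function.UnaryOperator;

/** Hypothetical sketch of forward propagation; not Spark's implementation. */
public final class ForwardSketch {
    private ForwardSketch() {}

    /**
     * Applies the layers in order; outputs[i] holds the output of layer i.
     * If includeLastLayer is false, the last layer is not evaluated and the
     * returned array has layers.size() - 1 entries (a choice of this sketch).
     */
    public static double[][] forward(double[] data,
                                     List<UnaryOperator<double[]>> layers,
                                     boolean includeLastLayer) {
        int count = includeLastLayer ? layers.size() : layers.size() - 1;
        double[][] outputs = new double[count][];
        double[] current = data;
        for (int i = 0; i < count; i++) {
            // The output of layer i becomes the input of layer i + 1.
            current = layers.get(i).apply(current);
            outputs[i] = current;
        }
        return outputs;
    }
}
```

If the final layer were a softmax, forward(x, false) would stop at the raw scores that rawPrediction needs. forward(x, true) would also produce the softmax output used for class predictions.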
    • layerModels

      LayerModel[] layerModels()
      Array of layer models
      Returns:
      array of layer models
    • layers

      Layer[] layers()
      Array of layers
      Returns:
      array of layers
    • predict

      Vector predict(Vector features)
      Computes the prediction of the model for the given features. See ProbabilisticClassificationModel.

      Parameters:
      features - input features
      Returns:
      prediction
    • predictRaw

      Vector predictRaw(Vector features)
      Computes the raw prediction of the model for the given features. See ProbabilisticClassificationModel.

      Parameters:
      features - input features
      Returns:
      raw prediction

      Note: This method is only used for classification models.

    • raw2ProbabilityInPlace

      Vector raw2ProbabilityInPlace(Vector rawPrediction)
      Converts a raw prediction into class probabilities, overwriting rawPrediction in place. See ProbabilisticClassificationModel.

      Parameters:
      rawPrediction - raw prediction vector
      Returns:
      probability

      Note: This method is only used for classification models.
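      In MultilayerPerceptronClassifier the last layer is softmax. A reasonable reading of raw2ProbabilityInPlace is therefore that it applies that softmax to the raw prediction and reuses its storage. The sketch below assumes exactly that. It uses a plain double[] instead of Vector, and Raw2ProbabilitySketch is a hypothetical name. Subtracting the maximum before exponentiating keeps large raw scores from overflowing without changing the result.

```java
/** Hypothetical sketch: in-place softmax over a raw prediction; not Spark's implementation. */
public final class Raw2ProbabilitySketch {
    private Raw2ProbabilitySketch() {}

    /** Overwrites rawPrediction with softmax probabilities and returns the same array. */
    public static double[] raw2ProbabilityInPlace(double[] rawPrediction) {
        // Find the maximum so that every exponent is <= 0 (avoids overflow).
        double max = Double.NEGATIVE_INFINITY;
        for (double v : rawPrediction) {
            max = Math.max(max, v);
        }
        // Exponentiate in place and accumulate the normalizer.
        double sum = 0.0;
        for (int i = 0; i < rawPrediction.length; i++) {
            rawPrediction[i] = Math.exp(rawPrediction[i] - max);
            sum += rawPrediction[i];
        }
        // Normalize so the entries sum to 1.
        for (int i = 0; i < rawPrediction.length; i++) {
            rawPrediction[i] /= sum;
        }
        return rawPrediction;
    }
}
```

Because the input array is returned, the caller can chain the result without allocating a new vector. The original raw scores are lost after the call.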

    • weights

      Vector weights()