Source code for pyspark.ml.tuning

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import itertools
import sys
from multiprocessing.pool import ThreadPool

import numpy as np

from pyspark import since, keyword_only
from pyspark.ml import Estimator, Model
from pyspark.ml.common import _py2java
from pyspark.ml.param import Params, Param, TypeConverters
from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaParams
from pyspark.sql.functions import rand

__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
           'TrainValidationSplitModel']


def _parallelFitTasks(est, train, eva, validation, epm, collectSubModel):
    """
    Creates a list of callables which can be called from different threads to fit and evaluate
    an estimator in parallel. Each callable returns an `(index, metric)` pair.

    :param est: Estimator, the estimator to be fit.
    :param train: DataFrame, training data set, used for fitting.
    :param eva: Evaluator, used to compute `metric`
    :param validation: DataFrame, validation data set, used for evaluation.
    :param epm: Sequence of ParamMap, params maps to be used during fitting & evaluation.
    :param collectSubModel: Whether to collect sub model.
    :return: (int, float, subModel), an index into `epm` and the associated metric value.
    """
    modelIter = est.fitMultiple(train, epm)

    def singleTask():
        index, model = next(modelIter)
        metric = eva.evaluate(model.transform(validation, epm[index]))
        return index, metric, model if collectSubModel else None

    return [singleTask] * len(epm)


class ParamGridBuilder(object):
    r"""
    Builder for a param grid used in grid search-based model selection.

    >>> from pyspark.ml.classification import LogisticRegression
    >>> lr = LogisticRegression()
    >>> output = ParamGridBuilder() \
    ...     .baseOn({lr.labelCol: 'l'}) \
    ...     .baseOn([lr.predictionCol, 'p']) \
    ...     .addGrid(lr.regParam, [1.0, 2.0]) \
    ...     .addGrid(lr.maxIter, [1, 5]) \
    ...     .build()
    >>> expected = [
    ...     {lr.regParam: 1.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
    ...     {lr.regParam: 2.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
    ...     {lr.regParam: 1.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'},
    ...     {lr.regParam: 2.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
    >>> len(output) == len(expected)
    True
    >>> all([m in expected for m in output])
    True

    .. versionadded:: 1.4.0
    """

    def __init__(self):
        # Maps each param to the list of candidate values to combine in build().
        self._param_grid = {}

    @since("1.4.0")
    def addGrid(self, param, values):
        """
        Adds a param to this grid together with the list of candidate values to try for it.
        Adding the same param again replaces its previous list of values.

        :param param: the param to vary; it must expose `typeConverter`, which
                      :py:meth:`build` applies to every value
        :param values: iterable of candidate values for `param`
        :return: this builder, so that calls can be chained
        """
        self._param_grid[param] = values
        return self

    @since("1.4.0")
    def baseOn(self, *args):
        """
        Sets the given parameters in this grid to fixed values.
        Accepts either a parameter dictionary or a list of (parameter, value) pairs.
        Calling it with no arguments, or with an empty dictionary, leaves the grid unchanged.

        :return: this builder, so that calls can be chained
        """
        if not args:
            # Nothing to fix. This is also reached through the recursive call below when
            # an empty dictionary is passed, which used to fail on args[0].
            return self
        if isinstance(args[0], dict):
            self.baseOn(*args[0].items())
        else:
            for (param, value) in args:
                # A fixed value is a grid dimension with a single candidate.
                self.addGrid(param, [value])
        return self

    @since("1.4.0")
    def build(self):
        """
        Builds and returns all combinations of parameters specified by the param grid.

        :return: list of dicts, one per combination, mapping each param to one of its
                 candidate values after conversion with the param's `typeConverter`
        """
        # Freeze the key order once so params and value lists stay aligned.
        params = list(self._param_grid)
        candidates = [self._param_grid[param] for param in params]
        return [
            {param: param.typeConverter(value) for param, value in zip(params, combo)}
            for combo in itertools.product(*candidates)
        ]
class ValidatorParams(HasSeed):
    """
    Params shared by TrainValidationSplit and CrossValidator: the estimator to tune, the
    param maps to try and the evaluator that scores each candidate.
    """

    estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated")
    estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")
    evaluator = Param(
        Params._dummy(), "evaluator",
        "evaluator used to select hyper-parameters that maximize the validator metric")

    def setEstimator(self, value):
        """
        Sets the value of :py:attr:`estimator`.
        """
        return self._set(estimator=value)

    def getEstimator(self):
        """
        Gets the value of estimator or its default value.
        """
        return self.getOrDefault(self.estimator)

    def setEstimatorParamMaps(self, value):
        """
        Sets the value of :py:attr:`estimatorParamMaps`.
        """
        return self._set(estimatorParamMaps=value)

    def getEstimatorParamMaps(self):
        """
        Gets the value of estimatorParamMaps or its default value.
        """
        return self.getOrDefault(self.estimatorParamMaps)

    def setEvaluator(self, value):
        """
        Sets the value of :py:attr:`evaluator`.
        """
        return self._set(evaluator=value)

    def getEvaluator(self):
        """
        Gets the value of evaluator or its default value.
        """
        return self.getOrDefault(self.evaluator)

    @classmethod
    def _from_java_impl(cls, java_stage):
        """
        Converts the estimator, estimatorParamMaps and evaluator held by a Java
        ValidatorParams into their Python counterparts.

        :param java_stage: Java object exposing getEstimator, getEvaluator and
                           getEstimatorParamMaps
        :return: tuple (estimator, list of param maps, evaluator)
        """
        py_estimator = JavaParams._from_java(java_stage.getEstimator())
        py_evaluator = JavaParams._from_java(java_stage.getEvaluator())

        # The param maps are converted through the Python estimator.
        py_param_maps = []
        for java_param_map in java_stage.getEstimatorParamMaps():
            py_param_maps.append(py_estimator._transfer_param_map_from_java(java_param_map))

        return py_estimator, py_param_maps, py_evaluator

    def _to_java_impl(self):
        """
        Converts the estimator, estimatorParamMaps and evaluator of this Python instance
        into their Java counterparts.

        :return: tuple (Java estimator, Java array of ParamMap, Java evaluator)
        """
        param_maps = self.getEstimatorParamMaps()
        py_estimator = self.getEstimator()

        # Allocate a typed Java array and fill it slot by slot.
        param_map_class = SparkContext._jvm.org.apache.spark.ml.param.ParamMap
        java_param_maps = SparkContext._gateway.new_array(param_map_class, len(param_maps))
        for position, param_map in enumerate(param_maps):
            java_param_maps[position] = py_estimator._transfer_param_map_to_java(param_map)

        return py_estimator._to_java(), java_param_maps, self.getEvaluator()._to_java()
class CrossValidator(Estimator, ValidatorParams, HasParallelism, HasCollectSubModels,
                     MLReadable, MLWritable):
    """

    K-fold cross validation performs model selection by splitting the dataset into a set of
    non-overlapping randomly partitioned folds which are used as separate training and test datasets
    e.g., with k=3 folds, K-fold cross validation will generate 3 (training, test) dataset pairs,
    each of which uses 2/3 of the data for training and 1/3 for testing. Each fold is used as the
    test set exactly once.


    >>> from pyspark.ml.classification import LogisticRegression
    >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
    >>> from pyspark.ml.linalg import Vectors
    >>> dataset = spark.createDataFrame(
    ...     [(Vectors.dense([0.0]), 0.0),
    ...      (Vectors.dense([0.4]), 1.0),
    ...      (Vectors.dense([0.5]), 0.0),
    ...      (Vectors.dense([0.6]), 1.0),
    ...      (Vectors.dense([1.0]), 1.0)] * 10,
    ...     ["features", "label"])
    >>> lr = LogisticRegression()
    >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
    >>> evaluator = BinaryClassificationEvaluator()
    >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
    ...     parallelism=2)
    >>> cvModel = cv.fit(dataset)
    >>> cvModel.avgMetrics[0]
    0.5
    >>> evaluator.evaluate(cvModel.transform(dataset))
    0.8333...

    .. versionadded:: 1.4.0
    """

    numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation",
                     typeConverter=TypeConverters.toInt)

    @keyword_only
    def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
                 seed=None, parallelism=1, collectSubModels=False):
        """
        __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
                 seed=None, parallelism=1, collectSubModels=False)
        """
        super(CrossValidator, self).__init__()
        self._setDefault(numFolds=3, parallelism=1)
        kwargs = self._input_kwargs
        self._set(**kwargs)

    @keyword_only
    @since("1.4.0")
    def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
                  seed=None, parallelism=1, collectSubModels=False):
        """
        setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
                  seed=None, parallelism=1, collectSubModels=False):
        Sets params for cross validator.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    @since("1.4.0")
    def setNumFolds(self, value):
        """
        Sets the value of :py:attr:`numFolds`.
        """
        return self._set(numFolds=value)

    @since("1.4.0")
    def getNumFolds(self):
        """
        Gets the value of numFolds or its default value.
        """
        return self.getOrDefault(self.numFolds)

    def _fit(self, dataset):
        """
        Evaluates every param map on every fold, averages the metric per param map and
        refits the estimator on the full dataset with the best param map.

        :param dataset: input DataFrame
        :return: CrossValidatorModel holding the best model, the averaged metrics and,
                 when collectSubModels is enabled, the per-fold sub models
        """
        est = self.getOrDefault(self.estimator)
        epm = self.getOrDefault(self.estimatorParamMaps)
        numModels = len(epm)
        eva = self.getOrDefault(self.evaluator)
        nFolds = self.getOrDefault(self.numFolds)
        seed = self.getOrDefault(self.seed)
        h = 1.0 / nFolds
        # Each row gets a random value; fold i is the rows falling in [i * h, (i + 1) * h).
        randCol = self.uid + "_rand"
        df = dataset.select("*", rand(seed).alias(randCol))
        metrics = [0.0] * numModels

        subModels = None
        collectSubModelsParam = self.getCollectSubModels()
        if collectSubModelsParam:
            subModels = [[None for j in range(numModels)] for i in range(nFolds)]

        pool = ThreadPool(processes=min(self.getParallelism(), numModels))
        try:
            for i in range(nFolds):
                validateLB = i * h
                validateUB = (i + 1) * h
                condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB)
                validation = df.filter(condition).cache()
                train = df.filter(~condition).cache()
                try:
                    tasks = _parallelFitTasks(est, train, eva, validation, epm,
                                              collectSubModelsParam)
                    for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
                        # Accumulate the running average over the folds.
                        metrics[j] += (metric / nFolds)
                        if collectSubModelsParam:
                            subModels[i][j] = subModel
                finally:
                    # Release the cached splits even when a fit/evaluate task raised.
                    validation.unpersist()
                    train.unpersist()
        finally:
            # Let the worker threads exit once outstanding work is done; without this
            # every call to fit() leaves the pool's threads alive.
            pool.close()

        if eva.isLargerBetter():
            bestIndex = np.argmax(metrics)
        else:
            bestIndex = np.argmin(metrics)
        bestModel = est.fit(dataset, epm[bestIndex])
        return self._copyValues(CrossValidatorModel(bestModel, metrics, subModels))

    @since("1.4.0")
    def copy(self, extra=None):
        """
        Creates a copy of this instance with some extra params. The copy is produced by
        :py:meth:`Params.copy`; the estimator and the evaluator, when set, are then replaced
        by their own copies made with the same extra params. The estimatorParamMaps are
        left as they are.

        :param extra: Extra parameters to copy to the new instance
        :return: Copy of this instance
        """
        if extra is None:
            extra = dict()
        newCV = Params.copy(self, extra)
        if self.isSet(self.estimator):
            newCV.setEstimator(self.getEstimator().copy(extra))
        # estimatorParamMaps remain the same
        if self.isSet(self.evaluator):
            newCV.setEvaluator(self.getEvaluator().copy(extra))
        return newCV

    @since("2.3.0")
    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        return JavaMLWriter(self)

    @classmethod
    @since("2.3.0")
    def read(cls):
        """Returns an MLReader instance for this class."""
        return JavaMLReader(cls)

    @classmethod
    def _from_java(cls, java_stage):
        """
        Given a Java CrossValidator, create and return a Python wrapper of it.
        Used for ML persistence.
        """

        estimator, epms, evaluator = super(CrossValidator, cls)._from_java_impl(java_stage)
        numFolds = java_stage.getNumFolds()
        seed = java_stage.getSeed()
        parallelism = java_stage.getParallelism()
        collectSubModels = java_stage.getCollectSubModels()
        # Create a new instance of this stage.
        py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
                       numFolds=numFolds, seed=seed, parallelism=parallelism,
                       collectSubModels=collectSubModels)
        # Keep the uid of the Java stage on the Python wrapper.
        py_stage._resetUid(java_stage.uid())
        return py_stage

    def _to_java(self):
        """
        Transfer this instance to a Java CrossValidator. Used for ML persistence.

        :return: Java object equivalent to this instance.
        """

        estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl()

        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
        _java_obj.setEstimatorParamMaps(epms)
        _java_obj.setEvaluator(evaluator)
        _java_obj.setEstimator(estimator)
        _java_obj.setSeed(self.getSeed())
        _java_obj.setNumFolds(self.getNumFolds())
        _java_obj.setParallelism(self.getParallelism())
        _java_obj.setCollectSubModels(self.getCollectSubModels())

        return _java_obj
class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
    """

    CrossValidatorModel contains the model with the highest average cross-validation
    metric across folds and uses this model to transform input data. CrossValidatorModel
    also tracks the metrics for each param map evaluated.

    .. versionadded:: 1.4.0
    """

    def __init__(self, bestModel, avgMetrics=None, subModels=None):
        """
        :param bestModel: the model used by :py:meth:`transform`
        :param avgMetrics: list of averaged metrics, one per param map; a fresh empty list
                           is used when omitted
        :param subModels: optional sub models collected per fold and param map
        """
        super(CrossValidatorModel, self).__init__()
        #: best model from cross validation
        self.bestModel = bestModel
        #: Average cross-validation metrics for each paramMap in
        #: CrossValidator.estimatorParamMaps, in the corresponding order.
        # A new list per instance: a shared `[]` default would be visible to, and mutated
        # through, every model built without explicit metrics.
        self.avgMetrics = [] if avgMetrics is None else avgMetrics
        #: sub model list from cross validation
        self.subModels = subModels

    def _transform(self, dataset):
        """Transforms `dataset` with the best model."""
        return self.bestModel.transform(dataset)

    @since("1.4.0")
    def copy(self, extra=None):
        """
        Creates a new CrossValidatorModel from this one. The underlying bestModel is copied
        with the extra params, avgMetrics is shallow-copied and the subModels are shared
        with this instance. It does not copy the extra Params into the subModels.

        NOTE(review): the new instance is built directly from the constructor, so params
        set on this instance (estimator, estimatorParamMaps, evaluator, seed) are not
        carried over by this method.

        :param extra: Extra parameters to copy to the new instance
        :return: Copy of this instance
        """
        if extra is None:
            extra = dict()
        bestModel = self.bestModel.copy(extra)
        # Shallow copy so that the two models do not share one mutable metrics list.
        avgMetrics = list(self.avgMetrics)
        subModels = self.subModels
        return CrossValidatorModel(bestModel, avgMetrics, subModels)

    @since("2.3.0")
    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        return JavaMLWriter(self)

    @classmethod
    @since("2.3.0")
    def read(cls):
        """Returns an MLReader instance for this class."""
        return JavaMLReader(cls)

    @classmethod
    def _from_java(cls, java_stage):
        """
        Given a Java CrossValidatorModel, create and return a Python wrapper of it.
        Used for ML persistence.
        """
        bestModel = JavaParams._from_java(java_stage.bestModel())
        estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)

        py_stage = cls(bestModel=bestModel).setEstimator(estimator)
        py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)

        if java_stage.hasSubModels():
            # Nested list: one inner list of sub models per fold.
            py_stage.subModels = [[JavaParams._from_java(sub_model)
                                   for sub_model in fold_sub_models]
                                  for fold_sub_models in java_stage.subModels()]

        py_stage._resetUid(java_stage.uid())
        return py_stage

    def _to_java(self):
        """
        Transfer this instance to a Java CrossValidatorModel. Used for ML persistence.

        :return: Java object equivalent to this instance.
        """

        sc = SparkContext._active_spark_context
        # TODO: persist average metrics as well
        # (an empty metrics list is handed to the Java constructor for now)
        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
                                             self.uid,
                                             self.bestModel._to_java(),
                                             _py2java(sc, []))
        estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()

        _java_obj.set("evaluator", evaluator)
        _java_obj.set("estimator", estimator)
        _java_obj.set("estimatorParamMaps", epms)

        if self.subModels is not None:
            java_sub_models = [[sub_model._to_java() for sub_model in fold_sub_models]
                               for fold_sub_models in self.subModels]
            _java_obj.setSubModels(java_sub_models)
        return _java_obj
class TrainValidationSplit(Estimator, ValidatorParams, HasParallelism, HasCollectSubModels,
                           MLReadable, MLWritable):
    """
    .. note:: Experimental

    Validation for hyper-parameter tuning. Randomly splits the input dataset into train and
    validation sets, and uses evaluation metric on the validation set to select the best model.
    Similar to :class:`CrossValidator`, but only splits the set once.

    >>> from pyspark.ml.classification import LogisticRegression
    >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
    >>> from pyspark.ml.linalg import Vectors
    >>> dataset = spark.createDataFrame(
    ...     [(Vectors.dense([0.0]), 0.0),
    ...      (Vectors.dense([0.4]), 1.0),
    ...      (Vectors.dense([0.5]), 0.0),
    ...      (Vectors.dense([0.6]), 1.0),
    ...      (Vectors.dense([1.0]), 1.0)] * 10,
    ...     ["features", "label"])
    >>> lr = LogisticRegression()
    >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
    >>> evaluator = BinaryClassificationEvaluator()
    >>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
    ...     parallelism=2)
    >>> tvsModel = tvs.fit(dataset)
    >>> evaluator.evaluate(tvsModel.transform(dataset))
    0.8333...

    .. versionadded:: 2.0.0
    """

    # The help text is built by implicit concatenation: a backslash continuation inside
    # the literal would embed stray characters in the user-facing description.
    trainRatio = Param(Params._dummy(), "trainRatio",
                       "Param for ratio between train and "
                       "validation data. Must be between 0 and 1.",
                       typeConverter=TypeConverters.toFloat)

    @keyword_only
    def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,
                 parallelism=1, collectSubModels=False, seed=None):
        """
        __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\
                 parallelism=1, collectSubModels=False, seed=None)
        """
        super(TrainValidationSplit, self).__init__()
        self._setDefault(trainRatio=0.75, parallelism=1)
        kwargs = self._input_kwargs
        self._set(**kwargs)

    @since("2.0.0")
    @keyword_only
    def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,
                  parallelism=1, collectSubModels=False, seed=None):
        """
        setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, trainRatio=0.75,\
                  parallelism=1, collectSubModels=False, seed=None):
        Sets params for the train validation split.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    @since("2.0.0")
    def setTrainRatio(self, value):
        """
        Sets the value of :py:attr:`trainRatio`.
        """
        return self._set(trainRatio=value)

    @since("2.0.0")
    def getTrainRatio(self):
        """
        Gets the value of trainRatio or its default value.
        """
        return self.getOrDefault(self.trainRatio)

    def _fit(self, dataset):
        """
        Splits the dataset once into train and validation parts, evaluates every param map
        on the validation part and refits the estimator on the full dataset with the best
        param map.

        :param dataset: input DataFrame
        :return: TrainValidationSplitModel holding the best model, the validation metrics
                 and, when collectSubModels is enabled, the sub models
        """
        est = self.getOrDefault(self.estimator)
        epm = self.getOrDefault(self.estimatorParamMaps)
        numModels = len(epm)
        eva = self.getOrDefault(self.evaluator)
        tRatio = self.getOrDefault(self.trainRatio)
        seed = self.getOrDefault(self.seed)
        # Each row gets a random value; rows at or above the ratio form the validation set.
        randCol = self.uid + "_rand"
        df = dataset.select("*", rand(seed).alias(randCol))
        condition = (df[randCol] >= tRatio)
        validation = df.filter(condition).cache()
        train = df.filter(~condition).cache()

        subModels = None
        collectSubModelsParam = self.getCollectSubModels()
        if collectSubModelsParam:
            subModels = [None for i in range(numModels)]

        metrics = [None] * numModels
        try:
            tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam)
            pool = ThreadPool(processes=min(self.getParallelism(), numModels))
            try:
                for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
                    metrics[j] = metric
                    if collectSubModelsParam:
                        subModels[j] = subModel
            finally:
                # Let the worker threads exit once outstanding work is done; without
                # this every call to fit() leaves the pool's threads alive.
                pool.close()
        finally:
            # Release the cached splits even when a fit/evaluate task raised.
            train.unpersist()
            validation.unpersist()

        if eva.isLargerBetter():
            bestIndex = np.argmax(metrics)
        else:
            bestIndex = np.argmin(metrics)
        bestModel = est.fit(dataset, epm[bestIndex])
        return self._copyValues(TrainValidationSplitModel(bestModel, metrics, subModels))

    @since("2.0.0")
    def copy(self, extra=None):
        """
        Creates a copy of this instance with some extra params. The copy is produced by
        :py:meth:`Params.copy`; the estimator and the evaluator, when set, are then replaced
        by their own copies made with the same extra params. The estimatorParamMaps are
        left as they are.

        :param extra: Extra parameters to copy to the new instance
        :return: Copy of this instance
        """
        if extra is None:
            extra = dict()
        newTVS = Params.copy(self, extra)
        if self.isSet(self.estimator):
            newTVS.setEstimator(self.getEstimator().copy(extra))
        # estimatorParamMaps remain the same
        if self.isSet(self.evaluator):
            newTVS.setEvaluator(self.getEvaluator().copy(extra))
        return newTVS

    @since("2.3.0")
    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        return JavaMLWriter(self)

    @classmethod
    @since("2.3.0")
    def read(cls):
        """Returns an MLReader instance for this class."""
        return JavaMLReader(cls)

    @classmethod
    def _from_java(cls, java_stage):
        """
        Given a Java TrainValidationSplit, create and return a Python wrapper of it.
        Used for ML persistence.
        """

        estimator, epms, evaluator = super(TrainValidationSplit, cls)._from_java_impl(java_stage)
        trainRatio = java_stage.getTrainRatio()
        seed = java_stage.getSeed()
        parallelism = java_stage.getParallelism()
        collectSubModels = java_stage.getCollectSubModels()
        # Create a new instance of this stage.
        py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
                       trainRatio=trainRatio, seed=seed, parallelism=parallelism,
                       collectSubModels=collectSubModels)
        # Keep the uid of the Java stage on the Python wrapper.
        py_stage._resetUid(java_stage.uid())
        return py_stage

    def _to_java(self):
        """
        Transfer this instance to a Java TrainValidationSplit. Used for ML persistence.

        :return: Java object equivalent to this instance.
        """

        estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl()

        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
                                             self.uid)
        _java_obj.setEstimatorParamMaps(epms)
        _java_obj.setEvaluator(evaluator)
        _java_obj.setEstimator(estimator)
        _java_obj.setTrainRatio(self.getTrainRatio())
        _java_obj.setSeed(self.getSeed())
        _java_obj.setParallelism(self.getParallelism())
        _java_obj.setCollectSubModels(self.getCollectSubModels())
        return _java_obj
class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
    """
    .. note:: Experimental

    Model from train validation split.

    .. versionadded:: 2.0.0
    """

    def __init__(self, bestModel, validationMetrics=None, subModels=None):
        """
        :param bestModel: the model used by :py:meth:`transform`
        :param validationMetrics: list of validation metrics, one per param map; a fresh
                                  empty list is used when omitted
        :param subModels: optional list of sub models, one per param map
        """
        super(TrainValidationSplitModel, self).__init__()
        #: best model from train validation split
        self.bestModel = bestModel
        #: evaluated validation metrics
        # A new list per instance: a shared `[]` default would be visible to, and mutated
        # through, every model built without explicit metrics.
        self.validationMetrics = [] if validationMetrics is None else validationMetrics
        #: sub models from train validation split
        self.subModels = subModels

    def _transform(self, dataset):
        """Transforms `dataset` with the best model."""
        return self.bestModel.transform(dataset)

    @since("2.0.0")
    def copy(self, extra=None):
        """
        Creates a new TrainValidationSplitModel from this one. The underlying bestModel is
        copied with the extra params, validationMetrics is shallow-copied and the subModels
        are shared with this instance. It does not copy the extra Params into the subModels.

        NOTE(review): the new instance is built directly from the constructor, so params
        set on this instance (estimator, estimatorParamMaps, evaluator, seed) are not
        carried over by this method.

        :param extra: Extra parameters to copy to the new instance
        :return: Copy of this instance
        """
        if extra is None:
            extra = dict()
        bestModel = self.bestModel.copy(extra)
        validationMetrics = list(self.validationMetrics)
        subModels = self.subModels
        return TrainValidationSplitModel(bestModel, validationMetrics, subModels)

    @since("2.3.0")
    def write(self):
        """Returns an MLWriter instance for this ML instance."""
        return JavaMLWriter(self)

    @classmethod
    @since("2.3.0")
    def read(cls):
        """Returns an MLReader instance for this class."""
        return JavaMLReader(cls)

    @classmethod
    def _from_java(cls, java_stage):
        """
        Given a Java TrainValidationSplitModel, create and return a Python wrapper of it.
        Used for ML persistence.
        """

        # Load information from java_stage to the instance.
        bestModel = JavaParams._from_java(java_stage.bestModel())
        estimator, epms, evaluator = super(TrainValidationSplitModel,
                                           cls)._from_java_impl(java_stage)
        # Create a new instance of this stage.
        py_stage = cls(bestModel=bestModel).setEstimator(estimator)
        py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)

        if java_stage.hasSubModels():
            py_stage.subModels = [JavaParams._from_java(sub_model)
                                  for sub_model in java_stage.subModels()]

        py_stage._resetUid(java_stage.uid())
        return py_stage

    def _to_java(self):
        """
        Transfer this instance to a Java TrainValidationSplitModel. Used for ML persistence.

        :return: Java object equivalent to this instance.
        """

        sc = SparkContext._active_spark_context
        # TODO: persist validation metrics as well
        # (an empty metrics list is handed to the Java constructor for now)
        _java_obj = JavaParams._new_java_obj(
            "org.apache.spark.ml.tuning.TrainValidationSplitModel",
            self.uid,
            self.bestModel._to_java(),
            _py2java(sc, []))
        estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl()

        _java_obj.set("evaluator", evaluator)
        _java_obj.set("estimator", estimator)
        _java_obj.set("estimatorParamMaps", epms)

        if self.subModels is not None:
            java_sub_models = [sub_model._to_java() for sub_model in self.subModels]
            _java_obj.setSubModels(java_sub_models)

        return _java_obj
if __name__ == "__main__":
    import doctest
    from pyspark.sql import SparkSession

    # Doctests run against a copy of the module namespace, extended with the `spark`
    # session and `sc` context that the examples in the docstrings refer to.
    globs = globals().copy()
    spark = (SparkSession.builder
             .master("local[2]")
             .appName("ml.tuning tests")
             .getOrCreate())
    sc = spark.sparkContext
    globs.update(sc=sc, spark=spark)

    # ELLIPSIS lets expected outputs such as `0.8333...` match the full float repr.
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    spark.stop()
    # Signal failure to the caller through the process exit status.
    if failure_count:
        sys.exit(-1)