# Source code for pyspark.ml.tuning

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools

import numpy as np

from pyspark.ml import Estimator, Model
from pyspark.ml.param import Params, Param
from pyspark.ml.util import keyword_only
from pyspark.sql.functions import rand

# Names exported by ``from <this module> import *``.
__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel']

class ParamGridBuilder(object):
    r"""
    Builder for a param grid used in grid search-based model selection.

    >>> from pyspark.ml.classification import LogisticRegression
    >>> lr = LogisticRegression()
    >>> output = ParamGridBuilder() \
    ...     .baseOn({lr.labelCol: 'l'}) \
    ...     .baseOn([lr.predictionCol, 'p']) \
    ...     .addGrid(lr.regParam, [1.0, 2.0]) \
    ...     .addGrid(lr.maxIter, [1, 5]) \
    ...     .build()
    >>> expected = [
    ...     {lr.regParam: 1.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
    ...     {lr.regParam: 2.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
    ...     {lr.regParam: 1.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'},
    ...     {lr.regParam: 2.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
    >>> len(output) == len(expected)
    True
    >>> all([m in expected for m in output])
    True
    """

    def __init__(self):
        # Maps each param to the list of candidate values to try for it.
        self._param_grid = {}

    def addGrid(self, param, values):
        """
        Sets the given parameter in this grid to the given list of
        candidate values, replacing any values previously registered
        for that parameter.

        :param param: the parameter to vary
        :param values: the candidate values for the parameter
        :return: this builder, to allow call chaining
        """
        self._param_grid[param] = values
        return self

    def baseOn(self, *args):
        """
        Sets the given parameters in this grid to fixed values.
        Accepts either a parameter dictionary or a list of
        (parameter, value) pairs.

        Calling it with no arguments or with an empty dictionary
        leaves the grid unchanged.

        :return: this builder, to allow call chaining
        """
        if not args:
            # Nothing to fix; also reached for an empty dict, which used to
            # recurse with no arguments and fail on ``args[0]``.
            return self
        if isinstance(args[0], dict):
            pairs = args[0].items()
        else:
            pairs = args
        for (param, value) in pairs:
            # A fixed value is simply a grid with a single candidate.
            self.addGrid(param, [value])
        return self

    def build(self):
        """
        Builds and returns all combinations of parameters specified
        by the param grid.

        :return: a list of dicts, one per combination, each mapping
                 every registered parameter to one of its values
        """
        # keys() and values() of an unmodified dict iterate in matching
        # order, so each product tuple lines up with ``keys``.
        keys = list(self._param_grid.keys())
        grid_values = list(self._param_grid.values())
        return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
class CrossValidator(Estimator):
    """
    K-fold cross validation.

    >>> from pyspark.ml.classification import LogisticRegression
    >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
    >>> from pyspark.mllib.linalg import Vectors
    >>> dataset = sqlContext.createDataFrame(
    ...     [(Vectors.dense([0.0]), 0.0),
    ...      (Vectors.dense([0.4]), 1.0),
    ...      (Vectors.dense([0.5]), 0.0),
    ...      (Vectors.dense([0.6]), 1.0),
    ...      (Vectors.dense([1.0]), 1.0)] * 10,
    ...     ["features", "label"])
    >>> lr = LogisticRegression()
    >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
    >>> evaluator = BinaryClassificationEvaluator()
    >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
    >>> cvModel = cv.fit(dataset)
    >>> evaluator.evaluate(cvModel.transform(dataset))
    0.8333...
    """

    # a placeholder to make it appear in the generated doc
    estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated")

    # a placeholder to make it appear in the generated doc
    estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")

    # a placeholder to make it appear in the generated doc
    evaluator = Param(
        Params._dummy(), "evaluator",
        "evaluator used to select hyper-parameters that maximize the cross-validated metric")

    # a placeholder to make it appear in the generated doc
    numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation")

    @keyword_only
    def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
        """
        __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3)
        """
        super(CrossValidator, self).__init__()
        #: param for estimator to be cross-validated
        self.estimator = Param(self, "estimator", "estimator to be cross-validated")
        #: param for estimator param maps
        self.estimatorParamMaps = Param(self, "estimatorParamMaps", "estimator param maps")
        #: param for the evaluator used to select hyper-parameters that
        #: maximize the cross-validated metric
        self.evaluator = Param(
            self, "evaluator",
            "evaluator used to select hyper-parameters that maximize the cross-validated metric")
        #: param for number of folds for cross validation
        self.numFolds = Param(self, "numFolds", "number of folds for cross validation")
        self._setDefault(numFolds=3)
        kwargs = self.__init__._input_kwargs
        self._set(**kwargs)

    @keyword_only
    def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
        """
        setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
        Sets params for cross validator.
        """
        kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)

    def setEstimator(self, value):
        """
        Sets the value of :py:attr:`estimator`.
        """
        self._paramMap[self.estimator] = value
        return self

    def getEstimator(self):
        """
        Gets the value of estimator or its default value.
        """
        return self.getOrDefault(self.estimator)

    def setEstimatorParamMaps(self, value):
        """
        Sets the value of :py:attr:`estimatorParamMaps`.
        """
        self._paramMap[self.estimatorParamMaps] = value
        return self

    def getEstimatorParamMaps(self):
        """
        Gets the value of estimatorParamMaps or its default value.
        """
        return self.getOrDefault(self.estimatorParamMaps)

    def setEvaluator(self, value):
        """
        Sets the value of :py:attr:`evaluator`.
        """
        self._paramMap[self.evaluator] = value
        return self

    def getEvaluator(self):
        """
        Gets the value of evaluator or its default value.
        """
        return self.getOrDefault(self.evaluator)

    def setNumFolds(self, value):
        """
        Sets the value of :py:attr:`numFolds`.
        """
        self._paramMap[self.numFolds] = value
        return self

    def getNumFolds(self):
        """
        Gets the value of numFolds or its default value.
        """
        return self.getOrDefault(self.numFolds)

    def _fit(self, dataset):
        """
        Evaluates every param map on every fold, then refits the estimator
        on the full dataset with the param map whose summed metric is largest.

        :param dataset: input dataset to cross-validate on
        :return: a :py:class:`CrossValidatorModel` wrapping the refitted model
        """
        est = self.getOrDefault(self.estimator)
        epm = self.getOrDefault(self.estimatorParamMaps)
        numModels = len(epm)
        eva = self.getOrDefault(self.evaluator)
        nFolds = self.getOrDefault(self.numFolds)
        # Each fold owns an equal-width slice [i*h, (i+1)*h) of the random column.
        h = 1.0 / nFolds
        randCol = self.uid + "_rand"
        # rand(0) uses a fixed seed for the fold-assignment column.
        df = dataset.select("*", rand(0).alias(randCol))
        metrics = np.zeros(numModels)
        for i in range(nFolds):
            validateLB = i * h
            validateUB = (i + 1) * h
            condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB)
            validation = df.filter(condition)
            train = df.filter(~condition)
            for j in range(numModels):
                model = est.fit(train, epm[j])
                # TODO: duplicate evaluator to take extra params from input
                metric = eva.evaluate(model.transform(validation, epm[j]))
                # Summed (not averaged) across folds; argmax is unaffected.
                metrics[j] += metric
        bestIndex = np.argmax(metrics)
        bestModel = est.fit(dataset, epm[bestIndex])
        return CrossValidatorModel(bestModel)

    def copy(self, extra=None):
        """
        Creates a copy of this instance, also copying the estimator and
        the evaluator when they are set. The estimator param maps are not
        copied separately.

        :param extra: Extra parameters to copy to the new instance
        :return: Copy of this instance
        """
        # Avoid a shared mutable default: build a fresh dict per call.
        if extra is None:
            extra = dict()
        newCV = Params.copy(self, extra)
        if self.isSet(self.estimator):
            newCV.setEstimator(self.getEstimator().copy(extra))
        # estimatorParamMaps remain the same
        if self.isSet(self.evaluator):
            newCV.setEvaluator(self.getEvaluator().copy(extra))
        return newCV
class CrossValidatorModel(Model):
    """
    Model from k-fold cross validation.
    """

    def __init__(self, bestModel):
        super(CrossValidatorModel, self).__init__()
        #: best model from cross validation
        self.bestModel = bestModel

    def _transform(self, dataset):
        # Delegates entirely to the wrapped best model.
        return self.bestModel.transform(dataset)

    def copy(self, extra=None):
        """
        Creates a new :py:class:`CrossValidatorModel` that wraps a copy of
        the underlying bestModel, made with the given extra params.

        :param extra: Extra parameters to copy to the new instance
        :return: Copy of this instance
        """
        # Avoid a shared mutable default: build a fresh dict per call.
        if extra is None:
            extra = dict()
        return CrossValidatorModel(self.bestModel.copy(extra))
if __name__ == "__main__":
    # Runs this module's doctests against a local SparkContext and exits
    # with a non-zero status if any of them fail.
    import doctest
    import sys
    from pyspark.context import SparkContext
    from pyspark.sql import SQLContext
    globs = globals().copy()
    sc = SparkContext("local[2]", "ml.tuning tests")
    try:
        sqlContext = SQLContext(sc)
        # Expose the contexts to the doctests as globals.
        globs['sc'] = sc
        globs['sqlContext'] = sqlContext
        (failure_count, test_count) = doctest.testmod(
            globs=globs, optionflags=doctest.ELLIPSIS)
    finally:
        # Always release the SparkContext, even if the doctest run raises.
        sc.stop()
    if failure_count:
        sys.exit(-1)