
Source Code for Module pyspark.mllib.clustering

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from numpy import array, dot
from math import sqrt
from pyspark import SparkContext
from pyspark.mllib._common import \
    _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
    _serialize_double_matrix, _deserialize_double_matrix, \
    _serialize_double_vector, _deserialize_double_vector, \
    _get_initial_weights, _serialize_rating, _regression_train_wrapper

class KMeansModel(object):
    """A clustering model derived from the k-means method.

    >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2)
    >>> clusters = KMeans.train(sc.parallelize(data), 2, maxIterations=10, runs=30, initializationMode="random")
    >>> clusters.predict(array([0.0, 0.0])) == clusters.predict(array([1.0, 1.0]))
    True
    >>> clusters.predict(array([8.0, 9.0])) == clusters.predict(array([9.0, 8.0]))
    True
    >>> clusters = KMeans.train(sc.parallelize(data), 2)
    """
    def __init__(self, centers_):
        self.centers = centers_

    def predict(self, x):
        """Find the cluster to which x belongs in this model."""
        # Scan all centers and return the index of the one closest to x
        # by Euclidean distance.
        best = 0
        best_distance = 1e75
        for i in range(0, self.centers.shape[0]):
            diff = x - self.centers[i]
            distance = sqrt(dot(diff, diff))
            if distance < best_distance:
                best = i
                best_distance = distance
        return best

class KMeans(object):
    @classmethod
    def train(cls, data, k, maxIterations=100, runs=1,
              initializationMode="k-means||"):
        """Train a k-means clustering model."""
        sc = data.context
        # Serialize the input vectors and invoke the JVM-side MLlib trainer.
        dataBytes = _get_unmangled_double_vector_rdd(data)
        ans = sc._jvm.PythonMLLibAPI().trainKMeansModel(dataBytes._jrdd,
            k, maxIterations, runs, initializationMode)
        if len(ans) != 1:
            raise RuntimeError("JVM call result had unexpected length")
        elif type(ans[0]) != bytearray:
            raise RuntimeError("JVM call result had first element of type "
                               + type(ans[0]).__name__ +
                               " which is not bytearray")
        # The returned bytes encode the matrix of cluster centers.
        return KMeansModel(_deserialize_double_matrix(ans[0]))

def _test():
    import doctest
    import sys
    globs = globals().copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    (failure_count, test_count) = doctest.testmod(globs=globs,
                                                  optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        sys.exit(-1)

if __name__ == "__main__":
    _test()
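
A minimal usage sketch (not part of the module source above), based on the doctest in KMeansModel: it builds a small RDD of NumPy vectors, trains a model, and queries the resulting cluster assignments. The application name and parameter values are illustrative only.

    from numpy import array
    from pyspark import SparkContext
    from pyspark.mllib.clustering import KMeans

    sc = SparkContext("local", "KMeansExample")  # hypothetical app name
    # Four 2-D points that form two well-separated clusters.
    points = sc.parallelize([array([0.0, 0.0]), array([1.0, 1.0]),
                             array([9.0, 8.0]), array([8.0, 9.0])])
    model = KMeans.train(points, 2, maxIterations=10, runs=5,
                         initializationMode="k-means||")
    print(model.predict(array([0.5, 0.5])))  # index of the cluster near (0, 0)
    print(model.predict(array([8.5, 8.5])))  # index of the cluster near (9, 8)
    sc.stop()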