Package pyspark :: Package mllib :: Module recommendation
[frames] | [no frames]

Source Code for Module pyspark.mllib.recommendation

 1  # 
 2  # Licensed to the Apache Software Foundation (ASF) under one or more 
 3  # contributor license agreements.  See the NOTICE file distributed with 
 4  # this work for additional information regarding copyright ownership. 
 5  # The ASF licenses this file to You under the Apache License, Version 2.0 
 6  # (the "License"); you may not use this file except in compliance with 
 7  # the License.  You may obtain a copy of the License at 
 8  # 
 9  #    http://www.apache.org/licenses/LICENSE-2.0 
10  # 
11  # Unless required by applicable law or agreed to in writing, software 
12  # distributed under the License is distributed on an "AS IS" BASIS, 
13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14  # See the License for the specific language governing permissions and 
15  # limitations under the License. 
16  # 
17   
18  from pyspark import SparkContext 
19  from pyspark.mllib._common import \ 
20      _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \ 
21      _serialize_double_matrix, _deserialize_double_matrix, \ 
22      _serialize_double_vector, _deserialize_double_vector, \ 
23      _get_initial_weights, _serialize_rating, _regression_train_wrapper, \ 
24      _serialize_tuple, RatingDeserializer 
25  from pyspark.rdd import RDD 
class MatrixFactorizationModel(object):
    """A matrix factorisation model trained by regularized alternating
    least-squares.

    >>> r1 = (1, 1, 1.0)
    >>> r2 = (1, 2, 2.0)
    >>> r3 = (2, 1, 2.0)
    >>> ratings = sc.parallelize([r1, r2, r3])
    >>> model = ALS.trainImplicit(ratings, 1)
    >>> model.predict(2,2) is not None
    True
    >>> testset = sc.parallelize([(1, 2), (1, 1)])
    >>> model.predictAll(testset).count() == 2
    True
    """

    def __init__(self, sc, java_model):
        """Wrap a JVM-side model.

        :param sc: the SparkContext, kept so the Java reference can be
            detached in __del__.
        :param java_model: py4j handle to the Java MatrixFactorizationModel.
        """
        self._context = sc
        self._java_model = java_model

    def __del__(self):
        # Detach the py4j reference so the Java-side model object can be
        # garbage-collected on the JVM.
        self._context._gateway.detach(self._java_model)

    def predict(self, user, product):
        """Predict the rating of a single user for a single product.

        Delegates directly to the Java model's predict(user, product).
        """
        return self._java_model.predict(user, product)

    def predictAll(self, usersProducts):
        """Predict ratings for an RDD of (user, product) pairs.

        :param usersProducts: RDD of (user, product) tuples.
        :return: RDD of predicted ratings, deserialized via
            RatingDeserializer.
        """
        # Serialize the Python tuples into the byte format the JVM expects
        # before handing the RDD across the py4j boundary.
        usersProductsJRDD = _get_unmangled_rdd(usersProducts, _serialize_tuple)
        return RDD(self._java_model.predict(usersProductsJRDD._jrdd),
                   self._context, RatingDeserializer())
class ALS(object):
    """Factory for matrix factorization models trained with alternating
    least squares (ALS) on the JVM side."""

    @classmethod
    def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
        """Train a matrix factorization model with explicit-feedback ALS.

        :param ratings: RDD of (user, product, rating) tuples.
        :param rank: number of latent features.
        :param iterations: number of ALS iterations to run.
        :param lambda_: regularization parameter.
        :param blocks: level of parallelism (-1 for auto).
        :return: a MatrixFactorizationModel wrapping the trained JVM model.
        """
        sc = ratings.context
        serialized_ratings = _get_unmangled_rdd(ratings, _serialize_rating)
        java_model = sc._jvm.PythonMLLibAPI().trainALSModel(
            serialized_ratings._jrdd, rank, iterations, lambda_, blocks)
        return MatrixFactorizationModel(sc, java_model)

    @classmethod
    def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01,
                      blocks=-1, alpha=0.01):
        """Train a matrix factorization model with implicit-feedback ALS.

        :param ratings: RDD of (user, product, rating) tuples, where the
            rating is treated as an implicit preference signal.
        :param rank: number of latent features.
        :param iterations: number of ALS iterations to run.
        :param lambda_: regularization parameter.
        :param blocks: level of parallelism (-1 for auto).
        :param alpha: confidence parameter for implicit feedback.
        :return: a MatrixFactorizationModel wrapping the trained JVM model.
        """
        sc = ratings.context
        serialized_ratings = _get_unmangled_rdd(ratings, _serialize_rating)
        java_model = sc._jvm.PythonMLLibAPI().trainImplicitALSModel(
            serialized_ratings._jrdd, rank, iterations, lambda_, blocks,
            alpha)
        return MatrixFactorizationModel(sc, java_model)
def _test():
    """Run this module's doctests against a local SparkContext.

    Exits the process with a non-zero status if any doctest fails.
    """
    import doctest
    import sys
    globs = globals().copy()
    # The doctests above reference the name `sc`; batchSize=2 forces the
    # serializer to use more than one batch for the small test RDDs.
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    (failure_count, test_count) = doctest.testmod(globs=globs,
                                                  optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        # Use sys.exit rather than the bare exit() builtin: exit() is only
        # injected by the `site` module and is absent under `python -S` or
        # in frozen/embedded interpreters.
        sys.exit(-1)


if __name__ == "__main__":
    _test()