Package pyspark :: Package mllib :: Module recommendation
[frames] | [no frames]

Source Code for Module pyspark.mllib.recommendation

 1  # 
 2  # Licensed to the Apache Software Foundation (ASF) under one or more 
 3  # contributor license agreements.  See the NOTICE file distributed with 
 4  # this work for additional information regarding copyright ownership. 
 5  # The ASF licenses this file to You under the Apache License, Version 2.0 
 6  # (the "License"); you may not use this file except in compliance with 
 7  # the License.  You may obtain a copy of the License at 
 8  # 
 9  #    http://www.apache.org/licenses/LICENSE-2.0 
10  # 
11  # Unless required by applicable law or agreed to in writing, software 
12  # distributed under the License is distributed on an "AS IS" BASIS, 
13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14  # See the License for the specific language governing permissions and 
15  # limitations under the License. 
16  # 
17   
18  from pyspark import SparkContext 
19  from pyspark.mllib._common import \ 
20      _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \ 
21      _serialize_double_matrix, _deserialize_double_matrix, \ 
22      _serialize_double_vector, _deserialize_double_vector, \ 
23      _get_initial_weights, _serialize_rating, _regression_train_wrapper, \ 
24      _serialize_tuple, RatingDeserializer 
25  from pyspark.rdd import RDD 
class MatrixFactorizationModel(object):

    """A matrix factorisation model trained by regularized alternating
    least-squares.

    >>> r1 = (1, 1, 1.0)
    >>> r2 = (1, 2, 2.0)
    >>> r3 = (2, 1, 2.0)
    >>> ratings = sc.parallelize([r1, r2, r3])
    >>> model = ALS.trainImplicit(ratings, 1)
    >>> model.predict(2,2) is not None
    True
    >>> testset = sc.parallelize([(1, 2), (1, 1)])
    >>> model.predictAll(testset).count() == 2
    True
    """

    def __init__(self, sc, java_model):
        """Wrap a JVM-side model handle.

        :param sc: SparkContext that owns the Py4J gateway.
        :param java_model: Py4J handle to the JVM MatrixFactorizationModel.
        """
        self._context = sc
        self._java_model = java_model

    def __del__(self):
        # Detach the Py4J object so the JVM side can be garbage-collected.
        # Guard against partially-constructed instances (or interpreter
        # teardown): if the attributes were never set, __del__ must not
        # raise AttributeError during garbage collection.
        try:
            self._context._gateway.detach(self._java_model)
        except AttributeError:
            pass

    def predict(self, user, product):
        """Predict the rating of a single (user, product) pair via the JVM model."""
        return self._java_model.predict(user, product)

    def predictAll(self, usersProducts):
        """Predict ratings for an RDD of (user, product) pairs.

        :param usersProducts: RDD of (user, product) tuples.
        :return: RDD of ratings, deserialized from the JVM result.
        """
        usersProductsJRDD = _get_unmangled_rdd(usersProducts, _serialize_tuple)
        return RDD(self._java_model.predict(usersProductsJRDD._jrdd),
                   self._context, RatingDeserializer())
class ALS(object):

    """Factory for matrix-factorization models fit with alternating
    least squares on the JVM via PythonMLLibAPI."""

    @classmethod
    def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
        """Train an explicit-feedback ALS model on an RDD of ratings."""
        context = ratings.context
        serialized = _get_unmangled_rdd(ratings, _serialize_rating)
        java_model = context._jvm.PythonMLLibAPI().trainALSModel(
            serialized._jrdd, rank, iterations, lambda_, blocks)
        return MatrixFactorizationModel(context, java_model)

    @classmethod
    def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01):
        """Train an implicit-feedback ALS model on an RDD of ratings."""
        context = ratings.context
        serialized = _get_unmangled_rdd(ratings, _serialize_rating)
        java_model = context._jvm.PythonMLLibAPI().trainImplicitALSModel(
            serialized._jrdd, rank, iterations, lambda_, blocks, alpha)
        return MatrixFactorizationModel(context, java_model)
def _test():
    """Run this module's doctests against a local SparkContext.

    Exits the process with a non-zero status if any doctest fails,
    so test harnesses detect the failure.
    """
    import doctest
    import sys
    globs = globals().copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        # Use sys.exit instead of the `exit` builtin: `exit` is injected
        # by the site module and is not guaranteed to exist (e.g. under
        # `python -S` or in frozen/embedded interpreters).
        sys.exit(-1)


if __name__ == "__main__":
    _test()
88 89 90 if __name__ == "__main__": 91 _test() 92