Package pyspark :: Package mllib :: Module recommendation
[frames] | [no frames]

Source Code for Module pyspark.mllib.recommendation

 1  # 
 2  # Licensed to the Apache Software Foundation (ASF) under one or more 
 3  # contributor license agreements.  See the NOTICE file distributed with 
 4  # this work for additional information regarding copyright ownership. 
 5  # The ASF licenses this file to You under the Apache License, Version 2.0 
 6  # (the "License"); you may not use this file except in compliance with 
 7  # the License.  You may obtain a copy of the License at 
 8  # 
 9  #    http://www.apache.org/licenses/LICENSE-2.0 
10  # 
11  # Unless required by applicable law or agreed to in writing, software 
12  # distributed under the License is distributed on an "AS IS" BASIS, 
13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14  # See the License for the specific language governing permissions and 
15  # limitations under the License. 
16  # 
17   
18  from pyspark import SparkContext 
19  from pyspark.mllib._common import \ 
20      _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \ 
21      _serialize_double_matrix, _deserialize_double_matrix, \ 
22      _serialize_double_vector, _deserialize_double_vector, \ 
23      _get_initial_weights, _serialize_rating, _regression_train_wrapper, \ 
24      _serialize_tuple, RatingDeserializer 
25  from pyspark.rdd import RDD 
class MatrixFactorizationModel(object):
    """Model produced by regularized alternating least-squares
    matrix factorisation, wrapping a JVM-side Spark model.

    >>> r1 = (1, 1, 1.0)
    >>> r2 = (1, 2, 2.0)
    >>> r3 = (2, 1, 2.0)
    >>> ratings = sc.parallelize([r1, r2, r3])
    >>> model = ALS.trainImplicit(ratings, 1)
    >>> model.predict(2,2) is not None
    True
    >>> testset = sc.parallelize([(1, 2), (1, 1)])
    >>> model.predictAll(testset).count() == 2
    True
    """

    def __init__(self, sc, java_model):
        # Keep the SparkContext around: its Py4J gateway is needed to
        # release the JVM-side model in __del__.
        self._context = sc
        self._java_model = java_model

    def __del__(self):
        # Drop the Py4J reference so the JVM model can be collected.
        self._context._gateway.detach(self._java_model)

    def predict(self, user, product):
        """Predict the rating of a single user for a single product."""
        return self._java_model.predict(user, product)

    def predictAll(self, usersProducts):
        """Predict ratings for an RDD of (user, product) pairs; returns
        an RDD of ratings deserialized from the JVM result."""
        jrdd = _get_unmangled_rdd(usersProducts, _serialize_tuple)
        return RDD(self._java_model.predict(jrdd._jrdd),
                   self._context, RatingDeserializer())
class ALS(object):
    """Factory for MatrixFactorizationModel instances trained with
    alternating least-squares on an RDD of (user, product, rating)
    tuples."""

    @classmethod
    def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
        """Train a model on explicit ratings.

        :param ratings: RDD of (user, product, rating) tuples.
        :param rank: number of latent factors.
        :param iterations: number of ALS iterations to run.
        :param lambda_: regularization parameter.
        :param blocks: parallelism level (presumably -1 auto-configures;
                       confirm against the JVM API).
        """
        sc = ratings.context
        serialized = _get_unmangled_rdd(ratings, _serialize_rating)
        java_model = sc._jvm.PythonMLLibAPI().trainALSModel(
            serialized._jrdd, rank, iterations, lambda_, blocks)
        return MatrixFactorizationModel(sc, java_model)

    @classmethod
    def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01):
        """Train a model on implicit-feedback ratings.

        Same parameters as :meth:`train`, plus ``alpha``, the confidence
        parameter forwarded to the implicit-feedback trainer.
        """
        sc = ratings.context
        serialized = _get_unmangled_rdd(ratings, _serialize_rating)
        java_model = sc._jvm.PythonMLLibAPI().trainImplicitALSModel(
            serialized._jrdd, rank, iterations, lambda_, blocks, alpha)
        return MatrixFactorizationModel(sc, java_model)
def _test():
    """Run this module's doctests against a local SparkContext.

    Exits the process with a non-zero status if any doctest fails.
    """
    import doctest
    import sys
    globs = globals().copy()
    # NOTE(review): batchSize=2 presumably exercises batched
    # serialization in the doctests — confirm.
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        # sys.exit, not the bare exit() builtin: 'exit' is injected by
        # the site module and is absent under 'python -S'.
        sys.exit(-1)


if __name__ == "__main__":
    _test()