1#
2# Licensed to the Apache Software Foundation (ASF) under one or more
3# contributor license agreements.  See the NOTICE file distributed with
4# this work for additional information regarding copyright ownership.
5# The ASF licenses this file to You under the Apache License, Version 2.0
6# (the "License"); you may not use this file except in compliance with
7# the License.  You may obtain a copy of the License at
8#
9#    http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18from __future__ import print_function
19
20import sys
# Python 3 has no separate ``long`` type, so alias it to ``int``; the
# ``long(...)`` call used later in this script then works on both major
# versions.
# The check uses the numeric ``sys.version_info`` tuple rather than the
# ``sys.version`` string: comparing version strings is lexicographic and is
# not a reliable way to order versions.
if sys.version_info[0] >= 3:
    long = int
23
24from pyspark.sql import SparkSession
25
26# $example on$
27from pyspark.ml.evaluation import RegressionEvaluator
28from pyspark.ml.recommendation import ALS
29from pyspark.sql import Row
30# $example off$
31
if __name__ == "__main__":
    # Obtain the SparkSession for this example application.
    spark = SparkSession\
        .builder\
        .appName("ALSExample")\
        .getOrCreate()

    # The try/finally guarantees the session is stopped even when loading,
    # training or evaluation raises.
    try:
        # $example on$
        # Each input line is split on "::" into four fields, which are read
        # in order as user id, movie id, rating and timestamp.
        lines = spark.read.text("data/mllib/als/sample_movielens_ratings.txt").rdd
        parts = lines.map(lambda row: row.value.split("::"))
        ratingsRDD = parts.map(lambda p: Row(userId=int(p[0]), movieId=int(p[1]),
                                             rating=float(p[2]), timestamp=long(p[3])))
        ratings = spark.createDataFrame(ratingsRDD)
        # Randomly split the ratings with weights 0.8 (training) and 0.2 (test).
        (training, test) = ratings.randomSplit([0.8, 0.2])

        # Build the recommendation model using ALS on the training data
        als = ALS(maxIter=5, regParam=0.01, userCol="userId", itemCol="movieId",
                  ratingCol="rating")
        model = als.fit(training)

        # Evaluate the model by computing the RMSE on the test data.
        # Because the split is random, the test set may contain users or
        # movies that never appeared in the training set; ALS has no factors
        # for them and yields NaN predictions, which would turn the RMSE into
        # NaN. Rows without a usable prediction are dropped before evaluating.
        predictions = model.transform(test).na.drop(subset=["prediction"])
        evaluator = RegressionEvaluator(metricName="rmse", labelCol="rating",
                                        predictionCol="prediction")
        rmse = evaluator.evaluate(predictions)
        print("Root-mean-square error = " + str(rmse))
        # $example off$
    finally:
        spark.stop()
58