1# 2# Licensed to the Apache Software Foundation (ASF) under one or more 3# contributor license agreements. See the NOTICE file distributed with 4# this work for additional information regarding copyright ownership. 5# The ASF licenses this file to You under the Apache License, Version 2.0 6# (the "License"); you may not use this file except in compliance with 7# the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16# 17 18from __future__ import print_function 19 20import sys 21if sys.version >= '3': 22 long = int 23 24from pyspark.sql import SparkSession 25 26# $example on$ 27from pyspark.ml.evaluation import RegressionEvaluator 28from pyspark.ml.recommendation import ALS 29from pyspark.sql import Row 30# $example off$ 31 32if __name__ == "__main__": 33 spark = SparkSession\ 34 .builder\ 35 .appName("ALSExample")\ 36 .getOrCreate() 37 38 # $example on$ 39 lines = spark.read.text("data/mllib/als/sample_movielens_ratings.txt").rdd 40 parts = lines.map(lambda row: row.value.split("::")) 41 ratingsRDD = parts.map(lambda p: Row(userId=int(p[0]), movieId=int(p[1]), 42 rating=float(p[2]), timestamp=long(p[3]))) 43 ratings = spark.createDataFrame(ratingsRDD) 44 (training, test) = ratings.randomSplit([0.8, 0.2]) 45 46 # Build the recommendation model using ALS on the training data 47 als = ALS(maxIter=5, regParam=0.01, userCol="userId", itemCol="movieId", ratingCol="rating") 48 model = als.fit(training) 49 50 # Evaluate the model by computing the RMSE on the test data 51 predictions = model.transform(test) 52 evaluator = RegressionEvaluator(metricName="rmse", labelCol="rating", 53 predictionCol="prediction") 54 rmse = evaluator.evaluate(predictions) 55 print("Root-mean-square error = " + str(rmse)) 56 # $example off$ 57 spark.stop() 58