1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 package org.apache.spark.examples.ml;
19 
20 import org.apache.spark.sql.Dataset;
21 import org.apache.spark.sql.Row;
22 import org.apache.spark.sql.SparkSession;
23 
24 // $example on$
25 import java.io.Serializable;
26 
27 import org.apache.spark.api.java.JavaRDD;
28 import org.apache.spark.api.java.function.Function;
29 import org.apache.spark.ml.evaluation.RegressionEvaluator;
30 import org.apache.spark.ml.recommendation.ALS;
31 import org.apache.spark.ml.recommendation.ALSModel;
32 // $example off$
33 
34 public class JavaALSExample {
35 
36   // $example on$
37   public static class Rating implements Serializable {
38     private int userId;
39     private int movieId;
40     private float rating;
41     private long timestamp;
42 
Rating()43     public Rating() {}
44 
Rating(int userId, int movieId, float rating, long timestamp)45     public Rating(int userId, int movieId, float rating, long timestamp) {
46       this.userId = userId;
47       this.movieId = movieId;
48       this.rating = rating;
49       this.timestamp = timestamp;
50     }
51 
getUserId()52     public int getUserId() {
53       return userId;
54     }
55 
getMovieId()56     public int getMovieId() {
57       return movieId;
58     }
59 
getRating()60     public float getRating() {
61       return rating;
62     }
63 
getTimestamp()64     public long getTimestamp() {
65       return timestamp;
66     }
67 
parseRating(String str)68     public static Rating parseRating(String str) {
69       String[] fields = str.split("::");
70       if (fields.length != 4) {
71         throw new IllegalArgumentException("Each line must contain 4 fields");
72       }
73       int userId = Integer.parseInt(fields[0]);
74       int movieId = Integer.parseInt(fields[1]);
75       float rating = Float.parseFloat(fields[2]);
76       long timestamp = Long.parseLong(fields[3]);
77       return new Rating(userId, movieId, rating, timestamp);
78     }
79   }
80   // $example off$
81 
main(String[] args)82   public static void main(String[] args) {
83     SparkSession spark = SparkSession
84       .builder()
85       .appName("JavaALSExample")
86       .getOrCreate();
87 
88     // $example on$
89     JavaRDD<Rating> ratingsRDD = spark
90       .read().textFile("data/mllib/als/sample_movielens_ratings.txt").javaRDD()
91       .map(new Function<String, Rating>() {
92         public Rating call(String str) {
93           return Rating.parseRating(str);
94         }
95       });
96     Dataset<Row> ratings = spark.createDataFrame(ratingsRDD, Rating.class);
97     Dataset<Row>[] splits = ratings.randomSplit(new double[]{0.8, 0.2});
98     Dataset<Row> training = splits[0];
99     Dataset<Row> test = splits[1];
100 
101     // Build the recommendation model using ALS on the training data
102     ALS als = new ALS()
103       .setMaxIter(5)
104       .setRegParam(0.01)
105       .setUserCol("userId")
106       .setItemCol("movieId")
107       .setRatingCol("rating");
108     ALSModel model = als.fit(training);
109 
110     // Evaluate the model by computing the RMSE on the test data
111     Dataset<Row> predictions = model.transform(test);
112 
113     RegressionEvaluator evaluator = new RegressionEvaluator()
114       .setMetricName("rmse")
115       .setLabelCol("rating")
116       .setPredictionCol("prediction");
117     Double rmse = evaluator.evaluate(predictions);
118     System.out.println("Root-mean-square error = " + rmse);
119     // $example off$
120     spark.stop();
121   }
122 }
123