1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 package org.apache.spark.examples.ml; 19 20 import org.apache.spark.sql.Dataset; 21 import org.apache.spark.sql.Row; 22 import org.apache.spark.sql.SparkSession; 23 24 // $example on$ 25 import java.io.Serializable; 26 27 import org.apache.spark.api.java.JavaRDD; 28 import org.apache.spark.api.java.function.Function; 29 import org.apache.spark.ml.evaluation.RegressionEvaluator; 30 import org.apache.spark.ml.recommendation.ALS; 31 import org.apache.spark.ml.recommendation.ALSModel; 32 // $example off$ 33 34 public class JavaALSExample { 35 36 // $example on$ 37 public static class Rating implements Serializable { 38 private int userId; 39 private int movieId; 40 private float rating; 41 private long timestamp; 42 Rating()43 public Rating() {} 44 Rating(int userId, int movieId, float rating, long timestamp)45 public Rating(int userId, int movieId, float rating, long timestamp) { 46 this.userId = userId; 47 this.movieId = movieId; 48 this.rating = rating; 49 this.timestamp = timestamp; 50 } 51 getUserId()52 public int getUserId() { 53 return userId; 54 } 55 getMovieId()56 public int getMovieId() { 57 return movieId; 58 } 59 getRating()60 public float getRating() { 61 return rating; 62 } 63 getTimestamp()64 public long getTimestamp() { 65 return timestamp; 66 } 67 parseRating(String str)68 public static Rating parseRating(String str) { 69 String[] fields = str.split("::"); 70 if (fields.length != 4) { 71 throw new IllegalArgumentException("Each line must contain 4 fields"); 72 } 73 int userId = Integer.parseInt(fields[0]); 74 int movieId = Integer.parseInt(fields[1]); 75 float rating = Float.parseFloat(fields[2]); 76 long timestamp = Long.parseLong(fields[3]); 77 return new Rating(userId, movieId, rating, timestamp); 78 } 79 } 80 // $example off$ 81 main(String[] args)82 public static void main(String[] args) { 83 SparkSession spark = SparkSession 84 .builder() 85 .appName("JavaALSExample") 86 .getOrCreate(); 87 88 // $example on$ 89 JavaRDD<Rating> ratingsRDD = spark 90 .read().textFile("data/mllib/als/sample_movielens_ratings.txt").javaRDD() 91 .map(new Function<String, Rating>() { 92 public Rating call(String str) { 93 return Rating.parseRating(str); 94 } 95 }); 96 Dataset<Row> ratings = spark.createDataFrame(ratingsRDD, Rating.class); 97 Dataset<Row>[] splits = ratings.randomSplit(new double[]{0.8, 0.2}); 98 Dataset<Row> training = splits[0]; 99 Dataset<Row> test = splits[1]; 100 101 // Build the recommendation model using ALS on the training data 102 ALS als = new ALS() 103 .setMaxIter(5) 104 .setRegParam(0.01) 105 .setUserCol("userId") 106 .setItemCol("movieId") 107 .setRatingCol("rating"); 108 ALSModel model = als.fit(training); 109 110 // Evaluate the model by computing the RMSE on the test data 111 Dataset<Row> predictions = model.transform(test); 112 113 RegressionEvaluator evaluator = new RegressionEvaluator() 114 .setMetricName("rmse") 115 .setLabelCol("rating") 116 .setPredictionCol("prediction"); 117 Double rmse = evaluator.evaluate(predictions); 118 System.out.println("Root-mean-square error = " + rmse); 119 // $example off$ 120 spark.stop(); 121 } 122 } 123