1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 package org.apache.spark.examples.mllib;
19 
20 // $example on$
21 import java.util.Arrays;
22 
23 import scala.Tuple2;
24 
25 import org.apache.spark.api.java.*;
26 import org.apache.spark.api.java.function.Function;
27 import org.apache.spark.mllib.classification.LogisticRegressionModel;
28 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
29 import org.apache.spark.mllib.linalg.Vector;
30 import org.apache.spark.mllib.linalg.Vectors;
31 import org.apache.spark.mllib.optimization.*;
32 import org.apache.spark.mllib.regression.LabeledPoint;
33 import org.apache.spark.mllib.util.MLUtils;
34 import org.apache.spark.SparkConf;
35 import org.apache.spark.SparkContext;
36 // $example off$
37 
38 public class JavaLBFGSExample {
main(String[] args)39   public static void main(String[] args) {
40     SparkConf conf = new SparkConf().setAppName("L-BFGS Example");
41     SparkContext sc = new SparkContext(conf);
42 
43     // $example on$
44     String path = "data/mllib/sample_libsvm_data.txt";
45     JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();
46     int numFeatures = data.take(1).get(0).features().size();
47 
48     // Split initial RDD into two... [60% training data, 40% testing data].
49     JavaRDD<LabeledPoint> trainingInit = data.sample(false, 0.6, 11L);
50     JavaRDD<LabeledPoint> test = data.subtract(trainingInit);
51 
52     // Append 1 into the training data as intercept.
53     JavaRDD<Tuple2<Object, Vector>> training = data.map(
54       new Function<LabeledPoint, Tuple2<Object, Vector>>() {
55         public Tuple2<Object, Vector> call(LabeledPoint p) {
56           return new Tuple2<Object, Vector>(p.label(), MLUtils.appendBias(p.features()));
57         }
58       });
59     training.cache();
60 
61     // Run training algorithm to build the model.
62     int numCorrections = 10;
63     double convergenceTol = 1e-4;
64     int maxNumIterations = 20;
65     double regParam = 0.1;
66     Vector initialWeightsWithIntercept = Vectors.dense(new double[numFeatures + 1]);
67 
68     Tuple2<Vector, double[]> result = LBFGS.runLBFGS(
69       training.rdd(),
70       new LogisticGradient(),
71       new SquaredL2Updater(),
72       numCorrections,
73       convergenceTol,
74       maxNumIterations,
75       regParam,
76       initialWeightsWithIntercept);
77     Vector weightsWithIntercept = result._1();
78     double[] loss = result._2();
79 
80     final LogisticRegressionModel model = new LogisticRegressionModel(
81       Vectors.dense(Arrays.copyOf(weightsWithIntercept.toArray(), weightsWithIntercept.size() - 1)),
82       (weightsWithIntercept.toArray())[weightsWithIntercept.size() - 1]);
83 
84     // Clear the default threshold.
85     model.clearThreshold();
86 
87     // Compute raw scores on the test set.
88     JavaRDD<Tuple2<Object, Object>> scoreAndLabels = test.map(
89       new Function<LabeledPoint, Tuple2<Object, Object>>() {
90         public Tuple2<Object, Object> call(LabeledPoint p) {
91           Double score = model.predict(p.features());
92           return new Tuple2<Object, Object>(score, p.label());
93         }
94       });
95 
96     // Get evaluation metrics.
97     BinaryClassificationMetrics metrics =
98       new BinaryClassificationMetrics(scoreAndLabels.rdd());
99     double auROC = metrics.areaUnderROC();
100 
101     System.out.println("Loss of each step in training process");
102     for (double l : loss)
103       System.out.println(l);
104     System.out.println("Area under ROC = " + auROC);
105     // $example off$
106   }
107 }
108 
109