/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// scalastyle:off println
package org.apache.spark.examples.mllib

import org.apache.log4j.{Level, Logger}
import scopt.OptionParser

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, SVMWithSGD}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.optimization.{L1Updater, SquaredL2Updater}
import org.apache.spark.mllib.util.MLUtils

/**
 * An example app for binary classification. Run with
 * {{{
 * bin/run-example org.apache.spark.examples.mllib.BinaryClassification
 * }}}
 * A synthetic dataset is located at `data/mllib/sample_binary_classification_data.txt`.
 * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
 */
object BinaryClassification {

  object Algorithm extends Enumeration {
    type Algorithm = Value
    val SVM, LR = Value
  }

  object RegType extends Enumeration {
    type RegType = Value
    val L1, L2 = Value
  }

  import Algorithm._
  import RegType._

  case class Params(
      input: String = null,
      numIterations: Int = 100,
      stepSize: Double = 1.0,
      algorithm: Algorithm = LR,
      regType: RegType = L2,
      regParam: Double = 0.01) extends AbstractParams[Params]

  def main(args: Array[String]): Unit = {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("BinaryClassification") {
      head("BinaryClassification: an example app for binary classification.")
      opt[Int]("numIterations")
        .text(s"number of iterations, default: ${defaultParams.numIterations}")
        .action((x, c) => c.copy(numIterations = x))
      opt[Double]("stepSize")
        .text("initial step size (ignored by logistic regression), " +
          s"default: ${defaultParams.stepSize}")
        .action((x, c) => c.copy(stepSize = x))
      opt[String]("algorithm")
        .text(s"algorithm (${Algorithm.values.mkString(",")}), " +
          s"default: ${defaultParams.algorithm}")
        .action((x, c) => c.copy(algorithm = Algorithm.withName(x)))
      opt[String]("regType")
        .text(s"regularization type (${RegType.values.mkString(",")}), " +
          s"default: ${defaultParams.regType}")
        .action((x, c) => c.copy(regType = RegType.withName(x)))
      opt[Double]("regParam")
        .text(s"regularization parameter, default: ${defaultParams.regParam}")
        .action((x, c) => c.copy(regParam = x))
      arg[String]("<input>")
        .required()
        .text("input paths to labeled examples in LIBSVM format")
        .action((x, c) => c.copy(input = x))
      note(
        """
          |For example, the following command runs this app on a synthetic dataset:
          |
          | bin/spark-submit --class org.apache.spark.examples.mllib.BinaryClassification \
          |  examples/target/scala-*/spark-examples-*.jar \
          |  --algorithm LR --regType L2 --regParam 1.0 \
          |  data/mllib/sample_binary_classification_data.txt
        """.stripMargin)
    }

    parser.parse(args, defaultParams) match {
      case Some(params) => run(params)
      case _ => sys.exit(1)
    }
  }

  def run(params: Params): Unit = {
    val conf = new SparkConf().setAppName(s"BinaryClassification with $params")
    val sc = new SparkContext(conf)

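    // Quiet the default INFO logging so the example's own output is easy to read.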
    Logger.getRootLogger.setLevel(Level.WARN)

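    // Load labeled points in LIBSVM format and cache them for the split below.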
    val examples = MLUtils.loadLibSVMFile(sc, params.input).cache()

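    // Randomly hold out 20% of the examples for evaluation.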
    val splits = examples.randomSplit(Array(0.8, 0.2))
    val training = splits(0).cache()
    val test = splits(1).cache()

    val numTraining = training.count()
    val numTest = test.count()
    println(s"Training: $numTraining, test: $numTest.")

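    // The cached splits make the parent RDD redundant; release it without blocking.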
    examples.unpersist(blocking = false)

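    // Map the requested penalty onto MLlib's gradient updaters: L1Updater applies
    // an L1 (lasso) penalty, SquaredL2Updater a squared L2 (ridge) penalty.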
    val updater = params.regType match {
      case L1 => new L1Updater()
      case L2 => new SquaredL2Updater()
    }

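    // Train the selected model. clearThreshold() makes predict() return raw scores
    // rather than 0/1 labels, which BinaryClassificationMetrics needs to sweep thresholds.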
    val model = params.algorithm match {
      case LR =>
        val algorithm = new LogisticRegressionWithLBFGS()
        algorithm.optimizer
          .setNumIterations(params.numIterations)
          .setUpdater(updater)
          .setRegParam(params.regParam)
        algorithm.run(training).clearThreshold()
      case SVM =>
        val algorithm = new SVMWithSGD()
        algorithm.optimizer
          .setNumIterations(params.numIterations)
          .setStepSize(params.stepSize)
          .setUpdater(updater)
          .setRegParam(params.regParam)
        algorithm.run(training).clearThreshold()
    }

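    // Score the held-out set and pair each raw score with its true label.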
    val prediction = model.predict(test.map(_.features))
    val predictionAndLabel = prediction.zip(test.map(_.label))

    val metrics = new BinaryClassificationMetrics(predictionAndLabel)

    println(s"Test areaUnderPR = ${metrics.areaUnderPR()}.")
    println(s"Test areaUnderROC = ${metrics.areaUnderROC()}.")

    sc.stop()
  }
}
// scalastyle:on println