/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// scalastyle:off println
package org.apache.spark.examples.mllib

import org.apache.log4j.{Level, Logger}
import scopt.OptionParser

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, SVMWithSGD}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.optimization.{L1Updater, SquaredL2Updater}
import org.apache.spark.mllib.util.MLUtils

/**
 * An example app for binary classification. Run with
 * {{{
 * bin/run-example org.apache.spark.examples.mllib.BinaryClassification
 * }}}
 * A synthetic dataset is located at `data/mllib/sample_binary_classification_data.txt`.
 * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
 */
object BinaryClassification {

  /** Classifiers selectable through `--algorithm`. */
  object Algorithm extends Enumeration {
    type Algorithm = Value
    val SVM, LR = Value
  }

  /** Regularization schemes selectable through `--regType`. */
  object RegType extends Enumeration {
    type RegType = Value
    val L1, L2 = Value
  }

  import Algorithm._
  import RegType._

  /**
   * Command-line configuration of the example.
   *
   * @param input         path(s) to labeled examples in LIBSVM format (required argument)
   * @param numIterations number of optimizer iterations
   * @param stepSize      initial step size; only applied on the SVM branch of [[run]]
   * @param algorithm     which classifier to train
   * @param regType       which updater (L1 or squared L2) to use
   * @param regParam      regularization parameter passed to the optimizer
   */
  case class Params(
      input: String = null,
      numIterations: Int = 100,
      stepSize: Double = 1.0,
      algorithm: Algorithm = LR,
      regType: RegType = L2,
      regParam: Double = 0.01) extends AbstractParams[Params]

  /**
   * Entry point: parses `args` into [[Params]] and calls [[run]].
   * Exits the JVM with status 1 when the arguments cannot be parsed.
   *
   * Note: `Algorithm.withName` / `RegType.withName` throw `NoSuchElementException`
   * when given a name that is not one of the enumeration's values.
   */
  def main(args: Array[String]): Unit = {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("BinaryClassification") {
      head("BinaryClassification: an example app for binary classification.")
      opt[Int]("numIterations")
        .text("number of iterations")
        .action((x, c) => c.copy(numIterations = x))
      opt[Double]("stepSize")
        .text("initial step size (ignored by logistic regression), " +
          s"default: ${defaultParams.stepSize}")
        .action((x, c) => c.copy(stepSize = x))
      opt[String]("algorithm")
        .text(s"algorithm (${Algorithm.values.mkString(",")}), " +
          s"default: ${defaultParams.algorithm}")
        .action((x, c) => c.copy(algorithm = Algorithm.withName(x)))
      opt[String]("regType")
        .text(s"regularization type (${RegType.values.mkString(",")}), " +
          s"default: ${defaultParams.regType}")
        .action((x, c) => c.copy(regType = RegType.withName(x)))
      opt[Double]("regParam")
        .text(s"regularization parameter, default: ${defaultParams.regParam}")
        // Store the parsed value; without this action the option is accepted on the
        // command line but discarded, so the default regParam would always be used.
        .action((x, c) => c.copy(regParam = x))
      arg[String]("<input>")
        .required()
        .text("input paths to labeled examples in LIBSVM format")
        .action((x, c) => c.copy(input = x))
      note(
        """
          |For example, the following command runs this app on a synthetic dataset:
          |
          | bin/spark-submit --class org.apache.spark.examples.mllib.BinaryClassification \
          |  examples/target/scala-*/spark-examples-*.jar \
          |  --algorithm LR --regType L2 --regParam 1.0 \
          |  data/mllib/sample_binary_classification_data.txt
        """.stripMargin)
    }

    parser.parse(args, defaultParams) match {
      case Some(params) => run(params)
      case _ => sys.exit(1)
    }
  }

  /**
   * Trains and evaluates a binary classifier as configured by `params`.
   *
   * Loads the LIBSVM file at `params.input`, randomly splits it 80/20 into training
   * and test sets, trains the selected algorithm with the selected updater, and
   * prints the area under the PR and ROC curves measured on the test set.
   *
   * @param params parsed command-line configuration
   */
  def run(params: Params): Unit = {
    val conf = new SparkConf().setAppName(s"BinaryClassification with $params")
    val sc = new SparkContext(conf)

    // Stop the context even when loading, training or evaluation throws.
    try {
      Logger.getRootLogger.setLevel(Level.WARN)

      val examples = MLUtils.loadLibSVMFile(sc, params.input).cache()

      val splits = examples.randomSplit(Array(0.8, 0.2))
      val training = splits(0).cache()
      val test = splits(1).cache()

      // The counts materialize both cached splits before the parent RDD is released.
      val numTraining = training.count()
      val numTest = test.count()
      println(s"Training: $numTraining, test: $numTest.")

      examples.unpersist(blocking = false)

      val updater = params.regType match {
        case L1 => new L1Updater()
        case L2 => new SquaredL2Updater()
      }

      val model = params.algorithm match {
        case LR =>
          // stepSize is intentionally not configured on this branch.
          val algorithm = new LogisticRegressionWithLBFGS()
          algorithm.optimizer
            .setNumIterations(params.numIterations)
            .setUpdater(updater)
            .setRegParam(params.regParam)
          algorithm.run(training).clearThreshold()
        case SVM =>
          val algorithm = new SVMWithSGD()
          algorithm.optimizer
            .setNumIterations(params.numIterations)
            .setStepSize(params.stepSize)
            .setUpdater(updater)
            .setRegParam(params.regParam)
          algorithm.run(training).clearThreshold()
      }

      // Pair each test prediction with its label for the metrics computation.
      val prediction = model.predict(test.map(_.features))
      val predictionAndLabel = prediction.zip(test.map(_.label))

      val metrics = new BinaryClassificationMetrics(predictionAndLabel)

      println(s"Test areaUnderPR = ${metrics.areaUnderPR()}.")
      println(s"Test areaUnderROC = ${metrics.areaUnderROC()}.")
    } finally {
      sc.stop()
    }
  }
}
// scalastyle:on println