/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mxnet.spark

import org.apache.mxnet._
import org.apache.mxnet.optimizer.SGD
import org.apache.mxnet.spark.io.LabeledPointIter

import org.slf4j.{Logger, LoggerFactory}

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext

class MXNet extends Serializable {

  /**
   * Runs the launcher of a parameter-server component (scheduler or servers)
   * off the calling thread, so that fit() can proceed with training.
   */
  class MXNetControllingThread(
      schedulerIP: String,
      schedulerPort: Int,
      sparkContext: SparkContext,
      triggerOfComponent: (String, Int, SparkContext) => Unit) extends Thread {
    override def run() {
      triggerOfComponent(schedulerIP, schedulerPort, sparkContext)
    }
  }

  private val logger: Logger = LoggerFactory.getLogger(classOf[MXNet])
  private val params: MXNetParams = new MXNetParams

  @transient private var psServerThread: MXNetControllingThread = _
  @transient private var psSchedulerThread: MXNetControllingThread = _

  def setBatchSize(batchSize: Int): this.type = {
    params.batchSize = batchSize
    this
  }

  def setNumEpoch(numEpoch: Int): this.type = {
    params.numEpoch = numEpoch
    this
  }

  def setDimension(dimension: Shape): this.type = {
    params.dimension = dimension
    this
  }

  def setNetwork(network: Symbol): this.type = {
    params.setNetwork(network)
    this
  }

  def setContext(ctx: Array[Context]): this.type = {
    params.context = ctx
    this
  }

  def setNumWorker(numWorker: Int): this.type = {
    params.numWorker = numWorker
    this
  }

  def setNumServer(numServer: Int): this.type = {
    params.numServer = numServer
    this
  }

  def setDataName(name: String): this.type = {
    params.dataName = name
    this
  }

  def setLabelName(name: String): this.type = {
    params.labelName = name
    this
  }

  /**
   * The application (including the parameter scheduler & servers)
   * will exit if it hasn't received a heartbeat for over `timeout` seconds.
   * @param timeout timeout in seconds (default 300)
   */
  def setTimeout(timeout: Int): this.type = {
    params.timeout = timeout
    this
  }

  /**
   * These jars are required by the KVStore at runtime.
   * They will be uploaded and distributed to each node automatically.
   * @param jars comma- or colon-separated list of jars required by the KVStore at runtime
   */
  def setExecutorJars(jars: String): this.type = {
    params.jars = jars.split(",|:")
    this
  }

  def setJava(java: String): this.type = {
    params.javabin = java
    this
  }

  /**
   * Launch one parameter-server process per partition on the executors,
   * from a background controlling thread.
   */
  private def startPSServers(
      schedulerIP: String,
      schedulerPort: Int,
      sc: SparkContext) = {
    def startPSServersInner(
        schedulerIP: String,
        schedulerPort: Int,
        sc: SparkContext): Unit = {
      sc.parallelize(1 to params.numServer, params.numServer).foreachPartition { p =>
        logger.info("Starting server ...")
        val server = new ParameterServer(params.runtimeClasspath,
          role = "server",
          rootUri = schedulerIP,
          rootPort = schedulerPort,
          numServer = params.numServer,
          numWorker = params.numWorker,
          timeout = params.timeout,
          java = params.javabin)
        val exitCode = server.startProcess()
        require(exitCode == 0, s"ps server process quit with exit code $exitCode")
      }
    }
    psServerThread = new MXNetControllingThread(schedulerIP, schedulerPort, sc, startPSServersInner)
    psServerThread.start()
  }

  /**
   * Launch the parameter-server scheduler process on the driver node,
   * from a background controlling thread.
   */
  private def startPSScheduler(
      schedulerIP: String,
      schedulerPort: Int,
      sc: SparkContext) = {
    def startPSSchedulerInner(
        schedulerIP: String,
        schedulerPort: Int,
        sc: SparkContext): Unit = {
      // TODO: check that the ip & port are available
      logger.info("Starting scheduler on {}:{}", schedulerIP, schedulerPort)
      val scheduler = new ParameterServer(params.runtimeClasspath, role = "scheduler",
        rootUri = schedulerIP, rootPort = schedulerPort,
        numServer = params.numServer, numWorker = params.numWorker,
        timeout = params.timeout, java = params.javabin)
      val exitCode = scheduler.startProcess()
      require(exitCode == 0, s"Failed to start ps scheduler process with exit code $exitCode")
    }
    psSchedulerThread = new MXNetControllingThread(schedulerIP, schedulerPort, sc,
      startPSSchedulerInner)
    psSchedulerThread.start()
  }

  /** Build a FeedForward model and train it on the partition-local data iterator. */
  private def setFeedForwardModel(
      optimizer: Optimizer,
      numExamples: Int,
      kv: KVStore,
      inputInPartition: LabeledPointIter): FeedForward = {
    logger.debug("Define model")
    val model = new FeedForward(ctx = params.context,
      symbol = params.getNetwork,
      numEpoch = params.numEpoch,
      optimizer = optimizer,
      initializer = new Xavier(factorType = "in", magnitude = 2.34f),
      argParams = null,
      auxParams = null,
      beginEpoch = 0,
      epochSize = numExamples / params.batchSize / kv.numWorkers)
    logger.info("Start training ...")
    model.fit(trainData = inputInPartition,
      evalData = null,
      evalMetric = new Accuracy(),
      kvStore = kv)
    model
  }

  /** Register this executor as a worker and create a distributed asynchronous KVStore. */
  private def setupKVStore(schedulerIP: String, schedulerPort: Int): KVStore = {
    KVStoreServer.init(ParameterServer.buildEnv(role = "worker",
      rootUri = schedulerIP, rootPort = schedulerPort,
      numServer = params.numServer,
      numWorker = params.numWorker))
    val kv = KVStore.create("dist_async")
    kv.setBarrierBeforeExit(false)
    kv
  }

  private def reclaimResources(dataIter: LabeledPointIter, kv: KVStore): Unit = {
    dataIter.dispose()
    kv.setBarrierBeforeExit(true)
    kv.dispose()
  }

  /** Run one training worker per data partition and collect the trained model. */
  private def trainModel(
      trainData: RDD[LabeledPoint],
      schedulerIP: String,
      schedulerPort: Int): MXNetModel = {
    val job = trainData.mapPartitions { partition =>
      val dataIter = new LabeledPointIter(
        partition, params.dimension,
        params.batchSize,
        dataName = params.dataName,
        labelName = params.labelName)
      // TODO: a more natural way to get the number of examples?
      var numExamples = 0
      while (dataIter.hasNext) {
        val dataBatch = dataIter.next()
        numExamples += dataBatch.label.head.shape(0)
      }
      logger.debug("Number of samples: {}", numExamples)
      dataIter.reset()

      logger.info("Launching worker ...")
      logger.info("Batch {}", params.batchSize)
      // give ps-lite enough time to detect the dead nodes
      Thread.sleep(20000)
      val kv = setupKVStore(schedulerIP, schedulerPort)
      val optimizer = new SGD(learningRate = 0.01f, momentum = 0.9f, wd = 0.00001f)
      val model = setFeedForwardModel(optimizer, numExamples, kv, dataIter)
      logger.info("Training finished, waiting for other workers ...")
      reclaimResources(dataIter, kv)
      Iterator(new MXNetModel(
        model, params.dimension, params.batchSize,
        dataName = params.dataName, labelName = params.labelName))
    }.cache()
    // force the job to run
    job.foreachPartition(_ => ())
    job.first()
  }

  /**
   * Train an MXNet model on a Spark RDD.
   * Starts the parameter-server scheduler on the driver and the servers on the
   * executors, repartitions the data to one partition per worker if necessary,
   * then runs distributed training and returns the resulting model.
   */
  def fit(data: RDD[LabeledPoint]): MXNetModel = {
    val sc = data.context
    // distribute native jars
    if (params.jars != null) {
      params.jars.foreach(jar => sc.addFile(jar))
    }
    val trainData = {
      if (params.numWorker != data.partitions.length) {
        logger.info("Repartitioning training set to {} partitions", params.numWorker)
        data.repartition(params.numWorker)
      } else {
        data
      }
    }
    val schedulerIP = utils.Network.ipAddress
    val schedulerPort = utils.Network.availablePort
    startPSScheduler(schedulerIP, schedulerPort, sc)
    startPSServers(schedulerIP, schedulerPort, sc)
    val mxModel = trainModel(trainData, schedulerIP, schedulerPort)
    logger.info("Waiting for scheduler ...")
    psSchedulerThread.join()
    psServerThread.join()
    mxModel
  }
}
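
// Hypothetical usage sketch (not part of the upstream file): it shows how the
// builder-style setters above could be wired together from a Spark driver.
// The concrete values (batch size, epochs, worker/server counts, label name,
// input dimension), the jar-list string and the caller-supplied `network`
// Symbol are illustrative assumptions, not defaults prescribed by this class.
object MXNetSparkExample {
  def train(sc: SparkContext,
            network: Symbol,
            executorJars: String,
            trainData: RDD[LabeledPoint]): MXNetModel = {
    val mxnet = new MXNet()
      .setBatchSize(128)
      .setLabelName("softmax_label")
      .setContext(Array(Context.cpu()))
      .setDimension(Shape(784))
      .setNetwork(network)
      .setNumEpoch(10)
      .setNumServer(2)
      .setNumWorker(4)
      // comma- or colon-separated jar paths, see setExecutorJars above
      .setExecutorJars(executorJars)
      .setJava("java")
    // fit() blocks until the scheduler and server threads have joined
    mxnet.fit(trainData)
  }
}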