1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.mxnet.spark
19
20import org.apache.mxnet._
21import org.apache.mxnet.optimizer.SGD
22import org.apache.mxnet.spark.io.LabeledPointIter
23
24import org.slf4j.{Logger, LoggerFactory}
25
26import org.apache.spark.mllib.regression.LabeledPoint
27import org.apache.spark.rdd.RDD
28import org.apache.spark.SparkContext
29
class MXNet extends Serializable {

  /**
   * Driver-side thread that launches one parameter-server component
   * (the scheduler, or the set of servers) asynchronously, so the Spark
   * driver is not blocked while that component is running.
   */
  class MXNetControllingThread(
      schedulerIP: String,
      schedulerPort: Int,
      sparkContext: SparkContext,
      triggerOfComponent: (String, Int, SparkContext) => Unit) extends Thread {
    override def run(): Unit = {
      triggerOfComponent(schedulerIP, schedulerPort, sparkContext)
    }
  }

  // @transient lazy so the logger is not pulled into serialized task closures
  // (this class is Serializable and `this` is captured by the mapPartitions /
  // foreachPartition lambdas below; slf4j Logger implementations are not
  // guaranteed Serializable). Each JVM re-creates the logger on first use.
  @transient private lazy val logger: Logger = LoggerFactory.getLogger(classOf[MXNet])
  private val params: MXNetParams = new MXNetParams

  // Driver-side controlling threads; never shipped to executors.
  @transient private var psServerThread: MXNetControllingThread = _
  @transient private var psSchedulerThread: MXNetControllingThread = _

  /** Sets the mini-batch size used for training. */
  def setBatchSize(batchSize: Int): this.type = {
    params.batchSize = batchSize
    this
  }

  /** Sets the number of training epochs. */
  def setNumEpoch(numEpoch: Int): this.type = {
    params.numEpoch = numEpoch
    this
  }

  /** Sets the shape of a single input example. */
  def setDimension(dimension: Shape): this.type = {
    params.dimension = dimension
    this
  }

  /** Sets the symbolic network definition to be trained. */
  def setNetwork(network: Symbol): this.type = {
    params.setNetwork(network)
    this
  }

  /** Sets the devices (e.g. CPU/GPU contexts) each worker trains on. */
  def setContext(ctx: Array[Context]): this.type = {
    params.context = ctx
    this
  }

  /** Sets the number of workers; the training RDD is repartitioned to match in [[fit]]. */
  def setNumWorker(numWorker: Int): this.type = {
    params.numWorker = numWorker
    this
  }

  /** Sets the number of parameter-server processes to launch. */
  def setNumServer(numServer: Int): this.type = {
    params.numServer = numServer
    this
  }

  /** Sets the name bound to the data input of the network. */
  def setDataName(name: String): this.type = {
    params.dataName = name
    this
  }

  /** Sets the name bound to the label input of the network. */
  def setLabelName(name: String): this.type = {
    params.labelName = name
    this
  }

  /**
   * The application (including parameter scheduler & servers)
   * will exit if it hasn't received a heart beat for over `timeout` seconds.
   * @param timeout timeout in seconds (default 300)
   */
  def setTimeout(timeout: Int): this.type = {
    params.timeout = timeout
    this
  }

  /**
   * These jars are required by the KVStores at runtime.
   * They will be uploaded and distributed to each node automatically.
   * @param jars comma- or colon-separated jar paths required by the KVStore at runtime.
   */
  def setExecutorJars(jars: String): this.type = {
    params.jars = jars.split(",|:")
    this
  }

  /** Sets the java executable used to spawn the parameter-server processes. */
  def setJava(java: String): this.type = {
    params.javabin = java
    this
  }

  /**
   * Launches `params.numServer` parameter-server processes, one per Spark
   * partition, on a background driver thread (the Spark job blocks until the
   * server processes exit, so it must not run on the caller's thread).
   */
  private def startPSServers(
      schedulerIP: String,
      schedulerPort: Int,
      sc: SparkContext) = {
    def startPSServersInner(
        schedulerIP: String,
        schedulerPort: Int,
        sc: SparkContext): Unit = {
      // numServer partitions => exactly one server process per partition.
      sc.parallelize(1 to params.numServer, params.numServer).foreachPartition { _ =>
          logger.info("Starting server ...")
          val server = new ParameterServer(params.runtimeClasspath,
            role = "server",
            rootUri = schedulerIP,
            rootPort = schedulerPort,
            numServer = params.numServer,
            numWorker = params.numWorker,
            timeout = params.timeout,
            java = params.javabin)
          val exitCode = server.startProcess()
          require(exitCode == 0, s"ps server process quit with exit code $exitCode")
        }
    }
    psServerThread = new MXNetControllingThread(schedulerIP, schedulerPort, sc, startPSServersInner)
    psServerThread.start()
  }

  /**
   * Launches the single parameter-server scheduler process on a background
   * driver thread; workers and servers rendezvous at schedulerIP:schedulerPort.
   */
  private def startPSScheduler(
      schedulerIP: String,
      schedulerPort: Int,
      sc: SparkContext) = {
    def startPSSchedulerInner(
        schedulerIP: String,
        schedulerPort: Int,
        sc: SparkContext): Unit = {
      // TODO: check ip & port available
      logger.info("Starting scheduler on {}:{}", schedulerIP, schedulerPort)
      val scheduler = new ParameterServer(params.runtimeClasspath, role = "scheduler",
        rootUri = schedulerIP, rootPort = schedulerPort,
        numServer = params.numServer, numWorker = params.numWorker,
        timeout = params.timeout, java = params.javabin)
      val exitCode = scheduler.startProcess()
      require(exitCode == 0, s"Failed to start ps scheduler process with exit code $exitCode")
    }
    psSchedulerThread = new MXNetControllingThread(schedulerIP, schedulerPort, sc,
      startPSSchedulerInner)
    psSchedulerThread.start()
  }

  /**
   * Builds a FeedForward model from the configured params and trains it on
   * this worker's partition, synchronizing gradients through the KVStore.
   * @param optimizer        optimizer driving the weight updates
   * @param numExamples      number of examples in this partition (for epochSize)
   * @param kv               distributed KVStore shared by all workers
   * @param inputInPartition data iterator over this partition
   * @return the trained model
   */
  private def setFeedForwardModel(
      optimizer: Optimizer,
      numExamples: Int,
      kv: KVStore,
      inputInPartition: LabeledPointIter): FeedForward = {
    logger.debug("Define model")
    val model = new FeedForward(ctx = params.context,
      symbol = params.getNetwork,
      numEpoch = params.numEpoch,
      optimizer = optimizer,
      initializer = new Xavier(factorType = "in", magnitude = 2.34f),
      argParams = null,
      auxParams = null,
      beginEpoch = 0,
      epochSize = numExamples / params.batchSize / kv.numWorkers)
    logger.info("Start training ...")
    model.fit(trainData = inputInPartition,
      evalData = null,
      evalMetric = new Accuracy(),
      kvStore = kv)
    model
  }

  /**
   * Initializes this executor as a ps-lite worker and creates the shared
   * asynchronous KVStore. The exit barrier is disabled here and re-enabled in
   * [[reclaimResources]] so workers only synchronize once training is done.
   */
  private def setupKVStore(schedulerIP: String, schedulerPort: Int): KVStore = {
    KVStoreServer.init(ParameterServer.buildEnv(role = "worker",
      rootUri = schedulerIP, rootPort = schedulerPort,
      numServer = params.numServer,
      numWorker = params.numWorker))
    val kv = KVStore.create("dist_async")
    kv.setBarrierBeforeExit(false)
    kv
  }

  /**
   * Frees the data iterator and disposes the KVStore; re-enables the exit
   * barrier first so all workers leave together (which also lets the
   * scheduler/servers shut down cleanly).
   */
  private def reclaimResources(dataIter: LabeledPointIter, kv: KVStore): Unit = {
    dataIter.dispose()
    kv.setBarrierBeforeExit(true)
    kv.dispose()
  }

  /**
   * Runs distributed training: each partition becomes one ps-lite worker that
   * trains on its local data, then the (replicated) trained model is fetched
   * back to the driver.
   */
  private def trainModel(
      trainData: RDD[LabeledPoint],
      schedulerIP: String,
      schedulerPort: Int): MXNetModel = {
    val job = trainData.mapPartitions { partition =>
      val dataIter = new LabeledPointIter(
        partition, params.dimension,
        params.batchSize,
        dataName = params.dataName,
        labelName = params.labelName)
      // TODO: more nature way to get the # of examples?
      var numExamples = 0
      while (dataIter.hasNext) {
        val dataBatch = dataIter.next()
        numExamples += dataBatch.label.head.shape(0)
      }
      logger.debug("Number of samples: {}", numExamples)
      dataIter.reset()

      logger.info("Launching worker ...")
      logger.info("Batch {}", params.batchSize)
      // give enough time for ps-lite to detect the dead nodes
      Thread.sleep(20000)
      val kv = setupKVStore(schedulerIP, schedulerPort)
      val optimizer = new SGD(learningRate = 0.01f, momentum = 0.9f, wd = 0.00001f)
      val model = setFeedForwardModel(optimizer, numExamples, kv, dataIter)
      logger.info("Training finished, waiting for other workers ...")
      reclaimResources(dataIter, kv)
      Iterator(new MXNetModel(
        model, params.dimension, params.batchSize,
        dataName = params.dataName, labelName = params.labelName))
    }.cache()
    // Force the job to run with an explicit no-op partition function.
    // (The original `() => _` only typechecked via placeholder expansion
    // plus value discarding; `_ => ()` states the intent directly.)
    job.foreachPartition(_ => ())
    job.first()
  }

  /**
   * Trains a model on the given RDD of labeled points.
   *
   * Distributes the configured jars, repartitions the data to one partition
   * per worker, starts the ps-lite scheduler and servers on background
   * threads, runs the distributed training job, and waits for the
   * parameter-server processes to terminate.
   *
   * @param data training set
   * @return the trained [[MXNetModel]]
   */
  def fit(data: RDD[LabeledPoint]): MXNetModel = {
    val sc = data.context
    // distribute native jars
    if (params.jars != null) {
      params.jars.foreach(jar => sc.addFile(jar))
    }
    val trainData = {
      if (params.numWorker != data.partitions.length) {
        logger.info("repartitioning training set to {} partitions", params.numWorker)
        data.repartition(params.numWorker)
      } else {
        data
      }
    }
    val schedulerIP = utils.Network.ipAddress
    val schedulerPort = utils.Network.availablePort
    // Scheduler and servers run on daemon-style controlling threads so the
    // driver can proceed to launch the worker job below.
    startPSScheduler(schedulerIP, schedulerPort, sc)
    startPSServers(schedulerIP, schedulerPort, sc)
    val mxModel = trainModel(trainData, schedulerIP, schedulerPort)
    logger.info("Waiting for scheduler ...")
    psSchedulerThread.join()
    psServerThread.join()
    mxModel
  }
}
267