1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.mxnet
19
20import org.apache.mxnet.Base._
21import org.slf4j.{Logger, LoggerFactory}
22
/**
 * Server-side wrapper around a [[KVStore]]: hands the store's native handle and a
 * controller callback to the native server loop.
 *
 * @param kvStore the store whose handle is served and which receives any optimizer
 *                delivered through the controller callback
 */
private[mxnet] class KVStoreServer(private val kvStore: KVStore) {
  private val logger: Logger = LoggerFactory.getLogger(classOf[KVStoreServer])
  private val handle: KVStoreHandle = kvStore.handle

  // Callback passed to the native server loop for command messages.
  // Command id 0 carries a Base64-encoded, serialized Optimizer that is installed on
  // the store; every other command id is only logged as unknown.
  private val controller = new KVServerControllerCallback {
    override def invoke(cmdId: Int, cmdBody: String): Unit = {
      logger.debug("Receive cmdId {}, cmdBody: {}", cmdId, cmdBody)
      cmdId match {
        case 0 =>
          val serializer = Serializer.getSerializer
          val payload = Serializer.decodeBase64String(cmdBody)
          kvStore.setOptimizer(serializer.deserialize[Optimizer](payload))
        case _ =>
          logger.warn(s"Server ${kvStore.rank}, unknown command ($cmdId, $cmdBody)")
      }
    }
  }

  /**
   * Run the server by delegating to the native server loop. The loop's behavior is like:
   * {{{
   * while receive(x):
   *   if is_command x: controller(x)
   *   else if is_key_value x: updater(x)
   * }}}
   * A failing native call is reported through `checkCall`.
   */
  def run(): Unit = checkCall(_LIB.mxKVStoreRunServer(handle, controller))
}
47
object KVStoreServer {
  private val logger: Logger = LoggerFactory.getLogger(classOf[KVStoreServer])

  /**
   * Start server/scheduler according to env variables.
   *
   * Creates a "dist" KVStore, runs the server loop on the calling thread and, once the
   * loop ends (normally or by throwing), stops the optional watchdog thread and disposes
   * the store.
   *
   * @param dieIfOthersGoOutTimeout When set to an integer greater than 0 (in seconds),
   *                                a daemon thread is started that periodically checks
   *                                whether the scheduler (server side) or the servers
   *                                (scheduler side) are dead. If so, this process exits.
   *                                This is useful when running mxnet on a distributed
   *                                data platform where you do not know which node your
   *                                application runs on and you want the remaining nodes
   *                                to die automatically once some of the nodes go down.
   * @throws IllegalArgumentException if the current process is a worker node
   */
  def start(dieIfOthersGoOutTimeout: Int = 0): Unit = {
    val isWorker = new RefInt
    checkCall(_LIB.mxKVStoreIsWorkerNode(isWorker))
    require(isWorker.value == 0, "cannot start kv-store server on worker node")
    val kvStore = KVStore.create("dist")
    try {
      val watchdog: Option[Thread] =
        if (dieIfOthersGoOutTimeout > 0) {
          Some(startDeadNodeWatchdog(kvStore, dieIfOthersGoOutTimeout))
        } else {
          None
        }
      try {
        new KVStoreServer(kvStore).run()
      } finally {
        // Stop the watchdog even when the server loop throws, and wait for it so that
        // it can no longer touch the store once the store is disposed below.
        watchdog.foreach { t =>
          t.interrupt()
          t.join()
        }
      }
    } finally {
      // Always release the store, including on failure paths.
      kvStore.dispose()
    }
  }

  /**
   * Start a daemon thread that wakes up every `timeoutSeconds` seconds, asks the store
   * for the number of dead nodes and terminates the JVM with exit code 1 when any are
   * reported. The thread stops when it is interrupted.
   *
   * @param kvStore        store used to query the number of dead nodes
   * @param timeoutSeconds polling interval in seconds
   * @return the already started daemon thread
   */
  private def startDeadNodeWatchdog(kvStore: KVStore, timeoutSeconds: Int): Thread = {
    // NOTE(review): the sum of the three group constants presumably selects all node
    // groups (scheduler, servers and workers) -- confirm against KVStore.numDeadNode.
    val nodeGroups =
      KVStore.GROUP_NODE_SCHEDULER + KVStore.GROUP_NODE_SERVER + KVStore.GROUP_NODE_WORKER
    val watchdog = new Runnable {
      override def run(): Unit = {
        var running = true
        while (running) {
          try {
            Thread.sleep(timeoutSeconds.toLong * 1000)
            val numDead = kvStore.numDeadNode(nodeGroups)
            if (numDead > 0) {
              logger.error(s"Detect $numDead dead node(s). Shutdown now.")
              System.exit(1)
            }
          } catch {
            // Interruption is the shutdown signal sent by start().
            case _: InterruptedException => running = false
          }
        }
      }
    }
    val thread = new Thread(watchdog)
    thread.setDaemon(true)
    thread.start()
    thread
  }

  /**
   * Pass the given environment entries to the native library via `mxInitPSEnv`.
   *
   * @param env environment variable names mapped to their values
   */
  def init(env: Map[String, String]): Unit = {
    val keys = env.keys.toArray
    val vals = env.values.toArray
    checkCall(_LIB.mxInitPSEnv(keys, vals))
  }
}
111
/**
 * Callback through which the KVStore server loop delivers controller commands.
 * An instance is handed to the native library when the server is run.
 */
private[mxnet] trait KVServerControllerCallback {
  /**
   * Handle one controller command.
   *
   * @param cmdId   numeric identifier of the command
   * @param cmdBody payload of the command as a string
   */
  def invoke(cmdId: Int, cmdBody: String): Unit
}
115