/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mxnet

import org.apache.mxnet.Base._
import org.slf4j.{Logger, LoggerFactory}

/**
 * Runs a kv-store server loop on top of the native handle of `kvStore`.
 *
 * @param kvStore the store whose native handle the server loop runs on
 */
private[mxnet] class KVStoreServer(private val kvStore: KVStore) {
  private val logger: Logger = LoggerFactory.getLogger(classOf[KVStoreServer])
  private val handle: KVStoreHandle = kvStore.handle

  // Handles controller commands delivered by the native server loop.
  // Command id 0 carries a base64-encoded, serialized Optimizer to install on the store;
  // any other command id is logged and ignored.
  private val controller = new KVServerControllerCallback {
    override def invoke(cmdId: Int, cmdBody: String): Unit = {
      logger.debug("Receive cmdId {}, cmdBody: {}", cmdId, cmdBody)
      if (cmdId == 0) {
        // NOTE(review): cmdBody arrives from a remote peer; if Serializer is backed by
        // Java serialization, this is only safe when all cluster peers are trusted — confirm.
        val optimizer = Serializer.getSerializer.deserialize[Optimizer](
          Serializer.decodeBase64String(cmdBody))
        kvStore.setOptimizer(optimizer)
      } else {
        logger.warn(s"Server ${kvStore.rank}, unknown command ($cmdId, $cmdBody)")
      }
    }
  }

  /**
   * Runs the server loop in the native library and blocks until that call returns.
   *
   * The loop behaves like:
   * {{{
   * while receive(x):
   *   if is_command x: controller(x)
   *   else if is_key_value x: updater(x)
   * }}}
   */
  def run(): Unit = {
    checkCall(_LIB.mxKVStoreRunServer(handle, controller))
  }
}

object KVStoreServer {
  private val logger: Logger = LoggerFactory.getLogger(classOf[KVStoreServer])

  /**
   * Start a server or scheduler node according to the environment variables.
   *
   * Creates a "dist" kv-store, runs the server loop until it returns, then disposes the
   * store. The store is disposed, and the watchdog thread (if any) is stopped, even if
   * the server loop throws.
   *
   * @param dieIfOthersGoOutTimeout when greater than 0 (in seconds), a daemon thread is
   *                                started that every `dieIfOthersGoOutTimeout` seconds
   *                                checks whether any scheduler, server or worker node is
   *                                reported dead, and if so terminates the JVM with exit
   *                                code 1. This is useful when running MXNet on a
   *                                distributed data platform where you do not know which
   *                                nodes your application runs on, and you want the
   *                                remaining nodes to shut down automatically once some
   *                                of them go away. 0 (the default) disables the check.
   * @throws IllegalArgumentException if called on a worker node
   */
  def start(dieIfOthersGoOutTimeout: Int = 0): Unit = {
    val isWorker = new RefInt
    checkCall(_LIB.mxKVStoreIsWorkerNode(isWorker))
    require(isWorker.value == 0, "cannot start kv-store server on worker node")
    val kvStore = KVStore.create("dist")
    try {
      val watchdog: Option[Thread] =
        if (dieIfOthersGoOutTimeout > 0) Some(startWatchdog(kvStore, dieIfOthersGoOutTimeout))
        else None
      try {
        new KVStoreServer(kvStore).run()
      } finally {
        // Stop the watchdog before disposing the store it polls.
        watchdog.foreach { t =>
          t.interrupt()
          t.join()
        }
      }
    } finally {
      kvStore.dispose()
    }
  }

  // Starts a daemon thread that polls `kvStore` for dead nodes every `timeoutSec` seconds
  // and exits the JVM with status 1 when any are found; the thread ends when interrupted.
  private def startWatchdog(kvStore: KVStore, timeoutSec: Int): Thread = {
    val watchdog = new Thread(new Runnable {
      override def run(): Unit = {
        var running = true
        while (running) {
          try {
            // toLong before multiplying so large timeouts do not overflow Int.
            Thread.sleep(timeoutSec.toLong * 1000)
            val numDead = kvStore.numDeadNode(KVStore.GROUP_NODE_SCHEDULER
              + KVStore.GROUP_NODE_SERVER + KVStore.GROUP_NODE_WORKER)
            if (numDead > 0) {
              logger.error(s"Detect $numDead dead node(s). Shutdown now.")
              System.exit(1)
            }
          } catch {
            case _: InterruptedException => running = false
          }
        }
      }
    })
    watchdog.setDaemon(true)
    watchdog.start()
    watchdog
  }

  /**
   * Initialize the parameter-server environment in the native library.
   *
   * @param env environment variable names mapped to their values
   */
  def init(env: Map[String, String]): Unit = {
    // Unzip the entries in one pass so each key is guaranteed to line up with its value.
    val (keys, vals) = env.toArray.unzip
    checkCall(_LIB.mxInitPSEnv(keys, vals))
  }
}

/**
 * Controller callback handed to the native server loop (`mxKVStoreRunServer`);
 * invoked with the id and body of each command the server receives.
 */
private[mxnet] trait KVServerControllerCallback {
  def invoke(cmdId: Int, cmdBody: String): Unit
}