/*
 Copyright (c) 2014 by Contributors

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */

package ml.dmlc.xgboost4j.scala.rabit.util

import java.nio.{ByteBuffer, ByteOrder}

/**
 * The assigned rank to a connecting Rabit worker, along with the information of the ranks of
 * its linked peer workers, which are critical to perform Allreduce.
 * When RabitWorkerHandler delegates "start" or "recover" commands from the connecting worker
 * client, RabitTrackerHandler utilizes LinkMap to figure out linkage relationships, and respond
 * with this class as a message, which is later encoded to byte string, and sent over socket
 * connection to the worker client.
 *
 * @param rank assigned rank (ranked by worker connection order: first worker connecting to the
 *             tracker is assigned rank 0, second with rank 1, etc.)
 * @param neighbors ranks of neighboring workers in a tree map.
 * @param ring ranks of neighboring workers in a ring map, as (previous, next).
 * @param parent rank of the parent worker (-1 when there is no parent).
 */
private[rabit] case class AssignedRank(rank: Int, neighbors: Seq[Int],
                                       ring: (Int, Int), parent: Int) {
  /**
   * Encode the AssignedRank message into byte sequence for socket communication with Rabit worker
   * client.
   *
   * The layout is a sequence of 32-bit integers in the platform's native byte order:
   * `rank, parent, worldSize, neighbors.length, neighbors..., ringPrev, ringNext`.
   * A ring link that points back at this worker itself is encoded as -1.
   *
   * @param worldSize the number of total distributed workers. Must match `numWorkers` used in
   *                  LinkMap.
   * @return a ByteBuffer containing encoded data, flipped and ready to be read.
   */
  def toByteBuffer(worldSize: Int): ByteBuffer = {
    // A ring link to oneself carries no information; -1 is the "no link" marker on the wire.
    def ringLink(peer: Int): Int = if (peer == rank) -1 else peer

    // 4 header ints + one int per tree neighbor + 2 ring ints, 4 bytes each.
    val buffer = ByteBuffer.allocate(4 * (neighbors.length + 6)).order(ByteOrder.nativeOrder())
    buffer.putInt(rank).putInt(parent).putInt(worldSize).putInt(neighbors.length)
    // neighbors in tree structure
    neighbors.foreach { n => buffer.putInt(n) }
    buffer.putInt(ringLink(ring._1))
    buffer.putInt(ringLink(ring._2))

    buffer.flip()
    buffer
  }
}

/**
 * Computes the tree, parent and ring linkage between `numWorkers` workers.
 *
 * Workers are first laid out as an implicit binary tree (node `r` has parent `(r + 1) / 2 - 1`
 * and children `2r + 1` and `2r + 2`), a ring that visits every node is derived from that tree,
 * and finally all nodes are relabeled by their position along the ring, so that in the public
 * maps the ring neighbors of rank `k` are `k - 1` and `k + 1` (modulo `numWorkers`).
 *
 * @param numWorkers total number of workers; must be positive.
 */
private[rabit] class LinkMap(numWorkers: Int) {
  require(numWorkers > 0, s"numWorkers must be positive, but got $numWorkers")

  /**
   * Tree neighbors (parent first, then left and right child) of `rank` in the implicit binary
   * tree, restricted to ranks that exist.
   */
  private def getNeighbors(rank: Int): Seq[Int] = {
    val rank1 = rank + 1
    Vector(rank1 / 2 - 1, rank1 * 2 - 1, rank1 * 2).filter { r =>
      r >= 0 && r < numWorkers
    }
  }

  /**
   * Construct a ring structure that tends to share nodes with the tree.
   *
   * The subtree rooted at `rank` is traversed depth-first; the sub-ring of the last child is
   * reversed so that it ends on that child, which is a tree neighbor of `rank`. The ring
   * therefore closes over an edge that also exists in the tree.
   *
   * @param treeMap tree neighbors (parent and children) of every rank.
   * @param parentMap parent of every rank; -1 for the root.
   * @param rank root of the subtree to traverse.
   * @return Seq[Int] instance starting from rank.
   */
  private def constructShareRing(treeMap: Map[Int, Seq[Int]],
                                 parentMap: Map[Int, Int],
                                 rank: Int = 0): Seq[Int] = {
    val parent = parentMap(rank)
    val children = treeMap(rank).filterNot(_ == parent).distinct
    val lastIdx = children.length - 1
    children.zipWithIndex.foldLeft(Vector(rank)) {
      case (ringSeq, (child, idx)) =>
        val subRing = constructShareRing(treeMap, parentMap, child)
        // Only the last child's sub-ring is reversed (see method documentation).
        ringSeq ++ (if (idx == lastIdx) subRing.reverse else subRing)
    }
  }

  /**
   * Construct a ring connection used to recover local data.
   *
   * @param treeMap tree neighbors (parent and children) of every rank.
   * @param parentMap parent of every rank; -1 for the root.
   * @return for every rank, its (previous, next) neighbors along the ring.
   */
  private def constructRingMap(treeMap: Map[Int, Seq[Int]],
                               parentMap: Map[Int, Int]): Map[Int, (Int, Int)] = {
    assert(parentMap(0) == -1)

    val sharedRing = constructShareRing(treeMap, parentMap, 0).toVector
    assert(sharedRing.length == treeMap.size)

    (0 until numWorkers).map { r =>
      val rPrev = (r + numWorkers - 1) % numWorkers
      val rNext = (r + 1) % numWorkers
      sharedRing(r) -> (sharedRing(rPrev), sharedRing(rNext))
    }.toMap
  }

  // Linkage expressed in the original (binary-tree index) numbering.
  private[this] val treeMap_ = (0 until numWorkers).map { r => r -> getNeighbors(r) }.toMap
  private[this] val parentMap_ = (0 until numWorkers).map { r => r -> ((r + 1) / 2 - 1) }.toMap
  private[this] val ringMap_ = constructRingMap(treeMap_, parentMap_)

  /**
   * Relabeling from tree-index numbering to position along the ring: starting at node 0 and
   * following the "next" links, the i-th node visited is given rank i.
   */
  val rMap_ : Map[Int, Int] = (0 until (numWorkers - 1)).foldLeft((Map(0 -> 0), 0)) {
    case ((rmap, k), i) =>
      val kNext = ringMap_(k)._2
      (rmap + (kNext -> (i + 1)), kNext)
  }._1

  /** Ring neighbors (previous, next) of every rank, in relabeled numbering. */
  val ringMap: Map[Int, (Int, Int)] = ringMap_.map {
    case (k, (v0, v1)) => rMap_(k) -> (rMap_(v0), rMap_(v1))
  }

  /** Tree neighbors of every rank, in relabeled numbering. */
  val treeMap: Map[Int, Seq[Int]] = treeMap_.map {
    case (k, vSeq) => rMap_(k) -> vSeq.map { v => rMap_(v) }
  }

  /** Parent of every rank, in relabeled numbering; the root maps to -1. */
  val parentMap: Map[Int, Int] = parentMap_.map {
    // The root has no parent, so there is nothing to relabel.
    case (0, _) =>
      rMap_(0) -> -1
    case (k, v) =>
      rMap_(k) -> rMap_(v)
  }

  /**
   * Build the message describing the linkage of the worker with the given rank.
   *
   * @param rank rank of the worker, in `[0, numWorkers)`; any other value fails with
   *             `NoSuchElementException` from the underlying map lookups.
   * @return the rank together with its tree neighbors, ring neighbors and parent.
   */
  def assignRank(rank: Int): AssignedRank = {
    AssignedRank(rank, treeMap(rank), ringMap(rank), parentMap(rank))
  }
}