1/*
2 Copyright (c) 2014 by Contributors
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 */
16
17package ml.dmlc.xgboost4j.scala.rabit.util
18
19import java.nio.{ByteBuffer, ByteOrder}
20
21/**
22  * The assigned rank to a connecting Rabit worker, along with the information of the ranks of
23  * its linked peer workers, which are critical to perform Allreduce.
24  * When RabitWorkerHandler delegates "start" or "recover" commands from the connecting worker
25  * client, RabitTrackerHandler utilizes LinkMap to figure out linkage relationships, and respond
26  * with this class as a message, which is later encoded to byte string, and sent over socket
27  * connection to the worker client.
28  *
29  * @param rank assigned rank (ranked by worker connection order: first worker connecting to the
30  *             tracker is assigned rank 0, second with rank 1, etc.)
31  * @param neighbors ranks of neighboring workers in a tree map.
32  * @param ring ranks of neighboring workers in a ring map.
33  * @param parent rank of the parent worker.
34  */
private[rabit] case class AssignedRank(rank: Int, neighbors: Seq[Int],
                                       ring: (Int, Int), parent: Int) {
  /**
    * Serialize this assignment into a sequence of 32-bit integers, written in the platform's
    * native byte order, laid out as:
    * `rank, parent, worldSize, neighbors.length, neighbors..., ring._1, ring._2`.
    *
    * A ring entry that is already -1, or that equals this worker's own rank, is written as -1.
    *
    * @param worldSize the number of total distributed workers; written verbatim as the third
    *                  integer of the message.
    * @return a ByteBuffer flipped for reading (position 0, limit at the end of the encoded data).
    */
  def toByteBuffer(worldSize: Int): ByteBuffer = {
    val (ringFirst, ringSecond) = ring

    // A ring link that is absent (-1) or points back at this worker is reported as -1.
    def ringLink(peer: Int): Int =
      if (peer == -1 || peer == rank) -1 else peer

    val header = Seq(rank, parent, worldSize, neighbors.length)
    // header, then neighbors in tree structure, then the two ring links
    val payload = header ++ neighbors ++ Seq(ringLink(ringFirst), ringLink(ringSecond))

    // every field is a 4-byte int
    val encoded = ByteBuffer.allocate(4 * payload.length).order(ByteOrder.nativeOrder())
    payload.foreach { value => encoded.putInt(value) }

    encoded.flip()
    encoded
  }
}
56
/**
  * Builds the tree and ring topologies that link `numWorkers` workers together.
  *
  * Workers are first laid out as an implicit binary tree (node `r` has parent `(r + 1) / 2 - 1`
  * and children `2r + 1`, `2r + 2`). A ring visiting every node is then derived from that tree,
  * and finally all ranks are relabelled so that following the ring's "next" links from rank 0
  * visits ranks 0, 1, 2, ... in order. The public `ringMap`, `treeMap` and `parentMap` are
  * expressed in those relabelled ranks.
  *
  * @param numWorkers total number of workers; must be positive.
  */
private[rabit] class LinkMap(numWorkers: Int) {
  // Construction cannot succeed without a root node; fail with a clear message instead of an
  // opaque "key not found: 0" from the map lookups below.
  require(numWorkers > 0, s"numWorkers must be positive, got $numWorkers")

  /**
    * Tree neighbors (parent, left child, right child) of `rank` that actually exist, i.e. fall
    * in `[0, numWorkers)`.
    */
  private def getNeighbors(rank: Int): Seq[Int] = {
    val rank1 = rank + 1
    Vector(rank1 / 2 - 1, rank1 * 2 - 1, rank1 * 2).filter { r =>
      r >= 0 && r < numWorkers
    }
  }

  /**
    * Construct a ring structure that tends to share nodes with the tree.
    *
    * The subtree rooted at `rank` is traversed depth-first. The sub-ring of the last child is
    * appended in reverse, so that the traversal ends on a direct child of `rank` and the ring
    * can close over a tree edge.
    *
    * @param treeMap tree neighbors (parent and children) of every node.
    * @param parentMap parent of every node (-1 for the root).
    * @param rank root of the subtree to traverse.
    * @return Seq[Int] instance starting from rank, containing every node of the subtree once.
    */
  private def constructShareRing(treeMap: Map[Int, Seq[Int]],
                                 parentMap: Map[Int, Int],
                                 rank: Int = 0): Seq[Int] = {
    val parent = parentMap(rank)
    // Children in a fixed, deterministic order (neighbors minus the parent).
    val children = treeMap(rank).filterNot(_ == parent).distinct
    val lastIdx = children.length - 1

    // A leaf has no children, so the fold simply yields List(rank).
    children.zipWithIndex.foldLeft(List(rank)) {
      case (ringSeq, (child, idx)) =>
        val subRing = constructShareRing(treeMap, parentMap, child)
        // Only the last child's sub-ring is reversed.
        if (idx == lastIdx) ringSeq ++ subRing.reverse else ringSeq ++ subRing
    }
  }

  /**
    * Construct a ring connection used to recover local data.
    *
    * @param treeMap tree neighbors of every node.
    * @param parentMap parent of every node; node 0 must be the root (parent -1).
    * @return for every node, the pair (previous node, next node) along the ring.
    */
  private def constructRingMap(treeMap: Map[Int, Seq[Int]],
                               parentMap: Map[Int, Int]): Map[Int, (Int, Int)] = {
    assert(parentMap(0) == -1)

    val sharedRing = constructShareRing(treeMap, parentMap, 0).toVector
    assert(sharedRing.length == treeMap.size)

    (0 until numWorkers).map { r =>
      // wrap around at both ends so the sequence forms a closed ring
      val rPrev = (r + numWorkers - 1) % numWorkers
      val rNext = (r + 1) % numWorkers
      sharedRing(r) -> (sharedRing(rPrev), sharedRing(rNext))
    }.toMap
  }

  // Topology expressed in the original (tree-order) node ids.
  private[this] val treeMap_ = (0 until numWorkers).map { r => r -> getNeighbors(r) }.toMap
  private[this] val parentMap_ = (0 until numWorkers).map { r => r -> ((r + 1) / 2 - 1) }.toMap
  private[this] val ringMap_ = constructRingMap(treeMap_, parentMap_)

  /**
    * Relabelling from tree-order node id to final rank: node 0 keeps rank 0, and each
    * subsequent node reached by following the ring's "next" link gets the next rank.
    */
  val rMap_ : Map[Int, Int] =
    Iterator.iterate(0) { k => ringMap_(k)._2 }.take(numWorkers).zipWithIndex.toMap

  /** (previous, next) ring links of every worker, in final ranks. */
  val ringMap: Map[Int, (Int, Int)] = ringMap_.map {
    case (k, (v0, v1)) => rMap_(k) -> (rMap_(v0), rMap_(v1))
  }

  /** Tree neighbors (parent and children) of every worker, in final ranks. */
  val treeMap: Map[Int, Seq[Int]] = treeMap_.map {
    case (k, vSeq) => rMap_(k) -> vSeq.map { v => rMap_(v) }
  }

  /** Parent of every worker, in final ranks; the root's parent is -1. */
  val parentMap: Map[Int, Int] = parentMap_.map {
    // the root has no parent to relabel
    case (0, _) => rMap_(0) -> -1
    case (k, v) => rMap_(k) -> rMap_(v)
  }

  /**
    * Look up the links of the worker with the given (final) rank.
    *
    * @param rank a rank in `[0, numWorkers)`; other values fail the underlying map lookup.
    * @return the rank together with its tree neighbors, ring links and parent.
    */
  def assignRank(rank: Int): AssignedRank = {
    AssignedRank(rank, treeMap(rank), ringMap(rank), parentMap(rank))
  }
}
137