1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file ps_rabit-inl.h
22  * \brief distributed version of PS using BSP
23  *     synchronization in the backend
24  * \author Tianqi Chen, Mu Li
25  */
26 #ifndef MSHADOW_PS_RABIT_INL_H_ // NOLINT(*)
27 #define MSHADOW_PS_RABIT_INL_H_ // NOLINT(*)
#include <cstdlib>
#include <cstring>
#include <vector>
#include "./mshadow_ps.h"
#include "./ps_local-inl.h"
31 
32 #if MSHADOW_RABIT_PS
33 #include <rabit.h>
34 namespace mshadow {
35 namespace ps {
36 // multi-threaded implementation of
37 template<typename xpu, typename DType>
38 class RabitModel : public LocalModel<xpu, DType> {
39  public:
40   // parent type
41   typedef LocalModel<xpu, DType> Parent;
42   // constructor
RabitModel()43   RabitModel() {
44     // enforce usage of fifo queue
45     this->use_fifo_push_queue = 1;
46     destroy_reduce_thread_ = false;
47     disable_allreduce_ = 0;
48     this->init_reducer_ = 0;
49   }
~RabitModel(void)50   virtual ~RabitModel(void) {
51     Parent::Destroy();
52     if (init_reducer_ != 0) {
53       destroy_reduce_thread_ = true;
54       reduce_queue_.Abort(1);
55       thread_reduce_handler_.Join();
56       reduce_queue_.Destroy();
57     }
58   }
59   // initialize the parameter server
Init(const std::vector<int> & devices)60   virtual void Init(const std::vector<int> &devices) {
61     this->use_fifo_push_queue = 1;
62     // use fifo
63     reduce_queue_.Init(true);
64     thread_reduce_handler_.Start(ReduceGlobalThread, this);
65     init_reducer_ = 1;
66     // initialize other things
67     Parent::Init(devices);
68   }
69   // set parameters
SetParam(const char * name,const char * val)70   virtual void SetParam(const char *name, const char *val) {
71     if (!strcmp(name, "msg:disable_allreduce")) {
72       disable_allreduce_ = atoi(val);
73     }
74     Parent::SetParam(name, val);
75   }
76   // override this function, to use parameter server
HandlePushFinish(Tensor<cpu,3,DType> data,int key)77   virtual void HandlePushFinish(Tensor<cpu, 3, DType> data,
78                                 int key) {
79     // summation the data fron all devices
80     LocalModel<xpu, DType>::ReduceSum(data);
81     CHECK_EQ(data[0].CheckContiguous(), true) << "data must be contiguous";
82     ReduceTask tsk;
83     tsk.data = data[0]; tsk.key = key;
84     reduce_queue_.Push(tsk, 0);
85   }
86 
87  private:
88   // reduce task
89   struct ReduceTask {
90     int key;
91     mshadow::Tensor<cpu, 2> data;
92   };
93   // destroy reduce
94   bool destroy_reduce_thread_;
95   // whether reducer is initialized
96   int init_reducer_;
97   // check disable_allreduce functionalities
98   int disable_allreduce_;
99   // reduce handler thread
100   utils::Thread thread_reduce_handler_;
101   // queue for allreduce task
102   utils::ThreadPQueue<ReduceTask> reduce_queue_;
103   // reduce handler
ReduceHandler(void)104   inline void ReduceHandler(void) {
105     while (!destroy_reduce_thread_) {
106       ReduceTask tsk;
107       if (reduce_queue_.Pop(&tsk)) {
108         CHECK_EQ(disable_allreduce_, 0) << "Allreduce disabled error";
109         int key = tsk.key;
110         rabit::Allreduce<rabit::op::Max>(&key, 1);
111         CHECK_EQ(key, tsk.key) << "Allreduce not concensus";
112         rabit::Allreduce<rabit::op::Sum>
113             (tsk.data.dptr_, tsk.data.MSize());
114         tsk.data *= 1.0f / rabit::GetWorldSize();
115         CHECK_EQ(disable_allreduce_, 0) << "Allreduce disabled error";
116         this->HandleReduceFinish(tsk.data, tsk.key);
117       } else {
118         CHECK_EQ(destroy_reduce_thread_, true) << "abort but not destroy";
119       }
120     }
121   }
122   /*!\brief entry point of reduce thread */
ReduceGlobalThread(void * pthread)123   inline static MSHADOW_THREAD_PREFIX ReduceGlobalThread(void *pthread) {
124     static_cast<RabitModel*>(pthread)->ReduceHandler();
125     return NULL;
126   }
127 };
128 }  // namespace ps
129 }  // namespace mshadow
130 #endif  // MSHADOW_RABIT_PS
131 #endif  // MSHADOW_PS_RABIT_INL_H_ // NOLINT(*)
132