1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file ps_rabit-inl.h 22 * \brief distributed version of PS using BSP 23 * synchronization in the backend 24 * \author Tianqi Chen, Mu Li 25 */ 26 #ifndef MSHADOW_PS_RABIT_INL_H_ // NOLINT(*) 27 #define MSHADOW_PS_RABIT_INL_H_ // NOLINT(*) 28 #include <vector> 29 #include "./mshadow_ps.h" 30 #include "./ps_local-inl.h" 31 32 #if MSHADOW_RABIT_PS 33 #include <rabit.h> 34 namespace mshadow { 35 namespace ps { 36 // multi-threaded implementation of 37 template<typename xpu, typename DType> 38 class RabitModel : public LocalModel<xpu, DType> { 39 public: 40 // parent type 41 typedef LocalModel<xpu, DType> Parent; 42 // constructor RabitModel()43 RabitModel() { 44 // enforce usage of fifo queue 45 this->use_fifo_push_queue = 1; 46 destroy_reduce_thread_ = false; 47 disable_allreduce_ = 0; 48 this->init_reducer_ = 0; 49 } ~RabitModel(void)50 virtual ~RabitModel(void) { 51 Parent::Destroy(); 52 if (init_reducer_ != 0) { 53 destroy_reduce_thread_ = true; 54 reduce_queue_.Abort(1); 55 thread_reduce_handler_.Join(); 56 reduce_queue_.Destroy(); 57 } 58 } 59 // initialize the parameter server Init(const std::vector<int> & devices)60 virtual void Init(const std::vector<int> &devices) { 61 this->use_fifo_push_queue = 1; 62 // use fifo 63 reduce_queue_.Init(true); 64 thread_reduce_handler_.Start(ReduceGlobalThread, this); 65 init_reducer_ = 1; 66 // initialize other things 67 Parent::Init(devices); 68 } 69 // set parameters SetParam(const char * name,const char * val)70 virtual void SetParam(const char *name, const char *val) { 71 if (!strcmp(name, "msg:disable_allreduce")) { 72 disable_allreduce_ = atoi(val); 73 } 74 Parent::SetParam(name, val); 75 } 76 // override this function, to use parameter server HandlePushFinish(Tensor<cpu,3,DType> data,int key)77 virtual void HandlePushFinish(Tensor<cpu, 3, DType> data, 78 int key) { 79 // summation the data fron all devices 80 LocalModel<xpu, DType>::ReduceSum(data); 81 CHECK_EQ(data[0].CheckContiguous(), true) << "data must be contiguous"; 82 ReduceTask tsk; 83 tsk.data = data[0]; tsk.key = key; 84 reduce_queue_.Push(tsk, 0); 85 } 86 87 private: 88 // reduce task 89 struct ReduceTask { 90 int key; 91 mshadow::Tensor<cpu, 2> data; 92 }; 93 // destroy reduce 94 bool destroy_reduce_thread_; 95 // whether reducer is initialized 96 int init_reducer_; 97 // check disable_allreduce functionalities 98 int disable_allreduce_; 99 // reduce handler thread 100 utils::Thread thread_reduce_handler_; 101 // queue for allreduce task 102 utils::ThreadPQueue<ReduceTask> reduce_queue_; 103 // reduce handler ReduceHandler(void)104 inline void ReduceHandler(void) { 105 while (!destroy_reduce_thread_) { 106 ReduceTask tsk; 107 if (reduce_queue_.Pop(&tsk)) { 108 CHECK_EQ(disable_allreduce_, 0) << "Allreduce disabled error"; 109 int key = tsk.key; 110 rabit::Allreduce<rabit::op::Max>(&key, 1); 111 CHECK_EQ(key, tsk.key) << "Allreduce not concensus"; 112 rabit::Allreduce<rabit::op::Sum> 113 (tsk.data.dptr_, tsk.data.MSize()); 114 tsk.data *= 1.0f / rabit::GetWorldSize(); 115 CHECK_EQ(disable_allreduce_, 0) << "Allreduce disabled error"; 116 this->HandleReduceFinish(tsk.data, tsk.key); 117 } else { 118 CHECK_EQ(destroy_reduce_thread_, true) << "abort but not destroy"; 119 } 120 } 121 } 122 /*!\brief entry point of reduce thread */ ReduceGlobalThread(void * pthread)123 inline static MSHADOW_THREAD_PREFIX ReduceGlobalThread(void *pthread) { 124 static_cast<RabitModel*>(pthread)->ReduceHandler(); 125 return NULL; 126 } 127 }; 128 } // namespace ps 129 } // namespace mshadow 130 #endif // MSHADOW_RABIT_PS 131 #endif // MSHADOW_PS_RABIT_INL_H_ // NOLINT(*) 132