/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/**
 * @file   kvstore_nccl.h
 * @brief  NCCL implementation of KVStore
 */
#ifndef MXNET_KVSTORE_KVSTORE_NCCL_H_
#define MXNET_KVSTORE_KVSTORE_NCCL_H_

#if MXNET_USE_NCCL

#include <mxnet/kvstore.h>
#include <nccl.h>
#include <unordered_map>
#include <bitset>
#include <vector>
#include <string>
#include <utility>
#include <functional>
#include <algorithm>
#include <tuple>
#include "./comm.h"
#include "./kvstore_local.h"
#include "../common/cuda_utils.h"

// NCCL v2 introduced the NCCL_MAJOR macro for versioning,
// so if nccl.h does not define it, we are building
// against NCCL v1.
#ifndef NCCL_MAJOR
#define NCCL_MAJOR 1
#endif

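// NCCL v1 has no group-call semantics, so the group calls used below
// compile away to no-ops, and the enum counting the supported data types
// is spelled nccl_NUM_TYPES instead of ncclNumTypes.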
#if NCCL_MAJOR == 1
#define ncclGroupStart()
#define ncclGroupEnd()
#define ncclNumTypes nccl_NUM_TYPES
#endif  // NCCL_MAJOR == 1

namespace mxnet {
namespace kvstore {

/**
 * \brief KVStore implementation that keeps data on the local machine and
 *        communicates between GPUs with NCCL, aggregating multiple keys
 *        into grouped collective calls
 */
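//
// Rough usage sketch (assuming the standard KVStore::Create factory; the
// variable names here are illustrative, not part of this header):
//
//   KVStore* kv = KVStore::Create("nccl");
//   kv->Init(keys, cpu_weights);             // one NDArray per key
//   kv->Push(keys, per_gpu_grads);           // grouped NCCL reduction
//   kv->Pull(keys, per_gpu_weight_ptrs);     // grouped NCCL broadcast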
class KVStoreNCCL : public KVStoreLocal {
 public:
  KVStoreNCCL() : KVStoreLocal() {
    // Due to aggregation, we do not use the Comm interface
    comm_ = nullptr;
    pinned_ctx_ = Context::CPUPinned(0);
    inited_ = false;
  }

  virtual ~KVStoreNCCL() {
    for (auto e : nccl_data_) {
      cudaStreamDestroy(e.second.stream);
      ncclCommDestroy(e.second.comm);
    }
  }

 private:
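  // Init keeps the master copy of every key in CPU pinned memory and records
  // the key's shape/dtype, so that the per-GPU merge buffers can be created
  // lazily once the device set is known (see InitMergeBuffer below).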
  void InitImpl(const std::vector<int>& keys,
                const std::vector<NDArray>& values) override {
    for (size_t i = 0; i < keys.size(); ++i) {
      CHECK(local_.find(keys[i]) == local_.end())
          << "duplicate init of key " << keys[i];
      local_[keys[i]] = values[i].Copy(pinned_ctx_);
      InitKey(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype());
    }
  }

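  // Push: group the key/value pairs, reduce each group across GPUs into the
  // key's merge buffer (one grouped NCCL launch for all keys), synchronize
  // the NCCL streams once, then run the updater against the local copy.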
  void PushImpl(const std::vector<int>& keys,
                const std::vector<NDArray>& values,
                int priority) override {
    std::vector<int> uniq_keys;
    std::vector<std::vector<NDArray> > grouped_vals;
    // nccl kvstore doesn't support sparse ndarrays
    GroupKVPairsHelper(keys, values, &uniq_keys, &grouped_vals, true);

    std::vector<const NDArray*> merged_ptrs;
    std::vector<NDArray*> local_ptrs;
    bool nccl_called = false;

    Reduce(uniq_keys, grouped_vals, priority, &merged_ptrs);

    for (size_t i = 0; i < uniq_keys.size(); ++i) {
      int key = uniq_keys[i];
      if (grouped_vals[i].size() > 1) {
        // We issued NCCL kernels, need to synchronize
        nccl_called = true;
      }
      auto& merged = *(merged_ptrs[i]);
      NDArray& local = local_[key];
      if (updater_ != nullptr) {
        CHECK(!local.is_none()) << "key " << key << " has not been inited";
        // if merged is on GPU, we may need to copy the weight from CPU to GPU
        if (merged.ctx().dev_mask() != cpu::kDevMask &&
            local.ctx().dev_mask() == cpu::kDevMask) {
          local = local.Copy(merged.ctx());
        }
      }
      local_ptrs.push_back(&local);
    }

    // Sync after all reductions in a group
    if (nccl_called) {
      CommSync(merged_ptrs, priority);
    }

    for (size_t i = 0; i < uniq_keys.size(); ++i) {
      int key = uniq_keys[i];
      auto& merged = *(merged_ptrs[i]);
      NDArray& local = *(local_ptrs[i]);
      if (updater_ != nullptr) {
        // Call the string updater if string keys are used and str_updater_
        // is available (i.e. the language binding has picked up the string
        // interface); otherwise fall back to the int-keyed updater_.
        if (key_type_ == kStringKey && str_updater_ != nullptr) {
          const std::string& str_key = reverse_str_key_dict_[key];
          str_updater_(str_key, merged, &local);
        } else {
          updater_(key, merged, &local);
        }
      } else {
        local = merged;
      }
    }
  }

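  // Pull: broadcast each key's local copy to every destination array with a
  // grouped ncclBcast (plain device-to-device copies before NCCL is
  // initialized), then synchronize the NCCL streams once.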
  void PullImpl(const std::vector<int>& keys,
                const std::vector<NDArray*>& values,
                int priority, bool ignore_sparse) override {
    CHECK(ignore_sparse) << "nccl kvstore pull doesn't support ignore_sparse=False";
    std::vector<int> uniq_keys;
    std::vector<std::vector<NDArray*> > grouped_vals;
    GroupKVPairsHelper(keys, values, &uniq_keys, &grouped_vals, true);
    std::vector<NDArray> locals;
    bool nccl_called = false;

    for (size_t i = 0; i < uniq_keys.size(); ++i) {
      int key = uniq_keys[i];
      const NDArray& local = local_[key];
      locals.push_back(local_[key]);
      CHECK(!local.is_none()) << "key " << key << " has not been inited";
      if (grouped_vals[i].size() > 1) {
        // We issued NCCL kernels, need to synchronize
        nccl_called = true;
      }
    }

    Broadcast(uniq_keys, locals, grouped_vals, priority);
    // Sync after all broadcasts in a group
    if (nccl_called) {
      const std::vector<const NDArray*> values_copy(values.begin(), values.end());
      CommSync(values_copy, priority);
    }
  }

  void PullRowSparseImpl(const std::vector<int>& keys,
                         const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
                         int priority = 0) override {
    LOG(FATAL) << "NCCL kvstore does not support sparse storage type";
  }

  void SetGradientCompression(const std::vector<std::pair<std::string, std::string> >
                                      & kwargs) override {
    LOG(FATAL) << "NCCL kvstore does not support gradient compression";
  }

 protected:
  /**
   * \brief group values by keys, validating that every value is dense
   */
  template <typename T>
  void GroupKVPairsHelper(const std::vector<int>& keys,
                          const std::vector<T>& values,
                          std::vector<int> *uniq_keys,
                          std::vector<std::vector<T>> *grouped_vals,
                          bool ignore_sparse) {
    // check if the storage type of a value is valid
    auto validator = [this](const int key, const T nd, bool ignore_sparse) -> bool {
      CHECK(ignore_sparse) << "NCCL kvstore doesn't support ignore_sparse=False";
      auto stype = ptr(nd)->storage_type();
      // valid NDArray
      if (stype == kDefaultStorage) return true;
      // invalid NDArray, abort
      LOG(FATAL) << "NCCL kvstore does not support sparse storage type";
      return false;
    };
    GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator, ignore_sparse);
  }

 private:
  // Aggregated reductions
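  // Reduce sums each key's per-GPU arrays into that key's merge buffer:
  // the merge buffer's device is the NCCL root, every source array feeds an
  // ncclReduce, and all keys share a single grouped launch pushed to the
  // engine as one CPU-side op (const on the sources, mutating the buffers).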
  virtual void Reduce(const std::vector<int> keys,
                      const std::vector<std::vector<NDArray>>& srcs,
                      int priority,
                      std::vector<const NDArray*>* merged_ptrs) {
    std::vector<size_t> root_ids(keys.size());
    std::vector<NDArray> reduces(keys.size());
    merged_ptrs->resize(keys.size());
    std::vector<Engine::VarHandle> const_vars;
    std::vector<Engine::VarHandle> mutate_vars;

    for (size_t k = 0; k < keys.size(); ++k) {
      auto& key = keys[k];
      auto& src = srcs[k];
      auto& root_id = root_ids[k];

      // Avoid the extra copy for a single device; note that the merged
      // pointer then aliases the caller's array, which may surprise
      // non-standard uses of kvstore.
      if (src.size() == 1) {
        (*merged_ptrs)[k] = &src[0];
        continue;
      }

      if (!inited_) {
        std::vector<Context> devs;
        for (const auto& a : src) {
          devs.push_back(a.ctx());
        }
        InitNCCL(devs);
        InitMergeBuffer(devs);
      }

      // Check whether we got the same set of devices
      std::vector<int> dev_ids;
      for (const auto& e : src) {
        dev_ids.push_back(e.ctx().dev_id);
      }
      std::sort(dev_ids.begin(), dev_ids.end());
      CHECK(device_ids_ == dev_ids) << "NCCL KVStore supports only a single set of devices";

      auto& buf = merge_buf_[key];
      int root = buf.merged.ctx().dev_id;
      root_id = FindRootId(src, root);

      auto& reduce = buf.merged;
      (*merged_ptrs)[k] = &reduce;
      // Need to pass NDArrays by value to the engine
      reduces[k] = reduce;

      for (size_t i = 0; i < src.size(); ++i) {
        const_vars.push_back(src[i].var());
      }
      mutate_vars.push_back(reduce.var());
    }

    Engine::Get()->PushSync([srcs, reduces, root_ids, this](RunContext rctx) {
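        // Hold the GPU storage-pool lock for the whole grouped launch: a
        // concurrent cudaMalloc/cudaFree between the per-device NCCL calls
        // below could otherwise deadlock the collective.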
        std::lock_guard<std::mutex> l(Storage::Get()->GetMutex(Context::kGPU));
// Nesting ncclGroupStart/ncclGroupEnd calls requires NCCL 2.2 or newer
#if (NCCL_MAJOR > 2 || (NCCL_MAJOR == 2 && NCCL_MINOR > 1))
        ncclGroupStart();
#endif
        for (size_t k = 0; k < srcs.size(); ++k) {
          auto& src = srcs[k];
          auto& root_id = root_ids[k];
          auto& reduce = reduces[k];
          if (src.size() <= 1) {
            continue;
          }
          int root = nccl_data_[src[root_id].ctx().dev_id].rank;
          ncclGroupStart();
          for (size_t i = 0; i < src.size(); ++i) {
            NCCLEntry cur = nccl_data_[src[i].ctx().dev_id];
            if (i == root_id) {
              // The root both contributes its data and receives the sum
              MSHADOW_TYPE_SWITCH(src[i].dtype(), DType,
                ncclReduce(src[i].data().dptr<DType>(),
                           reduce.data().dptr<DType>(),
                           src[i].shape().Size(),
                           GetNCCLType(src[i].dtype()),
                           ncclSum,
                           root,
                           cur.comm,
                           cur.stream););
            } else {
              // Non-root ranks only contribute; the recv buffer may be null
              MSHADOW_TYPE_SWITCH(src[i].dtype(), DType,
                ncclReduce(src[i].data().dptr<DType>(),
                           nullptr,
                           src[i].shape().Size(),
                           GetNCCLType(src[i].dtype()),
                           ncclSum,
                           root,
                           cur.comm,
                           cur.stream););
            }
          }
          ncclGroupEnd();
        }
#if (NCCL_MAJOR > 2 || (NCCL_MAJOR == 2 && NCCL_MINOR > 1))
        ncclGroupEnd();
#endif
      },
      Context::CPU(),
      const_vars,
      mutate_vars,
      FnProperty::kCPUPrioritized,
      priority,
      "KVStoreReduce");
  }

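  // Broadcast copies each key's value from the merge-buffer (root) device to
  // every destination array: a plain copy to the co-located destination plus
  // a grouped ncclBcast to the rest. Before NCCL is initialized it falls
  // back to ordinary device-to-device copies.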
  virtual void Broadcast(const std::vector<int> keys,
      const std::vector<NDArray>& srcs,
      const std::vector<std::vector<NDArray*>>& dsts,
      int priority) {
    std::vector<size_t> root_ids(keys.size());
    std::vector<Engine::VarHandle> const_vars;
    std::vector<Engine::VarHandle> mutable_vars;

    for (size_t k = 0; k < keys.size(); ++k) {
      auto& key = keys[k];
      auto& src = srcs[k];
      auto& dst = dsts[k];
      auto& root_id = root_ids[k];

      if (!inited_) {
        // NCCL is not set up yet; copy to an arbitrary device first
        // (chosen by key), then fan out from there
        int dev_id = key % dst.size();
        CopyFromTo(src, *dst[dev_id], priority);
        for (size_t i = 0; i < dst.size(); ++i) {
          if (i != static_cast<size_t>(dev_id)) {
            CopyFromTo(*dst[dev_id], *dst[i], priority);
          }
        }
      } else {
        auto& buf = merge_buf_[key];
        int root = src.ctx().dev_id;
        assert(root == buf.merged.ctx().dev_id);
        root_id = FindRootId(dst, root);

        // Check whether we got the same set of devices
        std::vector<int> dev_ids;
        for (size_t i = 0; i < dst.size(); ++i) {
          auto& bcast = (i == root_id) ? src : *dst[i];
          dev_ids.push_back(bcast.ctx().dev_id);
        }
        std::sort(dev_ids.begin(), dev_ids.end());
        CHECK(device_ids_ == dev_ids) << "NCCL KVStore supports only a single set of devices";

        // On root perform simple copy to the output
        CopyFromTo(src, *dst[root_id], priority);
        for (size_t i = 0; i < dst.size(); ++i) {
          if (i != root_id)
            mutable_vars.push_back(dst[i]->var());
        }
        const_vars.push_back(src.var());
      }
    }

    // If not yet inited, then all work is already scheduled
    if (!inited_) {
      return;
    }

    // We need to capture NDArrays by value
    // in order to push to the engine
    std::vector<std::vector<NDArray>> broadcasts(dsts.size());
    for (size_t i = 0; i < dsts.size(); ++i) {
      auto& broadcast = broadcasts[i];
      broadcast.resize(dsts[i].size());
      for (size_t j = 0; j < dsts[i].size(); ++j) {
        broadcast[j] = *(dsts[i][j]);
      }
    }

    Engine::Get()->PushSync([srcs, broadcasts, root_ids, this](RunContext rctx) {
        // Same locking rationale as in Reduce above
        std::lock_guard<std::mutex> l(Storage::Get()->GetMutex(Context::kGPU));
#if (NCCL_MAJOR > 2 || (NCCL_MAJOR == 2 && NCCL_MINOR > 1))
        ncclGroupStart();
#endif
        for (size_t k = 0; k < srcs.size(); ++k) {
          auto& src = srcs[k];
          auto& dst = broadcasts[k];
          auto& root_id = root_ids[k];
          if (dst.size() <= 1) {
            continue;
          }

          int root = nccl_data_[src.ctx().dev_id].rank;
          ncclGroupStart();
          for (size_t i = 0; i < dst.size(); ++i) {
            auto& bcast = (i == root_id) ? src : dst[i];
            NCCLEntry cur = nccl_data_[bcast.ctx().dev_id];
            MSHADOW_TYPE_SWITCH(bcast.dtype(), DType,
              ncclBcast(bcast.data().dptr<DType>(),
                        bcast.shape().Size(),
                        GetNCCLType(bcast.dtype()),
                        root,
                        cur.comm,
                        cur.stream););
          }
          ncclGroupEnd();
        }
#if (NCCL_MAJOR > 2 || (NCCL_MAJOR == 2 && NCCL_MINOR > 1))
        ncclGroupEnd();
#endif
      },
      Context::CPU(),
      const_vars,
      mutable_vars,
      FnProperty::kCPUPrioritized,
      priority,
      "KVStoreBCast");
  }

  // Wait for the NCCL collectives issued on our internal streams to finish,
  // so that subsequent engine ops that consume dst see completed results
  template <typename T>
  void CommSync(const std::vector<T>& dst, int priority) {
    std::vector<Engine::VarHandle> mutate_vars;
    for (size_t i = 0; i < dst.size(); ++i) {
      mutate_vars.push_back(ptr(dst[i])->var());
    }
    Engine::Get()->PushSync([this](RunContext rctx) {
        mxnet::common::cuda::DeviceStore device_store;
        for (auto cur : nccl_data_) {
          device_store.SetDevice(cur.second.dev_id);
          CUDA_CALL(cudaStreamSynchronize(cur.second.stream));
        }
      },
      Context::CPU(),
      {},
      mutate_vars,
      FnProperty::kCPUPrioritized,
      priority,
      "KVStoreStreamSync");
  }

  // Initialize single key
  void InitKey(int key, const NDArrayStorageType stype, const mxnet::TShape& shape,
               int dtype = mshadow::kFloat32) {
    if (stype == kDefaultStorage) {
      key_attrs_.push_back(std::make_tuple(key, shape, dtype));
    } else {
      LOG(FATAL) << "NCCL KVStore does not support sparse storage type";
    }
  }

  ncclDataType_t GetNCCLType(int dtype) {
    switch (dtype) {
      case mshadow::kFloat32:
        return ncclFloat;
      case mshadow::kFloat16:
        return ncclHalf;
      case mshadow::kFloat64:
        return ncclDouble;
      case mshadow::kUint8:
        return ncclChar;
      case mshadow::kInt32:
        return ncclInt;
      case mshadow::kInt64:
        return ncclInt64;
      default:
        LOG(FATAL) << "Unknown type passed to NCCL KVStore";
    }
    return ncclNumTypes;
  }

  void InitNCCL(const std::vector<Context>& devs) {
    for (size_t i = 0; i < devs.size(); ++i) {
      device_ids_.push_back(devs[i].dev_id);
    }
    std::sort(device_ids_.begin(), device_ids_.end());
    std::lock_guard<std::mutex> l(Storage::Get()->GetMutex(Context::kGPU));
    std::vector<ncclComm_t> comms(devs.size());
    ncclCommInitAll(&(comms[0]), devs.size(), &(device_ids_[0]));
    mxnet::common::cuda::DeviceStore device_store;
    for (size_t i = 0; i < devs.size(); ++i) {
      NCCLEntry e;
      e.dev_id = device_ids_[i];
      e.comm = comms[i];
      e.rank = i;
      device_store.SetDevice(e.dev_id);
      CUDA_CALL(cudaStreamCreate(&(e.stream)));
      nccl_data_[device_ids_[i]] = e;
    }
  }

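  // One merge buffer per key, allocated on devs[0]; Reduce targets it, and
  // Broadcast treats its device as the NCCL root.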
  using KeyAttrs = std::tuple<int, mxnet::TShape, int>;
  void InitMergeBuffer(const std::vector<Context>& devs) {
    for (size_t i = 0; i < key_attrs_.size(); ++i) {
      int key = std::get<0>(key_attrs_[i]);
      mxnet::TShape s = std::get<1>(key_attrs_[i]);
      int type = std::get<2>(key_attrs_[i]);
      auto& buf = merge_buf_[key];
      // always use devs[0] as root
      buf.merged = NDArray(s, devs[0], false, type);
    }
    inited_ = true;
  }

  // Functions that enable templates to work on both references
  // and pointers
  template<typename T>
  const T * ptr(const T & obj) { return &obj; }

  template<typename T>
  const T * ptr(T * obj) { return obj; }

  // Find which element of the vector corresponds to the root dev_id;
  // returns (size_t)-1 if the root device is not present
  template <typename T>
  size_t FindRootId(const std::vector<T>& vec, int root) {
    size_t root_id = -1;
    for (size_t i = 0; i < vec.size(); ++i) {
      if (ptr(vec[i])->ctx().dev_id == root) {
        root_id = i;
        break;
      }
    }
    return root_id;
  }

  std::vector<KeyAttrs> key_attrs_;
  /// \brief temporary space for pushing and pulling
  struct BufferEntry {
    /// \brief the merged value
    NDArray merged;
  };
  struct NCCLEntry {
    /// \brief device ID
    int dev_id;
    /// \brief NCCL communicator
    ncclComm_t comm;
    /// \brief NCCL rank
    int rank;
    /// \brief GPU stream to use with NCCL
    cudaStream_t stream;
  };
  std::unordered_map<int, BufferEntry> merge_buf_;
  std::unordered_map<int, NCCLEntry> nccl_data_;
  bool inited_;
  /// \brief devices used with this KVStore
  std::vector<int> device_ids_;
};
}  // namespace kvstore
}  // namespace mxnet
#endif  // MXNET_USE_NCCL
#endif  // MXNET_KVSTORE_KVSTORE_NCCL_H_