/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/**
 * @file kvstore_nccl.h
 * @brief NCCL implementation of KVStore
 */
#ifndef MXNET_KVSTORE_KVSTORE_NCCL_H_
#define MXNET_KVSTORE_KVSTORE_NCCL_H_

#if MXNET_USE_NCCL

#include <mxnet/kvstore.h>
#include <nccl.h>
#include <unordered_map>
#include <bitset>
#include <vector>
#include <string>
#include <utility>
#include <functional>
#include <algorithm>
#include <tuple>
#include "./comm.h"
#include "./kvstore_local.h"
#include "../common/cuda_utils.h"

// NCCL v2 introduces the NCCL_MAJOR macro for versioning;
// if nccl.h does not define it, we are dealing with NCCL v1.
#ifndef NCCL_MAJOR
#define NCCL_MAJOR 1
#endif

#if NCCL_MAJOR == 1
#define ncclGroupStart()
#define ncclGroupEnd()
#define ncclNumTypes nccl_NUM_TYPES
#endif  // NCCL_MAJOR == 1
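
// For context, a minimal sketch (not part of this header; the buffer,
// communicator, and stream names are placeholders) of the grouped-call
// semantics this file relies on: under NCCL v2, collectives issued
// between ncclGroupStart() and ncclGroupEnd() are fused into a single
// launch, which is what lets one thread drive many devices safely:
//
//   ncclGroupStart();
//   for (int i = 0; i < ndev; ++i) {
//     ncclAllReduce(sendbuf[i], recvbuf[i], count, ncclFloat, ncclSum,
//                   comms[i], streams[i]);
//   }
//   ncclGroupEnd();
//
// With the no-op macros defined above, the same code also compiles
// against NCCL v1, where the collectives are simply issued one by one.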

namespace mxnet {
namespace kvstore {

/**
 * \brief store data in local machine using NCCL
 */
class KVStoreNCCL : public KVStoreLocal {
 public:
  KVStoreNCCL() : KVStoreLocal() {
    // Due to aggregation, we do not use the Comm interface
    comm_ = nullptr;
    pinned_ctx_ = Context::CPUPinned(0);
    inited_ = false;
  }

  virtual ~KVStoreNCCL() {
    for (auto e : nccl_data_) {
      cudaStreamDestroy(e.second.stream);
      ncclCommDestroy(e.second.comm);
    }
  }

 private:
  void InitImpl(const std::vector<int>& keys,
                const std::vector<NDArray>& values) override {
    for (size_t i = 0; i < keys.size(); ++i) {
      CHECK(local_.find(keys[i]) == local_.end())
          << "duplicate init of key " << keys[i];
      local_[keys[i]] = values[i].Copy(pinned_ctx_);
      InitKey(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype());
    }
  }

  void PushImpl(const std::vector<int>& keys,
                const std::vector<NDArray>& values,
                int priority) override {
    std::vector<int> uniq_keys;
    std::vector<std::vector<NDArray> > grouped_vals;
    // the NCCL kvstore does not support sparse NDArrays
    GroupKVPairsHelper(keys, values, &uniq_keys, &grouped_vals, true);

    std::vector<const NDArray*> merged_ptrs;
    std::vector<NDArray*> local_ptrs;
    bool nccl_called = false;

    Reduce(uniq_keys, grouped_vals, priority, &merged_ptrs);

    for (size_t i = 0; i < uniq_keys.size(); ++i) {
      int key = uniq_keys[i];
      if (grouped_vals[i].size() > 1) {
        // We issued NCCL kernels, need to synchronize
        nccl_called = true;
      }
      auto& merged = *(merged_ptrs[i]);
      NDArray& local = local_[key];
      if (updater_ != nullptr) {
        CHECK(!local.is_none()) << "key " << key << " has not been inited";
        // if merged is on gpu, we may need to copy the weight from cpu to gpu
        if (merged.ctx().dev_mask() != cpu::kDevMask &&
            local.ctx().dev_mask() == cpu::kDevMask) {
          local = local.Copy(merged.ctx());
        }
      }
      local_ptrs.push_back(&local);
    }

    // Sync after all reductions in a group
    if (nccl_called) {
      CommSync(merged_ptrs, priority);
    }

    for (size_t i = 0; i < uniq_keys.size(); ++i) {
      int key = uniq_keys[i];
      auto& merged = *(merged_ptrs[i]);
      NDArray& local = *(local_ptrs[i]);
      if (updater_ != nullptr) {
        // Call the updater with string keys if string keys are used and
        // str_updater_ is available; otherwise fall back to updater_,
        // which uses the int-key interface. This can be simplified once
        // all language bindings pick up the string interface changes.
        if (key_type_ == kStringKey && str_updater_ != nullptr) {
          const std::string &str_key = reverse_str_key_dict_[key];
          str_updater_(str_key, merged, &local);
        } else {
          updater_(key, merged, &local);
        }
      } else {
        local = merged;
      }
    }
  }
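
  // For illustration only: a hypothetical caller-side sequence that drives
  // the aggregated path above through the public KVStore interface
  // (kv, weight, and the gradient arrays are placeholders):
  //
  //   kv->Init({0}, {weight});                      // keeps a pinned CPU copy
  //   kv->Push({0, 0}, {grad_gpu0, grad_gpu1}, 0);  // one fused ncclReduce,
  //                                                 // a single stream sync,
  //                                                 // then one update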

  void PullImpl(const std::vector<int>& keys,
                const std::vector<NDArray*>& values,
                int priority, bool ignore_sparse) override {
    CHECK(ignore_sparse) << "nccl kvstore pull doesn't support ignore_sparse=False";
    std::vector<int> uniq_keys;
    std::vector<std::vector<NDArray*> > grouped_vals;
    GroupKVPairsHelper(keys, values, &uniq_keys, &grouped_vals, true);
    std::vector<NDArray> locals;
    bool nccl_called = false;

    for (size_t i = 0; i < uniq_keys.size(); ++i) {
      int key = uniq_keys[i];
      const NDArray& local = local_[key];
      locals.push_back(local_[key]);
      CHECK(!local.is_none()) << "key " << key << " has not been inited";
      if (grouped_vals[i].size() > 1) {
        // We issued NCCL kernels, need to synchronize
        nccl_called = true;
      }
    }

    Broadcast(uniq_keys, locals, grouped_vals, priority);
    // Sync after all broadcasts in a group
    if (nccl_called) {
      const std::vector<const NDArray*> values_copy(values.begin(), values.end());
      CommSync(values_copy, priority);
    }
  }

  void PullRowSparseImpl(const std::vector<int>& keys,
                         const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
                         int priority = 0) override {
    LOG(FATAL) << "NCCL kvstore does not support sparse storage type";
  }

  void SetGradientCompression(const std::vector<std::pair<std::string, std::string> >
                              & kwargs) override {
    LOG(FATAL) << "NCCL kvstore does not support gradient compression";
  }

 protected:
  /**
   * \brief group values on keys
   */
  template <typename T>
  void GroupKVPairsHelper(const std::vector<int>& keys,
                          const std::vector<T>& values,
                          std::vector<int> *uniq_keys,
                          std::vector<std::vector<T>> *grouped_vals,
                          bool ignore_sparse) {
    // check whether the storage type of a value is valid
    auto validator = [this](const int key, const T nd, bool ignore_sparse) -> bool {
      CHECK(ignore_sparse) << "NCCL kvstore does not support ignore_sparse=False";
      auto stype = ptr(nd)->storage_type();
      // valid NDArray
      if (stype == kDefaultStorage) return true;
      // invalid NDArray, abort
      LOG(FATAL) << "NCCL kvstore does not support sparse storage type";
      return false;
    };
    GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator, ignore_sparse);
  }

 private:
  // Aggregated reductions
  virtual void Reduce(const std::vector<int> keys,
                      const std::vector<std::vector<NDArray>>& srcs,
                      int priority,
                      std::vector<const NDArray*>* merged_ptrs) {
    std::vector<size_t> root_ids(keys.size());
    std::vector<NDArray> reduces(keys.size());
    merged_ptrs->resize(keys.size());
    std::vector<Engine::VarHandle> const_vars;
    std::vector<Engine::VarHandle> mutate_vars;

    for (size_t k = 0; k < keys.size(); ++k) {
      auto& key = keys[k];
      auto& src = srcs[k];
      auto& root_id = root_ids[k];

      // Avoid an extra copy for a single device, though this may cause
      // problems for abnormal usage of the kvstore
      if (src.size() == 1) {
        (*merged_ptrs)[k] = &src[0];
        continue;
      }

      if (!inited_) {
        std::vector<Context> devs;
        for (const auto& a : src) {
          devs.push_back(a.ctx());
        }
        InitNCCL(devs);
        InitMergeBuffer(devs);
      }

      // Check whether we got the same set of devices
      std::vector<int> dev_ids;
      for (auto e : src) {
        dev_ids.push_back(e.ctx().dev_id);
      }
      std::sort(dev_ids.begin(), dev_ids.end());
      CHECK(device_ids_ == dev_ids) << "NCCL KVStore supports only a single set of devices";

      auto& buf = merge_buf_[key];
      int root = buf.merged.ctx().dev_id;
      root_id = FindRootId(src, root);

      auto& reduce = buf.merged;
      (*merged_ptrs)[k] = &reduce;
      // Need to pass NDArrays by value to the engine
      reduces[k] = reduce;

      for (size_t i = 0; i < src.size(); ++i) {
        const_vars.push_back(src[i].var());
      }
      mutate_vars.push_back(reduce.var());
    }

    Engine::Get()->PushSync([srcs, reduces, root_ids, this](RunContext rctx) {
        std::lock_guard<std::mutex> l(Storage::Get()->GetMutex(Context::kGPU));
#if (NCCL_MAJOR > 2 || (NCCL_MAJOR == 2 && NCCL_MINOR > 1))
        ncclGroupStart();
#endif
        for (size_t k = 0; k < srcs.size(); ++k) {
          auto& src = srcs[k];
          auto& root_id = root_ids[k];
          auto& reduce = reduces[k];
          if (src.size() <= 1) {
            continue;
          }
          int root = nccl_data_[src[root_id].ctx().dev_id].rank;
          ncclGroupStart();
          for (size_t i = 0; i < src.size(); ++i) {
            NCCLEntry cur = nccl_data_[src[i].ctx().dev_id];
            if (i == root_id) {
              MSHADOW_TYPE_SWITCH(src[i].dtype(), DType,
                  ncclReduce(src[i].data().dptr<DType>(),
                             reduce.data().dptr<DType>(),
                             src[i].shape().Size(),
                             GetNCCLType(src[i].dtype()),
                             ncclSum,
                             root,
                             cur.comm,
                             cur.stream););
            } else {
              MSHADOW_TYPE_SWITCH(src[i].dtype(), DType,
                  ncclReduce(src[i].data().dptr<DType>(),
                             nullptr,
                             src[i].shape().Size(),
                             GetNCCLType(src[i].dtype()),
                             ncclSum,
                             root,
                             cur.comm,
                             cur.stream););
            }
          }
          ncclGroupEnd();
        }
#if (NCCL_MAJOR > 2 || (NCCL_MAJOR == 2 && NCCL_MINOR > 1))
        ncclGroupEnd();
#endif
      },
      Context::CPU(),
      const_vars,
      mutate_vars,
      FnProperty::kCPUPrioritized,
      priority,
      "KVStoreReduce");
  }
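
  // A sketch, with placeholder buffers, of what the engine callback above
  // effectively issues for two keys on two GPUs (all merge buffers live on
  // the same root device, hence root rank 0 throughout). The outer group -
  // only legal on NCCL >= 2.2, hence the version guard - fuses the per-key
  // groups into one launch; each inner group ties the per-device calls of
  // one key together into a single collective:
  //
  //   ncclGroupStart();                                           // batch
  //     ncclGroupStart();                                         // key A
  //     ncclReduce(gA_d0, sumA, nA, ncclFloat, ncclSum, 0, comm0, s0);
  //     ncclReduce(gA_d1, nullptr, nA, ncclFloat, ncclSum, 0, comm1, s1);
  //     ncclGroupEnd();
  //     ncclGroupStart();                                         // key B
  //     ncclReduce(gB_d0, sumB, nB, ncclFloat, ncclSum, 0, comm0, s0);
  //     ncclReduce(gB_d1, nullptr, nB, ncclFloat, ncclSum, 0, comm1, s1);
  //     ncclGroupEnd();
  //   ncclGroupEnd();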

  virtual void Broadcast(const std::vector<int> keys,
                         const std::vector<NDArray>& srcs,
                         const std::vector<std::vector<NDArray*>>& dsts,
                         int priority) {
    std::vector<size_t> root_ids(keys.size());
    std::vector<Engine::VarHandle> const_vars;
    std::vector<Engine::VarHandle> mutable_vars;

    for (size_t k = 0; k < keys.size(); ++k) {
      auto& key = keys[k];
      auto& src = srcs[k];
      auto& dst = dsts[k];
      auto& root_id = root_ids[k];

      if (!inited_) {
        // Not initialized yet: copy to an arbitrary (key-dependent) device
        // first, then fan out from there with plain copies
        int dev_id = key % dst.size();
        CopyFromTo(src, *dst[dev_id], priority);
        for (size_t i = 0; i < dst.size(); ++i) {
          if (i != static_cast<size_t>(dev_id)) {
            CopyFromTo(*dst[dev_id], *dst[i], priority);
          }
        }
      } else {
        auto& buf = merge_buf_[key];
        int root = src.ctx().dev_id;
        assert(root == buf.merged.ctx().dev_id);
        root_id = FindRootId(dst, root);

        // Check whether we got the same set of devices
        std::vector<int> dev_ids;
        for (size_t i = 0; i < dst.size(); ++i) {
          auto& bcast = (i == root_id) ? src : *dst[i];
          dev_ids.push_back(bcast.ctx().dev_id);
        }
        std::sort(dev_ids.begin(), dev_ids.end());
        CHECK(device_ids_ == dev_ids) << "NCCL KVStore supports only a single set of devices";

        // On root perform a simple copy to the output
        CopyFromTo(src, *dst[root_id], priority);
        for (size_t i = 0; i < dst.size(); ++i) {
          if (i != root_id)
            mutable_vars.push_back(dst[i]->var());
        }
        const_vars.push_back(src.var());
      }
    }

    // If not yet inited, then all the work was already scheduled above
    if (!inited_) {
      return;
    }

    // We need to capture NDArrays by value
    // in order to push to the engine
    std::vector<std::vector<NDArray>> broadcasts(dsts.size());
    for (size_t i = 0; i < dsts.size(); ++i) {
      auto& broadcast = broadcasts[i];
      broadcast.resize(dsts[i].size());
      for (size_t j = 0; j < dsts[i].size(); ++j) {
        broadcast[j] = *(dsts[i][j]);
      }
    }

    Engine::Get()->PushSync([srcs, broadcasts, root_ids, this](RunContext rctx) {
        std::lock_guard<std::mutex> l(Storage::Get()->GetMutex(Context::kGPU));
#if (NCCL_MAJOR > 2 || (NCCL_MAJOR == 2 && NCCL_MINOR > 1))
        ncclGroupStart();
#endif
        for (size_t k = 0; k < srcs.size(); ++k) {
          auto& src = srcs[k];
          auto& dst = broadcasts[k];
          auto& root_id = root_ids[k];
          if (dst.size() <= 1) {
            continue;
          }

          int root = nccl_data_[src.ctx().dev_id].rank;
          ncclGroupStart();
          for (size_t i = 0; i < dst.size(); ++i) {
            auto& bcast = (i == root_id) ? src : dst[i];
            NCCLEntry cur = nccl_data_[bcast.ctx().dev_id];
            MSHADOW_TYPE_SWITCH(bcast.dtype(), DType,
                ncclBcast(bcast.data().dptr<DType>(),
                          bcast.shape().Size(),
                          GetNCCLType(bcast.dtype()),
                          root,
                          cur.comm,
                          cur.stream););
          }
          ncclGroupEnd();
        }
#if (NCCL_MAJOR > 2 || (NCCL_MAJOR == 2 && NCCL_MINOR > 1))
        ncclGroupEnd();
#endif
      },
      Context::CPU(),
      const_vars,
      mutable_vars,
      FnProperty::kCPUPrioritized,
      priority,
      "KVStoreBCast");
  }
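
  // The broadcast path mirrors the reduction: a per-key group of ncclBcast
  // calls, optionally fused by the version-guarded outer group. A sketch
  // for one key on two GPUs with placeholder names, where rank 0 (the
  // device holding the merge buffer) is the source:
  //
  //   ncclGroupStart();
  //   ncclBcast(w_d0, n, ncclFloat, 0, comm0, s0);  // root keeps its data
  //   ncclBcast(w_d1, n, ncclFloat, 0, comm1, s1);  // the other rank receives
  //   ncclGroupEnd();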

  // Waits for the NCCL collectives issued on the internal streams to complete
  template <typename T>
  void CommSync(const std::vector<T>& dst, int priority) {
    std::vector<Engine::VarHandle> mutate_vars;
    for (size_t i = 0; i < dst.size(); ++i) {
      mutate_vars.push_back(ptr(dst[i])->var());
    }
    Engine::Get()->PushSync([this](RunContext rctx) {
        mxnet::common::cuda::DeviceStore device_store;
        for (auto cur : nccl_data_) {
          device_store.SetDevice(cur.second.dev_id);
          CUDA_CALL(cudaStreamSynchronize(cur.second.stream));
        }
      },
      Context::CPU(),
      {},
      mutate_vars,
      FnProperty::kCPUPrioritized,
      priority,
      "KVStoreStreamSync");
  }

  // Initialize a single key
  void InitKey(int key, const NDArrayStorageType stype, const mxnet::TShape& shape,
               int dtype = mshadow::kFloat32) {
    if (stype == kDefaultStorage) {
      key_attrs_.push_back(std::make_tuple(key, shape, dtype));
    } else {
      LOG(FATAL) << "NCCL KVStore does not support sparse storage type";
    }
  }

  ncclDataType_t GetNCCLType(int dtype) {
    switch (dtype) {
      case mshadow::kFloat32:
        return ncclFloat;
      case mshadow::kFloat16:
        return ncclHalf;
      case mshadow::kFloat64:
        return ncclDouble;
      case mshadow::kUint8:
        return ncclChar;
      case mshadow::kInt32:
        return ncclInt;
      case mshadow::kInt64:
        return ncclInt64;
      default:
        LOG(FATAL) << "Unknown type passed to NCCL KVStore";
    }
    return ncclNumTypes;
  }

  void InitNCCL(const std::vector<Context>& devs) {
    for (size_t i = 0; i < devs.size(); ++i) {
      device_ids_.push_back(devs[i].dev_id);
    }
    std::sort(device_ids_.begin(), device_ids_.end());
    std::lock_guard<std::mutex> l(Storage::Get()->GetMutex(Context::kGPU));
    std::vector<ncclComm_t> comms(devs.size());
    ncclCommInitAll(&(comms[0]), devs.size(), &(device_ids_[0]));
    mxnet::common::cuda::DeviceStore device_store;
    for (size_t i = 0; i < devs.size(); ++i) {
      NCCLEntry e;
      e.dev_id = device_ids_[i];
      e.comm = comms[i];
      e.rank = i;
      device_store.SetDevice(e.dev_id);
      cudaStreamCreate(&(e.stream));
      nccl_data_[device_ids_[i]] = e;
    }
  }

  using KeyAttrs = std::tuple<int, mxnet::TShape, int>;
  void InitMergeBuffer(const std::vector<Context>& devs) {
    for (size_t i = 0; i < key_attrs_.size(); ++i) {
      int key = std::get<0>(key_attrs_[i]);
      mxnet::TShape s = std::get<1>(key_attrs_[i]);
      int type = std::get<2>(key_attrs_[i]);
      auto& buf = merge_buf_[key];
      // always use devs[0] as the root
      buf.merged = NDArray(s, devs[0], false, type);
    }
    inited_ = true;
  }
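
  // Standalone sketch of the setup that InitNCCL performs, here for two
  // GPUs with placeholder ids: ncclCommInitAll creates one communicator
  // per listed device in a single call, rank i corresponding to devs[i],
  // and a dedicated stream per device then carries all the collectives:
  //
  //   ncclComm_t comms[2];
  //   int devs[2] = {0, 1};
  //   ncclCommInitAll(comms, 2, devs);
  //   cudaStream_t streams[2];
  //   for (int i = 0; i < 2; ++i) {
  //     cudaSetDevice(devs[i]);
  //     cudaStreamCreate(&streams[i]);
  //   }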

  // Functions that enable templates to work on both references
  // and pointers
  template<typename T>
  const T * ptr(const T & obj) { return &obj; }

  template<typename T>
  const T * ptr(T * obj) { return obj; }

  // Find which element of the vector
  // corresponds to the root dev_id
  template <typename T>
  size_t FindRootId(const std::vector<T>& vec, int root) {
    size_t root_id = -1;  // wraps to SIZE_MAX as a "not found" sentinel
    for (size_t i = 0; i < vec.size(); ++i) {
      if (ptr(vec[i])->ctx().dev_id == root) {
        root_id = i;
        break;
      }
    }
    return root_id;
  }

  std::vector<KeyAttrs> key_attrs_;
  /// \brief temporary space for pushing and pulling
  struct BufferEntry {
    /// \brief the merged value
    NDArray merged;
  };
  struct NCCLEntry {
    /// \brief device ID
    int dev_id;
    /// \brief NCCL communicator
    ncclComm_t comm;
    /// \brief NCCL rank
    int rank;
    /// \brief GPU stream to use with NCCL
    cudaStream_t stream;
  };
  std::unordered_map<int, BufferEntry> merge_buf_;
  std::unordered_map<int, NCCLEntry> nccl_data_;
  bool inited_;
  /// \brief devices used with this KVStore
  std::vector<int> device_ids_;
};
}  // namespace kvstore
}  // namespace mxnet
#endif  // MXNET_USE_NCCL
#endif  // MXNET_KVSTORE_KVSTORE_NCCL_H_