/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19 
20 #ifndef MXNET_ENGINE_STREAM_MANAGER_H_
21 #define MXNET_ENGINE_STREAM_MANAGER_H_
22 
23 #include <dmlc/base.h>
24 #include <mxnet/base.h>
25 #include <cstddef>
26 #include <array>
27 #include <string>
28 #include <mutex>
29 #include "../common/cuda_utils.h"
30 
31 namespace mxnet {
32 namespace engine {
33 
34 /*!
35  * \brief Stream manager.
36  *
37  * Uses a basic round-robin algorithm to dispatch GPU streams. Returns default
38  * context on CPU.
39  */
template <std::size_t kNumGpus, std::size_t kStreams>
class StreamManager {
 public:
  /*! \brief Construct with all per-GPU state marked uninitialized. */
  StreamManager();
  /*! \brief Destructor tears down any streams created, via Finalize(). */
  ~StreamManager() {
    Finalize();
  }
  /*!
   * \brief Get a RunContext for compute work on \p ctx.
   *        CPU contexts get null streams; GPU contexts get one of kStreams
   *        per-device streams chosen round-robin (streams are created lazily
   *        on first request for that device).
   */
  RunContext GetRunContext(Context const& ctx);
  /*!
   * \brief Get a RunContext for I/O (copy) work on \p ctx.
   *        GPU contexts share a single lazily-created I/O stream per device.
   */
  RunContext GetIORunContext(Context const& ctx);
  /*! \brief Destroy all created streams; safe to call more than once. */
  void Finalize();
 private:
  /*! \brief Guards lazy stream creation and the round-robin counters. */
  std::mutex mutex_;
#if MXNET_USE_CUDA
  /*! \brief kStreams compute streams per GPU, created lazily. */
  std::array<std::array<mshadow::Stream<gpu>*, kStreams>, kNumGpus>
      gpu_streams_;
  /*! \brief One auxiliary stream paired with each compute stream. */
  std::array<std::array<GPUAuxStream*, kStreams>, kNumGpus>
      gpu_aux_streams_;
  /*! \brief One dedicated I/O stream per GPU, created lazily. */
  std::array<mshadow::Stream<gpu>*, kNumGpus> gpu_io_streams_;
  /*! \brief Round-robin counter per GPU; -1 means "streams not yet created". */
  std::array<int, kNumGpus> gpu_cnt_;
#endif  // MXNET_USE_CUDA
  DISALLOW_COPY_AND_ASSIGN(StreamManager);
};  // class StreamManager
62 
template <std::size_t kNumGpus, std::size_t kStreams>
RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
    Context const& ctx) {
  RunContext ret;
  switch (ctx.dev_mask()) {
    case cpu::kDevMask:
      // CPU work needs no CUDA stream; both stream slots stay null.
      ret = RunContext{ctx, nullptr, nullptr, false};
      break;
    case gpu::kDevMask: {
#if MXNET_USE_CUDA
      std::size_t use_counter;
      {
        // Lazy initialization + round-robin selection, both under the mutex.
        std::lock_guard<std::mutex> lock{mutex_};
        auto&& counter = gpu_cnt_.at(ctx.dev_id);
        if (counter == -1) {
          // -1 marks "not yet initialized" (set by the constructor and by
          // Finalize()). Create all kStreams compute streams for this device,
          // plus one aux stream paired with each.
          mxnet::common::cuda::DeviceStore device_store(ctx.dev_id);
          for (auto&& primary_stream : gpu_streams_.at(ctx.dev_id)) {
            primary_stream = mshadow::NewStream<gpu>(true, MXNET_USE_CUDNN != 0, ctx.dev_id);
          }
          int idx = 0;
          for (auto&& aux_stream : gpu_aux_streams_.at(ctx.dev_id)) {
            auto primary_stream = gpu_streams_.at(ctx.dev_id).at(idx++);
            aux_stream = new GPUAuxStream(primary_stream);
          }
          counter = 0;
        }
        // Hand out the current stream index and advance the counter.
        use_counter = counter;
        counter = (counter + 1) % kStreams;
      }
      ret = RunContext{ctx,
                       gpu_streams_.at(ctx.dev_id).at(use_counter),
                       gpu_aux_streams_.at(ctx.dev_id).at(use_counter),
                       false};
      break;
#else
      // Without CUDA support a GPU context is fatal; LOG(FATAL) aborts, so
      // falling through to the default label below is unreachable.
      LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
#endif  // MXNET_USE_CUDA
    default:
      LOG(FATAL) << "Not Reached";
    }
  }
  return ret;
}
106 
template <std::size_t kNumGpus, std::size_t kStreams>
RunContext StreamManager<kNumGpus, kStreams>::GetIORunContext(
    Context const& ctx) {
  RunContext ret;
  switch (ctx.dev_mask()) {
    case cpu::kDevMask:
      // CPU I/O needs no CUDA stream; both stream slots stay null.
      ret = RunContext{ctx, nullptr, nullptr, false};
      break;
    case gpu::kDevMask: {
#if MXNET_USE_CUDA
      {
        // Create the device's single I/O stream on first use, under the mutex.
        std::lock_guard<std::mutex> lock{mutex_};
        if (gpu_io_streams_.at(ctx.dev_id) == nullptr) {
          mxnet::common::cuda::DeviceStore device_store(ctx.dev_id);
          // Plain stream: no BLAS handle, no cuDNN handle needed for copies.
          gpu_io_streams_.at(ctx.dev_id) = mshadow::NewStream<gpu>(false, false, ctx.dev_id);
        }
      }
      // I/O contexts carry no auxiliary stream.
      ret = RunContext{ctx, gpu_io_streams_.at(ctx.dev_id), nullptr, false};
      break;
#else
      // Without CUDA support a GPU context is fatal; LOG(FATAL) aborts, so
      // falling through to the default label below is unreachable.
      LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
#endif  // MXNET_USE_CUDA
    default:
      LOG(FATAL) << "Not Reached";
    }
  }
  return ret;
}
135 
136 template <std::size_t kNumGpus, std::size_t kStreams>
StreamManager()137 StreamManager<kNumGpus, kStreams>::StreamManager() {
138 #if MXNET_USE_CUDA
139   for (std::size_t i = 0; i < kNumGpus; ++i) {
140     gpu_cnt_.at(i) = -1;
141   }
142   for (auto&& i : gpu_io_streams_) {
143     i = nullptr;
144   }
145 #endif  // MXNET_USE_CUDA
146 }
147 
148 template <std::size_t kNumGpus, std::size_t kStreams>
Finalize()149 void StreamManager<kNumGpus, kStreams>::Finalize() {
150 #if MXNET_USE_CUDA
151   for (std::size_t i = 0; i < kNumGpus; ++i) {
152     if (gpu_cnt_.at(i) != -1) {
153       for (auto&& primary_stream : gpu_streams_.at(i)) {
154         // Catch exception for CUDA driver shutdown
155         MSHADOW_CATCH_ERROR(mshadow::DeleteStream<gpu>(primary_stream));
156       }
157       for (auto&& aux_stream : gpu_aux_streams_.at(i)) {
158         delete aux_stream;
159       }
160       gpu_cnt_.at(i) = -1;
161     }
162   }
163 #endif  // MXNET_USE_CUDA
164 }
165 
166 }  // namespace engine
167 }  // namespace mxnet
168 
169 #endif  // MXNET_ENGINE_STREAM_MANAGER_H_
170