1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file relu_lib.cu
22 * \brief simple custom relu and noisy relu operator implemented using CUDA function
23 */
24
25 #include <iostream>
26 #include "mxnet/lib_api.h"
27
28 using namespace mxnet::ext;
29
#define NumThreadPerBlock 256  // mxnet recommended cuda thread number per block (multiple of warp size)
31
/*!
 * \brief ReLU forward kernel: out[i] = max(in[i], 0) for i in [0, N).
 * Launched with a 1-D grid; the grid-stride loop makes any launch
 * configuration valid (including grids smaller than N).
 */
__global__ void relu_gpu_forward(float *out, float *in, int64_t N) {
  // 64-bit index: blockIdx.x * blockDim.x can overflow a 32-bit int
  // when N (and therefore the grid) is large.
  int64_t stride = (int64_t)gridDim.x * blockDim.x;
  for (int64_t tid = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
       tid < N; tid += stride) {
    out[tid] = in[tid] > 0 ? in[tid] : 0;
  }
}
37
/*!
 * \brief ReLU backward kernel: ingrad[i] = outgrad[i] where the forward
 * input was positive, 0 elsewhere.
 * Launched with a 1-D grid; the grid-stride loop makes any launch
 * configuration valid (including grids smaller than N).
 */
__global__ void relu_gpu_backward(float *ingrad, float *outgrad, float *indata, int64_t N) {
  // 64-bit index: blockIdx.x * blockDim.x can overflow a 32-bit int
  // when N (and therefore the grid) is large.
  int64_t stride = (int64_t)gridDim.x * blockDim.x;
  for (int64_t tid = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
       tid < N; tid += stride) {
    ingrad[tid] = indata[tid] > 0 ? outgrad[tid] : 0;
  }
}
43
/*!
 * \brief CPU forward: elementwise ReLU over the single input tensor.
 * \param attrs   operator attributes (unused here)
 * \param inputs  inputs[0] is the data tensor (float32)
 * \param outputs outputs[0] is the result tensor, same size as inputs[0]
 * \param res     MXNet-provided resources (unused on the CPU path)
 */
MXReturnValue forwardCPU(const std::unordered_map<std::string, std::string>& attrs,
                         std::vector<MXTensor>* inputs,
                         std::vector<MXTensor>* outputs,
                         const OpResource& res) {
  float* in_data = inputs->at(0).data<float>();
  float* out_data = outputs->at(0).data<float>();
  // 64-bit loop index: the tensor element count may exceed INT_MAX
  int64_t N = inputs->at(0).size();
  for (int64_t i = 0; i < N; i++) {
    out_data[i] = in_data[i] > 0 ? in_data[i] : 0;
  }
  return MX_SUCCESS;
}
55
/*!
 * \brief CPU backward: passes the output gradient through where the
 * forward input was positive, zero elsewhere.
 * \param attrs   operator attributes (unused here)
 * \param inputs  inputs[0] is the output gradient, inputs[1] the forward input
 * \param outputs outputs[0] is the input gradient, same size as inputs[1]
 * \param res     MXNet-provided resources (unused on the CPU path)
 */
MXReturnValue backwardCPU(const std::unordered_map<std::string, std::string>& attrs,
                          std::vector<MXTensor>* inputs,
                          std::vector<MXTensor>* outputs,
                          const OpResource& res) {
  float* out_grad = inputs->at(0).data<float>();
  float* in_data = inputs->at(1).data<float>();
  float* in_grad = outputs->at(0).data<float>();
  // 64-bit loop index: the tensor element count may exceed INT_MAX
  int64_t N = inputs->at(1).size();
  for (int64_t i = 0; i < N; i++) {
    in_grad[i] = in_data[i] > 0 ? out_grad[i] : 0;
  }
  return MX_SUCCESS;
}
68
/*!
 * \brief GPU forward: launches relu_gpu_forward on MXNet's stream.
 * \param attrs   operator attributes (unused here)
 * \param inputs  inputs[0] is the data tensor (float32, device memory)
 * \param outputs outputs[0] is the result tensor, same size as inputs[0]
 * \param res     provides the CUDA stream the kernel must run on
 */
MXReturnValue forwardGPU(const std::unordered_map<std::string, std::string>& attrs,
                         std::vector<MXTensor>* inputs,
                         std::vector<MXTensor>* outputs,
                         const OpResource& res) {
  float* in_data = inputs->at(0).data<float>();
  float* out_data = outputs->at(0).data<float>();

  mx_stream_t cuda_stream = res.get_cuda_stream();
  int64_t N = inputs->at(0).size();
  // empty tensor: launching a 0-block grid is a CUDA error, so skip
  if (N == 0)
    return MX_SUCCESS;
  // ceil-div in 64-bit first, then narrow for the launch configuration
  int num_block = static_cast<int>((N + NumThreadPerBlock - 1) / NumThreadPerBlock);

  relu_gpu_forward<<<num_block, NumThreadPerBlock, 0, cuda_stream>>>(out_data, in_data, N);

  return MX_SUCCESS;
}
84
/*!
 * \brief GPU backward: launches relu_gpu_backward on MXNet's stream.
 * \param attrs   operator attributes (unused here)
 * \param inputs  inputs[0] is the output gradient, inputs[1] the forward input
 * \param outputs outputs[0] is the input gradient
 * \param res     provides the CUDA stream the kernel must run on
 */
MXReturnValue backwardGPU(const std::unordered_map<std::string, std::string>& attrs,
                          std::vector<MXTensor>* inputs,
                          std::vector<MXTensor>* outputs,
                          const OpResource& res) {
  float* out_grad = inputs->at(0).data<float>();
  float* in_data = inputs->at(1).data<float>();
  float* in_grad = outputs->at(0).data<float>();

  mx_stream_t cuda_stream = res.get_cuda_stream();
  // inputs[0] (out_grad) and inputs[1] share a size by inferShape
  int64_t N = inputs->at(0).size();
  // empty tensor: launching a 0-block grid is a CUDA error, so skip
  if (N == 0)
    return MX_SUCCESS;
  // ceil-div in 64-bit first, then narrow for the launch configuration
  int num_block = static_cast<int>((N + NumThreadPerBlock - 1) / NumThreadPerBlock);

  relu_gpu_backward<<<num_block, NumThreadPerBlock, 0, cuda_stream>>>(in_grad, out_grad, in_data, N);

  return MX_SUCCESS;
}
100
/*!
 * \brief Reports the operator arity to MXNet: one input tensor and one
 * output tensor, regardless of the attribute map.
 */
MXReturnValue parseAttrs(const std::unordered_map<std::string, std::string>& attrs,
                         int* num_in, int* num_out) {
  *num_out = 1;
  *num_in = 1;
  return MX_SUCCESS;
}
107
/*!
 * \brief Type inference: the single output dtype mirrors the single
 * input dtype.
 */
MXReturnValue inferType(const std::unordered_map<std::string, std::string>& attrs,
                        std::vector<int>* intypes,
                        std::vector<int>* outtypes) {
  const int in_type = intypes->at(0);
  outtypes->at(0) = in_type;
  return MX_SUCCESS;
}
114
/*!
 * \brief Shape inference: elementwise op, so the single output shape is
 * a copy of the single input shape.
 */
MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& attrs,
                         std::vector<std::vector<unsigned int>>* inshapes,
                         std::vector<std::vector<unsigned int>>* outshapes) {
  const std::vector<unsigned int>& in_shape = inshapes->at(0);
  outshapes->at(0) = in_shape;
  return MX_SUCCESS;
}
121
// Register the stateless ReLU operator with per-context forward/backward
// implementations; MXNet dispatches to the "cpu" or "gpu" variant based on
// where the input tensors live.
REGISTER_OP(my_relu)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape)
.setForward(forwardCPU, "cpu")
.setForward(forwardGPU, "gpu")
.setBackward(backwardCPU, "cpu")
.setBackward(backwardGPU, "gpu");
130
/*!
 * \brief Stateful CPU ReLU: captures the attribute map at construction
 * and delegates every Forward/Backward call to the stateless CPU
 * functions above. Exists to demonstrate the stateful-operator API; it
 * keeps no mutable state between calls.
 */
class MyStatefulReluCPU : public CustomStatefulOp {
 public:
  explicit MyStatefulReluCPU(const std::unordered_map<std::string, std::string>& attrs)
    : attrs_(attrs) {}
  // delegate to the stateless forwardCPU with the captured attributes
  MXReturnValue Forward(std::vector<MXTensor>* inputs,
                        std::vector<MXTensor>* outputs,
                        const OpResource& op_res) {
    return forwardCPU(attrs_, inputs, outputs, op_res);
  }
  // delegate to the stateless backwardCPU with the captured attributes
  MXReturnValue Backward(std::vector<MXTensor>* inputs,
                         std::vector<MXTensor>* outputs,
                         const OpResource& op_res) {
    return backwardCPU(attrs_, inputs, outputs, op_res);
  }
  ~MyStatefulReluCPU() {}
 private:
  // attribute map copied at construction time
  const std::unordered_map<std::string, std::string> attrs_;
};
149
/*!
 * \brief Stateful GPU ReLU: captures the attribute map at construction
 * and delegates every Forward/Backward call to the stateless GPU
 * functions above. Keeps no mutable state between calls.
 */
class MyStatefulReluGPU : public CustomStatefulOp {
 public:
  explicit MyStatefulReluGPU(const std::unordered_map<std::string, std::string>& attrs)
    : attrs_(attrs) {}
  // delegate to the stateless forwardGPU with the captured attributes
  MXReturnValue Forward(std::vector<MXTensor>* inputs,
                        std::vector<MXTensor>* outputs,
                        const OpResource& op_res) {
    return forwardGPU(attrs_, inputs, outputs, op_res);
  }
  // delegate to the stateless backwardGPU with the captured attributes
  MXReturnValue Backward(std::vector<MXTensor>* inputs,
                         std::vector<MXTensor>* outputs,
                         const OpResource& op_res) {
    return backwardGPU(attrs_, inputs, outputs, op_res);
  }
  ~MyStatefulReluGPU() {}
 private:
  // attribute map copied at construction time
  const std::unordered_map<std::string, std::string> attrs_;
};
168
/*!
 * \brief Factory for the stateful CPU operator instance.
 * NOTE(review): *op_inst is heap-allocated and not freed here —
 * presumably MXNet owns and deletes it; confirm against lib_api.h.
 */
MXReturnValue createOpStateCPU(const std::unordered_map<std::string, std::string>& attrs,
                               const MXContext& ctx,
                               const std::vector<std::vector<unsigned int> >& in_shapes,
                               const std::vector<int> in_types,
                               CustomStatefulOp** op_inst) {
  *op_inst = new MyStatefulReluCPU(attrs);
  return MX_SUCCESS;
}
177
/*!
 * \brief Factory for the stateful GPU operator instance.
 * NOTE(review): *op_inst is heap-allocated and not freed here —
 * presumably MXNet owns and deletes it; confirm against lib_api.h.
 */
MXReturnValue createOpStateGPU(const std::unordered_map<std::string, std::string>& attrs,
                               const MXContext& ctx,
                               const std::vector<std::vector<unsigned int> >& in_shapes,
                               const std::vector<int> in_types,
                               CustomStatefulOp** op_inst) {
  *op_inst = new MyStatefulReluGPU(attrs);
  return MX_SUCCESS;
}
186
// Register the stateful variant: instead of direct forward/backward
// callbacks, MXNet calls the per-context factory to create an operator
// instance and invokes Forward/Backward on it.
REGISTER_OP(my_state_relu)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape)
.setCreateOpState(createOpStateCPU, "cpu")
.setCreateOpState(createOpStateGPU, "gpu");
193
194 /*
195 * Below is noisy ReLU operator example
196 * noisy ReLU is made from ReLU extended to include Gaussian noise
197 * forward - add Gaussian noise generated from normal distribution to each unit
198 * backward - gradient doesn't need to change since noise is constant
199 */
200
#define NumRandomPerThread 64  // mxnet recommended random numbers generated per thread
202
/*!
 * \brief Noisy ReLU forward kernel: out[i] = max(in[i] + noise, 0) with
 * per-element Gaussian noise drawn from the thread's curand state.
 * Each thread processes the contiguous slice [tid*step, tid*step+step).
 */
__global__ void noisy_relu_gpu_forward(float *out, float *in, int64_t N, mx_gpu_rand_t* states, int step) {
  // the launcher logic ensures tid less than NumGPURandomStates
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  // copy the state to a register so curand advances it cheaply
  mx_gpu_rand_t thread_state = states[tid];
  // 64-bit slice bounds: tid * step can overflow a 32-bit int for large tensors
  int64_t start = (int64_t)tid * step;
  int64_t end = start + step;
  for (int64_t i = start; i < end && i < N; ++i) {
    float noisy = in[i] + curand_normal(&thread_state);
    out[i] = noisy > 0 ? noisy : 0;
  }
  // write the advanced state back; without this every launch would
  // replay the identical noise sequence
  states[tid] = thread_state;
}
216
/*!
 * \brief CPU forward for noisy ReLU: adds one Gaussian sample to each
 * element before applying the ReLU threshold.
 * \param attrs   operator attributes (unused here)
 * \param inputs  inputs[0] is the data tensor (float32)
 * \param outputs outputs[0] is the result tensor, same size as inputs[0]
 * \param res     provides MXNet's CPU random-engine state
 */
MXReturnValue noisyForwardCPU(const std::unordered_map<std::string, std::string>& attrs,
                              std::vector<MXTensor>* inputs,
                              std::vector<MXTensor>* outputs,
                              const OpResource& res) {
  float* in_data = inputs->at(0).data<float>();
  float* out_data = outputs->at(0).data<float>();

  // use MXNet's engine so results follow MXNet's global seed
  mx_cpu_rand_t* states = res.get_cpu_rand_states();
  std::normal_distribution<float> dist_normal;

  // 64-bit loop index: the tensor element count may exceed INT_MAX
  int64_t N = inputs->at(0).size();
  for (int64_t i = 0; i < N; ++i) {
    // hoist the noisy value so it is computed once per element
    float noisy = in_data[i] + dist_normal(*states);
    out_data[i] = noisy > 0 ? noisy : 0;
  }
  return MX_SUCCESS;
}
233
/*!
 * \brief GPU forward for noisy ReLU: sizes the launch so that no more
 * threads run than MXNet provides GPU random states, then launches
 * noisy_relu_gpu_forward on MXNet's stream.
 * \param attrs   operator attributes (unused here)
 * \param inputs  inputs[0] is the data tensor (float32, device memory)
 * \param outputs outputs[0] is the result tensor, same size as inputs[0]
 * \param res     provides the CUDA stream and the GPU random states
 */
MXReturnValue noisyForwardGPU(const std::unordered_map<std::string, std::string>& attrs,
                              std::vector<MXTensor>* inputs,
                              std::vector<MXTensor>* outputs,
                              const OpResource& res) {
  float* in_data = inputs->at(0).data<float>();
  float* out_data = outputs->at(0).data<float>();

  mx_stream_t cuda_stream = res.get_cuda_stream();
  int64_t N = inputs->at(0).size();
  // empty tensor: avoids a division by zero below and a 0-block launch
  if (N == 0)
    return MX_SUCCESS;

  // below is mxnet recommended workflow to parallel random number generating
  // ceil-div in 64-bit: N + NumRandomPerThread - 1 can overflow a 32-bit int
  int64_t nthread = (N + NumRandomPerThread - 1) / NumRandomPerThread;
  // we should not launch more threads than mxnet supported random number GPU states
  int num_thread_need = nthread < MX_NUM_GPU_RANDOM_STATES
                            ? static_cast<int>(nthread) : MX_NUM_GPU_RANDOM_STATES;
  // each cuda thread processes [step * tid, step * tid + step) snippet of input tensor
  int step = static_cast<int>((N + num_thread_need - 1) / num_thread_need);
  // this can ensure number of parallel threads less than mxnet supported random number states
  int num_block = (num_thread_need + NumThreadPerBlock - 1) / NumThreadPerBlock;

  noisy_relu_gpu_forward<<<num_block, NumThreadPerBlock, 0, cuda_stream>>>(
      out_data, in_data, N, res.get_gpu_rand_states(), step);

  return MX_SUCCESS;
}
258
// Register the noisy ReLU operator. Backward reuses the plain ReLU
// gradient functions: the added noise is constant w.r.t. the input, so
// the gradient is unchanged (see the comment block above).
REGISTER_OP(my_noisy_relu)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape)
.setForward(noisyForwardCPU, "cpu")
.setForward(noisyForwardGPU, "gpu")
.setBackward(backwardCPU, "cpu")
.setBackward(backwardGPU, "gpu");
267
/*!
 * \brief Library entry point invoked by MXNet on load.
 * \param version MXNet version number (e.g. 10900 for 1.9.0)
 * \return MX_SUCCESS for versions >= 1.9.0, MX_FAIL otherwise
 */
MXReturnValue initialize(int version) {
  // guard clause: reject versions older than 1.9.0
  if (version < 10900) {
    MX_ERROR_MSG << "MXNet version " << version << " not supported";
    return MX_FAIL;
  }
  std::cout << "MXNet version " << version << " supported" << std::endl;
  return MX_SUCCESS;
}
277