#include "darknet.h"
#include <cuda_runtime.h>
#include <curand.h>
#include <cublas_v2.h>
#include <float.h>

#include "activations.h"
#include "dark_cuda.h"

__device__ float lhtan_activate_kernel(float x)
{
    if(x < 0) return .001*x;
    if(x > 1) return .001*(x-1) + 1;
    return x;
}
__device__ float lhtan_gradient_kernel(float x)
{
    if(x > 0 && x < 1) return 1;
    return .001;
}

__device__ float hardtan_activate_kernel(float x)
{
    if (x < -1) return -1;
    if (x > 1) return 1;
    return x;
}
__device__ float linear_activate_kernel(float x){return x;}
__device__ float logistic_activate_kernel(float x){return 1.f/(1.f + expf(-x));}
__device__ float loggy_activate_kernel(float x){return 2.f/(1.f + expf(-x)) - 1;}
__device__ float relu_activate_kernel(float x){return x*(x>0);}
__device__ float relu6_activate_kernel(float x) { return min_val_cmp(max_val_cmp(x, 0), 6); }
__device__ float elu_activate_kernel(float x){return (x >= 0)*x + (x < 0)*(expf(x)-1);}
__device__ float selu_activate_kernel(float x) { return (x >= 0)*1.0507f*x + (x < 0)*1.0507f*1.6732f*(expf(x) - 1); }
__device__ float relie_activate_kernel(float x){return (x>0) ? x : .01f*x;}
__device__ float ramp_activate_kernel(float x){return x*(x>0)+.1f*x;}
__device__ float leaky_activate_kernel(float x){return (x>0) ? x : .1f*x;}
__device__ float tanh_activate_kernel(float x){return (2/(1 + expf(-2*x)) - 1);}
__device__ float gelu_activate_kernel(float x){return (0.5*x*(1 + tanhf(0.797885*x + 0.035677*powf(x, 3))));}
__device__ float softplus_kernel(float x, float threshold = 20) {
    if (x > threshold) return x;                // too large
    else if (x < -threshold) return expf(x);    // too small
    return log1pf(expf(x));
    //return logf(expf(x) + 1);
}
__device__ float plse_activate_kernel(float x)
{
    if(x < -4) return .01f * (x + 4);
    if(x > 4) return .01f * (x - 4) + 1;
    return .125f*x + .5f;
}
__device__ float stair_activate_kernel(float x)
{
    int n = floorf(x);
    if (n%2 == 0) return floorf(x/2.f);
    else return (x - n) + floorf(x/2.f);
}

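// Note: the *_gradient_kernel functions below are evaluated on the already-activated
// output y rather than the raw input x (e.g. logistic_gradient_kernel expects
// y = sigmoid(x) and returns y*(1-y); tanh_gradient_kernel expects y = tanh(x)).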
__device__ float hardtan_gradient_kernel(float x)
{
    if (x > -1 && x < 1) return 1;
    return 0;
}
__device__ float linear_gradient_kernel(float x){return 1;}
__device__ float logistic_gradient_kernel(float x){return (1-x)*x;}
__device__ float loggy_gradient_kernel(float x)
{
    float y = (x+1.F)/2.F;
    return 2*(1-y)*y;
}
__device__ float relu_gradient_kernel(float x){return (x>0);}
__device__ float relu6_gradient_kernel(float x) { return (x > 0 && x < 6); }
__device__ float elu_gradient_kernel(float x){return (x >= 0) + (x < 0)*(x + 1);}
__device__ float selu_gradient_kernel(float x) { return (x >= 0)*1.0507f + (x < 0)*(x + 1.0507f*1.6732f); }
__device__ float relie_gradient_kernel(float x){return (x>0) ? 1 : .01f;}
__device__ float ramp_gradient_kernel(float x){return (x>0)+.1f;}
__device__ float leaky_gradient_kernel(float x){return (x>0) ? 1 : .1f;}
__device__ float tanh_gradient_kernel(float x){return 1-x*x;}
__device__ float sech_gpu(float x) { return 2 / (expf(x) + expf(-x)); }
__device__ float gelu_gradient_kernel(float x) {
    const float x3 = powf(x, 3);
    return 0.5*tanhf(0.0356774*x3 + 0.797885*x) + (0.0535161*x3 + 0.398942*x) * powf(sech_gpu(0.0356774*x3 + 0.797885*x), 2) + 0.5;
}
__device__ float plse_gradient_kernel(float x){return (x < 0 || x > 1) ? .01f : .125f;}
__device__ float stair_gradient_kernel(float x)
{
    if (floor(x) == x) return 0;
    return 1;
}

__device__ float activate_kernel(float x, ACTIVATION a)
{
    switch(a){
        case LINEAR:
            return linear_activate_kernel(x);
        case LOGISTIC:
            return logistic_activate_kernel(x);
        case LOGGY:
            return loggy_activate_kernel(x);
        case RELU:
            return relu_activate_kernel(x);
        case RELU6:
            return relu6_activate_kernel(x);
        case ELU:
            return elu_activate_kernel(x);
        case SELU:
            return selu_activate_kernel(x);
        case GELU:
            return gelu_activate_kernel(x);
        case RELIE:
            return relie_activate_kernel(x);
        case RAMP:
            return ramp_activate_kernel(x);
        case LEAKY:
            return leaky_activate_kernel(x);
        case TANH:
            return tanh_activate_kernel(x);
        case PLSE:
            return plse_activate_kernel(x);
        case STAIR:
            return stair_activate_kernel(x);
        case HARDTAN:
            return hardtan_activate_kernel(x);
        case LHTAN:
            return lhtan_activate_kernel(x);
    }
    return 0;
}

__device__ float gradient_kernel(float x, ACTIVATION a)
{
    switch (a) {
        case LINEAR:
            return linear_gradient_kernel(x);
        case LOGISTIC:
            return logistic_gradient_kernel(x);
        case LOGGY:
            return loggy_gradient_kernel(x);
        case RELU:
            return relu_gradient_kernel(x);
        case RELU6:
            return relu6_gradient_kernel(x);
        case NORM_CHAN:
            return relu_gradient_kernel(x);
        case ELU:
            return elu_gradient_kernel(x);
        case SELU:
            return selu_gradient_kernel(x);
        case GELU:
            return gelu_gradient_kernel(x);
        case RELIE:
            return relie_gradient_kernel(x);
        case RAMP:
            return ramp_gradient_kernel(x);
        case LEAKY:
            return leaky_gradient_kernel(x);
        case TANH:
            return tanh_gradient_kernel(x);
        case PLSE:
            return plse_gradient_kernel(x);
        case STAIR:
            return stair_gradient_kernel(x);
        case HARDTAN:
            return hardtan_gradient_kernel(x);
        case LHTAN:
            return lhtan_gradient_kernel(x);
    }
    return 0;
}

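// BINARY activation: within each block of size s, the values at offsets i and i + s/2
// are multiplied pairwise in the forward pass; the backward pass applies the product
// rule, routing each incoming gradient to the partner element of the pair.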
__global__ void binary_gradient_array_kernel(float *x, float *dy, int n, int s, BINARY_ACTIVATION a, float *dx)
{
    int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    int i = id % s;
    int b = id / s;
    if (id < n) {
        // load the pair only for in-range threads to avoid out-of-bounds reads
        float x1 = x[b*s + i];
        float x2 = x[b*s + s / 2 + i];
        float de = dy[id];
        dx[b*s + i] = x2*de;
        dx[b*s + s / 2 + i] = x1*de;
    }
}

extern "C" void binary_gradient_array_gpu(float *x, float *dx, int n, int size, BINARY_ACTIVATION a, float *y)
{
    binary_gradient_array_kernel << <cuda_gridsize(n / 2), BLOCK, 0, get_cuda_stream() >> >(x, dx, n / 2, size, a, y);
    CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void binary_activate_array_kernel(float *x, int n, int s, BINARY_ACTIVATION a, float *y)
{
    int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    int i = id % s;
    int b = id / s;
    if (id < n) {
        float x1 = x[b*s + i];
        float x2 = x[b*s + s / 2 + i];
        y[id] = x1*x2;
    }
}

extern "C" void binary_activate_array_gpu(float *x, int n, int size, BINARY_ACTIVATION a, float *y)
{
    binary_activate_array_kernel << <cuda_gridsize(n / 2), BLOCK, 0, get_cuda_stream() >> >(x, n / 2, size, a, y);
    CHECK_CUDA(cudaPeekAtLastError());
}

__global__ void activate_array_kernel(float *x, int n, ACTIVATION a)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if(i < n) x[i] = activate_kernel(x[i], a);
}

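// Swish / SiLU: y = x * sigmoid(x). The sigmoid is cached in output_sigmoid_gpu
// (when provided) so the backward pass can reuse it.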
__global__ void activate_array_swish_kernel(float *x, int n, float *output_sigmoid_gpu, float *output_gpu)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i < n) {
        float x_val = x[i];
        float sigmoid = logistic_activate_kernel(x_val);
        if (output_sigmoid_gpu) output_sigmoid_gpu[i] = sigmoid;
        output_gpu[i] = x_val * sigmoid;
    }
}

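// Mish: y = x * tanh(softplus(x)). mish_njuffa() and mish_yashas() below are fast
// single-precision approximations of the same function; the kernel currently uses
// mish_yashas().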
__device__ float mish_njuffa(float x)
{
    float r;
    float e = expf(x);
    r = 1.0f / fmaf(fmaf(-0.5f, e, -1.0f), e, -1.0f);
    r = fmaf(r, x, x);
    return r;
}

__device__ float mish_yashas(float x)
{
    float e = __expf(x);
    if (x <= -18.0f)
        return x * e;

    float n = e * e + 2 * e;
    if (x <= -5.0f)
        return x * __fdividef(n, n + 2);

    return x - 2 * __fdividef(x, n + 2);
}

// https://github.com/digantamisra98/Mish
__global__ void activate_array_mish_kernel(float *x, int n, float *activation_input, float *output_gpu)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i < n) {
        const float MISH_THRESHOLD = 20;
        float x_val = x[i];
        if (activation_input) activation_input[i] = x_val; // store value before activation
        //output_gpu[i] = x_val * tanh_activate_kernel(logf(1 + expf(x_val)));

        // Pytorch: https://github.com/thomasbrandon/mish-cuda/blob/master/csrc/mish.h#L17-L20
        // TF: https://github.com/tensorflow/addons/blob/093cdfa85d334cbe19a37624c33198f3140109ed/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h#L40-L49
        // log1p(x) == log(x + 1)
        //output_gpu[i] = x_val * tanh_activate_kernel( softplus_kernel(x_val, MISH_THRESHOLD) );
        output_gpu[i] = mish_yashas(x_val);
        //output_gpu[i] = mish_njuffa(x_val);
    }
}

__global__ void activate_array_leaky_kernel(float *x, int n)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        x[index] = leaky_activate_kernel(x[index]);
    }
}

__global__ void activate_array_selu_kernel(float *x, int n)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        x[index] = selu_activate_kernel(x[index]);
    }
}

__global__ void activate_array_gelu_kernel(float *x, int n)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        x[index] = gelu_activate_kernel(x[index]);
    }
}

__global__ void activate_array_logistic_kernel(float *x, int n)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        x[index] = logistic_activate_kernel(x[index]);
    }
}

__global__ void activate_array_tanh_kernel(float *x, int n)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        x[index] = tanh_activate_kernel(x[index]);
    }
}

__global__ void activate_array_hardtan_kernel(float *x, int n)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        x[index] = hardtan_activate_kernel(x[index]);
    }
}

__global__ void activate_array_relu_kernel(float *x, int n)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        x[index] = relu_activate_kernel(x[index]);
    }
}

__global__ void activate_array_relu6_kernel(float *x, int n)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        x[index] = relu6_activate_kernel(x[index]);
    }
}

__global__ void gradient_array_kernel(float *x, int n, ACTIVATION a, float *delta)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if(i < n) delta[i] *= gradient_kernel(x[i], a);
}

// https://github.com/BVLC/caffe/blob/04ab089db018a292ae48d51732dd6c66766b36b6/src/caffe/layers/swish_layer.cu#L28-L30
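// d/dx [x*sigmoid(x)] = swish(x) + sigmoid(x)*(1 - swish(x)); here x[] already holds
// the forward swish output and sigmoid_gpu[] holds the cached sigmoid.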
__global__ void gradient_array_swish_kernel(float *x, int n, float *sigmoid_gpu, float *delta)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i < n) {
        float swish = x[i];
        delta[i] *= swish + sigmoid_gpu[i] * (1 - swish); // gradient_kernel(x[i], a);
    }
}

// https://github.com/digantamisra98/Mish
__global__ void gradient_array_mish_kernel(int n, float *activation_input_gpu, float *delta)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i < n) {
        const float MISH_THRESHOLD = 20.0f;

        // implementation from TensorFlow: https://github.com/tensorflow/addons/blob/093cdfa85d334cbe19a37624c33198f3140109ed/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h#L66-L80
        // implementation from Pytorch: https://github.com/thomasbrandon/mish-cuda/blob/master/csrc/mish.h#L26-L31
        // log1p(x) == log(x + 1)
        const float inp = activation_input_gpu[i];
        const float sp = softplus_kernel(inp, MISH_THRESHOLD);
        const float grad_sp = -expm1f(-sp);
        //const float grad_sp = 1 - expf(-sp);
        const float tsp = tanh(sp);
        const float grad_tsp = (1 - tsp*tsp) * grad_sp;
        const float grad = inp * grad_tsp + tsp;
        delta[i] *= grad;

        //float x = activation_input[i];
        //float d = 2 * expf(x) + expf(2 * x) + 2;
        //float w = 4 * (x + 1) + 4 * expf(2 * x) + expf(3 * x) + expf(x)*(4 * x + 6);
        //float derivative = expf(x) * w / (d * d);
        //delta[i] *= derivative;
    }
}

__global__ void gradient_array_leaky_kernel(float *x, int n, float *delta)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        delta[index] *= leaky_gradient_kernel(x[index]);
    }
}

__global__ void gradient_array_selu_kernel(float *x, int n, float *delta)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        delta[index] *= selu_gradient_kernel(x[index]);
    }
}

__global__ void gradient_array_gelu_kernel(float *x, int n, float *delta)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        delta[index] *= gelu_gradient_kernel(x[index]);
    }
}

__global__ void gradient_array_logistic_kernel(float *x, int n, float *delta)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        delta[index] *= logistic_gradient_kernel(x[index]);
    }
}

__global__ void gradient_array_tanh_kernel(float *x, int n, float *delta)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        delta[index] *= tanh_gradient_kernel(x[index]);
    }
}

__global__ void gradient_array_hardtan_kernel(float *x, int n, float *delta)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        delta[index] *= hardtan_gradient_kernel(x[index]);
    }
}

__global__ void gradient_array_relu_kernel(float *x, int n, float *delta)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        delta[index] *= relu_gradient_kernel(x[index]);
    }
}

__global__ void gradient_array_relu6_kernel(float *x, int n, float *delta)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        delta[index] *= relu6_gradient_kernel(x[index]);
    }
}

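// Host-side wrappers: the most common activations dispatch to their dedicated kernels
// above; anything else falls through to the generic switch-based activate/gradient kernel.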
extern "C" void activate_array_ongpu(float *x, int n, ACTIVATION a)
{
    const int num_blocks = get_number_of_blocks(n, BLOCK);
    if (a == LINEAR) return;
    else if (a == LEAKY) activate_array_leaky_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
    else if (a == LOGISTIC) activate_array_logistic_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
    else if (a == TANH) activate_array_tanh_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
    else if (a == HARDTAN) activate_array_hardtan_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
    else if (a == RELU) activate_array_relu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
    else if (a == RELU6) activate_array_relu6_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
    else if (a == SELU) activate_array_selu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
    else if (a == GELU) activate_array_gelu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
    else
        activate_array_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream()>>>(x, n, a);
    CHECK_CUDA(cudaPeekAtLastError());
}

extern "C" void activate_array_swish_ongpu(float *x, int n, float *output_sigmoid_gpu, float *output_gpu)
{
    const int num_blocks = get_number_of_blocks(n, BLOCK);
    activate_array_swish_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> >(x, n, output_sigmoid_gpu, output_gpu);
    CHECK_CUDA(cudaPeekAtLastError());
}

extern "C" void activate_array_mish_ongpu(float *x, int n, float *activation_input_gpu, float *output_gpu)
{
    const int num_blocks = get_number_of_blocks(n, BLOCK);
    activate_array_mish_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> >(x, n, activation_input_gpu, output_gpu);
    CHECK_CUDA(cudaPeekAtLastError());
}

extern "C" void gradient_array_ongpu(float *x, int n, ACTIVATION a, float *delta)
{
    const int num_blocks = get_number_of_blocks(n, BLOCK);
    if (a == LINEAR) return;
    else if (a == LEAKY) gradient_array_leaky_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
    else if (a == LOGISTIC) gradient_array_logistic_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
    else if (a == TANH) gradient_array_tanh_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
    else if (a == HARDTAN) gradient_array_hardtan_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
    else if (a == RELU) gradient_array_relu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
    else if (a == RELU6) gradient_array_relu6_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
    //else if (a == NORM_CHAN) gradient_array_relu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
    else if (a == NORM_CHAN_SOFTMAX || a == NORM_CHAN) {
        printf(" Error: the custom NORM_CHAN_SOFTMAX gradient function should be used instead \n");
        exit(0);
    }
    else if (a == SELU) gradient_array_selu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
    else if (a == GELU) gradient_array_gelu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
    else
        gradient_array_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> > (x, n, a, delta);
    CHECK_CUDA(cudaPeekAtLastError());
}


extern "C" void gradient_array_swish_ongpu(float *x, int n, float *sigmoid_gpu, float *delta)
{
    const int num_blocks = get_number_of_blocks(n, BLOCK);
    gradient_array_swish_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> > (x, n, sigmoid_gpu, delta);
    CHECK_CUDA(cudaPeekAtLastError());
}

extern "C" void gradient_array_mish_ongpu(int n, float *activation_input_gpu, float *delta)
{
    const int num_blocks = get_number_of_blocks(n, BLOCK);
    gradient_array_mish_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> > (n, activation_input_gpu, delta);
    CHECK_CUDA(cudaPeekAtLastError());
}

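// NORM_CHAN forward pass: for every spatial position, keep only positive activations
// and divide each by the sum of the positive values across channels (plus eps);
// negative activations are written out as 0.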
__global__ void activate_array_normalize_channels_kernel(float *x, int size, int batch, int channels, int wh_step, float *output_gpu)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    int wh_i = i % wh_step;
    int b = i / wh_step;

    const float eps = 0.0001;
    if (i < size) {
        float sum = eps;
        int k;
        for (k = 0; k < channels; ++k) {
            float val = x[wh_i + k * wh_step + b*wh_step*channels];
            if (val > 0) sum += val;
        }
        for (k = 0; k < channels; ++k) {
            float val = x[wh_i + k * wh_step + b*wh_step*channels];
            if (val > 0) val = val / sum;
            else val = 0;
            output_gpu[wh_i + k * wh_step + b*wh_step*channels] = val;
        }
    }
}

extern "C" void activate_array_normalize_channels_ongpu(float *x, int n, int batch, int channels, int wh_step, float *output_gpu)
{
    // n = w*h*c*batch
    // size = w*h*batch
    int size = n / channels;

    const int num_blocks = get_number_of_blocks(size, BLOCK);

    activate_array_normalize_channels_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (x, size, batch, channels, wh_step, output_gpu);
    CHECK_CUDA(cudaPeekAtLastError());
}

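// NORM_CHAN_SOFTMAX forward pass: per-pixel softmax across channels; subtracting the
// per-pixel maximum (when use_max_val is set) keeps expf() numerically stable.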
__global__ void activate_array_normalize_channels_softmax_kernel(float *x, int size, int batch, int channels, int wh_step, float *output_gpu, int use_max_val)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    int wh_i = i % wh_step;
    int b = i / wh_step;

    const float eps = 0.0001;
    if (i < size) {
        float sum = eps;
        float max_val = -FLT_MAX;
        int k;
        if (use_max_val) {
            for (k = 0; k < channels; ++k) {
                float val = x[wh_i + k * wh_step + b*wh_step*channels];
                if (val > max_val || k == 0) max_val = val;
            }
        }
        else
            max_val = 0;

        for (k = 0; k < channels; ++k) {
            float val = x[wh_i + k * wh_step + b*wh_step*channels];
            sum += expf(val - max_val);
        }
        for (k = 0; k < channels; ++k) {
            float val = x[wh_i + k * wh_step + b*wh_step*channels];
            val = expf(val - max_val) / sum;
            if (isnan(val) || isinf(val)) val = 0;
            output_gpu[wh_i + k * wh_step + b*wh_step*channels] = val;
        }
    }
}

extern "C" void activate_array_normalize_channels_softmax_ongpu(float *x, int n, int batch, int channels, int wh_step, float *output_gpu, int use_max_val)
{
    // n = w*h*c*batch
    // size = w*h*batch
    int size = n / channels;

    const int num_blocks = get_number_of_blocks(size, BLOCK);

    activate_array_normalize_channels_softmax_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (x, size, batch, channels, wh_step, output_gpu, use_max_val);
    CHECK_CUDA(cudaPeekAtLastError());
}

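// NORM_CHAN_SOFTMAX backward pass: multiplies each delta by y*(1 - y) computed from the
// stored softmax output x, and zeroes NaN/Inf deltas.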
__global__ void gradient_array_normalize_channels_softmax_kernel(float *x, int size, int batch, int channels, int wh_step, float *delta_gpu)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    int wh_i = i % wh_step;
    int b = i / wh_step;

    if (i < size) {
        int k;
        /*
        float grad = 0;
        for (k = 0; k < channels; ++k) {
            const int index = wh_i + k * wh_step + b*wh_step*channels;
            float out = x[index];
            float delta = delta_gpu[index];
            grad += out*fabs(delta);
        }
        */
        for (k = 0; k < channels; ++k) {
            const int index = wh_i + k * wh_step + b*wh_step*channels;
            float delta = delta_gpu[index];
            float grad = x[index] * (1 - x[index]);
            delta = delta * grad;
            if (isnan(delta) || isinf(delta)) delta = 0;
            delta_gpu[index] = delta;
        }
    }
}

extern "C" void gradient_array_normalize_channels_softmax_ongpu(float *output_gpu, int n, int batch, int channels, int wh_step, float *delta_gpu)
{
    // n = w*h*c*batch
    // size = w*h*batch
    int size = n / channels;

    const int num_blocks = get_number_of_blocks(size, BLOCK);

    gradient_array_normalize_channels_softmax_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (output_gpu, size, batch, channels, wh_step, delta_gpu);
    CHECK_CUDA(cudaPeekAtLastError());
}

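// NORM_CHAN backward pass: deltas at positions with positive output are scaled by that
// output; deltas at zeroed positions are left unchanged.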
__global__ void gradient_array_normalize_channels_kernel(float *x, int size, int batch, int channels, int wh_step, float *delta_gpu)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    int wh_i = i % wh_step;
    int b = i / wh_step;

    if (i < size) {
        int k;
        /*
        float grad = 0;
        for (k = 0; k < channels; ++k) {
            const int index = wh_i + k * wh_step + b*wh_step*channels;
            float out = x[index];
            float delta = delta_gpu[index];
            grad += out*fabs(delta);
        }
        */
        for (k = 0; k < channels; ++k) {
            const int index = wh_i + k * wh_step + b*wh_step*channels;
            if (x[index] > 0) {
                float delta = delta_gpu[index];
                float grad = x[index];
                delta = delta * grad;
                delta_gpu[index] = delta;
            }
        }
    }
}

extern "C" void gradient_array_normalize_channels_ongpu(float *output_gpu, int n, int batch, int channels, int wh_step, float *delta_gpu)
{
    // n = w*h*c*batch
    // size = w*h*batch
    int size = n / channels;

    const int num_blocks = get_number_of_blocks(size, BLOCK);

    gradient_array_normalize_channels_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (output_gpu, size, batch, channels, wh_step, delta_gpu);
    CHECK_CUDA(cudaPeekAtLastError());
}