#include "darknet.h"
#include <cuda_runtime.h>
#include <curand.h>
#include <cublas_v2.h>
#include <float.h>

#include "activations.h"
#include "dark_cuda.h"

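// "Leaky hard tanh": identity on [0,1], slope 0.001 outside that range, so the
// gradient never dies completely.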
__device__ float lhtan_activate_kernel(float x)
{
    if(x < 0) return .001*x;
    if(x > 1) return .001*(x-1) + 1;
    return x;
}
__device__ float lhtan_gradient_kernel(float x)
{
    if(x > 0 && x < 1) return 1;
    return .001;
}

__device__ float hardtan_activate_kernel(float x)
{
    if (x < -1) return -1;
    if (x > 1) return 1;
    return x;
}
__device__ float linear_activate_kernel(float x){return x;}
__device__ float logistic_activate_kernel(float x){return 1.f/(1.f + expf(-x));}
__device__ float loggy_activate_kernel(float x){return 2.f/(1.f + expf(-x)) - 1;}
__device__ float relu_activate_kernel(float x){return x*(x>0);}
__device__ float relu6_activate_kernel(float x) { return min_val_cmp(max_val_cmp(x, 0), 6); }
__device__ float elu_activate_kernel(float x){return (x >= 0)*x + (x < 0)*(expf(x)-1);}
__device__ float selu_activate_kernel(float x) { return (x >= 0)*1.0507f*x + (x < 0)*1.0507f*1.6732f*(expf(x) - 1); }
__device__ float relie_activate_kernel(float x){return (x>0) ? x : .01f*x;}
__device__ float ramp_activate_kernel(float x){return x*(x>0)+.1f*x;}
__device__ float leaky_activate_kernel(float x){return (x>0) ? x : .1f*x;}
__device__ float tanh_activate_kernel(float x){return (2/(1 + expf(-2*x)) - 1);}
__device__ float gelu_activate_kernel(float x){return (0.5*x*(1 + tanhf(0.797885*x + 0.035677*powf(x, 3))));}
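// softplus(x) = ln(1 + e^x), evaluated piecewise to stay numerically safe:
// for x > threshold the result is ~x (and e^x would overflow), and for
// x < -threshold it is ~e^x, so both tails are returned directly.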
__device__ float softplus_kernel(float x, float threshold = 20) {
    if (x > threshold) return x;                // too large
    else if (x < -threshold) return expf(x);    // too small
    return log1pf(expf(x));
    //return logf(expf(x) + 1);
}
__device__ float plse_activate_kernel(float x)
{
    if(x < -4) return .01f * (x + 4);
    if(x > 4)  return .01f * (x - 4) + 1;
    return .125f*x + .5f;
}
__device__ float stair_activate_kernel(float x)
{
    int n = floorf(x);
    if (n%2 == 0) return floorf(x/2.f);
    else return (x - n) + floorf(x/2.f);
}


__device__ float hardtan_gradient_kernel(float x)
{
    if (x > -1 && x < 1) return 1;
    return 0;
}
__device__ float linear_gradient_kernel(float x){return 1;}
__device__ float logistic_gradient_kernel(float x){return (1-x)*x;}
__device__ float loggy_gradient_kernel(float x)
{
    float y = (x+1.F)/2.F;
    return 2*(1-y)*y;
}
__device__ float relu_gradient_kernel(float x){return (x>0);}
__device__ float relu6_gradient_kernel(float x) { return (x > 0 && x < 6); }
__device__ float elu_gradient_kernel(float x){return (x >= 0) + (x < 0)*(x + 1);}
__device__ float selu_gradient_kernel(float x) { return (x >= 0)*1.0507f + (x < 0)*(x + 1.0507f*1.6732f); }
__device__ float relie_gradient_kernel(float x){return (x>0) ? 1 : .01f;}
__device__ float ramp_gradient_kernel(float x){return (x>0)+.1f;}
__device__ float leaky_gradient_kernel(float x){return (x>0) ? 1 : .1f;}
__device__ float tanh_gradient_kernel(float x){return 1-x*x;}
__device__ float sech_gpu(float x) { return 2 / (expf(x) + expf(-x)); }
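// Derivative of the tanh approximation of GELU used above:
//   gelu(x) ~= 0.5*x*(1 + tanh(u)),  u = 0.797885*x + 0.0356774*x^3
//   gelu'(x) = 0.5*(1 + tanh(u)) + 0.5*x*sech^2(u)*u'
// The constants below fold 0.5*u' into the polynomial: 0.398942 = 0.5*0.797885
// and 0.0535161 = 1.5*0.0356774.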
__device__ float gelu_gradient_kernel(float x) {
    const float x3 = powf(x, 3);
    return 0.5*tanhf(0.0356774*x3 + 0.797885*x) + (0.0535161*x3 + 0.398942*x) * powf(sech_gpu(0.0356774*x3 + 0.797885*x), 2) + 0.5;
}
__device__ float plse_gradient_kernel(float x){return (x < 0 || x > 1) ? .01f : .125f;}
__device__ float stair_gradient_kernel(float x)
{
    if (floor(x) == x) return 0;
    return 1;
}

__device__ float activate_kernel(float x, ACTIVATION a)
{
    switch(a){
        case LINEAR:
            return linear_activate_kernel(x);
        case LOGISTIC:
            return logistic_activate_kernel(x);
        case LOGGY:
            return loggy_activate_kernel(x);
        case RELU:
            return relu_activate_kernel(x);
        case RELU6:
            return relu6_activate_kernel(x);
        case ELU:
            return elu_activate_kernel(x);
        case SELU:
            return selu_activate_kernel(x);
        case GELU:
            return gelu_activate_kernel(x);
        case RELIE:
            return relie_activate_kernel(x);
        case RAMP:
            return ramp_activate_kernel(x);
        case LEAKY:
            return leaky_activate_kernel(x);
        case TANH:
            return tanh_activate_kernel(x);
        case PLSE:
            return plse_activate_kernel(x);
        case STAIR:
            return stair_activate_kernel(x);
        case HARDTAN:
            return hardtan_activate_kernel(x);
        case LHTAN:
            return lhtan_activate_kernel(x);
    }
    return 0;
}

__device__ float gradient_kernel(float x, ACTIVATION a)
{
    switch (a) {
    case LINEAR:
        return linear_gradient_kernel(x);
    case LOGISTIC:
        return logistic_gradient_kernel(x);
    case LOGGY:
        return loggy_gradient_kernel(x);
    case RELU:
        return relu_gradient_kernel(x);
    case RELU6:
        return relu6_gradient_kernel(x);
    case NORM_CHAN:
        return relu_gradient_kernel(x);
    case ELU:
        return elu_gradient_kernel(x);
    case SELU:
        return selu_gradient_kernel(x);
    case GELU:
        return gelu_gradient_kernel(x);
    case RELIE:
        return relie_gradient_kernel(x);
    case RAMP:
        return ramp_gradient_kernel(x);
    case LEAKY:
        return leaky_gradient_kernel(x);
    case TANH:
        return tanh_gradient_kernel(x);
    case PLSE:
        return plse_gradient_kernel(x);
    case STAIR:
        return stair_gradient_kernel(x);
    case HARDTAN:
        return hardtan_gradient_kernel(x);
    case LHTAN:
        return lhtan_gradient_kernel(x);
    }
    return 0;
}

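// Paired-channel ("binary") activation: the first and second halves of each
// block of s values are multiplied element-wise, y = x1*x2. The gradient kernel
// below applies the product rule: dx1 = x2*dy and dx2 = x1*dy.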
__global__ void binary_gradient_array_kernel(float *x, float *dy, int n, int s, BINARY_ACTIVATION a, float *dx)
{
    int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    int i = id % s;
    int b = id / s;
    if (id < n) {
        // read the paired inputs only for in-range threads to avoid out-of-bounds accesses
        float x1 = x[b*s + i];
        float x2 = x[b*s + s / 2 + i];
        float de = dy[id];
        dx[b*s + i] = x2*de;
        dx[b*s + s / 2 + i] = x1*de;
    }
}

extern "C" void binary_gradient_array_gpu(float *x, float *dx, int n, int size, BINARY_ACTIVATION a, float *y)
{
    binary_gradient_array_kernel << <cuda_gridsize(n / 2), BLOCK, 0, get_cuda_stream() >> >(x, dx, n / 2, size, a, y);
    CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void binary_activate_array_kernel(float *x, int n, int s, BINARY_ACTIVATION a, float *y)
{
    int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    int i = id % s;
    int b = id / s;
    if (id < n) {
        // read inputs only for in-range threads to avoid out-of-bounds accesses
        float x1 = x[b*s + i];
        float x2 = x[b*s + s / 2 + i];
        y[id] = x1*x2;
    }
}

extern "C" void binary_activate_array_gpu(float *x, int n, int size, BINARY_ACTIVATION a, float *y)
{
    binary_activate_array_kernel << <cuda_gridsize(n / 2), BLOCK, 0, get_cuda_stream() >> >(x, n / 2, size, a, y);
    CHECK_CUDA(cudaPeekAtLastError());
}

__global__ void activate_array_kernel(float *x, int n, ACTIVATION a)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if(i < n) x[i] = activate_kernel(x[i], a);
}



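// Swish / SiLU: f(x) = x * sigmoid(x). The sigmoid is optionally cached in
// output_sigmoid_gpu so the backward pass can reuse it instead of recomputing.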
__global__ void activate_array_swish_kernel(float *x, int n, float *output_sigmoid_gpu, float *output_gpu)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i < n) {
        float x_val = x[i];
        float sigmoid = logistic_activate_kernel(x_val);
        if (output_sigmoid_gpu) output_sigmoid_gpu[i] = sigmoid;
        output_gpu[i] = x_val * sigmoid;
    }
}

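// Mish: f(x) = x * tanh(softplus(x)). The two helpers below are algebraic
// rewrites of that expression that avoid evaluating tanh and log explicitly.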
__device__ float mish_njuffa(float x)
{
    float r;
    float e = expf(x);
    r = 1.0f / fmaf(fmaf(-0.5f, e, -1.0f), e, -1.0f);
    r = fmaf(r, x, x);
    return r;
}

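// With e = exp(x), softplus(x) = ln(1 + e), so
//   tanh(softplus(x)) = ((1+e)^2 - 1) / ((1+e)^2 + 1) = n / (n + 2),  n = e^2 + 2e,
// and mish(x) = x*n/(n+2) = x - 2x/(n+2). The early returns below handle the
// far-negative tail where n underflows.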
__device__ float mish_yashas(float x)
{
    float e = __expf(x);
    if (x <= -18.0f)
        return x * e;

    float n = e * e + 2 * e;
    if (x <= -5.0f)
        return x * __fdividef(n, n + 2);

    return x - 2 * __fdividef(x, n + 2);
}

// https://github.com/digantamisra98/Mish
__global__ void activate_array_mish_kernel(float *x, int n, float *activation_input, float *output_gpu)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i < n) {
        const float MISH_THRESHOLD = 20;
        float x_val = x[i];
        if (activation_input) activation_input[i] = x_val;    // store value before activation
        //output_gpu[i] = x_val * tanh_activate_kernel(logf(1 + expf(x_val)));

        // Pytorch: https://github.com/thomasbrandon/mish-cuda/blob/master/csrc/mish.h#L17-L20
        // TF: https://github.com/tensorflow/addons/blob/093cdfa85d334cbe19a37624c33198f3140109ed/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h#L40-L49
        // log1p(x) == log(x + 1)
        //output_gpu[i] = x_val * tanh_activate_kernel( softplus_kernel(x_val, MISH_THRESHOLD) );
        output_gpu[i] = mish_yashas(x_val);
        //output_gpu[i] = mish_njuffa(x_val);
    }
}

__global__ void activate_array_leaky_kernel(float *x, int n)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        x[index] = leaky_activate_kernel(x[index]);
    }
}

__global__ void activate_array_selu_kernel(float *x, int n)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        x[index] = selu_activate_kernel(x[index]);
    }
}

__global__ void activate_array_gelu_kernel(float *x, int n)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        x[index] = gelu_activate_kernel(x[index]);
    }
}

__global__ void activate_array_logistic_kernel(float *x, int n)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        x[index] = logistic_activate_kernel(x[index]);
    }
}

__global__ void activate_array_tanh_kernel(float *x, int n)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        x[index] = tanh_activate_kernel(x[index]);
    }
}

__global__ void activate_array_hardtan_kernel(float *x, int n)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        x[index] = hardtan_activate_kernel(x[index]);
    }
}

__global__ void activate_array_relu_kernel(float *x, int n)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        x[index] = relu_activate_kernel(x[index]);
    }
}

__global__ void activate_array_relu6_kernel(float *x, int n)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        x[index] = relu6_activate_kernel(x[index]);
    }
}

__global__ void gradient_array_kernel(float *x, int n, ACTIVATION a, float *delta)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if(i < n) delta[i] *= gradient_kernel(x[i], a);
}

// https://github.com/BVLC/caffe/blob/04ab089db018a292ae48d51732dd6c66766b36b6/src/caffe/layers/swish_layer.cu#L28-L30
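// For y = x*sigmoid(x): dy/dx = sigmoid(x) + x*sigmoid(x)*(1 - sigmoid(x))
//                             = y + sigmoid(x)*(1 - y).
// Here x[] is expected to hold the forward swish output y and sigmoid_gpu[]
// the cached sigmoid(x).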
__global__ void gradient_array_swish_kernel(float *x, int n, float *sigmoid_gpu, float *delta)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i < n) {
        float swish = x[i];
        delta[i] *= swish + sigmoid_gpu[i] * (1 - swish); // gradient_kernel(x[i], a);
    }
}

// https://github.com/digantamisra98/Mish
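// For y = x*tanh(sp(x)) with sp = softplus:
//   dy/dx = tanh(sp) + x*(1 - tanh^2(sp))*sp'(x),  sp'(x) = sigmoid(x) = 1 - exp(-sp(x)),
// which is what the kernel computes below (grad_sp uses -expm1f(-sp) for accuracy).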
__global__ void gradient_array_mish_kernel(int n, float *activation_input_gpu, float *delta)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i < n) {
        const float MISH_THRESHOLD = 20.0f;

        // implementation from TensorFlow: https://github.com/tensorflow/addons/blob/093cdfa85d334cbe19a37624c33198f3140109ed/tensorflow_addons/custom_ops/activations/cc/kernels/mish_op.h#L66-L80
        // implementation from Pytorch: https://github.com/thomasbrandon/mish-cuda/blob/master/csrc/mish.h#L26-L31
        // log1p(x) == log(x + 1)
        const float inp = activation_input_gpu[i];
        const float sp = softplus_kernel(inp, MISH_THRESHOLD);
        const float grad_sp = -expm1f(-sp);
        //const float grad_sp = 1 - expf(-sp);
        const float tsp = tanhf(sp);
        const float grad_tsp = (1 - tsp*tsp) * grad_sp;
        const float grad = inp * grad_tsp + tsp;
        delta[i] *= grad;

        //float x = activation_input[i];
        //float d = 2 * expf(x) + expf(2 * x) + 2;
        //float w = 4 * (x + 1) + 4 * expf(2 * x) + expf(3 * x) + expf(x)*(4 * x + 6);
        //float derivative = expf(x) * w / (d * d);
        //delta[i] *= derivative;
    }
}

__global__ void gradient_array_leaky_kernel(float *x, int n, float *delta)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        delta[index] *= leaky_gradient_kernel(x[index]);
    }
}

__global__ void gradient_array_selu_kernel(float *x, int n, float *delta)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        delta[index] *= selu_gradient_kernel(x[index]);
    }
}

__global__ void gradient_array_gelu_kernel(float *x, int n, float *delta)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        delta[index] *= gelu_gradient_kernel(x[index]);
    }
}

__global__ void gradient_array_logistic_kernel(float *x, int n, float *delta)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        delta[index] *= logistic_gradient_kernel(x[index]);
    }
}

__global__ void gradient_array_tanh_kernel(float *x, int n, float *delta)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        delta[index] *= tanh_gradient_kernel(x[index]);
    }
}

__global__ void gradient_array_hardtan_kernel(float *x, int n, float *delta)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        delta[index] *= hardtan_gradient_kernel(x[index]);
    }
}

__global__ void gradient_array_relu_kernel(float *x, int n, float *delta)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        delta[index] *= relu_gradient_kernel(x[index]);
    }
}

__global__ void gradient_array_relu6_kernel(float *x, int n, float *delta)
{
    int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < n) {
        delta[index] *= relu6_gradient_kernel(x[index]);
    }
}

extern "C" void activate_array_ongpu(float *x, int n, ACTIVATION a)
{
    const int num_blocks = get_number_of_blocks(n, BLOCK);
    if (a == LINEAR) return;
    else if(a == LEAKY) activate_array_leaky_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
    else if (a == LOGISTIC) activate_array_logistic_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
    else if (a == TANH) activate_array_tanh_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
    else if (a == HARDTAN) activate_array_hardtan_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
    else if (a == RELU) activate_array_relu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
    else if (a == RELU6) activate_array_relu6_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
    else if (a == SELU) activate_array_selu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
    else if (a == GELU) activate_array_gelu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n);
    else
        activate_array_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream()>>>(x, n, a);
    CHECK_CUDA(cudaPeekAtLastError());
}
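// Illustrative call site (field names follow darknet's layer struct; shown here
// only as a usage sketch):
//     activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);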

extern "C" void activate_array_swish_ongpu(float *x, int n, float *output_sigmoid_gpu, float *output_gpu)
{
    const int num_blocks = get_number_of_blocks(n, BLOCK);
    activate_array_swish_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> >(x, n, output_sigmoid_gpu, output_gpu);
    CHECK_CUDA(cudaPeekAtLastError());
}

extern "C" void activate_array_mish_ongpu(float *x, int n, float *activation_input_gpu, float *output_gpu)
{
    const int num_blocks = get_number_of_blocks(n, BLOCK);
    activate_array_mish_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> >(x, n, activation_input_gpu, output_gpu);
    CHECK_CUDA(cudaPeekAtLastError());
}

extern "C" void gradient_array_ongpu(float *x, int n, ACTIVATION a, float *delta)
{
    const int num_blocks = get_number_of_blocks(n, BLOCK);
    if (a == LINEAR) return;
    else if (a == LEAKY) gradient_array_leaky_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
    else if (a == LOGISTIC) gradient_array_logistic_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
    else if (a == TANH) gradient_array_tanh_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
    else if (a == HARDTAN) gradient_array_hardtan_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
    else if (a == RELU) gradient_array_relu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
    else if (a == RELU6) gradient_array_relu6_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
    //else if (a == NORM_CHAN) gradient_array_relu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
    else if (a == NORM_CHAN_SOFTMAX || a == NORM_CHAN) {
        printf(" Error: the custom NORM_CHAN/NORM_CHAN_SOFTMAX gradient function should be used instead \n");
        exit(0);
    }
    else if (a == SELU) gradient_array_selu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
    else if (a == GELU) gradient_array_gelu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
    else
        gradient_array_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> > (x, n, a, delta);
    CHECK_CUDA(cudaPeekAtLastError());
}


extern "C" void gradient_array_swish_ongpu(float *x, int n, float *sigmoid_gpu, float *delta)
{
    const int num_blocks = get_number_of_blocks(n, BLOCK);
    gradient_array_swish_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> > (x, n, sigmoid_gpu, delta);
    CHECK_CUDA(cudaPeekAtLastError());
}

extern "C" void gradient_array_mish_ongpu(int n, float *activation_input_gpu, float *delta)
{
    const int num_blocks = get_number_of_blocks(n, BLOCK);
    gradient_array_mish_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> > (n, activation_input_gpu, delta);
    CHECK_CUDA(cudaPeekAtLastError());
}


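// Channel-wise L1 normalization of positive values: for each spatial position
// (and batch item), negative channels are zeroed and positive channels are
// divided by the sum of the positive values (+eps). One thread processes one
// spatial position across all channels.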
__global__ void activate_array_normalize_channels_kernel(float *x, int size, int batch, int channels, int wh_step, float *output_gpu)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    int wh_i = i % wh_step;
    int b = i / wh_step;

    const float eps = 0.0001;
    if (i < size) {
        float sum = eps;
        int k;
        for (k = 0; k < channels; ++k) {
            float val = x[wh_i + k * wh_step + b*wh_step*channels];
            if (val > 0) sum += val;
        }
        for (k = 0; k < channels; ++k) {
            float val = x[wh_i + k * wh_step + b*wh_step*channels];
            if (val > 0) val = val / sum;
            else val = 0;
            output_gpu[wh_i + k * wh_step + b*wh_step*channels] = val;
        }
    }
}

extern "C" void activate_array_normalize_channels_ongpu(float *x, int n, int batch, int channels, int wh_step, float *output_gpu)
{
    // n = w*h*c*batch
    // size = w*h*batch
    int size = n / channels;

    const int num_blocks = get_number_of_blocks(size, BLOCK);

    activate_array_normalize_channels_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (x, size, batch, channels, wh_step, output_gpu);
    CHECK_CUDA(cudaPeekAtLastError());
}



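// Channel-wise softmax at each spatial position. When use_max_val is set, the
// per-position channel maximum is subtracted before exponentiation for
// numerical stability.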
__global__ void activate_array_normalize_channels_softmax_kernel(float *x, int size, int batch, int channels, int wh_step, float *output_gpu, int use_max_val)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    int wh_i = i % wh_step;
    int b = i / wh_step;

    const float eps = 0.0001;
    if (i < size) {
        float sum = eps;
        float max_val = -FLT_MAX;
        int k;
        if (use_max_val) {
            for (k = 0; k < channels; ++k) {
                float val = x[wh_i + k * wh_step + b*wh_step*channels];
                if (val > max_val || k == 0) max_val = val;
            }
        }
        else
            max_val = 0;

        for (k = 0; k < channels; ++k) {
            float val = x[wh_i + k * wh_step + b*wh_step*channels];
            sum += expf(val - max_val);
        }
        for (k = 0; k < channels; ++k) {
            float val = x[wh_i + k * wh_step + b*wh_step*channels];
            val = expf(val - max_val) / sum;
            if (isnan(val) || isinf(val)) val = 0;
            output_gpu[wh_i + k * wh_step + b*wh_step*channels] = val;
        }
    }
}

extern "C" void activate_array_normalize_channels_softmax_ongpu(float *x, int n, int batch, int channels, int wh_step, float *output_gpu, int use_max_val)
{
    // n = w*h*c*batch
    // size = w*h*batch
    int size = n / channels;

    const int num_blocks = get_number_of_blocks(size, BLOCK);

    activate_array_normalize_channels_softmax_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (x, size, batch, channels, wh_step, output_gpu, use_max_val);
    CHECK_CUDA(cudaPeekAtLastError());
}



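// Backward pass for the channel softmax above: each delta is scaled by the
// diagonal term of the softmax Jacobian, y*(1 - y), where x[] holds the forward
// softmax output (cross-channel terms are ignored).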
__global__ void gradient_array_normalize_channels_softmax_kernel(float *x, int size, int batch, int channels, int wh_step, float *delta_gpu)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    int wh_i = i % wh_step;
    int b = i / wh_step;

    if (i < size) {
        int k;
        /*
        float grad = 0;
        for (k = 0; k < channels; ++k) {
            const int index = wh_i + k * wh_step + b*wh_step*channels;
            float out = x[index];
            float delta = delta_gpu[index];
            grad += out*fabs(delta);
        }
        */
        for (k = 0; k < channels; ++k) {
            const int index = wh_i + k * wh_step + b*wh_step*channels;
            float delta = delta_gpu[index];
            float grad = x[index] * (1 - x[index]);
            delta = delta * grad;
            if (isnan(delta) || isinf(delta)) delta = 0;
            delta_gpu[index] = delta;
        }
    }
}

extern "C" void gradient_array_normalize_channels_softmax_ongpu(float *output_gpu, int n, int batch, int channels, int wh_step, float *delta_gpu)
{
    // n = w*h*c*batch
    // size = w*h*batch
    int size = n / channels;

    const int num_blocks = get_number_of_blocks(size, BLOCK);

    gradient_array_normalize_channels_softmax_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (output_gpu, size, batch, channels, wh_step, delta_gpu);
    CHECK_CUDA(cudaPeekAtLastError());
}


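// Backward pass for the L1 channel normalization: deltas of channels whose
// forward output was positive are scaled by that output; deltas of zeroed
// channels are left unchanged.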
__global__ void gradient_array_normalize_channels_kernel(float *x, int size, int batch, int channels, int wh_step, float *delta_gpu)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    int wh_i = i % wh_step;
    int b = i / wh_step;

    if (i < size) {
        int k;
        /*
        float grad = 0;
        for (k = 0; k < channels; ++k) {
            const int index = wh_i + k * wh_step + b*wh_step*channels;
            float out = x[index];
            float delta = delta_gpu[index];
            grad += out*fabs(delta);
        }
        */
        for (k = 0; k < channels; ++k) {
            const int index = wh_i + k * wh_step + b*wh_step*channels;
            if (x[index] > 0) {
                float delta = delta_gpu[index];
                float grad = x[index];
                delta = delta * grad;
                delta_gpu[index] = delta;
            }
        }
    }
}

extern "C" void gradient_array_normalize_channels_ongpu(float *output_gpu, int n, int batch, int channels, int wh_step, float *delta_gpu)
{
    // n = w*h*c*batch
    // size = w*h*batch
    int size = n / channels;

    const int num_blocks = get_number_of_blocks(size, BLOCK);

    gradient_array_normalize_channels_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (output_gpu, size, batch, channels, wh_step, delta_gpu);
    CHECK_CUDA(cudaPeekAtLastError());
}