1 #include <cuda_runtime.h>
2 #include <curand.h>
3 #include <cublas_v2.h>
4 #include <assert.h>
5 #include <float.h>
6 
7 #include "blas.h"
8 #include "dark_cuda.h"
9 #include "utils.h"
10 #include "tree.h"
11 
/*
 * XOR-butterfly sum across a warp: the shuffle distance is halved from
 * WARP_SIZE / 2 down to 1, and on each step every lane adds its partner's value.
 * Assuming WARP_SIZE is a power of two (the hardware warp width), every lane ends
 * up holding the total. The CUDA >= 9 path passes the full 0xffffffff mask, so
 * all 32 lanes of the warp must call this together.
 */
__inline__ __device__
float warpAllReduceSum(float val) {
    float acc = val;
    int lane_offset = WARP_SIZE / 2;
    while (lane_offset > 0) {
#if CUDART_VERSION >= 9000
        acc += __shfl_xor_sync(0xffffffff, acc, lane_offset);
#else
        acc += __shfl_xor(acc, lane_offset);
#endif
        lane_offset /= 2;
    }
    return acc;
}
22 
// Debug helper, one thread per element. Prints every element whose relative
// difference 100 * |one - two| / |one| exceeds 10 (percent).
// When one[i] == 0 the ratio is inf (printed when two[i] != 0) or NaN (0/0, not printed).
__global__ void compare_2_arrays_kernel(float *one, float *two, int size)
{
    const int idx = blockIdx.x*blockDim.x + threadIdx.x;
    if (idx < size) {
        const float a = one[idx];
        const float b = two[idx];
        const float rel_pct = 100 * fabs(a - b) / fabs(a);
        if (rel_pct > 10) printf(" i: %d - one = %f, two = %f, diff = %f %% \n", idx, a, b, rel_pct);
    }
}
32 
// Debug launcher: runs compare_2_arrays_kernel over `size` elements on the
// Darknet CUDA stream, then waits on the whole device so that asynchronous
// execution errors are reported here rather than at a later call.
void compare_2_arrays_gpu(float *one, float *two, int size)
{
    compare_2_arrays_kernel<<<get_number_of_blocks(size, BLOCK), BLOCK, 0, get_cuda_stream()>>>(one, two, size);
    CHECK_CUDA(cudaPeekAtLastError());
    CHECK_CUDA(cudaDeviceSynchronize());
}
41 
// Element-wise exponential moving average, one thread per element:
//   avg[k] = avg[k] * (1 - alpha) + src[k] * alpha
// The updated average is also written back into src, overwriting it.
__global__ void mean_array_kernel(float *src, int size, float alpha, float *avg)
{
    const int k = blockIdx.x*blockDim.x + threadIdx.x;
    if (k < size) {
        const float blended = avg[k] * (1 - alpha) + src[k] * alpha;
        avg[k] = blended;
        src[k] = blended;
    }
}
50 
51 
// Host launcher for mean_array_kernel on the Darknet CUDA stream. The launch is
// asynchronous; only launch-time errors are checked here.
void mean_array_gpu(float *src, int size, float alpha, float *avg)
{
    const int grid = get_number_of_blocks(size, BLOCK);
    mean_array_kernel<<<grid, BLOCK, 0, get_cuda_stream()>>>(src, size, alpha, avg);
    CHECK_CUDA(cudaPeekAtLastError());
}
59 
60 
// One thread per element: output[index] *= scale[f], where
// f = (index / spatial) % filters is the channel in a layout whose innermost
// two dimensions are [filters][spatial].
__global__ void scale_bias_kernel(float *output, float *scale, int batch, int filters, int spatial, int current_size)
{
    const int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < current_size) {
        const int channel = (index / spatial) % filters;
        output[index] = output[index] * scale[channel];
    }
}
69 
// Multiplies every element of channel f of a batch*filters*spatial tensor by
// scale[f]. Asynchronous on the Darknet stream.
void scale_bias_gpu(float *output, float *scale, int batch, int filters, int spatial)
{
    const int total = batch * filters * spatial;
    scale_bias_kernel<<<get_number_of_blocks(total, BLOCK), BLOCK, 0, get_cuda_stream()>>>(output, scale, batch, filters, spatial, total);
    CHECK_CUDA(cudaPeekAtLastError());
}
78 
79 
// Per-filter scale gradient: scale_updates[filter] += sum over batch and spatial
// positions of delta * x_norm (the result is accumulated, not overwritten).
// Launch with one block per filter (blockIdx.x == filter) and exactly BLOCK
// threads: the shared buffer holds one partial sum per thread, and thread 0
// reads all BLOCK entries.
__global__ void backward_scale_kernel(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
{
    __shared__ float part[BLOCK];
    const int filter = blockIdx.x;
    const int tid = threadIdx.x;

    float partial = 0;
    for (int b = 0; b < batch; ++b) {
        const int base = size*(filter + n*b);
        for (int off = 0; off < size; off += BLOCK) {
            const int s = tid + off;
            partial += (s < size) ? delta[base + s]*x_norm[base + s] : 0;
        }
    }
    part[tid] = partial;
    __syncthreads();

    if (tid == 0) {
        // Sequential in-order sum of the partials onto the existing value.
        float acc = scale_updates[filter];
        for (int k = 0; k < BLOCK; ++k) acc += part[k];
        scale_updates[filter] = acc;
    }
}
99 
// Launches backward_scale_kernel with one block of BLOCK threads per filter.
// Results are added (+=) into scale_updates.
void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
{
    const int num_filters = n;
    backward_scale_kernel<<<num_filters, BLOCK, 0, get_cuda_stream()>>>(x_norm, delta, batch, n, size, scale_updates);
    CHECK_CUDA(cudaPeekAtLastError());
}
105 
// One thread per element: output[index] += biases[f], where
// f = (index / spatial) % filters is the channel in a layout whose innermost
// two dimensions are [filters][spatial].
__global__ void add_bias_kernel(float *output, float *biases, int batch, int filters, int spatial, int current_size)
{
    const int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < current_size) {
        const int channel = (index / spatial) % filters;
        output[index] = output[index] + biases[channel];
    }
}
114 
// Adds biases[f] to every element of channel f of a batch*filters*spatial
// tensor. Asynchronous on the Darknet stream.
void add_bias_gpu(float *output, float *biases, int batch, int filters, int spatial)
{
    const int total = batch * filters * spatial;
    add_bias_kernel<<<get_number_of_blocks(total, BLOCK), BLOCK, 0, get_cuda_stream()>>>(output, biases, batch, filters, spatial, total);
    CHECK_CUDA(cudaPeekAtLastError());
}
123 
// Per-filter bias gradient: bias_updates[filter] += sum over batch and spatial
// positions of delta (the result is accumulated, not overwritten).
// Launch with one block per filter (blockIdx.x == filter) and exactly BLOCK
// threads: the shared buffer holds one partial sum per thread.
__global__ void backward_bias_kernel(float *bias_updates, float *delta, int batch, int n, int size)
{
    __shared__ float part[BLOCK];
    const int filter = blockIdx.x;
    const int tid = threadIdx.x;

    float partial = 0;
    for (int b = 0; b < batch; ++b) {
        const int base = size*(filter + n*b);
        for (int off = 0; off < size; off += BLOCK) {
            const int s = tid + off;
            partial += (s < size) ? delta[base + s] : 0;
        }
    }
    part[tid] = partial;
    __syncthreads();

    if (tid == 0) {
        // Sequential in-order sum of the partials onto the existing value.
        float acc = bias_updates[filter];
        for (int k = 0; k < BLOCK; ++k) acc += part[k];
        bias_updates[filter] = acc;
    }
}
143 
144 /*
145 __global__ void dot_kernel(float *output, float scale, int batch, int n, int size, float *delta)
146 {
147     int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
148     int f1 = index / n;
149     int f2 = index % n;
150     if (f2 <= f1) return;
151 
152     float sum = 0;
153     float norm1 = 0;
154     float norm2 = 0;
155     int b, i;
156     for(b = 0; b <  batch; ++b){
157         for(i = 0; i < size; ++i){
158             int i1 = b * size * n + f1 * size + i;
159             int i2 = b * size * n + f2 * size + i;
160             sum += output[i1] * output[i2];
161             norm1 += output[i1] * output[i1];
162             norm2 += output[i2] * output[i2];
163         }
164     }
165     norm1 = sqrt(norm1);
166     norm2 = sqrt(norm2);
167     float norm = norm1 * norm2;
168     sum = sum / norm;
169     for(b = 0; b <  batch; ++b){
170         for(i = 0; i < size; ++i){
171             int i1 = b * size * n + f1 * size + i;
172             int i2 = b * size * n + f2 * size + i;
173             delta[i1] += - scale * sum * output[i2] / norm;
174             delta[i2] += - scale * sum * output[i1] / norm;
175         }
176     }
177 }
178 
179 void dot_error_gpu(layer l)
180 {
181     dot_kernel<<<cuda_gridsize(l.n*l.n), BLOCK, 0, get_cuda_stream()>>>(l.output_gpu, l.dot, l.batch, l.n, l.out_w * l.out_h, l.delta_gpu);
182     CHECK_CUDA(cudaPeekAtLastError());
183 }
184 */
185 
// Launches backward_bias_kernel with one block of BLOCK threads per filter.
// Results are added (+=) into bias_updates.
void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size)
{
    const int num_filters = n;
    backward_bias_kernel<<<num_filters, BLOCK, 0, get_cuda_stream()>>>(bias_updates, delta, batch, n, size);
    CHECK_CUDA(cudaPeekAtLastError());
}
191 
// Bias-corrected Adam step, one thread per parameter (the flat index also covers 2D grids):
//   mhat = m / (1 - B1^t),  vhat = v / (1 - B2^t)
//   x   += rate * mhat / (sqrt(vhat) + eps)
// The step is ADDED to x. NOTE(review): this presumably means m holds the
// descent direction (negative gradient) — confirm against the caller's sign convention.
__global__ void adam_kernel(int N, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t)
{
    const int idx = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (idx < N) {
        const float m_corr = 1.f - powf(B1, t);
        const float v_corr = 1.f - powf(B2, t);
        const float mhat = m[idx] / m_corr;
        const float vhat = v[idx] / v_corr;
        x[idx] += rate * mhat / (sqrtf(vhat) + eps);
    }
}
202 
// Launches adam_kernel over n parameters on the Darknet stream, using a grid
// from cuda_gridsize(). Asynchronous; only launch errors are checked.
extern "C" void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t)
{
    adam_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream()>>>(n, x, m, v, B1, B2, rate, eps, t);
    CHECK_CUDA(cudaPeekAtLastError());
}
208 
// One Adam update of n weights `w`, driven by the accumulated update buffer `d`.
// The calls below are stream-ordered and must stay in this order: `d` is
// modified in place, squared, and finally cleared.
extern "C" void adam_update_gpu(float *w, float *d, float *m, float *v, float B1, float B2, float eps, float decay, float rate, int n, int batch, int t)
{
    // Decay the running moments (scal_ongpu presumably scales in place: m *= B1, v *= B2).
    scal_ongpu(n, B1, m, 1);
    scal_ongpu(n, B2, v, 1);

    // Fold weight decay into the update buffer: d += (-decay * batch) * w.
    const float weight_decay_coef = -decay*batch;
    axpy_ongpu(n, weight_decay_coef, w, 1, d, 1);

    // m += (1 - B1) * d; then d is squared in place (presumably via mul_kernel)
    // and v += (1 - B2) * d^2.
    axpy_ongpu(n, (1 - B1), d, 1, m, 1);
    mul_ongpu(n, d, 1, d, 1);
    axpy_ongpu(n, (1 - B2), d, 1, v, 1);

    // Apply the bias-corrected step to w, then reset the update buffer to 0.
    adam_gpu(n, w, m, v, B1, B2, rate, eps, t);
    fill_ongpu(n, 0, d, 1);
    CHECK_CUDA(cudaPeekAtLastError());
}
223 
// One thread per element: x = (x - mean[f]) / sqrt(variance[f] + 1e-5), where
// f = (index / spatial) % filters.
__global__ void normalize_kernel(int N, float *x, float *mean, float *variance, int batch, int filters, int spatial)
{
    const int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < N) {
        const int f = (index / spatial) % filters;
        const float stddev = sqrtf(variance[f] + .00001f);
        x[index] = (x[index] - mean[f]) / stddev;
    }
}
232 
// Normalizes a batch*filters*spatial tensor in place with per-filter mean and
// variance (see normalize_kernel). Asynchronous on the Darknet stream.
extern "C" void normalize_gpu(float *x, float *mean, float *variance, int batch, int filters, int spatial)
{
    const int total = batch * filters * spatial;
    normalize_kernel<<<get_number_of_blocks(total, BLOCK), BLOCK, 0, get_cuda_stream()>>>(total, x, mean, variance, batch, filters, spatial);
    CHECK_CUDA(cudaPeekAtLastError());
}
241 
// Batch-norm input gradient, one thread per element (the flat index supports
// 2D grids such as those from cuda_gridsize()). With f = (index / spatial) % filters
// and count = spatial * batch:
//   delta = delta / (sqrt(variance[f]) + 1e-6)
//         + variance_delta[f] * 2 * (x - mean[f]) / count
//         + mean_delta[f] / count
// NOTE(review): epsilon is added outside the square root here, but inside it in
// the mean/variance delta kernels; kept as-is to preserve numerics.
__global__ void normalize_delta_kernel(int N, float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta)
{
    const int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (index >= N) return;
    const int f = (index/spatial)%filters;
    const float count = (float)(spatial * batch);

    // The float literal (2.f) keeps the whole expression in single precision;
    // a bare `2.` would promote it to double, which is very slow on most GPUs.
    delta[index] = delta[index] * 1.F/(sqrtf(variance[f]) + .000001f)
                 + variance_delta[f] * 2.f * (x[index] - mean[f]) / count
                 + mean_delta[f] / count;
}
250 
// Launches normalize_delta_kernel over every element of a batch*filters*spatial
// tensor, using a grid from cuda_gridsize(). Asynchronous on the Darknet stream.
extern "C" void normalize_delta_gpu(float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta)
{
    // NOTE(review): the product is evaluated in int and only then widened to size_t.
    const size_t total = batch*filters*spatial;
    normalize_delta_kernel<<<cuda_gridsize(total), BLOCK, 0, get_cuda_stream()>>>(total, x, mean, variance, mean_delta, variance_delta, batch, filters, spatial, delta);
    CHECK_CUDA(cudaPeekAtLastError());
}
257 
// Batch-norm variance gradient, one thread per filter i (supports 2D grids):
//   variance_delta[i] = -0.5 * (variance[i] + 1e-6)^(-3/2)
//                       * sum over batch j, spatial k of delta * (x - mean[i])
// for a tensor indexed as j*filters*spatial + i*spatial + k.
// The sum is kept in a register. Accumulating straight into
// variance_delta[i] forces a global read-modify-write every iteration, because
// the compiler cannot prove the output does not alias the inputs.
__global__ void  variance_delta_kernel(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta)
{
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= filters) return;

    const float mu = mean[i];
    float acc = 0;
    for (int j = 0; j < batch; ++j) {
        for (int k = 0; k < spatial; ++k) {
            const int index = j*filters*spatial + i*spatial + k;
            acc += delta[index]*(x[index] - mu);
        }
    }
    // Float literals keep this in single precision (no double promotion).
    variance_delta[i] = acc * (-.5f * powf(variance[i] + .000001f, -1.5f));
}
272 
// Column sums, one thread per group i (supports 2D grids):
//   sum[i] = sum over k < n of x[k*groups + i]
// The running total is kept in a register instead of doing a global
// read-modify-write of sum[i] on every iteration.
__global__ void accumulate_kernel(float *x, int n, int groups, float *sum)
{
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= groups) return;

    float total = 0;
    for (int k = 0; k < n; ++k) {
        total += x[k*groups + i];
    }
    sum[i] = total;
}
283 
// Block-per-filter variant of mean_delta_kernel:
//   mean_delta[filter] = -1/sqrt(variance[filter] + 1e-6) * sum of delta over
//   batch and spatial positions of that filter.
// Launch with gridDim.x == filters and exactly BLOCK threads. Each thread
// accumulates a strided slice into its shared-memory slot, then thread 0
// reduces the BLOCK slots serially.
__global__ void fast_mean_delta_kernel(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
{
    const int threads = BLOCK;
    __shared__ float local[threads];

    const int id = threadIdx.x;
    const int filter = blockIdx.x;
    local[id] = 0;

    for (int j = 0; j < batch; ++j) {
        const int base = j*spatial*filters + filter*spatial;
        for (int i = 0; i < spatial; i += threads) {
            const int s = i + id;
            local[id] += (s < spatial) ? delta[base + s] : 0;
        }
    }
    __syncthreads();

    // No barriers follow, so the other threads may simply leave.
    if (id != 0) return;

    mean_delta[filter] = 0;
    for (int k = 0; k < threads; ++k) {
        mean_delta[filter] += local[k];
    }
    mean_delta[filter] *= (-1.F/sqrtf(variance[filter] + .000001f));
}
311 
// Block-per-filter variant of variance_delta_kernel:
//   variance_delta[filter] = -0.5 * (variance[filter] + 1e-6)^(-3/2)
//                            * sum of delta * (x - mean[filter]) over batch and spatial
// Launch with gridDim.x == filters and exactly BLOCK threads (one shared slot
// per thread); thread 0 reduces the slots serially.
__global__ void  fast_variance_delta_kernel(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta)
{
    const int threads = BLOCK;
    __shared__ float local[threads];

    const int id = threadIdx.x;
    const int filter = blockIdx.x;
    local[id] = 0;

    const float mu = mean[filter];  // loop-invariant, loaded once
    for (int j = 0; j < batch; ++j) {
        const int base = j*spatial*filters + filter*spatial;
        for (int i = 0; i < spatial; i += threads) {
            const int s = i + id;
            local[id] += (s < spatial) ? delta[base + s]*(x[base + s] - mu) : 0;
        }
    }
    __syncthreads();

    if (id == 0) {
        // Reduce in a register and write the result once; the float literal
        // avoids promoting the final scale to double.
        float total = 0;
        for (int k = 0; k < threads; ++k) total += local[k];
        variance_delta[filter] = total * (-.5f * powf(variance[filter] + .000001f, -1.5f));
    }
}
340 
341 
// Batch-norm mean gradient, one thread per filter i (supports 2D grids):
//   mean_delta[i] = -1/sqrt(variance[i] + 1e-6) * sum over batch j, spatial k of delta
// for a tensor indexed as j*filters*spatial + i*spatial + k.
// The sum is kept in a register rather than doing a global read-modify-write
// of mean_delta[i] on every iteration.
__global__ void mean_delta_kernel(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
{
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= filters) return;

    float acc = 0;
    for (int j = 0; j < batch; ++j) {
        for (int k = 0; k < spatial; ++k) {
            acc += delta[j*filters*spatial + i*spatial + k];
        }
    }
    mean_delta[i] = acc * (-1.F/sqrtf(variance[i] + .000001f));
}
356 
// Thread-per-filter launcher for mean_delta_kernel (grid from cuda_gridsize()).
extern "C" void mean_delta_gpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
{
    mean_delta_kernel<<<cuda_gridsize(filters), BLOCK, 0, get_cuda_stream()>>>(delta, variance, batch, filters, spatial, mean_delta);
    CHECK_CUDA(cudaPeekAtLastError());
}
362 
// Block-per-filter launcher: `filters` blocks of BLOCK threads each, as
// fast_mean_delta_kernel requires.
extern "C" void fast_mean_delta_gpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
{
    const int num_blocks = filters;
    fast_mean_delta_kernel<<<num_blocks, BLOCK, 0, get_cuda_stream()>>>(delta, variance, batch, filters, spatial, mean_delta);
    CHECK_CUDA(cudaPeekAtLastError());
}
368 
// Block-per-filter launcher: `filters` blocks of BLOCK threads each, as
// fast_variance_delta_kernel requires.
extern "C" void fast_variance_delta_gpu(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta)
{
    const int num_blocks = filters;
    fast_variance_delta_kernel<<<num_blocks, BLOCK, 0, get_cuda_stream()>>>(x, delta, mean, variance, batch, filters, spatial, variance_delta);
    CHECK_CUDA(cudaPeekAtLastError());
}
374 
// Per-filter mean, one thread per filter i (supports 2D grids):
//   mean[i] = sum over batch j, spatial k of x[j*filters*spatial + i*spatial + k] / (batch*spatial)
// The sum is kept in a register rather than doing a global read-modify-write
// of mean[i] on every iteration.
__global__ void  mean_kernel(float *x, int batch, int filters, int spatial, float *mean)
{
    const float scale = 1.F/(batch * spatial);
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= filters) return;

    float acc = 0;
    for (int j = 0; j < batch; ++j) {
        for (int k = 0; k < spatial; ++k) {
            acc += x[j*filters*spatial + i*spatial + k];
        }
    }
    mean[i] = acc * scale;
}
390 
// Unbiased per-filter variance, one thread per filter i (supports 2D grids):
//   variance[i] = sum (x - mean[i])^2 / (batch*spatial - 1)
// over a tensor indexed as j*filters*spatial + i*spatial + k.
// NOTE(review): batch*spatial == 1 divides by zero, giving NaN; unchanged from before.
// The sum is kept in a register, and the square is a plain multiply instead of powf.
__global__ void variance_kernel(float *x, float *mean, int batch, int filters, int spatial, float *variance)
{
    const float scale = 1.F/(batch * spatial - 1);
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= filters) return;

    const float mu = mean[i];
    float acc = 0;
    for (int j = 0; j < batch; ++j) {
        for (int k = 0; k < spatial; ++k) {
            const float diff = x[j*filters*spatial + i*spatial + k] - mu;
            acc += diff * diff;
        }
    }
    variance[i] = acc * scale;
}
406 
// Index remapping between a tensor flattened as [batch][c][h][w] and one with
// c/(stride*stride) channels and spatial size (h*stride) x (w*stride).
// Input channel in_c selects output channel in_c % out_c and a sub-pixel offset
// in_c / out_c inside each stride x stride cell. One thread per element (supports 2D grids).
// forward != 0 scatters x[in_index] to out[out_index]; otherwise it gathers
// x[out_index] into out[in_index].
__global__ void reorg_kernel(int N, float *x, int w, int h, int c, int batch, int stride, int forward, float *out)
{
    const int linear = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (linear >= N) return;

    // Decompose the flat input index into (b, in_c, in_h, in_w).
    int rest = linear;
    const int in_w = rest % w;   rest /= w;
    const int in_h = rest % h;   rest /= h;
    const int in_c = rest % c;   rest /= c;
    const int b = rest % batch;

    const int out_c = c / (stride*stride);
    const int c2 = in_c % out_c;
    const int offset = in_c / out_c;
    const int w2 = in_w*stride + offset % stride;
    const int h2 = in_h*stride + offset / stride;
    const int out_index = w2 + w*stride*(h2 + h*stride*(c2 + out_c*b));

    if (forward) {
        out[out_index] = x[linear];
    } else {
        out[linear] = x[out_index];
    }
}
438 
// Clamps each weight update so that |weight_updates_gpu[i]| <= |weights_gpu[i] * coef|,
// keeping its sign. One thread per element (supports 2D grids).
// The sign is applied with copysignf. The former fabs(wu)/wu sign trick gives
// NaN for wu = +/-inf (inf/inf), which turned infinite updates into NaN instead
// of clamping them.
__global__ void constrain_weight_updates_kernel(int N, float coef, float *weights_gpu, float *weight_updates_gpu)
{
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= N) return;

    const float w = weights_gpu[i];
    const float wu = weight_updates_gpu[i];
    const float abs_limit = fabsf(w * coef);
    // NaN updates fail this comparison and are left untouched, as before.
    if (fabsf(wu) > abs_limit) weight_updates_gpu[i] = copysignf(abs_limit, wu);
}
450 
// Host launcher for constrain_weight_updates_kernel over N elements (grid from
// cuda_gridsize()). Asynchronous on the Darknet stream.
extern "C" void constrain_weight_updates_ongpu(int N, float coef, float *weights_gpu, float *weight_updates_gpu)
{
    constrain_weight_updates_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream()>>>(N, coef, weights_gpu, weight_updates_gpu);
    CHECK_CUDA(cudaPeekAtLastError());
}
456 
// Strided, offset AXPY: Y[OFFY + i*INCY] += ALPHA * X[OFFX + i*INCX] for i < N.
// One thread per i (supports 2D grids).
__global__ void axpy_kernel(int N, float ALPHA, float *X, int OFFX, int INCX,  float *Y, int OFFY, int INCY)
{
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= N) return;
    Y[OFFY + i*INCY] += ALPHA*X[OFFX + i*INCX];
}
462 
// Y[i*INCY] = powf(X[i*INCX], ALPHA) for i < N (supports 2D grids).
__global__ void pow_kernel(int N, float ALPHA, float *X, int INCX, float *Y, int INCY)
{
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= N) return;
    Y[i*INCY] = powf(X[i*INCX], ALPHA);
}
468 
// X[i*INCX] = ALPHA for i < N (supports 2D grids).
__global__ void const_kernel(int N, float ALPHA, float *X, int INCX)
{
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= N) return;
    X[i*INCX] = ALPHA;
}
474 
// Clamps X[i*INCX] into [-ALPHA, ALPHA] for i < N (supports 2D grids).
__global__ void constrain_kernel(int N, float ALPHA, float *X, int INCX)
{
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= N) return;
    const float value = X[i*INCX];
    X[i*INCX] = fminf(ALPHA, fmaxf(-ALPHA, value));
}
// Clamps X[i*INCX] into [MIN, MAX] for i < N (supports 2D grids).
__global__ void constrain_min_max_kernel(int N, float MIN, float MAX, float *X, int INCX)
{
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= N) return;
    const float value = X[i*INCX];
    X[i*INCX] = fminf(MAX, fmaxf(MIN, value));
}
485 
// Zeroes X[i*INCX] when its magnitude is below |ALPHA| (compared via squares)
// for i < N (supports 2D grids).
__global__ void supp_kernel(int N, float ALPHA, float *X, int INCX)
{
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= N) return;
    const float value = X[i*INCX];
    if ((value * value) < (ALPHA * ALPHA)) X[i*INCX] = 0;
}
493 
// In-place scale: X[i*INCX] *= ALPHA for i < N (supports 2D grids).
__global__ void scal_kernel(int N, float ALPHA, float *X, int INCX)
{
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= N) return;
    X[i*INCX] = X[i*INCX] * ALPHA;
}
499 
// In-place affine map: X[i*INCX] = X[i*INCX] * ALPHA + BETA for i < N (supports 2D grids).
__global__ void scal_add_kernel(int N, float ALPHA, float BETA, float *X, int INCX)
{
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= N) return;
    X[i*INCX] = X[i*INCX] * ALPHA + BETA;
}
505 
// X[index*INCX] = ALPHA for index < N.
// Uses the same flat-index formula as the sibling element-wise kernels, so it
// is correct for 1D grids (blockIdx.y == 0, same as plain 1D indexing) and also
// for 2D grids such as those cuda_gridsize() presumably returns. With 1D-only
// indexing, a 2D grid would rewrite the same elements and leave the tail unfilled.
__global__ void fill_kernel(int N, float ALPHA, float *X, int INCX)
{
    const int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (index >= N) return;
    X[index*INCX] = ALPHA;
}
512 
// For i < n, sets x[i] = val wherever mask[i] equals mask_num (supports 2D grids).
__global__ void mask_kernel_new_api(int n, float *x, float mask_num, float *mask, float val)
{
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= n) return;
    if (mask[i] == mask_num) x[i] = val;
}
518 
// For i < n, sets x[i] = mask_num wherever mask[i] equals mask_num (supports 2D grids).
__global__ void mask_kernel(int n, float *x, float mask_num, float *mask)
{
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= n) return;
    if (mask[i] == mask_num) x[i] = mask_num;
}
524 
// Strided, offset copy: Y[i*INCY + OFFY] = X[i*INCX + OFFX] for i < N (supports 2D grids).
__global__ void copy_kernel(int N,  float *X, int OFFX, int INCX, float *Y, int OFFY, int INCY)
{
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= N) return;
    Y[i*INCY + OFFY] = X[i*INCX + OFFX];
}
530 
// Contiguous copy dst[index] = src[index] for index < size (1D grid indexing).
__global__ void simple_copy_kernel(int size, float *src, float *dst)
{
    const int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index >= size) return;
    dst[index] = src[index];
}
537 
// Element-wise in-place product: Y[i*INCY] *= X[i*INCX] for i < N (supports 2D grids).
__global__ void mul_kernel(int N, float *X, int INCX, float *Y, int INCY)
{
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= N) return;
    Y[i*INCY] = Y[i*INCY] * X[i*INCX];
}
543 
544 
// Block-per-filter mean: mean[filter] = sum of x over batch and spatial
// positions of that filter / (spatial * batch).
// Launch with gridDim.x == filters and exactly BLOCK threads. Each thread sums
// a strided slice into its shared slot, then thread 0 reduces the slots serially.
__global__ void  fast_mean_kernel(float *x, int batch, int filters, int spatial, float *mean)
{
    const int threads = BLOCK;
    __shared__ float local[threads];

    const int id = threadIdx.x;
    const int filter = blockIdx.x;
    local[id] = 0;

    for (int j = 0; j < batch; ++j) {
        const int base = j*spatial*filters + filter*spatial;
        for (int i = 0; i < spatial; i += threads) {
            local[id] += (i + id < spatial) ? x[base + i + id] : 0;
        }
    }
    __syncthreads();

    // No barriers follow, so the other threads may simply leave.
    if (id != 0) return;

    float total = 0;
    for (int k = 0; k < threads; ++k) total += local[k];
    mean[filter] = total / (spatial * batch);
}
573 
// Block-per-filter launcher: `filters` blocks of BLOCK threads each, as
// fast_mean_kernel requires.
extern "C" void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *mean)
{
    const int num_blocks = filters;
    fast_mean_kernel<<<num_blocks, BLOCK, 0, get_cuda_stream()>>>(x, batch, filters, spatial, mean);
    CHECK_CUDA(cudaPeekAtLastError());
}
579 
// Block-per-filter population variance:
//   variance[filter] = sum (x - mean[filter])^2 / (spatial * batch)
// Note the divisor is N, not N - 1, unlike variance_kernel.
// Launch with gridDim.x == filters and exactly BLOCK threads (one shared slot
// per thread); thread 0 reduces the slots serially.
// The square is a plain multiply (cheaper and more exact than powf(d, 2)), and
// mean[filter] is loaded once outside the loops.
__global__ void  fast_variance_kernel(float *x, float *mean, int batch, int filters, int spatial, float *variance)
{
    const int threads = BLOCK;
    __shared__ float local[threads];

    const int id = threadIdx.x;
    const int filter = blockIdx.x;
    local[id] = 0;

    const float mu = mean[filter];
    for (int j = 0; j < batch; ++j) {
        const int base = j*spatial*filters + filter*spatial;
        for (int i = 0; i < spatial; i += threads) {
            if (i + id < spatial) {
                const float diff = x[base + i + id] - mu;
                local[id] += diff * diff;
            }
        }
    }
    __syncthreads();

    if (id == 0) {
        float variance_tmp = 0;
        for (int k = 0; k < threads; ++k) variance_tmp += local[k];
        variance[filter] = variance_tmp / (spatial * batch);
    }
}
609 
// Block-per-filter launcher: `filters` blocks of BLOCK threads each, as
// fast_variance_kernel requires.
extern "C" void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance)
{
    const int num_blocks = filters;
    fast_variance_kernel<<<num_blocks, BLOCK, 0, get_cuda_stream()>>>(x, mean, batch, filters, spatial, variance);
    CHECK_CUDA(cudaPeekAtLastError());
}
615 
616 
// Per-filter statistics update, one block per filter, exactly BLOCK threads.
//  1. Block-reduce sum(x^2) / (spatial*batch - 1), floored at mean[filter]^2.
//  2. Blend mean[filter] and that second moment into m_avg / v_avg with weight
//     alpha_cbn = 1 / minibatch_index; mean[filter] is then overwritten with m_avg.
//  3. variance[filter] = max(0, v_avg - m_avg^2), or 1/sqrt(that + epsilon)
//     when inverse_variance is set.
//  4. Optionally update rolling_mean_gpu / rolling_variance_gpu (if non-null)
//     with momentum alpha.
// max_minibatch_index is not referenced; the rolling statistics are updated on every call.
__global__ void  fast_v_cbn_kernel(const float *x, float *mean, int batch, int filters, int spatial, int minibatch_index, int max_minibatch_index, float *m_avg, float *v_avg, float *variance,
    const float alpha, float *rolling_mean_gpu, float *rolling_variance_gpu, int inverse_variance, float epsilon)
{
    const int threads = BLOCK;
    __shared__ float local[threads];

    const int id = threadIdx.x;
    const int filter = blockIdx.x;
    local[id] = 0;

    // Each thread accumulates x^2 over a strided slice of this filter's elements.
    for (int j = 0; j < batch; ++j) {
        const int base = j*spatial*filters + filter*spatial;
        for (int i = 0; i < spatial; i += threads) {
            const int s = i + id;
            local[id] += (s < spatial) ? powf(x[base + s], 2) : 0;
        }
    }
    __syncthreads();

    // No barriers follow, so the other threads may simply leave.
    if (id != 0) return;

    float sq_sum = 0;
    for (int k = 0; k < threads; ++k) sq_sum += local[k];
    float v_tmp = sq_sum / (spatial * batch - 1);
    v_tmp = fmax(v_tmp, powf(mean[filter], 2));

    const float alpha_cbn = 1.0f / minibatch_index;

    m_avg[filter] = alpha_cbn * mean[filter] + (1 - alpha_cbn) * m_avg[filter];
    mean[filter] = m_avg[filter];
    v_avg[filter] = alpha_cbn * v_tmp + (1 - alpha_cbn) * v_avg[filter];

    const float variance_tmp = fmax(0.0f, v_avg[filter] - powf(m_avg[filter], 2));
    variance[filter] = inverse_variance ? 1.0f / sqrtf(variance_tmp + epsilon) : variance_tmp;

    if (rolling_mean_gpu) {
        rolling_mean_gpu[filter] = alpha * mean[filter] + (1 - alpha) * rolling_mean_gpu[filter];
    }
    if (rolling_variance_gpu) {
        rolling_variance_gpu[filter] = alpha * variance_tmp + (1 - alpha) * rolling_variance_gpu[filter];
    }
}
668 
// Block-per-filter launcher: `filters` blocks of BLOCK threads each, as
// fast_v_cbn_kernel requires. Asynchronous on the Darknet stream.
extern "C" void fast_v_cbn_gpu(const float *x, float *mean, int batch, int filters, int spatial, int minibatch_index, int max_minibatch_index, float *m_avg, float *v_avg, float *variance,
    const float alpha, float *rolling_mean_gpu, float *rolling_variance_gpu, int inverse_variance, float epsilon)
{
    const int num_blocks = filters;
    fast_v_cbn_kernel<<<num_blocks, BLOCK, 0, get_cuda_stream()>>>(x, mean, batch, filters, spatial, minibatch_index, max_minibatch_index,
        m_avg, v_avg, variance, alpha, rolling_mean_gpu, rolling_variance_gpu, inverse_variance, epsilon);
    CHECK_CUDA(cudaPeekAtLastError());
}
675 
// dst[index] = 1 / sqrt(src[index] + epsilon) for index < size (1D grid indexing).
__global__ void inverse_variance_kernel(int size, float *src, float *dst, float epsilon)
{
    const int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index >= size) return;
    dst[index] = 1.0f / sqrtf(src[index] + epsilon);
}
682 
// Host launcher for inverse_variance_kernel on the Darknet stream.
extern "C" void inverse_variance_ongpu(int size, float *src, float *dst, float epsilon)
{
    // size / BLOCK + 1 always yields at least one block, so size == 0 is still a
    // valid launch; the kernel's bounds check discards the surplus threads.
    inverse_variance_kernel<<<size / BLOCK + 1, BLOCK, 0, get_cuda_stream()>>>(size, src, dst, epsilon);
    CHECK_CUDA(cudaPeekAtLastError());
}
689 
// Fused normalize + scale + bias, one thread per element, f = (index / spatial) % filters:
//   val = (x - mean[f]) * variance[f]                 if inverse_variance
//                                                     (variance presumably already holds 1/sqrt(var + eps))
//   val = (x - mean[f]) / sqrt(variance[f] + epsilon)  otherwise
//   x   = val * scales[f] + biases[f]
// A non-finite result (NaN or +/-inf) is discarded and x is left unchanged.
__global__ void normalize_scale_bias_kernel(int N, float *x, float *mean, float *variance, float *scales, float *biases, int batch, int filters, int spatial, int inverse_variance, float epsilon)
{
    const int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index >= N) return;
    const int f = (index / spatial) % filters;

    const float centered = x[index] - mean[f];
    float val = inverse_variance ? centered * variance[f]
                                 : centered / (sqrtf(variance[f] + epsilon));
    val *= scales[f];
    val += biases[f];

    if (isfinite(val)) x[index] = val;
}
705 
// Host launcher for normalize_scale_bias_kernel over a batch*filters*spatial
// tensor. Asynchronous on the Darknet stream.
extern "C" void normalize_scale_bias_gpu(float *x, float *mean, float *variance, float *scales, float *biases, int batch, int filters, int spatial, int inverse_variance, float epsilon)
{
    const int total = batch * filters * spatial;
    normalize_scale_bias_kernel<<<get_number_of_blocks(total, BLOCK), BLOCK, 0, get_cuda_stream()>>>(total, x, mean, variance, scales, biases,
        batch, filters, spatial, inverse_variance, epsilon);
    CHECK_CUDA(cudaPeekAtLastError());
}
714 
// Host launcher for mean_kernel (defined elsewhere in this file), with a grid sized by
// `filters` — presumably one thread per filter writing mean[f]; verify against mean_kernel.
extern "C" void mean_gpu(float *x, int batch, int filters, int spatial, float *mean)
{
    const dim3 grid = cuda_gridsize(filters);
    cudaStream_t stream = get_cuda_stream();
    mean_kernel<<<grid, BLOCK, 0, stream>>>(x, batch, filters, spatial, mean);
    CHECK_CUDA(cudaPeekAtLastError());
}
720 
// Host launcher for variance_kernel (defined elsewhere in this file), with a grid sized by
// `filters` — presumably one thread per filter writing variance[f]; verify against the kernel.
extern "C" void variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance)
{
    const dim3 grid = cuda_gridsize(filters);
    cudaStream_t stream = get_cuda_stream();
    variance_kernel<<<grid, BLOCK, 0, stream>>>(x, mean, batch, filters, spatial, variance);
    CHECK_CUDA(cudaPeekAtLastError());
}
726 
// Strided axpy without offsets: forwards to axpy_ongpu_offset with both offsets = 0.
extern "C" void axpy_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY)
{
    const int no_offset = 0;
    axpy_ongpu_offset(N, ALPHA, X, no_offset, INCX, Y, no_offset, INCY);
}
731 
// Host launcher for pow_kernel (defined elsewhere in this file) over N strided elements
// of X/Y; the exact formula lives in pow_kernel.
extern "C" void pow_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY)
{
    const dim3 grid = cuda_gridsize(N);
    cudaStream_t stream = get_cuda_stream();
    pow_kernel<<<grid, BLOCK, 0, stream>>>(N, ALPHA, X, INCX, Y, INCY);
    CHECK_CUDA(cudaPeekAtLastError());
}
737 
// Host launcher for axpy_kernel (defined elsewhere) over N elements, with explicit start
// offsets (OFFX, OFFY) and strides (INCX, INCY) for X and Y.
extern "C" void axpy_ongpu_offset(int N, float ALPHA, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY)
{
    const dim3 grid = cuda_gridsize(N);
    cudaStream_t stream = get_cuda_stream();
    axpy_kernel<<<grid, BLOCK, 0, stream>>>(N, ALPHA, X, OFFX, INCX, Y, OFFY, INCY);
    CHECK_CUDA(cudaPeekAtLastError());
}
743 
// Strided copy without offsets: forwards to copy_ongpu_offset with both offsets = 0.
extern "C" void copy_ongpu(int N, float * X, int INCX, float * Y, int INCY)
{
    const int no_offset = 0;
    copy_ongpu_offset(N, X, no_offset, INCX, Y, no_offset, INCY);
}
748 
// Host launcher for simple_copy_kernel (defined elsewhere) over `size` contiguous floats.
extern "C" void simple_copy_ongpu(int size, float *src, float *dst)
{
    // Always at least one block; the kernel is expected to bounds-check against size.
    const int blocks = 1 + size / BLOCK;
    cudaStream_t stream = get_cuda_stream();
    simple_copy_kernel<<<blocks, BLOCK, 0, stream>>>(size, src, dst);
    CHECK_CUDA(cudaPeekAtLastError());
}
755 
// Asynchronous byte copy on the darknet stream. cudaMemcpyDefault lets the runtime infer
// the copy direction from the pointer values (requires unified virtual addressing).
extern "C" void memcpy_ongpu(void *dst, void *src, int size_bytes)
{
    cudaStream_t stream = get_cuda_stream();
    const cudaError_t status = cudaMemcpyAsync(dst, src, size_bytes, cudaMemcpyDefault, stream);
    CHECK_CUDA(status);
    CHECK_CUDA(cudaPeekAtLastError());
}
761 
// Host launcher for mul_kernel (defined elsewhere) over N strided elements of X and Y.
extern "C" void mul_ongpu(int N, float * X, int INCX, float * Y, int INCY)
{
    const dim3 grid = cuda_gridsize(N);
    cudaStream_t stream = get_cuda_stream();
    mul_kernel<<<grid, BLOCK, 0, stream>>>(N, X, INCX, Y, INCY);
    CHECK_CUDA(cudaPeekAtLastError());
}
767 
// Host launcher for copy_kernel (defined elsewhere) over N elements, with explicit start
// offsets and strides for the source X and destination Y.
extern "C" void copy_ongpu_offset(int N, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY)
{
    const dim3 grid = cuda_gridsize(N);
    cudaStream_t stream = get_cuda_stream();
    copy_kernel<<<grid, BLOCK, 0, stream>>>(N, X, OFFX, INCX, Y, OFFY, INCY);
    CHECK_CUDA(cudaPeekAtLastError());
}
773 
// Per-sample transpose between [layers][spatial] and [spatial][layers] layouts.
//   forward != 0: out[b][s][c] = x[b][c][s]
//   forward == 0: out[b][c][s] = x[b][s][c]
// 2-D grid flattened to a linear index; threads with index >= N exit.
__global__ void flatten_kernel(int N, float *x, int spatial, int layers, int batch, int forward, float *out)
{
    const int tid = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (tid >= N) return;

    const int s = tid % spatial;
    const int rest = tid / spatial;
    const int c = rest % layers;
    const int b = rest / layers;

    const int sample_base = b*layers*spatial;
    const int channel_major = sample_base + c*spatial + s;   // [c][s] position
    const int spatial_major = sample_base + s*layers + c;    // [s][c] position

    if (forward) {
        out[spatial_major] = x[channel_major];
    }
    else {
        out[channel_major] = x[spatial_major];
    }
}
790 
// Host launcher for flatten_kernel: one thread per element of the spatial*layers*batch tensor.
extern "C" void flatten_ongpu(float *x, int spatial, int layers, int batch, int forward, float *out)
{
    const int total = spatial*batch*layers;
    const dim3 grid = cuda_gridsize(total);
    cudaStream_t stream = get_cuda_stream();
    flatten_kernel<<<grid, BLOCK, 0, stream>>>(total, x, spatial, layers, batch, forward, out);
    CHECK_CUDA(cudaPeekAtLastError());
}
797 
// Host launcher for reorg_kernel (defined elsewhere): grid sized for w*h*c*batch threads.
extern "C" void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out)
{
    const int total = w*h*c*batch;
    const dim3 grid = cuda_gridsize(total);
    cudaStream_t stream = get_cuda_stream();
    reorg_kernel<<<grid, BLOCK, 0, stream>>>(total, x, w, h, c, batch, stride, forward, out);
    CHECK_CUDA(cudaPeekAtLastError());
}
804 
// Host launcher for mask_kernel_new_api (defined elsewhere) over N elements of X, with
// mask/mask_num selecting entries and `val` as the replacement value passed through.
extern "C" void mask_gpu_new_api(int N, float * X, float mask_num, float * mask, float val)
{
    const dim3 grid = cuda_gridsize(N);
    cudaStream_t stream = get_cuda_stream();
    mask_kernel_new_api<<<grid, BLOCK, 0, stream>>>(N, X, mask_num, mask, val);
    CHECK_CUDA(cudaPeekAtLastError());
}
810 
// Host launcher for mask_kernel (defined elsewhere) over N elements of X.
extern "C" void mask_ongpu(int N, float * X, float mask_num, float * mask)
{
    const dim3 grid = cuda_gridsize(N);
    cudaStream_t stream = get_cuda_stream();
    mask_kernel<<<grid, BLOCK, 0, stream>>>(N, X, mask_num, mask);
    CHECK_CUDA(cudaPeekAtLastError());
}
816 
// Host launcher for const_kernel (defined elsewhere) over N strided elements of X.
extern "C" void const_ongpu(int N, float ALPHA, float * X, int INCX)
{
    const dim3 grid = cuda_gridsize(N);
    cudaStream_t stream = get_cuda_stream();
    const_kernel<<<grid, BLOCK, 0, stream>>>(N, ALPHA, X, INCX);
    CHECK_CUDA(cudaPeekAtLastError());
}
822 
// Host launcher for constrain_kernel (defined elsewhere) over N strided elements of X.
extern "C" void constrain_ongpu(int N, float ALPHA, float * X, int INCX)
{
    const dim3 grid = cuda_gridsize(N);
    cudaStream_t stream = get_cuda_stream();
    constrain_kernel<<<grid, BLOCK, 0, stream>>>(N, ALPHA, X, INCX);
    CHECK_CUDA(cudaPeekAtLastError());
}
828 
// Host launcher for constrain_min_max_kernel (defined elsewhere) over N strided elements
// of X, passing the MIN/MAX bounds through.
extern "C" void constrain_min_max_ongpu(int N, float MIN, float MAX, float * X, int INCX)
{
    const dim3 grid = cuda_gridsize(N);
    cudaStream_t stream = get_cuda_stream();
    constrain_min_max_kernel<<<grid, BLOCK, 0, stream>>>(N, MIN, MAX, X, INCX);
    CHECK_CUDA(cudaPeekAtLastError());
}
834 
835 
// Host launcher for scal_kernel (defined elsewhere) over N strided elements of X.
extern "C" void scal_ongpu(int N, float ALPHA, float * X, int INCX)
{
    const dim3 grid = cuda_gridsize(N);
    cudaStream_t stream = get_cuda_stream();
    scal_kernel<<<grid, BLOCK, 0, stream>>>(N, ALPHA, X, INCX);
    CHECK_CUDA(cudaPeekAtLastError());
}
841 
// Host launcher for scal_add_kernel (defined elsewhere) over N strided elements of X.
extern "C" void scal_add_ongpu(int N, float ALPHA, float BETA, float * X, int INCX)
{
    const dim3 grid = cuda_gridsize(N);
    cudaStream_t stream = get_cuda_stream();
    scal_add_kernel<<<grid, BLOCK, 0, stream>>>(N, ALPHA, BETA, X, INCX);
    CHECK_CUDA(cudaPeekAtLastError());
}
847 
// Host launcher for supp_kernel (defined elsewhere) over N strided elements of X.
extern "C" void supp_ongpu(int N, float ALPHA, float * X, int INCX)
{
    const dim3 grid = cuda_gridsize(N);
    cudaStream_t stream = get_cuda_stream();
    supp_kernel<<<grid, BLOCK, 0, stream>>>(N, ALPHA, X, INCX);
    CHECK_CUDA(cudaPeekAtLastError());
}
853 
// Host launcher for fill_kernel (defined elsewhere) over N strided elements of X.
// Uses a 1-D grid of get_number_of_blocks(N, BLOCK) blocks (presumably ceil(N / BLOCK)).
extern "C" void fill_ongpu(int N, float ALPHA, float * X, int INCX)
{
    const int blocks = get_number_of_blocks(N, BLOCK);
    cudaStream_t stream = get_cuda_stream();
    fill_kernel<<<blocks, BLOCK, 0, stream>>>(N, ALPHA, X, INCX);
    CHECK_CUDA(cudaPeekAtLastError());
}
861 
// Gradient centralization: subtracts from each filter's f_size contiguous values their mean.
// One warp per filter: global thread t works on filter t / WARP_SIZE as lane t % WARP_SIZE.
// Preconditions (not checked here): f_size is a multiple of WARP_SIZE, and blockDim.x is a
// multiple of WARP_SIZE so every 32-thread group is a whole hardware warp. Both are needed
// because warpAllReduceSum shuffles across all lanes with a full mask.
__global__ void gradient_centralization_kernel(int filters, int f_size, float *in)
{
    const int gtid = blockIdx.x*blockDim.x + threadIdx.x;
    const int filter = gtid / WARP_SIZE;
    if (filter >= filters) return;   // warp-uniform: all lanes share `filter`

    const int lane = gtid % WARP_SIZE;
    float *row = in + filter*f_size;

    // Every lane receives the full per-chunk sum, so `total` is identical across the warp.
    float total = 0;
    for (int base = 0; base < f_size; base += WARP_SIZE)
        total += warpAllReduceSum(row[base + lane]);

    const float avg = total / f_size;
    for (int base = 0; base < f_size; base += WARP_SIZE)
        row[base + lane] -= avg;
}
880 
// Host launcher for gradient_centralization_kernel over f filters of c*h*w values each.
extern "C" void gradient_centralization_gpu(int w, int h, int c, int f, float *in)
{
    const int f_size = c * h * w;   // values per filter
    // The kernel's warp-wide reduction requires f_size to be a whole number of warps;
    // other shapes are left unmodified.
    if (f_size % WARP_SIZE != 0) return;

    const int total_threads = f * WARP_SIZE;   // one warp per filter
    const int blocks = get_number_of_blocks(total_threads, BLOCK);
    gradient_centralization_kernel<<<blocks, BLOCK, 0, get_cuda_stream()>>>(f, f_size, in);
    CHECK_CUDA(cudaPeekAtLastError());
}
891 
// Rectified linear unit: src for src > 0, otherwise 0 (NaN also maps to 0).
__device__ float relu(float src) {
    return (src > 0) ? src : 0;
}
896 
// Despite the name this is not a leaky ReLU: it clamps from below at eps = 0.001,
// returning src when src > eps and eps otherwise (so the result is always positive).
__device__ float lrelu(float src) {
    const float eps = 0.001;
    return (src > eps) ? src : eps;
}
902 
// Derivative of relu: 1.0f where src > 0, otherwise 0.0f.
__device__ float grad_relu(float src) {
    return (src > 0) ? 1.0f : 0.0f;
}
906 
// Derivative of lrelu: 1.0f where src > eps (0.001), otherwise 0.0f.
__device__ float grad_lrelu(float src) {
    const float eps = 0.001;
    return (src > eps) ? 1.0f : 0.0f;
}
911 
// Unweighted shortcut with exactly one extra input (layer 0):
//   out[id] = in[id] + layers_output_gpu[0][b*add_outputs + i]   when i < add_outputs
//   out[id] = in[id]                                             otherwise
// where i = id % src_outputs and b = id / src_outputs. The extra parameters (batch, n,
// weights_gpu, nweights, weights_normalization) are accepted for signature parity and unused.
// 2-D grid flattened to a linear id; ids >= size exit.
__global__ void shortcut_singlelayer_simple_kernel(int size, int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in, float *weights_gpu, int nweights, WEIGHTS_NORMALIZATION_T weights_normalization)
{
    const int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (id >= size) return;

    const int pos = id % src_outputs;     // position within one sample
    const int sample = id / src_outputs;  // sample index within the batch

    float acc = in[id];
    const int other_outputs = outputs_of_layers_gpu[0];
    if (pos < other_outputs) {
        const float *other = layers_output_gpu[0];
        acc += other[other_outputs*sample + pos];
    }
    out[id] = acc;
}
933 
// Forward pass of the multi-input shortcut layer. One thread per element of the
// [batch][src_outputs] output; 2-D grid flattened to a linear id; ids >= size exit.
//
//   out[id] = w_0 * in[id] + sum over i < n with src_i < outputs_of_layers_gpu[i] of
//             w_{i+1} * layers_output_gpu[i][add_outputs*src_b + src_i]
//
// With weights_gpu == NULL every w is 1. Otherwise w_k = weights_gpu[src_i / step + k*layer_step],
// normalized by weights_normalization:
//   RELU_NORMALIZATION:    lrelu(w_k) / (eps + sum_j lrelu(w_j))
//   SOFTMAX_NORMALIZATION: expf(w_k - max_j w_j) / (eps + sum_j expf(w_j - max_j w_j))
//   otherwise:             w_k used as-is
// where eps = 0.0001 and j runs over the n + 1 inputs at the same weight position.
// NOTE(review): if weights_gpu is non-NULL while nweights == 0, step stays 0 and
// src_i / step divides by zero — presumably callers never do this; confirm.
__global__ void shortcut_multilayer_kernel(int size, int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in, float *weights_gpu, int nweights, WEIGHTS_NORMALIZATION_T weights_normalization)
{
    const int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (id >= size) return;

    // nweights - l.n or l.n*l.c or (l.n*l.c*l.h*l.w)
    const int layer_step = nweights / (n + 1);    // 1 or l.c or (l.c * l.h * l.w)
    int step = 0;
    if (nweights > 0) step = src_outputs / layer_step; // (l.c * l.h * l.w) or (l.w*l.h) or 1

    // Split the flat id into position within one sample (src_i) and sample index (src_b).
    int src_id = id;
    const int src_i = src_id % src_outputs;
    src_id /= src_outputs;
    int src_b = src_id;

    // Normalization denominator (sum) and, for softmax, the maximum weight (for stability).
    // Both keep their defaults (1, -FLT_MAX) when no normalization is applied.
    float sum = 1, max_val = -FLT_MAX;
    if (weights_gpu && weights_normalization) {
        if (weights_normalization == SOFTMAX_NORMALIZATION) {
            for (int i = 0; i < (n + 1); ++i) {
                const int weights_index = src_i / step + i*layer_step;  // [0 or c or (c, h ,w)]
                const float w = weights_gpu[weights_index];
                if (max_val < w) max_val = w;
            }
        }
        const float eps = 0.0001;
        sum = eps;
        for (int i = 0; i < (n + 1); ++i) {
            const int weights_index = src_i / step + i*layer_step;  // [0 or c or (c, h ,w)]
            const float w = weights_gpu[weights_index];
            if (weights_normalization == RELU_NORMALIZATION) sum += lrelu(w);
            else if (weights_normalization == SOFTMAX_NORMALIZATION) sum += expf(w - max_val);
        }
    }

    float out_val = 0;

    // Own input: weighted by slot 0 of the weights (index src_i / step).
    if (weights_gpu) {
        float w = weights_gpu[src_i / step];
        if (weights_normalization == RELU_NORMALIZATION) w = lrelu(w) / sum;
        else if (weights_normalization == SOFTMAX_NORMALIZATION) w = expf(w - max_val) / sum;

        out_val = in[id] * w; // [0 or c or (c, h ,w)]
    }
    else out_val = in[id];

    // layers
    // Extra inputs contribute only where they have an element at position src_i
    // (an input may have fewer outputs than the source); slot i + 1 of the weights.
    for (int i = 0; i < n; ++i) {
        int add_outputs = outputs_of_layers_gpu[i];
        if (src_i < add_outputs) {
            int add_index = add_outputs*src_b + src_i;

            float *add = layers_output_gpu[i];

            if (weights_gpu) {
                const int weights_index = src_i / step + (i + 1)*layer_step;  // [0 or c or (c, h ,w)]
                float w = weights_gpu[weights_index];
                if (weights_normalization == RELU_NORMALIZATION) w = lrelu(w) / sum;
                else if (weights_normalization == SOFTMAX_NORMALIZATION) w = expf(w - max_val) / sum;

                out_val += add[add_index] * w; // [0 or c or (c, h ,w)]
            }
            else out_val += add[add_index];
        }
    }
    out[id] = out_val;
}
1000 
// Host launcher for the shortcut forward pass: one thread per element of the
// [batch][src_outputs] output. The unweighted single-extra-input case (nweights == 0,
// n == 1) uses the lean shortcut_singlelayer_simple_kernel; everything else uses
// shortcut_multilayer_kernel.
extern "C" void shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu, float **layers_output_gpu, float *out, float *in, float *weights_gpu, int nweights, WEIGHTS_NORMALIZATION_T weights_normalization)
{
    const int size = batch * src_outputs;
    const dim3 grid = cuda_gridsize(size);
    cudaStream_t stream = get_cuda_stream();

    const int simple_case = (nweights == 0 && n == 1);
    if (simple_case) {
        shortcut_singlelayer_simple_kernel<<<grid, BLOCK, 0, stream>>>(size, src_outputs, batch, n, outputs_of_layers_gpu, layers_output_gpu, out, in, weights_gpu, nweights, weights_normalization);
    }
    else {
        shortcut_multilayer_kernel<<<grid, BLOCK, 0, stream>>>(size, src_outputs, batch, n, outputs_of_layers_gpu, layers_output_gpu, out, in, weights_gpu, nweights, weights_normalization);
    }
    CHECK_CUDA(cudaPeekAtLastError());
}
1013 
1014 
// Backward pass of the (optionally weighted) multi-input shortcut layer.
// One thread per element id of the [batch][src_outputs] gradient delta_in; 2-D grid
// flattened to a linear id; ids >= size exit. Each thread:
//  * adds delta_in[id] (times the normalized weight of the own input, if weighted) into
//    delta_out[id];
//  * for each extra input layer i with an element at position src_i, adds the (weighted)
//    gradient into layers_delta_gpu[i];
//  * when weights_gpu is set, accumulates weight gradients into weight_updates_gpu with
//    atomicAdd, dropping NaN/Inf contributions.
//
// Weight layout (from the index arithmetic): input k uses weights_gpu[src_i / step + k*layer_step],
// with layer_step = nweights / (n + 1).
//
// When layer_step == 1, every element of an input shares one weight. In that case,
// warpAllReduceSum (full 0xffffffff mask) collapses 32 atomics into one, but only for warps
// whose 32 ids are all < size. That condition depends on id / 32 alone and is therefore
// warp-uniform (assuming blockDim.x is a multiple of 32).
// Every lane of such a warp executes the reduction, and lanes without a contribution pass 0.
// This keeps the full-mask shuffle out of divergent code: previously it sat inside
// `if (src_i < add_outputs)`, which is undefined behavior when a layer is smaller than the
// source, and lane 0 could also skip the atomic for the whole warp.
__global__ void backward_shortcut_multilayer_kernel(int size, int src_outputs, int batch, int n, int *outputs_of_layers_gpu,
    float **layers_delta_gpu, float *delta_out, float *delta_in, float *weights_gpu, float *weight_updates_gpu, int nweights, float *in, float **layers_output_gpu, WEIGHTS_NORMALIZATION_T weights_normalization)
{
    const int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (id >= size) return;

    // nweights - l.n or l.n*l.c or (l.n*l.c*l.h*l.w)
    const int layer_step = nweights / (n + 1);    // 1 or l.c or (l.c * l.h * l.w)
    int step = 0;
    if (nweights > 0) step = src_outputs / layer_step; // (l.c * l.h * l.w) or (l.w*l.h) or 1

    // Split the flat id into position within one sample (src_i) and sample index (src_b).
    const int src_i = id % src_outputs;
    const int src_b = id / src_outputs;

    // Warp-uniform: per-input scalar weights and all 32 ids of this warp are < size.
    const bool use_warp_reduce = (layer_step == 1) && ((size / 32) > (id / 32 + 1));
    const bool is_lane0 = (threadIdx.x % 32 == 0);

    // Normalization denominator (sum) and, for softmax, the maximum weight (for stability).
    float grad = 1, sum = 1, max_val = -FLT_MAX;
    if (weights_gpu && weights_normalization) {
        if (weights_normalization == SOFTMAX_NORMALIZATION) {
            for (int i = 0; i < (n + 1); ++i) {
                const int weights_index = src_i / step + i*layer_step;  // [0 or c or (c, h ,w)]
                float w = weights_gpu[weights_index];
                if (max_val < w) max_val = w;
            }
        }
        const float eps = 0.0001;
        sum = eps;
        for (int i = 0; i < (n + 1); ++i) {
            const int weights_index = src_i / step + i*layer_step;  // [0 or c or (c, h ,w)]
            const float w = weights_gpu[weights_index];
            if (weights_normalization == RELU_NORMALIZATION) sum += lrelu(w);
            else if (weights_normalization == SOFTMAX_NORMALIZATION) sum += expf(w - max_val);
        }
    }

    // Own input: slot 0 of the weights. No divergence here — every lane takes this path.
    if (weights_gpu) {
        float w = weights_gpu[src_i / step];
        if (weights_normalization == RELU_NORMALIZATION) w = lrelu(w) / sum;
        else if (weights_normalization == SOFTMAX_NORMALIZATION) w = expf(w - max_val) / sum;

        if (weights_normalization == RELU_NORMALIZATION) grad = w;
        else if (weights_normalization == SOFTMAX_NORMALIZATION) grad = w*(1 - w);

        delta_out[id] += delta_in[id] * w; // [0 or c or (c, h ,w)]
        float weights_update_tmp = delta_in[id] * in[id] * grad;

        if (use_warp_reduce) {
            if (isnan(weights_update_tmp) || isinf(weights_update_tmp)) weights_update_tmp = 0;
            const float wu = warpAllReduceSum(weights_update_tmp);
            if (is_lane0 && !isnan(wu) && !isinf(wu))
                atomicAdd(&weight_updates_gpu[src_i / step], wu);
        }
        else if (!isnan(weights_update_tmp) && !isinf(weights_update_tmp)) {
            atomicAdd(&weight_updates_gpu[src_i / step], weights_update_tmp);
        }
    }
    else delta_out[id] += delta_in[id];

    // Extra input layers: slot i + 1 of the weights. A layer may have fewer outputs than
    // the source, so only positions src_i < add_outputs receive a gradient.
    for (int i = 0; i < n; ++i) {
        const int add_outputs = outputs_of_layers_gpu[i];
        const bool has_input = (src_i < add_outputs);
        const int add_index = add_outputs*src_b + src_i;
        float *layer_delta = layers_delta_gpu[i];

        if (!weights_gpu) {
            if (has_input) layer_delta[add_index] += delta_in[id];
            continue;
        }

        // With layer_step == 1 this index is i + 1 for every lane, so lane 0 can publish
        // the warp's sum even when it has no input itself.
        const int weights_index = src_i / step + (i + 1)*layer_step;  // [0 or c or (c, h ,w)]
        float weights_update_tmp = 0;

        if (has_input) {
            float *add = layers_output_gpu[i];
            float w = weights_gpu[weights_index];
            if (weights_normalization == RELU_NORMALIZATION) w = lrelu(w) / sum;
            else if (weights_normalization == SOFTMAX_NORMALIZATION) w = expf(w - max_val) / sum;

            if (weights_normalization == RELU_NORMALIZATION) grad = w;
            else if (weights_normalization == SOFTMAX_NORMALIZATION) grad = w*(1 - w);

            layer_delta[add_index] += delta_in[id] * w;
            weights_update_tmp = delta_in[id] * add[add_index] * grad;
        }

        if (use_warp_reduce) {
            if (isnan(weights_update_tmp) || isinf(weights_update_tmp)) weights_update_tmp = 0;
            // Reached by all 32 lanes regardless of has_input (required by the full mask).
            const float wu = warpAllReduceSum(weights_update_tmp);
            if (is_lane0 && !isnan(wu) && !isinf(wu))
                atomicAdd(&weight_updates_gpu[weights_index], wu);
        }
        else if (has_input && !isnan(weights_update_tmp) && !isinf(weights_update_tmp)) {
            atomicAdd(&weight_updates_gpu[weights_index], weights_update_tmp);
        }
    }
}
1125 
// Host launcher for backward_shortcut_multilayer_kernel: one thread per element of the
// [batch][src_outputs] gradient. The weight-layout values (layer_step, step) are derived
// inside the kernel, so the host computes none of them. In particular it never evaluates
// src_outputs / layer_step, which is an integer division by zero on the host when
// 0 < nweights < n + 1.
extern "C" void backward_shortcut_multilayer_gpu(int src_outputs, int batch, int n, int *outputs_of_layers_gpu,
    float **layers_delta_gpu, float *delta_out, float *delta_in, float *weights_gpu, float *weight_updates_gpu, int nweights, float *in, float **layers_output_gpu, WEIGHTS_NORMALIZATION_T weights_normalization)
{
    const int size = batch * src_outputs;
    backward_shortcut_multilayer_kernel<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream()>>>(size, src_outputs, batch, n, outputs_of_layers_gpu,
        layers_delta_gpu, delta_out, delta_in, weights_gpu, weight_updates_gpu, nweights, in, layers_output_gpu, weights_normalization);
    CHECK_CUDA(cudaPeekAtLastError());
}
1140 
// Adds `add` (w1 x h1 x c1 per sample) into `out` (w2 x h2 x c2 per sample) over the
// overlapping minw x minh x minc region; per-sample layout is [c][h][w].
// `stride` steps through the add tensor and `sample` through the out tensor (the host
// launcher clamps both to >= 1). 2-D grid flattened to a linear id; ids >= size exit.
__global__ void shortcut_kernel(int size, int minw, int minh, int minc, int stride, int sample, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
{
    int rem = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (rem >= size) return;

    const int x = rem % minw;
    rem /= minw;
    const int y = rem % minh;
    rem /= minh;
    const int ch = rem % minc;
    rem /= minc;
    const int b = rem % batch;

    const int dst = x*sample + w2*(y*sample + h2*(ch + c2*b));
    const int src = x*stride + w1*(y*stride + h1*(ch + c1*b));
    out[dst] += add[src];
}
1157 
// Host launcher for shortcut_kernel: out (w2,h2,c2) += add (w1,h1,c1) over the overlap.
// Width and height ratios must agree (asserted); stride/sample are clamped to >= 1.
extern "C" void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
{
    const int minw = (w1 < w2) ? w1 : w2;
    const int minh = (h1 < h2) ? h1 : h2;
    const int minc = (c1 < c2) ? c1 : c2;

    int stride = w1/w2;   // > 1 when add is larger (downsampling it)
    int sample = w2/w1;   // > 1 when out is larger (spreading into it)
    assert(stride == h1/h2);
    assert(sample == h2/h1);
    if (stride < 1) stride = 1;
    if (sample < 1) sample = 1;

    const int size = batch * minw * minh * minc;
    const dim3 grid = cuda_gridsize(size);
    cudaStream_t stream = get_cuda_stream();
    shortcut_kernel<<<grid, BLOCK, 0, stream>>>(size, minw, minh, minc, stride, sample, batch, w1, h1, c1, add, w2, h2, c2, out);
    CHECK_CUDA(cudaPeekAtLastError());
}
1175 
// Element-wise out[id] = in[id] + add[id] for id in [0, size); 2-D grid flattened to a linear id.
__global__ void simple_input_shortcut_kernel(float *in, int size, float *add, float *out)
{
    const int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (id < size) {
        out[id] = in[id] + add[id];
    }
}
1183 
// Like shortcut_kernel but non-accumulating: out[o] = in[o] + add[a] over the overlapping
// minw x minh x minc region, with o/a computed with the same [c][h][w] per-sample indexing
// and sample/stride steps. 2-D grid flattened to a linear id; ids >= size exit.
__global__ void input_shortcut_kernel(float *in, int size, int minw, int minh, int minc, int stride, int sample, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
{
    int rem = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (rem >= size) return;

    const int x = rem % minw;
    rem /= minw;
    const int y = rem % minh;
    rem /= minh;
    const int ch = rem % minc;
    rem /= minc;
    const int b = rem % batch;

    const int dst = x*sample + w2*(y*sample + h2*(ch + c2*b));
    const int src = x*stride + w1*(y*stride + h1*(ch + c1*b));
    out[dst] = in[dst] + add[src];
}
1200 
// out = in + add. Same-shape tensors use an element-wise kernel. Otherwise `in`
// (w2*h2*c2*batch floats) is copied into `out` and `add` is accumulated over the
// overlapping region via shortcut_kernel.
extern "C" void input_shortcut_gpu(float *in, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
{
    const int same_shape = (w1 == w2 && h1 == h2 && c1 == c2);
    if (same_shape) {
        const int total = batch * w1 * h1 * c1;
        simple_input_shortcut_kernel<<<cuda_gridsize(total), BLOCK, 0, get_cuda_stream()>>>(in, total, add, out);
        CHECK_CUDA(cudaPeekAtLastError());
        return;
    }

    const int minw = (w1 < w2) ? w1 : w2;
    const int minh = (h1 < h2) ? h1 : h2;
    const int minc = (c1 < c2) ? c1 : c2;

    int stride = w1 / w2;
    int sample = w2 / w1;
    assert(stride == h1 / h2);
    assert(sample == h2 / h1);
    if (stride < 1) stride = 1;
    if (sample < 1) sample = 1;

    // out = in, then out += add over the overlap.
    simple_copy_ongpu(w2 * h2 * c2 * batch, in, out);
    const int overlap = batch * minw * minh * minc;
    shortcut_kernel<<<cuda_gridsize(overlap), BLOCK, 0, get_cuda_stream()>>>(overlap, minw, minh, minc, stride, sample, batch, w1, h1, c1, add, w2, h2, c2, out);
    CHECK_CUDA(cudaPeekAtLastError());
}
1227 
// Smooth-L1 (Huber, threshold 1) loss, scaled by 2, with diff = truth - pred:
//   |diff| < 1:  error = diff^2,        delta = diff
//   otherwise:   error = 2|diff| - 1,   delta = sign(diff)
// 2-D grid flattened to a linear index; indices >= n do nothing.
__global__ void smooth_l1_kernel(int n, float *pred, float *truth, float *delta, float *error)
{
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= n) return;

    const float diff = truth[i] - pred[i];
    const float magnitude = abs(diff);
    if (magnitude >= 1) {
        error[i] = 2*magnitude - 1;
        delta[i] = (diff < 0) ? -1 : 1;
        return;
    }
    error[i] = diff * diff;
    delta[i] = diff;
}
1244 
// Host launcher for smooth_l1_kernel over n elements.
extern "C" void smooth_l1_gpu(int n, float *pred, float *truth, float *delta, float *error)
{
    const dim3 grid = cuda_gridsize(n);
    cudaStream_t stream = get_cuda_stream();
    smooth_l1_kernel<<<grid, BLOCK, 0, stream>>>(n, pred, truth, delta, error);
    CHECK_CUDA(cudaPeekAtLastError());
}
1250 
// Cross-entropy on softmax outputs: error = -log(p) where truth is non-zero (else 0),
// delta = truth - pred. 2-D grid flattened to a linear index; indices >= n do nothing.
__global__ void softmax_x_ent_kernel(int n, float *pred, float *truth, float *delta, float *error)
{
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= n) return;

    const float target = truth[i];
    const float prob = pred[i];
    if (target) error[i] = -log(prob);
    else error[i] = 0;
    delta[i] = target - prob;
}
1261 
// Host launcher for softmax_x_ent_kernel over n elements.
extern "C" void softmax_x_ent_gpu(int n, float *pred, float *truth, float *delta, float *error)
{
    const dim3 grid = cuda_gridsize(n);
    cudaStream_t stream = get_cuda_stream();
    softmax_x_ent_kernel<<<grid, BLOCK, 0, stream>>>(n, pred, truth, delta, error);
    CHECK_CUDA(cudaPeekAtLastError());
}
1267 
// Squared-error loss with diff = truth - pred: error = diff^2, delta = diff.
// (delta is not the exact derivative of error — the original author notes this
// deliberately.) 2-D grid flattened to a linear index; indices >= n do nothing.
__global__ void l2_kernel(int n, float *pred, float *truth, float *delta, float *error)
{
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= n) return;
    const float residual = truth[i] - pred[i];
    delta[i] = residual;
    error[i] = residual * residual;
}
1277 
// Host launcher for l2_kernel over n elements.
extern "C" void l2_gpu(int n, float *pred, float *truth, float *delta, float *error)
{
    const dim3 grid = cuda_gridsize(n);
    cudaStream_t stream = get_cuda_stream();
    l2_kernel<<<grid, BLOCK, 0, stream>>>(n, pred, truth, delta, error);
    CHECK_CUDA(cudaPeekAtLastError());
}
1283 
1284 
1285 
// Per-element blend c = s*a + (1 - s)*b; b is optional and treated as 0 when NULL.
// 2-D grid flattened to a linear index; indices >= n do nothing.
__global__ void weighted_sum_kernel(int n, float *a, float *b, float *s, float *c)
{
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= n) return;
    const float other = b ? b[i] : 0;
    c[i] = s[i]*a[i] + (1 - s[i])*other;
}
1293 
// Host launcher for weighted_sum_kernel over num elements.
extern "C" void weighted_sum_gpu(float *a, float *b, float *s, int num, float *c)
{
    const dim3 grid = cuda_gridsize(num);
    cudaStream_t stream = get_cuda_stream();
    weighted_sum_kernel<<<grid, BLOCK, 0, stream>>>(num, a, b, s, c);
    CHECK_CUDA(cudaPeekAtLastError());
}
1299 
// Backward pass of weighted_sum_kernel (c = s*a + (1 - s)*b, b treated as 0 when NULL).
// Accumulates into da (if non-NULL), db (if non-NULL) and ds:
//   da += dc*s,   db += dc*(1 - s),   ds += dc*a - dc*b
// `b` and `db` are optional here just as `b` is in the forward kernel. Previously both were
// dereferenced unconditionally, which faults when the optional second input is absent.
// 2-D grid flattened to a linear index; indices >= n do nothing.
__global__ void weighted_delta_kernel(int n, float *a, float *b, float *s, float *da, float *db, float *ds, float *dc)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if(i < n){
        if(da) da[i] += dc[i] * s[i];
        if(db) db[i] += dc[i] * (1-s[i]);
        const float b_val = b ? b[i] : 0;
        ds[i] += dc[i] * a[i] + dc[i] * -b_val;
    }
}
1309 
// Host launcher for weighted_delta_kernel over num elements.
extern "C" void weighted_delta_gpu(float *a, float *b, float *s, float *da, float *db, float *ds, int num, float *dc)
{
    const dim3 grid = cuda_gridsize(num);
    cudaStream_t stream = get_cuda_stream();
    weighted_delta_kernel<<<grid, BLOCK, 0, stream>>>(num, a, b, s, da, db, ds, dc);
    CHECK_CUDA(cudaPeekAtLastError());
}
1315 
// Element-wise multiply-accumulate c[i] += a[i]*b[i] for i in [0, n);
// 2-D grid flattened to a linear index.
__global__ void mult_add_into_kernel(int n, float *a, float *b, float *c)
{
    const int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= n) return;
    c[i] += a[i]*b[i];
}
1323 
// Host launcher for mult_add_into_kernel over num elements.
extern "C" void mult_add_into_gpu(int num, float *a, float *b, float *c)
{
    const dim3 grid = cuda_gridsize(num);
    cudaStream_t stream = get_cuda_stream();
    mult_add_into_kernel<<<grid, BLOCK, 0, stream>>>(num, a, b, c);
    CHECK_CUDA(cudaPeekAtLastError());
}
1329 
1330 
// Serial softmax with temperature over n contiguous floats, run by the calling thread:
//   output[i] = exp(input[i]/temp - max/temp) / sum_j exp(input[j]/temp - max/temp)
// Subtracting the maximum keeps exp() from overflowing.
// The running maximum is now a float. It used to be stored in an int, which truncated it
// toward zero (so the shift was not the true maximum) and made the float->int conversion
// undefined for NaN or values outside the int range.
__device__ void softmax_device(int n, float *input, float temp, float *output)
{
    int i;
    float sum = 0;
    float largest = -INFINITY;
    for(i = 0; i < n; ++i){
        const float val = input[i];
        largest = (val>largest) ? val : largest;
    }
    for(i = 0; i < n; ++i){
        float e = exp(input[i]/temp - largest/temp);
        sum += e;
        output[i] = e;
    }
    for(i = 0; i < n; ++i){
        output[i] /= sum;
    }
}
1349 
// One thread per batch entry b < batch: serial softmax over n values starting at
// input + b*offset, written to output + b*offset. 2-D grid flattened to a linear index.
__global__ void softmax_kernel(int n, int offset, int batch, float *input, float temp, float *output)
{
    const int b = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (b < batch) {
        const int shift = b*offset;
        softmax_device(n, input + shift, temp, output + shift);
    }
}
1356 
softmax_gpu(float * input,int n,int offset,int groups,float temp,float * output)1357 extern "C" void softmax_gpu(float *input, int n, int offset, int groups, float temp, float *output)
1358 {
1359     int inputs = n;
1360     int batch = groups;
1361     softmax_kernel<<<cuda_gridsize(batch), BLOCK, 0, get_cuda_stream()>>>(inputs, offset, batch, input, temp, output);
1362     CHECK_CUDA(cudaPeekAtLastError());
1363 }
1364 
softmax_device_new_api(float * input,int n,float temp,int stride,float * output)1365 __device__ void softmax_device_new_api(float *input, int n, float temp, int stride, float *output)
1366 {
1367 	int i;
1368 	float sum = 0;
1369 	float largest = -INFINITY;
1370 	for (i = 0; i < n; ++i) {
1371 		int val = input[i*stride];
1372 		largest = (val>largest) ? val : largest;
1373 	}
1374 	for (i = 0; i < n; ++i) {
1375 		float e = expf(input[i*stride] / temp - largest / temp);
1376 		sum += e;
1377 		output[i*stride] = e;
1378 	}
1379 	for (i = 0; i < n; ++i) {
1380 		output[i*stride] /= sum;
1381 	}
1382 }
1383 
softmax_kernel_new_api(float * input,int n,int batch,int batch_offset,int groups,int group_offset,int stride,float temp,float * output)1384 __global__ void softmax_kernel_new_api(float *input, int n, int batch, int batch_offset, int groups, int group_offset, int stride, float temp, float *output)
1385 {
1386 	int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
1387 	if (id >= batch*groups) return;
1388 	int b = id / groups;
1389 	int g = id % groups;
1390 	softmax_device_new_api(input + b*batch_offset + g*group_offset, n, temp, stride, output + b*batch_offset + g*group_offset);
1391 }
1392 
softmax_gpu_new_api(float * input,int n,int batch,int batch_offset,int groups,int group_offset,int stride,float temp,float * output)1393 extern "C" void softmax_gpu_new_api(float *input, int n, int batch, int batch_offset, int groups, int group_offset, int stride, float temp, float *output)
1394 {
1395 	softmax_kernel_new_api << <cuda_gridsize(batch*groups), BLOCK, 0, get_cuda_stream() >> >(input, n, batch, batch_offset, groups, group_offset, stride, temp, output);
1396     CHECK_CUDA(cudaPeekAtLastError());
1397 }
1398 
1399 
upsample_kernel(size_t N,float * x,int w,int h,int c,int batch,int stride,int forward,float scale,float * out)1400 __global__ void upsample_kernel(size_t N, float *x, int w, int h, int c, int batch, int stride, int forward, float scale, float *out)
1401 {
1402     size_t i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
1403     if (i >= N) return;
1404     int out_index = i;
1405     int out_w = i % (w*stride);
1406     i = i / (w*stride);
1407     int out_h = i % (h*stride);
1408     i = i / (h*stride);
1409     int out_c = i%c;
1410     i = i / c;
1411     int b = i%batch;
1412 
1413     int in_w = out_w / stride;
1414     int in_h = out_h / stride;
1415     int in_c = out_c;
1416 
1417     int in_index = b*w*h*c + in_c*w*h + in_h*w + in_w;
1418 
1419 
1420     if (forward) out[out_index] += scale * x[in_index];
1421     else atomicAdd(x + in_index, scale * out[out_index]);
1422 }
1423 
upsample_gpu(float * in,int w,int h,int c,int batch,int stride,int forward,float scale,float * out)1424 extern "C" void upsample_gpu(float *in, int w, int h, int c, int batch, int stride, int forward, float scale, float *out)
1425 {
1426     size_t size = w*h*c*batch*stride*stride;
1427     upsample_kernel << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> >(size, in, w, h, c, batch, stride, forward, scale, out);
1428     CHECK_CUDA(cudaPeekAtLastError());
1429 }
1430 
softmax_tree_kernel(float * input,int spatial,int batch,int stride,float temp,float * output,int groups,int * group_size,int * group_offset)1431 __global__ void softmax_tree_kernel(float *input, int spatial, int batch, int stride, float temp, float *output, int groups, int *group_size, int *group_offset)
1432 {
1433 	int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
1434 	if (id >= spatial*batch*groups) return;
1435 	int s = id % spatial;
1436 	id = id / spatial;
1437 	int g = id % groups;
1438 	int b = id / groups;
1439 	int goff = group_offset[g] * spatial;
1440 	int boff = b*stride;
1441 	softmax_device_new_api(input + goff + boff + s, group_size[g], temp, spatial, output + goff + boff + s);
1442 }
1443 
softmax_tree_gpu(float * input,int spatial,int batch,int stride,float temp,float * output,tree hier)1444 extern "C" void softmax_tree_gpu(float *input, int spatial, int batch, int stride, float temp, float *output, tree hier)
1445 {
1446 	int *tree_groups_size = cuda_make_int_array_new_api(hier.group_size, hier.groups);
1447 	int *tree_groups_offset = cuda_make_int_array_new_api(hier.group_offset, hier.groups);
1448 	/*
1449 	static int *tree_groups_size = 0;
1450 	static int *tree_groups_offset = 0;
1451 	if(!tree_groups_size){
1452 	tree_groups_size = cuda_make_int_array(hier.group_size, hier.groups);
1453 	tree_groups_offset = cuda_make_int_array(hier.group_offset, hier.groups);
1454 	}
1455 	*/
1456 	int num = spatial*batch*hier.groups;
1457 	softmax_tree_kernel <<<cuda_gridsize(num), BLOCK, 0, get_cuda_stream() >>>(input, spatial, batch, stride, temp, output, hier.groups, tree_groups_size, tree_groups_offset);
1458     CHECK_CUDA(cudaPeekAtLastError());
1459 	cuda_free((float *)tree_groups_size);
1460 	cuda_free((float *)tree_groups_offset);
1461 }
1462 
1463 
fix_nan_and_inf_kernel(float * input,size_t size)1464 __global__ void fix_nan_and_inf_kernel(float *input, size_t size)
1465 {
1466     const int index = blockIdx.x*blockDim.x + threadIdx.x;
1467     if (index < size) {
1468         float val = input[index];
1469         if (isnan(val) || isinf(val)) {
1470             input[index] = 1.0f / (fabs((float)index) + 1);  // pseudo random value
1471         }
1472     }
1473 }
1474 
fix_nan_and_inf(float * input,size_t size)1475 extern "C" void fix_nan_and_inf(float *input, size_t size)
1476 {
1477     const int block_size = BLOCK;
1478     const int num_blocks = get_number_of_blocks(size, block_size);
1479     fix_nan_and_inf_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> >(input, size);
1480     CHECK_CUDA(cudaPeekAtLastError());
1481     //CHECK_CUDA(cudaDeviceSynchronize());
1482 }
1483 
1484 
reset_nan_and_inf_kernel(float * input,size_t size)1485 __global__ void reset_nan_and_inf_kernel(float *input, size_t size)
1486 {
1487     const int index = blockIdx.x*blockDim.x + threadIdx.x;
1488     if (index < size) {
1489         float val = input[index];
1490         if (isnan(val) || isinf(val)) {
1491             input[index] = 0;
1492         }
1493     }
1494 }
1495 
reset_nan_and_inf(float * input,size_t size)1496 extern "C" void reset_nan_and_inf(float *input, size_t size)
1497 {
1498     const int block_size = BLOCK;
1499     const int num_blocks = get_number_of_blocks(size, block_size);
1500     reset_nan_and_inf_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> >(input, size);
1501     CHECK_CUDA(cudaPeekAtLastError());
1502     //CHECK_CUDA(cudaDeviceSynchronize());
1503 }
1504 
1505 
1506 
is_nan_or_inf_kernel(float * input,size_t size,int * pinned_return)1507 __global__ void is_nan_or_inf_kernel(float *input, size_t size, int *pinned_return)
1508 {
1509     const int index = blockIdx.x*blockDim.x + threadIdx.x;
1510     if (index < size) {
1511         float val = input[index];
1512         if (isnan(val) || isinf(val))
1513             *pinned_return = 1;
1514     }
1515 }
1516 
is_nan_or_inf(float * input,size_t size)1517 extern "C" int is_nan_or_inf(float *input, size_t size)
1518 {
1519     int *pinned_return;
1520     CHECK_CUDA(cudaHostAlloc(&pinned_return, sizeof(int), cudaHostRegisterMapped));
1521     *pinned_return = 0;
1522 
1523     const int block_size = BLOCK;
1524     const int num_blocks = get_number_of_blocks(size, block_size);
1525     is_nan_or_inf_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> >(input, size, pinned_return);
1526     CHECK_CUDA(cudaDeviceSynchronize());
1527     int ret_val = *pinned_return;
1528 
1529     CHECK_CUDA(cudaFreeHost(pinned_return));
1530     return ret_val;
1531 }
1532 
add_3_arrays_activate_kernel(float * a1,float * a2,float * a3,size_t size,ACTIVATION a,float * dst)1533 __global__ void add_3_arrays_activate_kernel(float *a1, float *a2, float *a3, size_t size, ACTIVATION a, float *dst)
1534 {
1535     const int index = blockIdx.x*blockDim.x + threadIdx.x;
1536     if (index < size) {
1537         float val = 0;
1538         val += a1[index];
1539         val += a2[index];
1540         if (a3) val += a3[index];
1541         if (a == LOGISTIC) val = 1.f / (1.f + expf(-val));
1542         else if(a == TANH) val = (2 / (1 + expf(-2 * val)) - 1);
1543         dst[index] = val;
1544     }
1545 }
1546 
add_3_arrays_activate(float * a1,float * a2,float * a3,size_t size,ACTIVATION a,float * dst)1547 extern "C" void add_3_arrays_activate(float *a1, float *a2, float *a3, size_t size, ACTIVATION a, float *dst)
1548 {
1549     const int block_size = BLOCK;
1550     const int num_blocks = get_number_of_blocks(size, block_size);
1551     if (a != LOGISTIC && a != TANH) {
1552         printf(" add_3_arrays_activate() doesn't support activation %d, it supports only LOGISTIC and TANH \n", a);
1553         exit(EXIT_FAILURE);
1554     }
1555     add_3_arrays_activate_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> >(a1, a2, a3, size, a, dst);
1556 }
1557 
1558 
sum_of_mults_kernel(float * a1,float * a2,float * b1,float * b2,size_t size,float * dst)1559 __global__ void sum_of_mults_kernel(float *a1, float *a2, float *b1, float *b2, size_t size, float *dst)
1560 {
1561     const int index = blockIdx.x*blockDim.x + threadIdx.x;
1562     if (index < size) {
1563         dst[index] = a1[index] * a2[index] + b1[index] * b2[index];
1564     }
1565 }
1566 
sum_of_mults(float * a1,float * a2,float * b1,float * b2,size_t size,float * dst)1567 extern "C" void sum_of_mults(float *a1, float *a2, float *b1, float *b2,  size_t size, float *dst)
1568 {
1569     const int block_size = BLOCK;
1570     const int num_blocks = get_number_of_blocks(size, block_size);
1571     sum_of_mults_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> >(a1, a2, b1, b2, size, dst);
1572 }
1573 
1574 
activate_and_mult_kernel(float * a1,float * a2,size_t size,ACTIVATION a,float * dst)1575 __global__ void activate_and_mult_kernel(float *a1, float *a2, size_t size, ACTIVATION a, float *dst)
1576 {
1577     const int index = blockIdx.x*blockDim.x + threadIdx.x;
1578     if (index < size) {
1579         float val = a1[index];
1580         if (a == TANH) val = (2 / (1 + expf(-2 * val)) - 1);
1581         dst[index] = val * a2[index];
1582     }
1583 }
1584 
activate_and_mult(float * a1,float * a2,size_t size,ACTIVATION a,float * dst)1585 extern "C" void activate_and_mult(float *a1, float *a2, size_t size, ACTIVATION a, float *dst)
1586 {
1587     const int block_size = BLOCK;
1588     const int num_blocks = get_number_of_blocks(size, block_size);
1589     if (a != TANH) {
1590         printf(" activat_and_mult() doesn't support activation %d, it supports only TANH \n", a);
1591         exit(EXIT_FAILURE);
1592     }
1593     activate_and_mult_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> >(a1, a2, size, a, dst);
1594 }
1595 
1596 
1597 
scale_channels_kernel(float * in_w_h_c,int size,int channel_size,int batch_size,int scale_wh,float * scales_c,float * out)1598 __global__ void scale_channels_kernel(float *in_w_h_c, int size, int channel_size, int batch_size, int scale_wh, float *scales_c, float *out)
1599 {
1600     const int index = blockIdx.x*blockDim.x + threadIdx.x;
1601     if (index < size) {
1602         if (scale_wh) {
1603             int osd_index = index % channel_size + (index / batch_size)*channel_size;
1604 
1605             out[index] = in_w_h_c[index] * scales_c[osd_index];
1606         }
1607         else {
1608             out[index] = in_w_h_c[index] * scales_c[index / channel_size];
1609         }
1610     }
1611 }
1612 
scale_channels_gpu(float * in_w_h_c,int size,int channel_size,int batch_size,int scale_wh,float * scales_c,float * out)1613 extern "C" void scale_channels_gpu(float *in_w_h_c, int size, int channel_size, int batch_size, int scale_wh, float *scales_c, float *out)
1614 {
1615     const int block_size = BLOCK;
1616     const int num_blocks = get_number_of_blocks(size, block_size);
1617     scale_channels_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> >(in_w_h_c, size, channel_size, batch_size, scale_wh, scales_c, out);
1618     CHECK_CUDA(cudaPeekAtLastError());
1619 }
1620 
1621 
1622 
1623 
backward_scale_channels_kernel(float * in_w_h_c_delta,int size,int channel_size,int batch_size,int scale_wh,float * in_scales_c,float * out_from_delta,float * in_from_output,float * out_state_delta)1624 __global__ void backward_scale_channels_kernel(float *in_w_h_c_delta, int size, int channel_size, int batch_size, int scale_wh,
1625     float *in_scales_c, float *out_from_delta,
1626     float *in_from_output, float *out_state_delta)
1627 {
1628     const int index = blockIdx.x*blockDim.x + threadIdx.x;
1629 
1630     if (index < size) {
1631 
1632         if (scale_wh)
1633         {
1634             int osd_index = index % channel_size + (index / batch_size)*channel_size;
1635 
1636             //out_state_delta[osd_index] += in_w_h_c_delta[index] * in_from_output[index]; // l.delta * from  (should be divided by channel_size?)
1637             atomicAdd(&out_state_delta[osd_index], in_w_h_c_delta[index] * in_from_output[index] / channel_size); // l.delta * from
1638 
1639             out_from_delta[index] += in_scales_c[osd_index] * in_w_h_c_delta[index]; // input * l.delta  // atomic isn't required here
1640 
1641         }
1642         else {
1643             int osd_index = index / channel_size;
1644             //out_state_delta[osd_index] += in_w_h_c_delta[index] * in_from_output[index]; // l.delta * from  (should be divided by channel_size?)
1645 
1646             int warp_id = index / 32;
1647             int index_warp_start = warp_id * 32;
1648             int osd_index_warp_start = index_warp_start / channel_size;
1649             int osd_index_warp_end = (index_warp_start + 31) / channel_size;
1650 
1651             if (osd_index_warp_start == osd_index_warp_end) // all thread in warp process the same channel
1652             {
1653                 float sum = warpAllReduceSum(in_w_h_c_delta[index] * in_from_output[index]); // l.delta * from
1654                 if (threadIdx.x % 32 == 0) {
1655                     atomicAdd(&out_state_delta[osd_index], sum);
1656                     //out_state_delta[osd_index] += sum;
1657                 }
1658             }
1659             else {
1660                 atomicAdd(&out_state_delta[osd_index], in_w_h_c_delta[index] * in_from_output[index]); // l.delta * from
1661             }
1662 
1663             out_from_delta[index] += in_scales_c[osd_index] * in_w_h_c_delta[index]; // input * l.delta  // atomic isn't required here
1664         }
1665     }
1666 }
1667 
backward_scale_channels_gpu(float * in_w_h_c_delta,int size,int channel_size,int batch_size,int scale_wh,float * in_scales_c,float * out_from_delta,float * in_from_output,float * out_state_delta)1668 extern "C" void backward_scale_channels_gpu(float *in_w_h_c_delta, int size, int channel_size, int batch_size, int scale_wh,
1669     float *in_scales_c, float *out_from_delta,
1670     float *in_from_output, float *out_state_delta)
1671 {
1672     const int block_size = BLOCK;
1673     const int num_blocks = get_number_of_blocks(size, block_size);
1674     backward_scale_channels_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> > (in_w_h_c_delta, size, channel_size, batch_size, scale_wh,
1675         in_scales_c, out_from_delta,
1676         in_from_output, out_state_delta);
1677 
1678     CHECK_CUDA(cudaPeekAtLastError());
1679 }
1680 
1681 
sam_kernel(float * in_w_h_c,int size,int channel_size,float * scales_c,float * out)1682 __global__ void sam_kernel(float *in_w_h_c, int size, int channel_size, float *scales_c, float *out)
1683 {
1684     const int index = blockIdx.x*blockDim.x + threadIdx.x;
1685     if (index < size) {
1686         out[index] = in_w_h_c[index] * scales_c[index];
1687     }
1688 }
1689 
sam_gpu(float * in_w_h_c,int size,int channel_size,float * scales_c,float * out)1690 extern "C" void sam_gpu(float *in_w_h_c, int size, int channel_size, float *scales_c, float *out)
1691 {
1692     const int block_size = BLOCK;
1693     const int num_blocks = get_number_of_blocks(size, block_size);
1694     sam_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> >(in_w_h_c, size, channel_size, scales_c, out);
1695     CHECK_CUDA(cudaPeekAtLastError());
1696 }
1697 
1698 
backward_sam_kernel(float * in_w_h_c_delta,int size,int channel_size,float * in_scales_c,float * out_from_delta,float * in_from_output,float * out_state_delta)1699 __global__ void backward_sam_kernel(float *in_w_h_c_delta, int size, int channel_size,
1700     float *in_scales_c, float *out_from_delta,
1701     float *in_from_output, float *out_state_delta)
1702 {
1703     const int index = blockIdx.x*blockDim.x + threadIdx.x;
1704     if (index < size) {
1705         out_state_delta[index] += in_w_h_c_delta[index] * in_from_output[index]; // l.delta * from  (should be divided by channel_size?)
1706         out_from_delta[index] += in_scales_c[index] * in_w_h_c_delta[index]; // input * l.delta
1707 
1708                                                                              //out_state_delta[index] += in_w_h_c_delta[index];
1709                                                                              //out_from_delta[index] = in_w_h_c_delta[index];
1710     }
1711 }
1712 
backward_sam_gpu(float * in_w_h_c_delta,int size,int channel_size,float * in_scales_c,float * out_from_delta,float * in_from_output,float * out_state_delta)1713 extern "C" void backward_sam_gpu(float *in_w_h_c_delta, int size, int channel_size,
1714     float *in_scales_c, float *out_from_delta,
1715     float *in_from_output, float *out_state_delta)
1716 {
1717     const int block_size = BLOCK;
1718     const int num_blocks = get_number_of_blocks(size, block_size);
1719     backward_sam_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> > (in_w_h_c_delta, size, channel_size,
1720         in_scales_c, out_from_delta,
1721         in_from_output, out_state_delta);
1722 
1723     CHECK_CUDA(cudaPeekAtLastError());
1724 }
1725 
1726 
smooth_rotate_weights_kernel(const float * src_weight_gpu,float * weight_deform_gpu,int nweights,int n,int kernel_size,int angle,int reverse)1727 __global__  void smooth_rotate_weights_kernel(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int kernel_size, int angle, int reverse)
1728 {
1729     const int index = blockIdx.x*blockDim.x + threadIdx.x;
1730     const int kernel_area = kernel_size * kernel_size;
1731     const int i = index * kernel_area;
1732 
1733     const int stage_step = (nweights / kernel_area) / 4;  // 4 stages
1734     const int stage_id = index / stage_step;
1735 
1736     // nweights = (c / groups) * n * size * size;
1737     // kernel_area = size*size
1738 
1739     if (i < nweights)
1740     {
1741         // rotate left or right
1742         if (reverse) angle = -angle;
1743 
1744         const float cos_a = cosf(angle * 3.14159265 / 180);
1745         const float sin_a = sinf(angle * 3.14159265 / 180);
1746         const int x_c = kernel_size / 2;
1747         const int y_c = kernel_size / 2;
1748 
1749         float dropout_sum = 0;
1750 
1751         for (int y = 0; y < kernel_size; ++y) {
1752             for (int x = 0; x < kernel_size; ++x) {
1753                 // Xsource = x*cos(alpha) + y*sin(alpha)
1754                 // Ysource = -x*sin(alpha) + y*cos(alpha)
1755 
1756                 float x_s = x_c + (x - x_c)*cos_a + (y - y_c)*sin_a;
1757                 float y_s = y_c - (x - x_c)*sin_a + (y - y_c)*cos_a;
1758 
1759                 int x_0 = floor(x_s);   // round down
1760                 int x_1 = ceil(x_s);    // round up
1761                 if (x_0 == x_1) x_1 = x_0 + 1;
1762                 int y_0 = floor(y_s);
1763                 int y_1 = ceil(y_s);
1764                 if (y_0 == y_1) y_1 = y_0 + 1;
1765 
1766                 float c_x_0 = x_1 - x_s;
1767                 float c_x_1 = x_s - x_0;
1768                 float c_y_0 = y_1 - y_s;
1769                 float c_y_1 = y_s - y_0;
1770 
1771 
1772                 float val = 0;
1773                 if (x_0 >= 0 && x_0 < kernel_size && y_0 >= 0 && y_0 < kernel_size) val += src_weight_gpu[x_0 + y_0*kernel_size + i] * c_x_0 * c_y_0;
1774                 else dropout_sum += c_x_0 * c_y_0;
1775 
1776                 if (x_1 >= 0 && x_1 < kernel_size && y_0 >= 0 && y_0 < kernel_size) val += src_weight_gpu[x_1 + y_0*kernel_size + i] * c_x_1 * c_y_0;
1777                 else dropout_sum += c_x_1 * c_y_0;
1778 
1779                 if (x_0 >= 0 && x_0 < kernel_size && y_1 >= 0 && y_1 < kernel_size) val += src_weight_gpu[x_0 + y_1*kernel_size + i] * c_x_0 * c_y_1;
1780                 else dropout_sum += c_x_0 * c_y_1;
1781 
1782                 if (x_1 >= 0 && x_1 < kernel_size && y_1 >= 0 && y_1 < kernel_size) val += src_weight_gpu[x_1 + y_1*kernel_size + i] * c_x_1 * c_y_1;
1783                 else dropout_sum += c_x_1 * c_y_1;
1784 
1785                 weight_deform_gpu[x + y*kernel_size + i] = val;
1786             }
1787         }
1788 
1789         // compensate for dropped items
1790         const float coef = (kernel_size*kernel_size) / (kernel_size*kernel_size - dropout_sum);
1791         for (int y = 0; y < kernel_size; ++y) {
1792             for (int x = 0; x < kernel_size; ++x) {
1793                 weight_deform_gpu[x + y*kernel_size + i] *= coef;
1794             }
1795         }
1796     }
1797 }
1798 
1799 
smooth_rotate_weights_gpu(const float * src_weight_gpu,float * weight_deform_gpu,int nweights,int n,int size,int angle,int reverse)1800 extern "C" void smooth_rotate_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int angle, int reverse)
1801 {
1802     const int kernel_area = size*size;
1803     const int block_size = BLOCK;
1804     const int num_blocks = get_number_of_blocks(nweights / kernel_area, block_size);
1805     smooth_rotate_weights_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> > (src_weight_gpu, weight_deform_gpu, nweights, n, size, angle, reverse);
1806 
1807     CHECK_CUDA(cudaPeekAtLastError());
1808 }
1809 
1810 
1811 
stretch_weights_kernel(const float * src_weight_gpu,float * weight_deform_gpu,int nweights,int n,int kernel_size,float scale,int reverse)1812 __global__  void stretch_weights_kernel(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int kernel_size, float scale, int reverse)
1813 {
1814     const int index = blockIdx.x*blockDim.x + threadIdx.x;
1815     const int kernel_area = kernel_size * kernel_size;
1816     const int i = index * kernel_area;
1817 
1818     const int stage_step = (nweights / kernel_area) / 4;  // 4 stages
1819     const int stage_id = index / stage_step;
1820 
1821     // nweights = (c / groups) * n * size * size;
1822     // kernel_area = size*size
1823 
1824     if (i < nweights)
1825     {
1826 
1827         if (stage_id == 0) {
1828             // simple copy
1829             for (int x = 0; x < kernel_size; ++x) {
1830                 for (int y = 0; y < kernel_size; ++y) {
1831                     weight_deform_gpu[x + y*kernel_size + i] = src_weight_gpu[x + y*kernel_size + i];
1832                 }
1833             }
1834         }
1835         else if (stage_id > 0)
1836         {
1837             if (stage_id == 1) scale = 0.65;
1838             else if (stage_id == 2) scale = 0.8;
1839             else if (stage_id == 3) scale = 1.3;
1840 
1841             if (reverse) scale = 1 / scale;
1842 
1843             const int x_c = kernel_size / 2;
1844             const int y_c = kernel_size / 2;
1845 
1846             float dropout_sum = 0;
1847 
1848             for (int y = 0; y < kernel_size; ++y) {
1849                 for (int x = 0; x < kernel_size; ++x) {
1850                     // Xsource = x_c + (x_d - x_c) / scale
1851                     // Ysource = y_c + (y_d - y_c) / scale
1852 
1853                     float x_s = x_c + (x - x_c) / scale;
1854                     float y_s = y_c + (y - y_c) / scale;
1855 
1856                     int x_0 = floor(x_s);   // round down
1857                     int x_1 = ceil(x_s);    // round up
1858                     if (x_0 == x_1) x_1 = x_0 + 1;
1859                     int y_0 = floor(y_s);
1860                     int y_1 = ceil(y_s);
1861                     if (y_0 == y_1) y_1 = y_0 + 1;
1862 
1863                     float c_x_0 = x_1 - x_s;
1864                     float c_x_1 = x_s - x_0;
1865                     float c_y_0 = y_1 - y_s;
1866                     float c_y_1 = y_s - y_0;
1867 
1868                     float val = 0;
1869                     if (x_0 >= 0 && x_0 < kernel_size && y_0 >= 0 && y_0 < kernel_size) val += src_weight_gpu[x_0 + y_0*kernel_size + i] * c_x_0 * c_y_0;
1870                     else dropout_sum += c_x_0 * c_y_0;
1871 
1872                     if (x_1 >= 0 && x_1 < kernel_size && y_0 >= 0 && y_0 < kernel_size) val += src_weight_gpu[x_1 + y_0*kernel_size + i] * c_x_1 * c_y_0;
1873                     else dropout_sum += c_x_1 * c_y_0;
1874 
1875                     if (x_0 >= 0 && x_0 < kernel_size && y_1 >= 0 && y_1 < kernel_size) val += src_weight_gpu[x_0 + y_1*kernel_size + i] * c_x_0 * c_y_1;
1876                     else dropout_sum += c_x_0 * c_y_1;
1877 
1878                     if (x_1 >= 0 && x_1 < kernel_size && y_1 >= 0 && y_1 < kernel_size) val += src_weight_gpu[x_1 + y_1*kernel_size + i] * c_x_1 * c_y_1;
1879                     else dropout_sum += c_x_1 * c_y_1;
1880 
1881                     weight_deform_gpu[x + y*kernel_size + i] = val;
1882                 }
1883             }
1884 
1885             // compensate for dropped items
1886             //const float coef = (kernel_size*kernel_size) / (kernel_size*kernel_size - dropout_sum);
1887             for (int y = 0; y < kernel_size; ++y) {
1888                 for (int x = 0; x < kernel_size; ++x) {
1889                     //if (scale < 1) weight_deform_gpu[x + y*kernel_size + i] /= scale;// *= coef;
1890                     weight_deform_gpu[x + y*kernel_size + i] /= scale;// *= coef;
1891                 }
1892             }
1893         }
1894     }
1895 }
1896 
1897 
stretch_weights_gpu(const float * src_weight_gpu,float * weight_deform_gpu,int nweights,int n,int size,float scale,int reverse)1898 extern "C" void stretch_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, float scale, int reverse)
1899 {
1900     const int kernel_area = size*size;
1901     const int block_size = BLOCK;
1902     const int num_blocks = get_number_of_blocks(nweights / kernel_area, block_size);
1903     stretch_weights_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> > (src_weight_gpu, weight_deform_gpu, nweights, n, size, scale, reverse);
1904 
1905     CHECK_CUDA(cudaPeekAtLastError());
1906 }
1907 
1908 
1909 
sway_and_flip_weights_kernel(const float * src_weight_gpu,float * weight_deform_gpu,int nweights,int n,int kernel_size,int angle,int reverse)1910 __global__  void sway_and_flip_weights_kernel(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int kernel_size, int angle, int reverse)
1911 {
1912     const int index = blockIdx.x*blockDim.x + threadIdx.x;
1913     const int kernel_area = kernel_size * kernel_size;
1914     const int i = index * kernel_area;
1915 
1916     const int stage_step = (nweights / kernel_area) / 4;  // 4 stages
1917     const int stage_id = index / stage_step;
1918 
1919     // nweights = (c / groups) * n * size * size;
1920     // kernel_area = size*size
1921 
1922     if (i < nweights)
1923     {
1924 
1925         if (stage_id == 0) {
1926             // simple copy
1927             for (int x = 0; x < kernel_size; ++x) {
1928                 for (int y = 0; y < kernel_size; ++y) {
1929                     weight_deform_gpu[x + y*kernel_size + i] = src_weight_gpu[x + y*kernel_size + i];
1930                 }
1931             }
1932         }
1933         else if (stage_id == 1 || stage_id == 2)
1934         {
1935             // rotate left or right
1936             if (stage_id == 2) angle = -angle;
1937             if (reverse) angle = -angle;
1938 
1939             const float cos_a = cosf(angle * 3.14159265 / 180);
1940             const float sin_a = sinf(angle * 3.14159265 / 180);
1941             const int x_c = kernel_size / 2;
1942             const int y_c = kernel_size / 2;
1943 
1944             float dropout_sum = 0;
1945 
1946             for (int y = 0; y < kernel_size; ++y) {
1947                 for (int x = 0; x < kernel_size; ++x) {
1948                     // Xsource = x*cos(alpha) + y*sin(alpha)
1949                     // Ysource = -x*sin(alpha) + y*cos(alpha)
1950 
1951                     float x_s = x_c + (x - x_c)*cos_a + (y - y_c)*sin_a;
1952                     float y_s = y_c - (x - x_c)*sin_a + (y - y_c)*cos_a;
1953 
1954                     int x_0 = floor(x_s);   // round down
1955                     int x_1 = ceil(x_s);    // round up
1956                     if (x_0 == x_1) x_1 = x_0 + 1;
1957                     int y_0 = floor(y_s);
1958                     int y_1 = ceil(y_s);
1959                     if (y_0 == y_1) y_1 = y_0 + 1;
1960 
1961                     float c_x_0 = x_1 - x_s;
1962                     float c_x_1 = x_s - x_0;
1963                     float c_y_0 = y_1 - y_s;
1964                     float c_y_1 = y_s - y_0;
1965 
1966                     float val = 0;
1967                     if (x_0 >= 0 && x_0 < kernel_size && y_0 >= 0 && y_0 < kernel_size) val += src_weight_gpu[x_0 + y_0*kernel_size + i] * c_x_0 * c_y_0;
1968                     else dropout_sum += c_x_0 * c_y_0;
1969 
1970                     if (x_1 >= 0 && x_1 < kernel_size && y_0 >= 0 && y_0 < kernel_size) val += src_weight_gpu[x_1 + y_0*kernel_size + i] * c_x_1 * c_y_0;
1971                     else dropout_sum += c_x_1 * c_y_0;
1972 
1973                     if (x_0 >= 0 && x_0 < kernel_size && y_1 >= 0 && y_1 < kernel_size) val += src_weight_gpu[x_0 + y_1*kernel_size + i] * c_x_0 * c_y_1;
1974                     else dropout_sum += c_x_0 * c_y_1;
1975 
1976                     if (x_1 >= 0 && x_1 < kernel_size && y_1 >= 0 && y_1 < kernel_size) val += src_weight_gpu[x_1 + y_1*kernel_size + i] * c_x_1 * c_y_1;
1977                     else dropout_sum += c_x_1 * c_y_1;
1978 
1979                     weight_deform_gpu[x + y*kernel_size + i] = val;
1980                 }
1981             }
1982 
1983             // compensate for dropped items
1984             const float coef = (kernel_size*kernel_size) / (kernel_size*kernel_size - dropout_sum);
1985             for (int y = 0; y < kernel_size; ++y) {
1986                 for (int x = 0; x < kernel_size; ++x) {
1987                     weight_deform_gpu[x + y*kernel_size + i] *= coef;
1988                 }
1989             }
1990         }
1991         else if (stage_id == 3)
1992         {
1993             // flip
1994             for (int y = 0; y < kernel_size; ++y) {
1995                 for (int x = 0; x < kernel_size; ++x) {
1996                     weight_deform_gpu[(kernel_size - x - 1) + y*kernel_size + i] = src_weight_gpu[x + y*kernel_size + i];
1997                 }
1998             }
1999         }
2000     }
2001 }
2002 
2003 
sway_and_flip_weights_gpu(const float * src_weight_gpu,float * weight_deform_gpu,int nweights,int n,int size,int angle,int reverse)2004 extern "C" void sway_and_flip_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int angle, int reverse)
2005 {
2006     const int kernel_area = size*size;
2007     const int block_size = BLOCK;
2008     const int num_blocks = get_number_of_blocks(nweights / kernel_area, block_size);
2009     sway_and_flip_weights_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> > (src_weight_gpu, weight_deform_gpu, nweights, n, size, angle, reverse);
2010 
2011     CHECK_CUDA(cudaPeekAtLastError());
2012 }
2013 
2014 
2015 
2016 
2017 
2018 
2019 
rotate_weights_kernel(const float * src_weight_gpu,float * weight_deform_gpu,int nweights,int n,int kernel_size,int reverse)2020 __global__  void rotate_weights_kernel(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int kernel_size, int reverse)
2021 {
2022     const int index = blockIdx.x*blockDim.x + threadIdx.x;
2023     const int kernel_area = kernel_size * kernel_size;
2024     const int i = index * kernel_area;
2025 
2026     const int stage_step = (nweights / kernel_area) / 4;  // 4 stages
2027     const int stage_id = index / stage_step;
2028 
2029     // nweights = (c / groups) * n * size * size;
2030     // kernel_area = size*size
2031 
2032     if (i < nweights)
2033     {
2034         // if(reverse)
2035 
2036         if (stage_id == 0) {
2037             // simple copy
2038             for (int y = 0; y < kernel_size; ++y) {
2039                 for (int x = 0; x < kernel_size; ++x) {
2040                     const int src_i = x + y*kernel_size + i;
2041                     const int dst_i = x + y*kernel_size + i;
2042                     if (reverse) weight_deform_gpu[src_i] = src_weight_gpu[dst_i];
2043                     else weight_deform_gpu[dst_i] = src_weight_gpu[src_i];
2044                 }
2045             }
2046         }
2047         else if (stage_id == 1)
2048         {
2049             // 90 degree clockwise rotation - 1
2050             for (int y = 0; y < kernel_size; ++y) {
2051                 for (int x = 0; x < kernel_size; ++x) {
2052                     const int src_i = x + y*kernel_size + i;
2053                     const int dst_i = (kernel_size - 1 - y) + x*kernel_size + i;
2054                     if (reverse) weight_deform_gpu[src_i] = src_weight_gpu[dst_i];
2055                     else weight_deform_gpu[dst_i] = src_weight_gpu[src_i];
2056                 }
2057             }
2058         }
2059         else if (stage_id == 2)
2060         {
2061             // 180 degree clockwise rotation - 2
2062             for (int y = 0; y < kernel_size; ++y) {
2063                 for (int x = 0; x < kernel_size; ++x) {
2064                     const int src_i = x + y*kernel_size + i;
2065                     const int dst_i = (kernel_size - 1 - x) + (kernel_size - 1 - y)*kernel_size + i;
2066                     if (reverse) weight_deform_gpu[src_i] = src_weight_gpu[dst_i];
2067                     else weight_deform_gpu[dst_i] = src_weight_gpu[src_i];
2068                 }
2069             }
2070         }
2071         else if (stage_id == 3)
2072         {
2073             // 270 degree clockwise rotation - 3
2074             for (int y = 0; y < kernel_size; ++y) {
2075                 for (int x = 0; x < kernel_size; ++x) {
2076                     const int src_i = x + y*kernel_size + i;
2077                     const int dst_i = y + (kernel_size - 1 - x)*kernel_size + i;
2078                     if (reverse) weight_deform_gpu[src_i] = src_weight_gpu[dst_i];
2079                     else weight_deform_gpu[dst_i] = src_weight_gpu[src_i];
2080                 }
2081             }
2082         }
2083     }
2084 }
2085 
2086 
rotate_weights_gpu(const float * src_weight_gpu,float * weight_deform_gpu,int nweights,int n,int size,int reverse)2087 extern "C" void rotate_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int reverse)
2088 {
2089     const int kernel_area = size*size;
2090     const int block_size = BLOCK;
2091     const int num_blocks = get_number_of_blocks(nweights / kernel_area, block_size);
2092     rotate_weights_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> > (src_weight_gpu, weight_deform_gpu, nweights, n, size, reverse);
2093 
2094     CHECK_CUDA(cudaPeekAtLastError());
2095 }
2096 
2097 
2098 
stretch_sway_flip_weights_kernel(const float * src_weight_gpu,float * weight_deform_gpu,int nweights,int n,int kernel_size,float angle,int reverse)2099 __global__  void stretch_sway_flip_weights_kernel(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int kernel_size, float angle, int reverse)
2100 {
2101     const int index = blockIdx.x*blockDim.x + threadIdx.x;
2102     const int kernel_area = kernel_size * kernel_size;
2103     const int i = index * kernel_area;
2104 
2105     const int stage_step = (nweights / kernel_area) / 8;  // 8 stages
2106     const int stage_id = index / stage_step;
2107 
2108     // nweights = (c / groups) * n * size * size;
2109     // kernel_area = size*size
2110 
2111     if (i < nweights)
2112     {
2113 
2114         if (stage_id == 0) {
2115             // simple copy
2116             for (int x = 0; x < kernel_size; ++x) {
2117                 for (int y = 0; y < kernel_size; ++y) {
2118                     weight_deform_gpu[x + y*kernel_size + i] = src_weight_gpu[x + y*kernel_size + i];
2119                 }
2120             }
2121         }
2122         else if (stage_id == 1 || stage_id == 2 || stage_id == 3 || stage_id == 4)
2123         {
2124             float scale = 0.5;
2125             if (stage_id == 1) scale = 0.65;
2126             else if (stage_id == 2) scale = 0.8;
2127             else if (stage_id == 3) scale = 1.2;
2128             else if (stage_id == 4) scale = 1.4;
2129 
2130             if (reverse) scale = 1 / scale;
2131 
2132             const int x_c = kernel_size / 2;
2133             const int y_c = kernel_size / 2;
2134 
2135             float dropout_sum = 0;
2136 
2137             for (int y = 0; y < kernel_size; ++y) {
2138                 for (int x = 0; x < kernel_size; ++x) {
2139                     // Xsource = x_c + (x_d - x_c) / scale
2140                     // Ysource = y_c + (y_d - y_c) / scale
2141 
2142                     float x_s = x_c + (x - x_c) / scale;
2143                     float y_s = y_c + (y - y_c) / scale;
2144 
2145                     int x_0 = floor(x_s);   // round down
2146                     int x_1 = ceil(x_s);    // round up
2147                     if (x_0 == x_1) x_1 = x_0 + 1;
2148                     int y_0 = floor(y_s);
2149                     int y_1 = ceil(y_s);
2150                     if (y_0 == y_1) y_1 = y_0 + 1;
2151 
2152                     float c_x_0 = x_1 - x_s;
2153                     float c_x_1 = x_s - x_0;
2154                     float c_y_0 = y_1 - y_s;
2155                     float c_y_1 = y_s - y_0;
2156 
2157                     float val = 0;
2158                     if (x_0 >= 0 && x_0 < kernel_size && y_0 >= 0 && y_0 < kernel_size) val += src_weight_gpu[x_0 + y_0*kernel_size + i] * c_x_0 * c_y_0;
2159                     else dropout_sum += c_x_0 * c_y_0;
2160 
2161                     if (x_1 >= 0 && x_1 < kernel_size && y_0 >= 0 && y_0 < kernel_size) val += src_weight_gpu[x_1 + y_0*kernel_size + i] * c_x_1 * c_y_0;
2162                     else dropout_sum += c_x_1 * c_y_0;
2163 
2164                     if (x_0 >= 0 && x_0 < kernel_size && y_1 >= 0 && y_1 < kernel_size) val += src_weight_gpu[x_0 + y_1*kernel_size + i] * c_x_0 * c_y_1;
2165                     else dropout_sum += c_x_0 * c_y_1;
2166 
2167                     if (x_1 >= 0 && x_1 < kernel_size && y_1 >= 0 && y_1 < kernel_size) val += src_weight_gpu[x_1 + y_1*kernel_size + i] * c_x_1 * c_y_1;
2168                     else dropout_sum += c_x_1 * c_y_1;
2169 
2170                     weight_deform_gpu[x + y*kernel_size + i] = val;
2171                 }
2172             }
2173 
2174             // compensate for dropped items
2175             //const float coef = (kernel_size*kernel_size) / (kernel_size*kernel_size - dropout_sum);
2176             for (int y = 0; y < kernel_size; ++y) {
2177                 for (int x = 0; x < kernel_size; ++x) {
2178                     if(scale > 1)
2179                         weight_deform_gpu[x + y*kernel_size + i] /= scale;// *= coef;
2180                 }
2181             }
2182         }
2183         else if (stage_id == 5 || stage_id == 6)
2184         {
2185             // rotate left or right
2186             if (stage_id == 6) angle = -angle;
2187             if (reverse) angle = -angle;
2188 
2189             const float cos_a = cosf(angle * 3.14159265 / 180);
2190             const float sin_a = sinf(angle * 3.14159265 / 180);
2191             const int x_c = kernel_size / 2;
2192             const int y_c = kernel_size / 2;
2193 
2194             float dropout_sum = 0;
2195 
2196             for (int y = 0; y < kernel_size; ++y) {
2197                 for (int x = 0; x < kernel_size; ++x) {
2198                     // Xsource = x*cos(alpha) + y*sin(alpha)
2199                     // Ysource = -x*sin(alpha) + y*cos(alpha)
2200 
2201                     float x_s = x_c + (x - x_c)*cos_a + (y - y_c)*sin_a;
2202                     float y_s = y_c - (x - x_c)*sin_a + (y - y_c)*cos_a;
2203 
2204                     int x_0 = floor(x_s);   // round down
2205                     int x_1 = ceil(x_s);    // round up
2206                     if (x_0 == x_1) x_1 = x_0 + 1;
2207                     int y_0 = floor(y_s);
2208                     int y_1 = ceil(y_s);
2209                     if (y_0 == y_1) y_1 = y_0 + 1;
2210 
2211                     float c_x_0 = x_1 - x_s;
2212                     float c_x_1 = x_s - x_0;
2213                     float c_y_0 = y_1 - y_s;
2214                     float c_y_1 = y_s - y_0;
2215 
2216                     float val = 0;
2217                     if (x_0 >= 0 && x_0 < kernel_size && y_0 >= 0 && y_0 < kernel_size) val += src_weight_gpu[x_0 + y_0*kernel_size + i] * c_x_0 * c_y_0;
2218                     else dropout_sum += c_x_0 * c_y_0;
2219 
2220                     if (x_1 >= 0 && x_1 < kernel_size && y_0 >= 0 && y_0 < kernel_size) val += src_weight_gpu[x_1 + y_0*kernel_size + i] * c_x_1 * c_y_0;
2221                     else dropout_sum += c_x_1 * c_y_0;
2222 
2223                     if (x_0 >= 0 && x_0 < kernel_size && y_1 >= 0 && y_1 < kernel_size) val += src_weight_gpu[x_0 + y_1*kernel_size + i] * c_x_0 * c_y_1;
2224                     else dropout_sum += c_x_0 * c_y_1;
2225 
2226                     if (x_1 >= 0 && x_1 < kernel_size && y_1 >= 0 && y_1 < kernel_size) val += src_weight_gpu[x_1 + y_1*kernel_size + i] * c_x_1 * c_y_1;
2227                     else dropout_sum += c_x_1 * c_y_1;
2228 
2229                     weight_deform_gpu[x + y*kernel_size + i] = val;
2230                 }
2231             }
2232 
2233             // compensate for dropped items
2234             const float coef = (kernel_size*kernel_size) / (kernel_size*kernel_size - dropout_sum);
2235             for (int y = 0; y < kernel_size; ++y) {
2236                 for (int x = 0; x < kernel_size; ++x) {
2237                     weight_deform_gpu[x + y*kernel_size + i] *= coef;
2238                 }
2239             }
2240         }
2241         else if (stage_id == 7)
2242         {
2243             // flip
2244             for (int y = 0; y < kernel_size; ++y) {
2245                 for (int x = 0; x < kernel_size; ++x) {
2246                     weight_deform_gpu[(kernel_size - x - 1) + y*kernel_size + i] = src_weight_gpu[x + y*kernel_size + i];
2247                 }
2248             }
2249         }
2250     }
2251 }
2252 
2253 
stretch_sway_flip_weights_gpu(const float * src_weight_gpu,float * weight_deform_gpu,int nweights,int n,int size,int angle,int reverse)2254 extern "C" void stretch_sway_flip_weights_gpu(const float *src_weight_gpu, float *weight_deform_gpu, int nweights, int n, int size, int angle, int reverse)
2255 {
2256     const int kernel_area = size*size;
2257     const int block_size = BLOCK;
2258     const int num_blocks = get_number_of_blocks(nweights / kernel_area, block_size);
2259     stretch_sway_flip_weights_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> > (src_weight_gpu, weight_deform_gpu, nweights, n, size, angle, reverse);
2260 
2261     CHECK_CUDA(cudaPeekAtLastError());
2262 }
2263 
2264 
2265 
reduce_and_expand_array_kernel(const float * src_gpu,float * dst_gpu,int current_size,int groups)2266 __global__  void reduce_and_expand_array_kernel(const float *src_gpu, float *dst_gpu, int current_size, int groups)
2267 {
2268     const int index = blockIdx.x*blockDim.x + threadIdx.x;
2269 
2270     if (index < current_size) {
2271         float val = 0;
2272         for (int i = 0; i < groups; ++i) {
2273             val += src_gpu[index + i*current_size];
2274         }
2275         for (int i = 0; i < groups; ++i) {
2276             dst_gpu[index + i*current_size] = val / groups;
2277         }
2278     }
2279 }
2280 
reduce_and_expand_array_gpu(const float * src_gpu,float * dst_gpu,int size,int groups)2281 extern "C" void reduce_and_expand_array_gpu(const float *src_gpu, float *dst_gpu, int size, int groups)
2282 {
2283     const int current_size = size / groups;
2284     const int block_size = BLOCK;
2285     const int num_blocks = get_number_of_blocks(current_size, block_size);
2286     reduce_and_expand_array_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> > (src_gpu, dst_gpu, current_size, groups);
2287 
2288     CHECK_CUDA(cudaPeekAtLastError());
2289 }
2290 
2291 
2292 
expand_array_kernel(const float * src_gpu,float * dst_gpu,int current_size,int groups)2293 __global__  void expand_array_kernel(const float *src_gpu, float *dst_gpu, int current_size, int groups)
2294 {
2295     const int index = blockIdx.x*blockDim.x + threadIdx.x;
2296 
2297     if (index < current_size) {
2298         for (int i = 0; i < groups; ++i) {
2299             dst_gpu[index + i*current_size] = src_gpu[index];
2300         }
2301     }
2302 }
2303 
expand_array_gpu(const float * src_gpu,float * dst_gpu,int size,int groups)2304 extern "C" void expand_array_gpu(const float *src_gpu, float *dst_gpu, int size, int groups)
2305 {
2306     const int current_size = size / groups;
2307     const int block_size = BLOCK;
2308     const int num_blocks = get_number_of_blocks(current_size, block_size);
2309     expand_array_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> > (src_gpu, dst_gpu, current_size, groups);
2310 
2311     CHECK_CUDA(cudaPeekAtLastError());
2312 }