#include "Raster.cuh"
#include <cmath>
#include "TensorflowOp_generated.h"
3 namespace MNN {
4 namespace CUDA {
5 
// Scatters a contiguous [outside][axis][inside] tensor into the packed
// [outside][axisC4][inside][4] layout: source element (o, a, i) lands in lane
// (a % 4) of pack (o, a / 4, i). Lanes of the last pack beyond `axis` are never
// written here. Each thread walks the flattened source index, advancing by the
// total number of launched threads, so any 1-D launch size is valid.
template <typename T>
__global__ void pack_c4(const T *input, T *output, int inside, int axis, int outside, int axisC4) {
    const int elementCount = inside * axis * outside;
    const size_t threadStride = blockDim.x * gridDim.x;
    for (size_t src = blockIdx.x * blockDim.x + threadIdx.x; src < elementCount; src += threadStride) {
        const int innerIdx = src % inside;
        const int rowIdx   = src / inside;
        const int axisIdx  = rowIdx % axis;
        const int outerIdx = rowIdx / axis;
        // Index of the 4-wide pack this element belongs to.
        const int packIdx  = (outerIdx * axisC4 + axisIdx / 4) * inside + innerIdx;
        output[packIdx * 4 + axisIdx % 4] = input[src];
    }
}
20 
// Host wrapper around pack_c4: converts `input` from the plain
// [outside][axis][inside] layout to the C4-packed layout in `output`.
// When `axis` is not a multiple of 4, `output` is first cleared with
// runtime->memset so the unused lanes of the final pack hold zero bytes.
// Element widths of 4, 2 and 1 bytes are supported; any other `bytes` value
// launches nothing.
void PackC4(uint8_t* output, const uint8_t* input, int inside, int axis, int outside, int bytes, CUDARuntime* runtime) {
    const int axisC4 = (axis + 3) / 4;
    const bool hasPaddingLanes = (axis % 4) != 0;
    if (hasPaddingLanes) {
        runtime->memset(output, 0, inside * axisC4 * 4 * outside * bytes);
    }
    const int blocks  = runtime->blocks_num(inside * axis * outside);
    const int threads = runtime->threads_num();
    if (bytes == 4) {
        pack_c4<<<blocks, threads>>>((const float*)input, (float*)output, inside, axis, outside, axisC4);
    } else if (bytes == 2) {
        pack_c4<<<blocks, threads>>>((const int16_t*)input, (int16_t*)output, inside, axis, outside, axisC4);
    } else if (bytes == 1) {
        pack_c4<<<blocks, threads>>>((const int8_t*)input, (int8_t*)output, inside, axis, outside, axisC4);
    }
}
42 
// Inverse of pack_c4: gathers a packed [outside][axisC4][inside][4] tensor back
// into the contiguous [outside][axis][inside] layout. Destination element
// (o, a, i) is read from lane (a % 4) of pack (o, a / 4, i); padding lanes of
// the last pack are never read. Each thread walks the flattened destination
// index, advancing by the total number of launched threads.
template <typename T>
__global__ void unpack_c4(const T *input, T *output, int inside, int axis, int outside, int axisC4) {
    const int elementCount = inside * axis * outside;
    const size_t threadStride = blockDim.x * gridDim.x;
    for (size_t dst = blockIdx.x * blockDim.x + threadIdx.x; dst < elementCount; dst += threadStride) {
        const int innerIdx = dst % inside;
        const int rowIdx   = dst / inside;
        const int axisIdx  = rowIdx % axis;
        const int outerIdx = rowIdx / axis;
        // Index of the 4-wide pack holding this element.
        const int packIdx  = (outerIdx * axisC4 + axisIdx / 4) * inside + innerIdx;
        output[dst] = input[packIdx * 4 + axisIdx % 4];
    }
}
// Host wrapper around unpack_c4: converts the C4-packed `input` back to the
// plain [outside][axis][inside] layout in `output`. Element widths of 4, 2 and
// 1 bytes are supported; any other `bytes` value launches nothing.
void UnpackC4(uint8_t* output, const uint8_t* input, int inside, int axis, int outside, int bytes, CUDARuntime* runtime) {
    const int axisC4  = (axis + 3) / 4;
    const int blocks  = runtime->blocks_num(inside * axis * outside);
    const int threads = runtime->threads_num();
    if (bytes == 4) {
        unpack_c4<<<blocks, threads>>>((const float*)input, (float*)output, inside, axis, outside, axisC4);
    } else if (bytes == 2) {
        unpack_c4<<<blocks, threads>>>((const int16_t*)input, (int16_t*)output, inside, axis, outside, axisC4);
    } else if (bytes == 1) {
        unpack_c4<<<blocks, threads>>>((const int8_t*)input, (int8_t*)output, inside, axis, outside, axisC4);
    }
}
75 
76 
// Copies `loopCount` sub-regions of sizeZ x sizeY x sizeX elements from inputO
// to outputO, one region per loop iteration. The kernel takes no base offsets:
// region start positions come only from the step/index arithmetic below.
//  - Source start of region `loop`: (srcUseIndice >= 0 ? srcIndice[loop] : loop) * srcStep.
//  - Destination start:             (dstUseIndice >= 0 ? dstIndice[loop] : loop) * dstStep.
//  - If the source start lies outside [0, srcLimit), the destination region is
//    zero-filled instead. Only the start offset is checked against srcLimit,
//    not the full extent of the region.
// Each thread handles whole regions, advancing by the total number of launched threads.
template <typename T>
__global__ void blitRegion(const T *inputO, T *outputO,
        int loopCount,
        const int32_t* dstIndice, const int32_t* srcIndice,
        int dstUseIndice, int srcUseIndice,
        int dstStep, int srcStep,int srcLimit,
        int sizeZ, int sizeY, int sizeX,
        int strideZ, int strideY, int strideX,
        int dstStrideZ, int dstStrideY, int dstStrideX
        ) {
    const size_t threadStride = blockDim.x * gridDim.x;
    for (size_t loop = blockIdx.x * blockDim.x + threadIdx.x; loop < loopCount; loop += threadStride) {
        int srcStart;
        if (srcUseIndice >= 0) {
            srcStart = srcIndice[loop] * srcStep;
        } else {
            srcStart = loop * srcStep;
        }
        int dstStart;
        if (dstUseIndice >= 0) {
            dstStart = dstIndice[loop] * dstStep;
        } else {
            dstStart = loop * dstStep;
        }
        T* dst = outputO + dstStart;
        if (srcStart < 0 || srcStart >= srcLimit) {
            // Out-of-range source: write zeros over the whole destination region.
            for (int z = 0; z < sizeZ; ++z) {
                for (int y = 0; y < sizeY; ++y) {
                    for (int x = 0; x < sizeX; ++x) {
                        dst[z * dstStrideZ + y * dstStrideY + x * dstStrideX] = (T)0;
                    }
                }
            }
            continue;
        }
        const T* src = inputO + srcStart;
        for (int z = 0; z < sizeZ; ++z) {
            for (int y = 0; y < sizeY; ++y) {
                for (int x = 0; x < sizeX; ++x) {
                    dst[z * dstStrideZ + y * dstStrideY + x * dstStrideX] =
                        src[z * strideZ + y * strideY + x * strideX];
                }
            }
        }
    }
}
// Host wrapper around blitRegion: launches one logical iteration per region
// copy (`loopCount` in total), taking the copy shape from reg.size and the
// per-dimension strides from reg.src.stride / reg.dst.stride. The region's
// src/dst offsets are not forwarded to the kernel. Element widths of 4, 2 and
// 1 bytes are supported; any other `bytes` value launches nothing.
void BlitWithIndice(uint8_t* output, const uint8_t* input, const int32_t* dstIndices, const int32_t* srcIndices, int dstUseIndice, int srcUseIndice, int loopCount, int dstStep, int srcStep, int srcLimit, const Tensor::InsideDescribe::Region& reg, int bytes, CUDARuntime* runtime) {
    const int blocks  = runtime->blocks_num(loopCount);
    const int threads = runtime->threads_num();
    if (bytes == 4) {
        blitRegion<<<blocks, threads>>>((const float*)input, (float*)output,
            loopCount,
            dstIndices, srcIndices,
            dstUseIndice, srcUseIndice,
            dstStep, srcStep, srcLimit,
            reg.size[0], reg.size[1], reg.size[2],
            reg.src.stride[0], reg.src.stride[1], reg.src.stride[2],
            reg.dst.stride[0], reg.dst.stride[1], reg.dst.stride[2]);
    } else if (bytes == 2) {
        blitRegion<<<blocks, threads>>>((const int16_t*)input, (int16_t*)output,
            loopCount,
            dstIndices, srcIndices,
            dstUseIndice, srcUseIndice,
            dstStep, srcStep, srcLimit,
            reg.size[0], reg.size[1], reg.size[2],
            reg.src.stride[0], reg.src.stride[1], reg.src.stride[2],
            reg.dst.stride[0], reg.dst.stride[1], reg.dst.stride[2]);
    } else if (bytes == 1) {
        blitRegion<<<blocks, threads>>>((const int8_t*)input, (int8_t*)output,
            loopCount,
            dstIndices, srcIndices,
            dstUseIndice, srcUseIndice,
            dstStep, srcStep, srcLimit,
            reg.size[0], reg.size[1], reg.size[2],
            reg.src.stride[0], reg.src.stride[1], reg.src.stride[2],
            reg.dst.stride[0], reg.dst.stride[1], reg.dst.stride[2]);
    }
}
162 
// UNARY_FUNC(Name, Func) defines a kernel template `Name<T>` that evaluates the
// expression `Func`, written in terms of the loaded element `x` of type T,
// over a 3-D strided region of sizeZ x sizeY x sizeX elements.
// - Elements are read with (strideZ, strideY, strideX) and written with
//   (dstStrideZ, dstStrideY, dstStrideX).
// - Each thread walks the flattened element index, advancing by the total
//   number of launched threads, so any 1-D launch size is valid.
// - The value of `Func` is parenthesized and converted to T before it is stored.
//
// The expressions below use T-typed constants, because double literals such as
// `1.` would silently promote the float math to double. They use log1p / expm1
// because log(1+x) and exp(x)-1 lose all precision for |x| close to 0.
#define UNARY_FUNC(Name, Func)\
template<typename T>\
__global__ void Name(const T *input, T *output,\
        int sizeZ, int sizeY, int sizeX,\
        int strideZ, int strideY, int strideX,\
        int dstStrideZ, int dstStrideY, int dstStrideX\
        ) { \
  int count = sizeZ * sizeY * sizeX;\
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {\
    int ix = i % sizeX;\
    int tmp = i / sizeX;\
    int iy = tmp % sizeY;\
    int iz = tmp / sizeY;\
    int srcOffset = iz * strideZ + iy * strideY + ix * strideX;\
    int dstOffset = iz * dstStrideZ + iy * dstStrideY + ix * dstStrideX;\
    T x = input[srcOffset];\
    output[dstOffset] = (T)(Func);\
  }\
}

UNARY_FUNC(blit, x);
UNARY_FUNC(ABS, abs(x));
UNARY_FUNC(EXP, exp(x));
UNARY_FUNC(NEG, -x);
UNARY_FUNC(RECIPROCAL, (T)(1.0)/x);
UNARY_FUNC(FLOOR, floor(x));
UNARY_FUNC(CEIL, ceil(x));
UNARY_FUNC(SQUARE, x*x);
UNARY_FUNC(SQRT, (T)(sqrt((float)x)));
UNARY_FUNC(RSQRT, (T)(rsqrt((float)x)));
UNARY_FUNC(LOG, (T)(log((float)x)));
UNARY_FUNC(SIN, (T)(sin((float)x)));
UNARY_FUNC(COS, (T)(cos((float)x)));
UNARY_FUNC(TAN, (T)(tan((float)x)));
UNARY_FUNC(ASIN, (T)(asin((float)x)));
UNARY_FUNC(ACOS, (T)(acos((float)x)));
UNARY_FUNC(ATAN, (T)(atan((float)x)));
UNARY_FUNC(LOG1P, log1p(x));
UNARY_FUNC(TANH, tanh(x));
UNARY_FUNC(SIGMOID, (T)1 / ((T)1 + exp(-x)));
UNARY_FUNC(EXPM1, expm1(x));
UNARY_FUNC(ATANH, atanh(x));
UNARY_FUNC(ACOSH, acosh(x));
UNARY_FUNC(COSH, cosh(x));
UNARY_FUNC(SIGN, x > 0 ? 1 : (x<0 ? -1 : 0));
UNARY_FUNC(ROUND, round(x));
UNARY_FUNC(SINH, sinh(x));
UNARY_FUNC(ASINH, asinh(x));
// hard-swish: x * clamp(x + 3, 0, 6) / 6
UNARY_FUNC(HARDSWISH, x * min(max(x + (T)3, (T)0), (T)6) / (T)6);
213 
// Host wrapper: copies a 3-D strided region of size[0] x size[1] x size[2]
// elements from `input` to `output` with the `blit` kernel. srcStride and
// dstStride give the per-dimension element strides of the two sides.
// Element widths of 4, 2 and 1 bytes are supported; any other `bytes` value
// launches nothing.
void RasterBlit(uint8_t* output, const uint8_t* input, const int32_t* size, const int32_t* srcStride, const int32_t* dstStride, int bytes, CUDARuntime* runtime) {
    const int elementCount = size[0] * size[1] * size[2];
    const int blocks  = runtime->blocks_num(elementCount);
    const int threads = runtime->threads_num();
    if (bytes == 4) {
        blit<<<blocks, threads>>>((const float*)input, (float*)output,
            size[0], size[1], size[2],
            srcStride[0], srcStride[1], srcStride[2],
            dstStride[0], dstStride[1], dstStride[2]);
    } else if (bytes == 2) {
        blit<<<blocks, threads>>>((const int16_t*)input, (int16_t*)output,
            size[0], size[1], size[2],
            srcStride[0], srcStride[1], srcStride[2],
            dstStride[0], dstStride[1], dstStride[2]);
    } else if (bytes == 1) {
        blit<<<blocks, threads>>>((const int8_t*)input, (int8_t*)output,
            size[0], size[1], size[2],
            srcStride[0], srcStride[1], srcStride[2],
            dstStride[0], dstStride[1], dstStride[2]);
    }
}
241 
// Applies the unary kernel selected by `opType` (a MNN::UnaryOpOperation
// value) to a 3-D strided region of size[0] x size[1] x size[2] float elements.
// Only 4-byte (float) data is handled, which is asserted through MNN_ASSERT.
// An opType without a matching kernel launches nothing.
void UnaryBlit(uint8_t* output, const uint8_t* input, const int32_t* size, const int32_t* srcStride, const int32_t* dstStride, int bytes, CUDARuntime* runtime, int opType) {
    const int blocks  = runtime->blocks_num(size[0] * size[1] * size[2]);
    const int threads = runtime->threads_num();
    // TODO: Support FP16
    MNN_ASSERT(bytes==4);
    const float* src = (const float*)input;
    float* dst = (float*)output;
#define UNARY_CASE(TYPE)\
    case MNN::UnaryOpOperation_##TYPE:\
        TYPE<<<blocks, threads>>>(src, dst,\
            size[0], size[1], size[2],\
            srcStride[0], srcStride[1], srcStride[2],\
            dstStride[0], dstStride[1], dstStride[2]);\
        break;
    switch (opType) {
        UNARY_CASE(ABS)
        UNARY_CASE(NEG)
        UNARY_CASE(FLOOR)
        UNARY_CASE(CEIL)
        UNARY_CASE(SQUARE)
        UNARY_CASE(SQRT)
        UNARY_CASE(RSQRT)
        UNARY_CASE(EXP)
        UNARY_CASE(LOG)
        UNARY_CASE(SIN)
        UNARY_CASE(COS)
        UNARY_CASE(TAN)
        UNARY_CASE(ASIN)
        UNARY_CASE(ACOS)
        UNARY_CASE(ATAN)
        UNARY_CASE(RECIPROCAL)
        UNARY_CASE(LOG1P)
        UNARY_CASE(TANH)
        UNARY_CASE(SIGMOID)
        UNARY_CASE(EXPM1)
        UNARY_CASE(ACOSH)
        UNARY_CASE(ATANH)
        UNARY_CASE(SIGN)
        UNARY_CASE(COSH)
        UNARY_CASE(ROUND)
        UNARY_CASE(SINH)
        UNARY_CASE(ASINH)
        UNARY_CASE(HARDSWISH)
        default:
            break;
    }
#undef UNARY_CASE
}
288 
// Truncated remainder used by BinaryMOD: the result has the sign of x and
// satisfies x == trunc(x / y) * y + result (C `%` / fmodf semantics).
// Integer version: integer division already truncates toward zero.
template <typename T>
__host__ __device__ __forceinline__ T binaryTruncMod(T x, T y) {
    return x - (x / y) * y;
}
// Float version: float division does not truncate, so defer to fmodf.
__host__ __device__ __forceinline__ float binaryTruncMod(float x, float y) {
    return fmodf(x, y);
}

// BINARY_FUNC(Name, Func) defines a kernel template `Binary##Name<TIn, TOut>`.
// - It evaluates the expression `Func`, written in terms of `x` (from input0)
//   and `y` (from input1), over a 3-D strided region of sizeZ x sizeY x sizeX
//   elements.
// - input0, input1 and output each have their own per-dimension strides.
// - Each thread walks the flattened element index, advancing by the total
//   number of launched threads.
// - `Func` is parenthesized before the cast to TOut. Without the parentheses,
//   e.g. `(int)x > y` truncates a float x before comparing it.
#define BINARY_FUNC(Name, Func)\
template<typename TIn, typename TOut>\
__global__ void Binary##Name(\
        const TIn *input0, const TIn* input1, TOut *output,\
        int sizeZ, int sizeY, int sizeX,\
        int strideZ, int strideY, int strideX,\
        int strideZ1, int strideY1, int strideX1,\
        int dstStrideZ, int dstStrideY, int dstStrideX\
        ) { \
  int count = sizeZ * sizeY * sizeX;\
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {\
    int ix = i % sizeX;\
    int tmp = i / sizeX;\
    int iy = tmp % sizeY;\
    int iz = tmp / sizeY;\
    int srcOffset = iz * strideZ + iy * strideY + ix * strideX;\
    int srcOffset1 = iz * strideZ1 + iy * strideY1 + ix * strideX1;\
    int dstOffset = iz * dstStrideZ + iy * dstStrideY + ix * dstStrideX;\
    TIn x = input0[srcOffset];\
    TIn y = input1[srcOffset1];\
    output[dstOffset] = (TOut)(Func);\
  }\
}

// Local helper macro for REALDIV; it is #undef'd after the instantiations below.
#define sign(y) ((y) > 0 ? 1 : ((y) < 0 ? -1 : 0))

BINARY_FUNC(ADD, x+y);
BINARY_FUNC(SUB, x-y);
BINARY_FUNC(MUL, x*y);
BINARY_FUNC(DIV, x/y);
// Division with |y| clamped to at least 1e-7 (sign of y preserved) to avoid division by zero.
BINARY_FUNC(REALDIV, (float)sign(y) * x / max(abs(y), 0.0000001f));
BINARY_FUNC(MINIMUM, min(x, y));
BINARY_FUNC(MAXIMUM, max(x, y));
BINARY_FUNC(GREATER, x > y ? 1 : 0);
BINARY_FUNC(LESS, x < y ? 1 : 0);
BINARY_FUNC(LESS_EQUAL, x <= y ? 1 : 0);
BINARY_FUNC(GREATER_EQUAL, x >= y ? 1 : 0);
BINARY_FUNC(EQUAL, x == y ? 1 : 0);
BINARY_FUNC(NOTEQUAL, x != y ? 1 : 0);
BINARY_FUNC(FLOORDIV, floor(x / y));
BINARY_FUNC(FLOORMOD, x - floor(x / y) * y);
BINARY_FUNC(SquaredDifference, (x-y)*(x-y));
BINARY_FUNC(POW, pow(x, y));
BINARY_FUNC(ATAN2, atan2(x, y));
BINARY_FUNC(MOD, binaryTruncMod(x, y));
BINARY_FUNC(LOGICALOR, (x || y) ? 1 : 0);

#undef sign
336 
// Dispatches the binary kernel selected by `opType` (a MNN::BinaryOpOperation
// value) for float operands over a 3-D strided region of
// size[0] x size[1] x size[2] elements.
// - Arithmetic ops write float.
// - Comparison ops write int.
// - Only 4-byte data is handled, which is asserted through MNN_ASSERT.
// - An opType without a float kernel (e.g. LOGICALOR) launches nothing.
void BinaryBlitTemplateFloat(uint8_t* output, const uint8_t* input, const uint8_t* input1, const int32_t* size, const int32_t* srcStride, const int32_t* srcStride1, const int32_t* dstStride, int bytes, CUDARuntime* runtime, int opType) {
    const int blocks  = runtime->blocks_num(size[0] * size[1] * size[2]);
    const int threads = runtime->threads_num();
    // TODO: Support FP16
    MNN_ASSERT(bytes==4);
    const float* lhs = (const float*)input;
    const float* rhs = (const float*)(input1);
#define FLOAT_BINARY_CASE(TYPE, TOut)\
    case MNN::BinaryOpOperation_##TYPE:\
        Binary##TYPE<<<blocks, threads>>>(lhs, rhs, (TOut*)output,\
            size[0], size[1], size[2],\
            srcStride[0], srcStride[1], srcStride[2],\
            srcStride1[0], srcStride1[1], srcStride1[2],\
            dstStride[0], dstStride[1], dstStride[2]);\
        break;
    switch (opType) {
        FLOAT_BINARY_CASE(ADD, float)
        FLOAT_BINARY_CASE(SUB, float)
        FLOAT_BINARY_CASE(MUL, float)
        FLOAT_BINARY_CASE(DIV, float)
        FLOAT_BINARY_CASE(REALDIV, float)
        FLOAT_BINARY_CASE(MINIMUM, float)
        FLOAT_BINARY_CASE(MAXIMUM, float)
        FLOAT_BINARY_CASE(GREATER, int)
        FLOAT_BINARY_CASE(LESS, int)
        FLOAT_BINARY_CASE(LESS_EQUAL, int)
        FLOAT_BINARY_CASE(GREATER_EQUAL, int)
        FLOAT_BINARY_CASE(EQUAL, int)
        FLOAT_BINARY_CASE(NOTEQUAL, int)
        FLOAT_BINARY_CASE(FLOORDIV, float)
        FLOAT_BINARY_CASE(FLOORMOD, float)
        FLOAT_BINARY_CASE(POW, float)
        FLOAT_BINARY_CASE(SquaredDifference, float)
        FLOAT_BINARY_CASE(ATAN2, float)
        FLOAT_BINARY_CASE(MOD, float)
        default:
            break;
    }
#undef FLOAT_BINARY_CASE
}
373 
// Dispatches the binary kernel selected by `opType` (a MNN::BinaryOpOperation
// value) for int operands over a 3-D strided region of
// size[0] x size[1] x size[2] elements. Every op writes int. An opType without
// an int kernel launches nothing.
// NOTE(review): `bytes` is not inspected; inputs are always read as int —
// confirm callers only route 4-byte integer tensors here.
void BinaryBlitTemplateInt32(uint8_t* output, const uint8_t* input, const uint8_t* input1, const int32_t* size, const int32_t* srcStride, const int32_t* srcStride1, const int32_t* dstStride, int bytes, CUDARuntime* runtime, int opType) {
    const int blocks  = runtime->blocks_num(size[0] * size[1] * size[2]);
    const int threads = runtime->threads_num();
    const int* lhs = (const int*)input;
    const int* rhs = (const int*)(input1);
#define INT_BINARY_CASE(TYPE, TOut)\
    case MNN::BinaryOpOperation_##TYPE:\
        Binary##TYPE<<<blocks, threads>>>(lhs, rhs, (TOut*)output,\
            size[0], size[1], size[2],\
            srcStride[0], srcStride[1], srcStride[2],\
            srcStride1[0], srcStride1[1], srcStride1[2],\
            dstStride[0], dstStride[1], dstStride[2]);\
        break;
    switch (opType) {
        INT_BINARY_CASE(ADD, int)
        INT_BINARY_CASE(SUB, int)
        INT_BINARY_CASE(MUL, int)
        INT_BINARY_CASE(DIV, int)
        INT_BINARY_CASE(MINIMUM, int)
        INT_BINARY_CASE(MAXIMUM, int)
        INT_BINARY_CASE(GREATER, int)
        INT_BINARY_CASE(LESS, int)
        INT_BINARY_CASE(LESS_EQUAL, int)
        INT_BINARY_CASE(GREATER_EQUAL, int)
        INT_BINARY_CASE(EQUAL, int)
        INT_BINARY_CASE(NOTEQUAL, int)
        INT_BINARY_CASE(SquaredDifference, int)
        INT_BINARY_CASE(MOD, int)
        INT_BINARY_CASE(LOGICALOR, int)
        default:
            break;
    }
#undef INT_BINARY_CASE
}
404 
405 
// Entry point for binary ops: routes float operands to BinaryBlitTemplateFloat
// and signed-integer operands to BinaryBlitTemplateInt32, passing the element
// width from `type`. Any other type code launches nothing.
void BinaryBlit(uint8_t* output, const uint8_t* input, const uint8_t* input1, const int32_t* size, const int32_t* srcStride, const int32_t* srcStride1, const int32_t* dstStride, halide_type_t type, CUDARuntime* runtime, int opType) {
    if (type.code == halide_type_float) {
        BinaryBlitTemplateFloat(output, input, input1, size, srcStride, srcStride1, dstStride, type.bytes(), runtime, opType);
        return;
    }
    if (type.code == halide_type_int) {
        BinaryBlitTemplateInt32(output, input, input1, size, srcStride, srcStride1, dstStride, type.bytes(), runtime, opType);
    }
}
413 
414 
415 }// namespace CUDA
416 }// namespace MNN
417