1 #include "Raster.cuh"
2 #include "TensorflowOp_generated.h"
3 namespace MNN {
4 namespace CUDA {
5
6 template <typename T>
pack_c4(const T * input,T * output,int inside,int axis,int outside,int axisC4)7 __global__ void pack_c4(const T *input, T *output, int inside, int axis, int outside, int axisC4) {
8 int total = inside * axis * outside;
9 for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < total; i += blockDim.x * gridDim.x) {
10 int x = i % inside;
11 int tmp = i / inside;
12 int y = tmp % axis;
13 int z = tmp / axis;
14 int y4 = y / 4;
15 int yR = y % 4;
16 int dstOffset = 4 * (z * axisC4 * inside + y4 * inside + x) + yR;
17 output[dstOffset] = input[i];
18 }
19 }
20
// Host launcher for pack_c4: repacks `inside * axis * outside` elements of `bytes`
// bytes each from `input` into the 4-channel-interleaved layout in `output`.
// When axis is not a multiple of 4, the whole packed output (rounded up to
// (axis + 3) / 4 groups of 4) is first cleared with runtime->memset so padding lanes read as 0.
// Element widths other than 4, 2 or 1 bytes launch nothing.
// The kernel is launched without a stream argument and without an error check here.
void PackC4(uint8_t* output, const uint8_t* input, int inside, int axis, int outside, int bytes, CUDARuntime* runtime) {
    const int axisC4 = (axis + 3) / 4;
    const bool hasPaddingLanes = (axis % 4) != 0;
    if (hasPaddingLanes) {
        runtime->memset(output, 0, inside * axisC4 * 4 * outside * bytes);
    }
    const int blocks = runtime->blocks_num(inside * axis * outside);
    const int threads = runtime->threads_num();
    if (bytes == 4) {
        pack_c4<<<blocks, threads>>>(reinterpret_cast<const float*>(input), reinterpret_cast<float*>(output),
                                     inside, axis, outside, axisC4);
    } else if (bytes == 2) {
        pack_c4<<<blocks, threads>>>(reinterpret_cast<const int16_t*>(input), reinterpret_cast<int16_t*>(output),
                                     inside, axis, outside, axisC4);
    } else if (bytes == 1) {
        pack_c4<<<blocks, threads>>>(reinterpret_cast<const int8_t*>(input), reinterpret_cast<int8_t*>(output),
                                     inside, axis, outside, axisC4);
    }
}
42
// Inverse of pack_c4: gathers from the packed layout indexed as
// ((z * axisC4 + y / 4) * inside + x) * 4 + y % 4 back into the dense buffer
// indexed as ((z * axis + y) * inside + x). Padding lanes (y >= axis) are never read.
// Uses a grid-stride loop, so any launch configuration covers all elements.
template <typename T>
__global__ void unpack_c4(const T *input, T *output, int inside, int axis, int outside, int axisC4) {
    const int elementCount = inside * axis * outside;
    const size_t gridStride = blockDim.x * gridDim.x;
    size_t linear = blockIdx.x * blockDim.x + threadIdx.x;
    while (linear < elementCount) {
        const int col = linear % inside;
        const int rest = linear / inside;
        const int channel = rest % axis;
        const int plane = rest / axis;
        const int group = channel / 4;
        const int lane = channel % 4;
        const int srcOffset = 4 * ((plane * axisC4 + group) * inside + col) + lane;
        output[linear] = input[srcOffset];
        linear += gridStride;
    }
}
// Host launcher for unpack_c4: converts the 4-channel-interleaved buffer in `input`
// back into `inside * axis * outside` dense elements of `bytes` bytes each in `output`.
// Element widths other than 4, 2 or 1 bytes launch nothing.
// The kernel is launched without a stream argument and without an error check here.
void UnpackC4(uint8_t* output, const uint8_t* input, int inside, int axis, int outside, int bytes, CUDARuntime* runtime) {
    const int axisC4 = (axis + 3) / 4;
    const int blocks = runtime->blocks_num(inside * axis * outside);
    const int threads = runtime->threads_num();
    if (bytes == 4) {
        unpack_c4<<<blocks, threads>>>(reinterpret_cast<const float*>(input), reinterpret_cast<float*>(output),
                                       inside, axis, outside, axisC4);
        return;
    }
    if (bytes == 2) {
        unpack_c4<<<blocks, threads>>>(reinterpret_cast<const int16_t*>(input), reinterpret_cast<int16_t*>(output),
                                       inside, axis, outside, axisC4);
        return;
    }
    if (bytes == 1) {
        unpack_c4<<<blocks, threads>>>(reinterpret_cast<const int8_t*>(input), reinterpret_cast<int8_t*>(output),
                                       inside, axis, outside, axisC4);
    }
}
75
76
// Indexed region blit. Region offsets are not applied by this kernel; presumably the
// caller has already folded them into inputO / outputO (TODO confirm at call sites).
//
// Each of the `loopCount` iterations copies one sizeZ x sizeY x sizeX strided window.
// The window's base offsets are `loop * step`, or `indices[loop] * step` when the
// corresponding *UseIndice flag is >= 0. If the source base lies outside
// [0, srcLimit), the destination window is filled with zeros instead.
// The destination base is not range-checked.
template <typename T>
__global__ void blitRegion(const T *inputO, T *outputO,
                           int loopCount,
                           const int32_t* dstIndice, const int32_t* srcIndice,
                           int dstUseIndice, int srcUseIndice,
                           int dstStep, int srcStep, int srcLimit,
                           int sizeZ, int sizeY, int sizeX,
                           int strideZ, int strideY, int strideX,
                           int dstStrideZ, int dstStrideY, int dstStrideX) {
    const int total = loopCount;
    for (size_t loop = blockIdx.x * blockDim.x + threadIdx.x; loop < total; loop += blockDim.x * gridDim.x) {
        const int srcBase = (srcUseIndice >= 0) ? srcIndice[loop] * srcStep : (int)(loop * srcStep);
        const int dstBase = (dstUseIndice >= 0) ? dstIndice[loop] * dstStep : (int)(loop * dstStep);
        T* dst = outputO + dstBase;

        const bool srcInRange = srcBase >= 0 && srcBase < srcLimit;
        if (!srcInRange) {
            // Source window is out of bounds: zero-fill the destination window.
            for (int z = 0; z < sizeZ; ++z) {
                for (int y = 0; y < sizeY; ++y) {
                    for (int x = 0; x < sizeX; ++x) {
                        dst[z * dstStrideZ + y * dstStrideY + x * dstStrideX] = (T)0;
                    }
                }
            }
            continue;
        }

        const T* src = inputO + srcBase;
        for (int z = 0; z < sizeZ; ++z) {
            for (int y = 0; y < sizeY; ++y) {
                for (int x = 0; x < sizeX; ++x) {
                    dst[z * dstStrideZ + y * dstStrideY + x * dstStrideX] =
                        src[z * strideZ + y * strideY + x * strideX];
                }
            }
        }
    }
}
// Host launcher for blitRegion: runs `loopCount` strided window copies described by
// `reg` (size and src/dst strides; reg offsets are not passed to the kernel), optionally
// remapped through dstIndices / srcIndices. Elements are moved as float, int16_t or
// int8_t according to `bytes`; other widths launch nothing.
// The kernel is launched without a stream argument and without an error check here.
void BlitWithIndice(uint8_t* output, const uint8_t* input, const int32_t* dstIndices, const int32_t* srcIndices, int dstUseIndice, int srcUseIndice, int loopCount, int dstStep, int srcStep, int srcLimit, const Tensor::InsideDescribe::Region& reg, int bytes, CUDARuntime* runtime) {
    const int blocks = runtime->blocks_num(loopCount);
    const int threads = runtime->threads_num();
#define MNN_BLIT_WITH_INDICE(ELEM)                                                         \
    blitRegion<<<blocks, threads>>>(reinterpret_cast<const ELEM*>(input),                  \
                                    reinterpret_cast<ELEM*>(output),                       \
                                    loopCount,                                             \
                                    dstIndices, srcIndices,                                \
                                    dstUseIndice, srcUseIndice,                            \
                                    dstStep, srcStep, srcLimit,                            \
                                    reg.size[0], reg.size[1], reg.size[2],                 \
                                    reg.src.stride[0], reg.src.stride[1], reg.src.stride[2], \
                                    reg.dst.stride[0], reg.dst.stride[1], reg.dst.stride[2])
    switch (bytes) {
        case 4:
            MNN_BLIT_WITH_INDICE(float);
            break;
        case 2:
            MNN_BLIT_WITH_INDICE(int16_t);
            break;
        case 1:
            MNN_BLIT_WITH_INDICE(int8_t);
            break;
        default:
            break;
    }
#undef MNN_BLIT_WITH_INDICE
}
162
// Defines an elementwise unary kernel template `Name<T>` over a strided 3-D region
// of sizeZ x sizeY x sizeX elements with separate source and destination strides.
// `Func` is an expression in terms of `x`, the loaded source element of type T, and is
// assigned to the destination element. Work is spread with a grid-stride loop.
#define UNARY_FUNC(Name, Func)                                                      \
template <typename T>                                                               \
__global__ void Name(const T *input, T *output,                                     \
                     int sizeZ, int sizeY, int sizeX,                               \
                     int strideZ, int strideY, int strideX,                         \
                     int dstStrideZ, int dstStrideY, int dstStrideX) {              \
    int count = sizeZ * sizeY * sizeX;                                              \
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count);             \
         i += blockDim.x * gridDim.x) {                                             \
        int ix = i % sizeX;                                                         \
        int tmp = i / sizeX;                                                        \
        int iy = tmp % sizeY;                                                       \
        int iz = tmp / sizeY;                                                       \
        int srcOffset = iz * strideZ + iy * strideY + ix * strideX;                 \
        int dstOffset = iz * dstStrideZ + iy * dstStrideY + ix * dstStrideX;        \
        T x = input[srcOffset];                                                     \
        output[dstOffset] = (Func);                                                 \
    }                                                                               \
}
// Unary elementwise kernels. `x` is the loaded source element of type T; UnaryBlit
// launches them with T = float. Constants are written as (T) casts so the math stays in
// T instead of being promoted to double by bare literals such as `1.` or `3.0`.
UNARY_FUNC(blit, x);
UNARY_FUNC(ABS, abs(x));
UNARY_FUNC(EXP, exp(x));
UNARY_FUNC(NEG, -x);
UNARY_FUNC(RECIPROCAL, (T)(1.0)/x);
UNARY_FUNC(FLOOR, floor(x));
UNARY_FUNC(CEIL, ceil(x));
UNARY_FUNC(SQUARE, x*x);
UNARY_FUNC(SQRT, (T)(sqrt((float)x)));
UNARY_FUNC(RSQRT, (T)(rsqrt((float)x)));
UNARY_FUNC(LOG, (T)(log((float)x)));
UNARY_FUNC(SIN, (T)(sin((float)x)));
UNARY_FUNC(COS, (T)(cos((float)x)));
UNARY_FUNC(TAN, (T)(tan((float)x)));
UNARY_FUNC(ASIN, (T)(asin((float)x)));
UNARY_FUNC(ACOS, (T)(acos((float)x)));
UNARY_FUNC(ATAN, (T)(atan((float)x)));
// log1p keeps full precision for |x| << 1, where log(1 + x) loses it to rounding of 1 + x.
UNARY_FUNC(LOG1P, log1p(x));
UNARY_FUNC(TANH, tanh(x));
UNARY_FUNC(SIGMOID, (T)1 / ((T)1 + exp(-x)));
// expm1 avoids the cancellation in exp(x) - 1 for |x| << 1.
UNARY_FUNC(EXPM1, expm1(x));
UNARY_FUNC(ATANH, atanh(x));
UNARY_FUNC(ACOSH, acosh(x));
UNARY_FUNC(COSH, cosh(x));
UNARY_FUNC(SIGN, x > 0 ? 1 : (x<0 ? -1 : 0));
UNARY_FUNC(ROUND, round(x));
UNARY_FUNC(SINH, sinh(x));
UNARY_FUNC(ASINH, asinh(x));
// hard-swish: x * relu6(x + 3) / 6
UNARY_FUNC(HARDSWISH, (T)(1.0 / 6.0) * x * min(max(x + (T)3, (T)0), (T)6));
213
// Copies a strided sizeZ x sizeY x sizeX region (size[0..2]) from `input` to `output`
// using the `blit` kernel, which assigns each element unchanged. Elements are moved as
// float, int16_t or int8_t according to `bytes`; other widths launch nothing.
// The kernel is launched without a stream argument and without an error check here.
void RasterBlit(uint8_t* output, const uint8_t* input, const int32_t* size, const int32_t* srcStride, const int32_t* dstStride, int bytes, CUDARuntime* runtime) {
    const int elementCount = size[0] * size[1] * size[2];
    const int blocks = runtime->blocks_num(elementCount);
    const int threads = runtime->threads_num();
    if (bytes == 4) {
        blit<<<blocks, threads>>>(reinterpret_cast<const float*>(input), reinterpret_cast<float*>(output),
                                  size[0], size[1], size[2],
                                  srcStride[0], srcStride[1], srcStride[2],
                                  dstStride[0], dstStride[1], dstStride[2]);
    } else if (bytes == 2) {
        blit<<<blocks, threads>>>(reinterpret_cast<const int16_t*>(input), reinterpret_cast<int16_t*>(output),
                                  size[0], size[1], size[2],
                                  srcStride[0], srcStride[1], srcStride[2],
                                  dstStride[0], dstStride[1], dstStride[2]);
    } else if (bytes == 1) {
        blit<<<blocks, threads>>>(reinterpret_cast<const int8_t*>(input), reinterpret_cast<int8_t*>(output),
                                  size[0], size[1], size[2],
                                  srcStride[0], srcStride[1], srcStride[2],
                                  dstStride[0], dstStride[1], dstStride[2]);
    }
}
241
// Launches the unary elementwise kernel that matches `opType` (an
// MNN::UnaryOpOperation value) over a strided 3-D region of size[0..2] elements.
// Buffers are treated as float, so `bytes` is asserted to be 4. An opType with no
// matching case returns without launching anything.
// The kernel is launched without a stream argument and without an error check here.
void UnaryBlit(uint8_t* output, const uint8_t* input, const int32_t* size, const int32_t* srcStride, const int32_t* dstStride, int bytes, CUDARuntime* runtime, int opType) {
    const int elementCount = size[0] * size[1] * size[2];
    const int blocks = runtime->blocks_num(elementCount);
    const int threads = runtime->threads_num();
    // TODO: Support FP16
    MNN_ASSERT(bytes==4);
    const float* src = reinterpret_cast<const float*>(input);
    float* dst = reinterpret_cast<float*>(output);
#define COMPUTE(TYPE)                                                             \
    case MNN::UnaryOpOperation_##TYPE:                                            \
        TYPE<<<blocks, threads>>>(src, dst,                                       \
                                  size[0], size[1], size[2],                      \
                                  srcStride[0], srcStride[1], srcStride[2],       \
                                  dstStride[0], dstStride[1], dstStride[2]);      \
        return;

    switch (opType) {
        COMPUTE(ABS)
        COMPUTE(NEG)
        COMPUTE(FLOOR)
        COMPUTE(CEIL)
        COMPUTE(SQUARE)
        COMPUTE(SQRT)
        COMPUTE(RSQRT)
        COMPUTE(EXP)
        COMPUTE(LOG)
        COMPUTE(SIN)
        COMPUTE(COS)
        COMPUTE(TAN)
        COMPUTE(ASIN)
        COMPUTE(ACOS)
        COMPUTE(ATAN)
        COMPUTE(RECIPROCAL)
        COMPUTE(LOG1P)
        COMPUTE(TANH)
        COMPUTE(SIGMOID)
        COMPUTE(EXPM1)
        COMPUTE(ACOSH)
        COMPUTE(ATANH)
        COMPUTE(SIGN)
        COMPUTE(COSH)
        COMPUTE(ROUND)
        COMPUTE(SINH)
        COMPUTE(ASINH)
        COMPUTE(HARDSWISH)
        default:
            break;
    }
#undef COMPUTE
}
288
// Defines an elementwise binary kernel template `Binary##Name<TIn, TOut>` over a
// strided 3-D region of sizeZ x sizeY x sizeX elements, with independent strides for
// input0, input1 and output. `Func` is an expression in terms of `x` (from input0) and
// `y` (from input1), both of type TIn. Work is spread with a grid-stride loop.
//
// The WHOLE of `Func` is cast to TOut. Writing `(TOut)Func` would bind the cast to the
// first operand only, e.g. `(int)x > y ? 1 : 0`. That truncates a float `x` before a
// comparison, so 0.5 > 0.3 would yield 0.
#define BINARY_FUNC(Name, Func)                                                     \
template <typename TIn, typename TOut>                                              \
__global__ void Binary##Name(                                                       \
    const TIn *input0, const TIn* input1, TOut *output,                             \
    int sizeZ, int sizeY, int sizeX,                                                \
    int strideZ, int strideY, int strideX,                                          \
    int strideZ1, int strideY1, int strideX1,                                       \
    int dstStrideZ, int dstStrideY, int dstStrideX) {                               \
    int count = sizeZ * sizeY * sizeX;                                              \
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count);             \
         i += blockDim.x * gridDim.x) {                                             \
        int ix = i % sizeX;                                                         \
        int tmp = i / sizeX;                                                        \
        int iy = tmp % sizeY;                                                       \
        int iz = tmp / sizeY;                                                       \
        int srcOffset = iz * strideZ + iy * strideY + ix * strideX;                 \
        int srcOffset1 = iz * strideZ1 + iy * strideY1 + ix * strideX1;             \
        int dstOffset = iz * dstStrideZ + iy * dstStrideY + ix * dstStrideX;        \
        TIn x = input0[srcOffset];                                                  \
        TIn y = input1[srcOffset1];                                                 \
        output[dstOffset] = (TOut)(Func);                                           \
    }                                                                               \
}
// sign(v): -1 for negative, +1 for positive, 0 otherwise (including zero and NaN).
// Note: this macro stays defined for the rest of the translation unit.
#define sign(y) ((y) < 0 ? -1 : ((y) > 0 ? 1 : 0))
315
// Truncated remainder used by MOD: the result has the sign of the dividend, matching
// C's `%` / fmod and the C-semantics Mod op (FLOORMOD below is the floor variant).
// Overloads cover the two TIn types the launchers instantiate (float and int).
static __device__ __forceinline__ int binaryTruncMod(int x, int y) {
    return x % y;
}
static __device__ __forceinline__ float binaryTruncMod(float x, float y) {
    return fmodf(x, y);
}

// Binary elementwise kernels. `x` and `y` are the loaded operands of type TIn.
// Each compound expression is parenthesized so that the macro's (TOut) cast applies to
// the full result rather than only to its first operand.
BINARY_FUNC(ADD, (x + y));
BINARY_FUNC(SUB, (x - y));
BINARY_FUNC(MUL, (x * y));
BINARY_FUNC(DIV, (x / y));
BINARY_FUNC(REALDIV, ((float)sign(y) * x / max(abs(y), 0.0000001)));
BINARY_FUNC(MINIMUM, min(x, y));
BINARY_FUNC(MAXIMUM, max(x, y));
BINARY_FUNC(GREATER, (x > y ? 1 : 0));
BINARY_FUNC(LESS, (x < y ? 1 : 0));
BINARY_FUNC(LESS_EQUAL, (x <= y ? 1 : 0));
BINARY_FUNC(GREATER_EQUAL, (x >= y ? 1 : 0));
BINARY_FUNC(EQUAL, (x == y ? 1 : 0));
BINARY_FUNC(NOTEQUAL, (x != y ? 1 : 0));
BINARY_FUNC(FLOORDIV, floor(x / y));
BINARY_FUNC(FLOORMOD, (x - floor(x / y) * y));
BINARY_FUNC(SquaredDifference, ((x - y) * (x - y)));
BINARY_FUNC(POW, pow(x, y));
BINARY_FUNC(ATAN2, atan2(x, y));
// Previously `x - x / y`, which is not a remainder (7 mod 3 gave 5 for ints).
BINARY_FUNC(MOD, binaryTruncMod(x, y));
BINARY_FUNC(LOGICALOR, ((x || y) ? 1 : 0));
336
// Launches the float32 binary kernel selected by `opType` (an MNN::BinaryOpOperation
// value) over a strided 3-D region of size[0..2] elements. input / input1 are read as
// float; arithmetic ops write float and comparison ops write int into `output`.
// opType values without a case here (e.g. LOGICALOR) return without launching anything.
// The kernel is launched without a stream argument and without an error check here.
void BinaryBlitTemplateFloat(uint8_t* output, const uint8_t* input, const uint8_t* input1, const int32_t* size, const int32_t* srcStride, const int32_t* srcStride1, const int32_t* dstStride, int bytes, CUDARuntime* runtime, int opType) {
    int count = size[0] * size[1] * size[2];
    int block_num = runtime->blocks_num(count);
    int threads_num = runtime->threads_num();
    // TODO: Support FP16
    MNN_ASSERT(bytes==4);
#define COMPUTE_FLOAT(TYPE, TOut)                                                                          \
    if (opType == MNN::BinaryOpOperation_##TYPE) {                                                         \
        Binary##TYPE<<<block_num, threads_num>>>((const float*)input, (const float*)(input1), (TOut*)output, \
                                                 size[0], size[1], size[2],                                \
                                                 srcStride[0], srcStride[1], srcStride[2],                 \
                                                 srcStride1[0], srcStride1[1], srcStride1[2],              \
                                                 dstStride[0], dstStride[1], dstStride[2]);                \
        return;                                                                                            \
    }

    COMPUTE_FLOAT(ADD, float);
    COMPUTE_FLOAT(SUB, float);
    COMPUTE_FLOAT(MUL, float);
    COMPUTE_FLOAT(DIV, float);
    COMPUTE_FLOAT(REALDIV, float);
    COMPUTE_FLOAT(MINIMUM, float);
    COMPUTE_FLOAT(MAXIMUM, float);
    COMPUTE_FLOAT(GREATER, int);
    COMPUTE_FLOAT(LESS, int);
    COMPUTE_FLOAT(LESS_EQUAL, int);
    COMPUTE_FLOAT(GREATER_EQUAL, int);
    COMPUTE_FLOAT(EQUAL, int);
    COMPUTE_FLOAT(NOTEQUAL, int);
    COMPUTE_FLOAT(FLOORDIV, float);
    COMPUTE_FLOAT(FLOORMOD, float);
    COMPUTE_FLOAT(POW, float);
    COMPUTE_FLOAT(SquaredDifference, float);
    COMPUTE_FLOAT(ATAN2, float);
    COMPUTE_FLOAT(MOD, float);
    // Keep the helper macro local to this function, like COMPUTE in UnaryBlit.
#undef COMPUTE_FLOAT
}
373
// Launches the int32 binary kernel selected by `opType` (an MNN::BinaryOpOperation
// value) over a strided 3-D region of size[0..2] elements. input / input1 are read as
// int and every op writes int into `output`. `bytes` is not inspected; the buffers are
// always reinterpreted as 32-bit int. opType values without a case here (e.g. FLOORDIV,
// POW) return without launching anything.
// The kernel is launched without a stream argument and without an error check here.
void BinaryBlitTemplateInt32(uint8_t* output, const uint8_t* input, const uint8_t* input1, const int32_t* size, const int32_t* srcStride, const int32_t* srcStride1, const int32_t* dstStride, int bytes, CUDARuntime* runtime, int opType) {
    int count = size[0] * size[1] * size[2];
    int block_num = runtime->blocks_num(count);
    int threads_num = runtime->threads_num();
#define COMPUTE_INT(TYPE, TOut)                                                                        \
    if (opType == MNN::BinaryOpOperation_##TYPE) {                                                     \
        Binary##TYPE<<<block_num, threads_num>>>((const int*)input, (const int*)(input1), (TOut*)output, \
                                                 size[0], size[1], size[2],                            \
                                                 srcStride[0], srcStride[1], srcStride[2],             \
                                                 srcStride1[0], srcStride1[1], srcStride1[2],          \
                                                 dstStride[0], dstStride[1], dstStride[2]);            \
        return;                                                                                        \
    }

    COMPUTE_INT(ADD, int);
    COMPUTE_INT(SUB, int);
    COMPUTE_INT(MUL, int);
    COMPUTE_INT(DIV, int);
    COMPUTE_INT(MINIMUM, int);
    COMPUTE_INT(MAXIMUM, int);
    COMPUTE_INT(GREATER, int);
    COMPUTE_INT(LESS, int);
    COMPUTE_INT(LESS_EQUAL, int);
    COMPUTE_INT(GREATER_EQUAL, int);
    COMPUTE_INT(EQUAL, int);
    COMPUTE_INT(NOTEQUAL, int);
    COMPUTE_INT(SquaredDifference, int);
    COMPUTE_INT(MOD, int);
    COMPUTE_INT(LOGICALOR, int);
    // Keep the helper macro local to this function, like COMPUTE in UnaryBlit.
#undef COMPUTE_INT
}
404
405
// Dispatches a binary elementwise op on a strided 3-D region by element type code:
// float tensors go to BinaryBlitTemplateFloat, int tensors to BinaryBlitTemplateInt32.
// Any other type code is ignored without launching a kernel.
void BinaryBlit(uint8_t* output, const uint8_t* input, const uint8_t* input1, const int32_t* size, const int32_t* srcStride, const int32_t* srcStride1, const int32_t* dstStride, halide_type_t type, CUDARuntime* runtime, int opType) {
    switch (type.code) {
        case halide_type_float:
            BinaryBlitTemplateFloat(output, input, input1, size, srcStride, srcStride1, dstStride, type.bytes(), runtime, opType);
            break;
        case halide_type_int:
            BinaryBlitTemplateInt32(output, input, input1, size, srcStride, srcStride1, dstStride, type.bytes(), runtime, opType);
            break;
        default:
            break;
    }
}
413
414
415 }// namespace CUDA
416 }// namespace MNN
417