1#ifdef MNN_SUPPORT_FP16
2#pragma OPENCL EXTENSION cl_khr_fp16 : enable
3#endif
// Shared sampler for all image reads in this file: integer (unnormalized)
// texel coordinates, nearest-texel lookup, and per the OpenCL spec an
// out-of-range read returns the border color instead of wrapping.
__constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_FILTER_NEAREST | CLK_ADDRESS_CLAMP;
5
/*
 * Element-wise binary operation: output = CONVERT_FLOAT4(OPERATOR), where the
 * host-supplied OPERATOR macro combines the locals `in0` and `in1`
 * (those names must not change).
 *
 * global_dim0/global_dim1: valid extent of the launch; work-items outside it
 *                          do nothing.
 * input0/input1:           operand images. When the matching isFull lane is 0
 *                          the operand is treated as a scalar: only texel
 *                          (0,0) is read and its first lane is broadcast.
 * shape:                   [N,H,W,C4]; not read by this kernel.
 * isFull:                  (.x for input0, .y for input1) non-zero means the
 *                          operand is read at the same position as the output.
 */
__kernel void binary(__private int global_dim0, __private int global_dim1,
                         __read_only image2d_t input0, __read_only image2d_t input1,
                         __write_only image2d_t output,
                         __private const int4 shape,//[N,H,W,C4]
                         __private const int2 isFull) {
    const int2 pos = (int2)(get_global_id(0), get_global_id(1)); // WC4, NH

    if (pos.x >= global_dim0 || pos.y >= global_dim1) {
        return;
    }

    FLOAT4 in0;
    if (isFull.x != 0) {
        in0 = RI_F(input0, SAMPLER, pos);
    } else {
        // Scalar operand: splat lane x of texel (0,0) across all four lanes.
        const FLOAT4 scalar0 = RI_F(input0, SAMPLER, (int2)(0, 0));
        in0 = (FLOAT4)(scalar0.x, scalar0.x, scalar0.x, scalar0.x);
    }

    FLOAT4 in1;
    if (isFull.y != 0) {
        in1 = RI_F(input1, SAMPLER, pos);
    } else {
        const FLOAT4 scalar1 = RI_F(input1, SAMPLER, (int2)(0, 0));
        in1 = (FLOAT4)(scalar1.x, scalar1.x, scalar1.x, scalar1.x);
    }

    FLOAT4 out = CONVERT_FLOAT4(OPERATOR);
    WI_F(output, pos, out);
}
33
/*
 * Binary operation with per-axis broadcasting of input1:
 * output = CONVERT_FLOAT4(OPERATOR), where OPERATOR uses the locals `in0`
 * and `in1` (names must not change).
 *
 * shape:          output extents as (N, H, W, C4); pos.x = c4 * W + w,
 *                 pos.y = n * H + h.
 * whInput1:       (W, H) of input1, used to rebuild input1's 2D coordinate.
 * input1NHWCStep: per-axis multiplier applied to (n, h, w, c4) before
 *                 indexing input1. Presumably 0 on broadcast axes and 1
 *                 otherwise; confirm with the host code.
 */
__kernel void binary_prelu(__read_only image2d_t input0, __read_only image2d_t input1, __write_only image2d_t output,
                            int4 shape, int2 whInput1, int4 input1NHWCStep) {
    const int2 pos = (int2)(get_global_id(0), get_global_id(1));

    const int n  = pos.y / shape.y;
    const int h  = pos.y % shape.y;
    const int w  = pos.x % shape.z;
    const int c4 = pos.x / shape.z;

    // h < H and w < W hold by construction; only n and c4 can overflow.
    if (n >= shape.x || c4 >= shape.w) {
        return;
    }

    const int4 nhwc1 = (int4)(n, h, w, c4) * input1NHWCStep;
    const int2 pos1 = (int2)(nhwc1.w * whInput1.x + nhwc1.z,
                             nhwc1.x * whInput1.y + nhwc1.y);

    FLOAT4 in0 = RI_F(input0, SAMPLER, pos);
    FLOAT4 in1 = RI_F(input1, SAMPLER, pos1);
    FLOAT4 out = CONVERT_FLOAT4(OPERATOR);
    WI_F(output, pos, out);
}
47
/*
 * Copies `input` to `output` texel by texel at the same 2D coordinate.
 * The launch may cover more than the image. Work-items outside the input
 * image's extent do nothing. Assumes `output` is at least as large as
 * `input` (not checked here; confirm with the caller).
 */
__kernel void imageCopy(__read_only image2d_t input, __write_only image2d_t output) {
    const int2 pos = (int2)(get_global_id(0), get_global_id(1));
    const int2 dim = get_image_dim(input);
    // Out of range on EITHER axis must bail out. The previous `&&` only
    // skipped items out of range on both axes. Items out of range on one axis
    // went on to write_image at an out-of-bounds coordinate, which OpenCL
    // leaves undefined.
    if (pos.x >= dim.x || pos.y >= dim.y) {
        return;
    }
    WI_F(output, pos, RI_F(input, SAMPLER, pos));
}
56