1#ifdef MNN_SUPPORT_FP16
2#pragma OPENCL EXTENSION cl_khr_fp16 : enable
3#endif
4
// Declares the two logical global-size arguments at the front of a kernel's
// parameter list. The trailing comma is intentional: the macro is written
// directly before the kernel's remaining parameters, e.g.
//     __kernel void k(GLOBAL_SIZE_2_DIMS __global const FLOAT* in, ...)
#define GLOBAL_SIZE_2_DIMS \
    __private const int global_size_dim0, __private const int global_size_dim1,

// Returns early from the enclosing kernel when the work item's 2-D index
// (input1, input2) lies outside global_size_dim0 x global_size_dim1.
// Requires global_size_dim0/global_size_dim1 to be in scope, i.e. the kernel
// must be declared with GLOBAL_SIZE_2_DIMS. Presumably the host enqueues a
// global size rounded up past the logical size, so surplus work items must be
// discarded here -- TODO confirm against the host-side launch code.
// The arguments are parenthesized and the body is wrapped in do { } while (0)
// so any argument expression is compared as a whole and the macro acts as a
// single statement (safe as the body of an unbraced if/else).
#define DEAL_NON_UNIFORM_DIM2(input1, input2)                                \
    do {                                                                     \
        if ((input1) >= global_size_dim0 || (input2) >= global_size_dim1) { \
            return;                                                          \
        }                                                                    \
    } while (0)
12
// Per-channel affine transform: output = input * scale (+ bias when BIAS is
// defined), applied four channels at a time.
// The input/output buffers are addressed as [N][C4][H][W][4], which is the
// layout implied by the index computation below. scale (and bias) supply one
// FLOAT4 per channel block, so they presumably hold at least C4 * 4 values
// (channels padded to a multiple of 4) -- TODO confirm on the host side.
// Work mapping: global id 0 spans C4 * W (channel block, column) and
// global id 1 spans N * H (batch, row).
__kernel void scale_buf(GLOBAL_SIZE_2_DIMS
                        __global const FLOAT* input,
                        __global const FLOAT* scale,
#ifdef BIAS
                        __global const FLOAT* bias,
#endif
                        __global FLOAT* output,
                        __private const int4 shape) { // shape = (N, H, W, C4)

    const int cw = get_global_id(0); // block * width + col
    const int nh = get_global_id(1); // batch * height + row

    DEAL_NON_UNIFORM_DIM2(cw, nh);

    const int height   = shape.y;
    const int width    = shape.z;
    const int channel4 = shape.w;

    const int block = cw / width; // which group of 4 channels
    const int col   = cw % width;
    const int batch = nh / height;
    const int row   = nh % height;

    // Position measured in FLOAT4 units; vload4/vstore4 multiply the index
    // by 4 themselves, so this addresses the same elements as offset * 4.
    const int vec_idx = ((batch * channel4 + block) * height + row) * width + col;

    const FLOAT4 x = vload4(vec_idx, input);
    const FLOAT4 s = vload4(block, scale);
#ifdef BIAS
    const FLOAT4 y = x * s + vload4(block, bias);
#else
    const FLOAT4 y = x * s;
#endif
    vstore4(y, vec_idx, output);
}
43