1# coding=utf-8
2# Permission is hereby granted, free of charge, to any person obtaining a
3# copy of this software and associated documentation files (the "Software"),
4# to deal in the Software without restriction, including without limitation
5# the rights to use, copy, modify, merge, publish, distribute, sublicense,
6# and/or sell copies of the Software, and to permit persons to whom the
7# Software is furnished to do so, subject to the following conditions:
8#
9# The above copyright notice and this permission notice (including the next
10# paragraph) shall be included in all copies or substantial portions of the
11# Software.
12#
13# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
16# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19# SOFTWARE.
20
21import itertools
22import os
23import random
24import textwrap
25
26from modules import utils
27from genclbuiltins import MAX_VALUES
28
# Element type under test -> unsigned integer type of the same bit width,
# used as the element type of shuffle2's mask argument. Insertion order
# matters: main() consumes the seeded random stream in this order.
TYPES = dict(
    char='uchar',
    uchar='uchar',
    short='ushort',
    ushort='ushort',
    half='ushort',
    int='uint',
    uint='uint',
    float='uint',
    long='ulong',
    ulong='ulong',
    double='ulong',
)

# Vector widths (as strings, appended directly to type names) exercised for
# both the source and the destination/mask vectors.
VEC_SIZES = '2 4 8 16'.split()

# Number of work-items (global_size) and vectors stored per buffer argument.
ELEMENTS = 8

# Output directory for the generated .cl test files.
DIR_NAME = os.path.join("cl", "builtin", "misc")
47
48
def gen_array(size, m):
    """Return a list of ``size`` random integers, each in ``[0, m]`` inclusive.

    Values come from the module-level ``random`` generator, so the output is
    reproducible after ``random.seed``.
    """
    values = []
    for _ in range(size):
        values.append(random.randint(0, m))
    return values
51
52
def permute(data1, data2, mask, ssize, dsize):
    """Compute the expected shuffle2 output for flattened buffers.

    ``data1``/``data2`` hold consecutive source vectors of width ``ssize``;
    ``mask`` holds consecutive mask vectors of width ``dsize``. Mask element
    ``i`` belongs to vector ``i // dsize`` and selects, from the matching
    source vectors, lane ``sel % ssize`` of ``data1`` when
    ``sel % (2 * ssize) < ssize`` and of ``data2`` otherwise.
    """
    def pick(idx, sel):
        base = (idx // dsize) * ssize
        source = data2 if sel % (2 * ssize) >= ssize else data1
        return source[base + sel % ssize]

    return [pick(idx, sel) for idx, sel in enumerate(mask)]
59
60
def ext_req(type_name):
    """Return the piglit ``require_device_extensions`` line for ``type_name``.

    ``double*`` types need cl_khr_fp64 and ``half*`` types need cl_khr_fp16;
    every other type needs no extension, so an empty string is returned.
    """
    requirements = (
        ("double", "require_device_extensions: cl_khr_fp64"),
        ("half", "require_device_extensions: cl_khr_fp16"),
    )
    for prefix, line in requirements:
        if type_name.startswith(prefix):
            return line
    return ""
67
68
def print_config(f, type_name, utype_name):
    """Write the opening ``/*!`` comment and ``[config]`` section to ``f``.

    The test is named after ``type_name`` and ``utype_name``. When
    ``ext_req(type_name)`` is non-empty, its line is appended to the config.
    The ``/*!`` comment is left open here; the ``[test]`` sections and the
    closing ``!*/`` are written afterwards by the caller.
    """
    template = """\
    /*!
    [config]
    name: shuffle2 {type_name} {utype_name}
    dimensions: 1
    """ + ext_req(type_name)
    config = template.format(type_name=type_name, utype_name=utype_name)
    f.write(textwrap.dedent(config))
77
78
def begin_test(type_name, utype_name):
    """Create the test file for one element type and write its config header.

    The file ``builtin-shuffle2-<type_name>-<utype_name>.cl`` is created in
    DIR_NAME, its path is printed, and it is returned open for writing. The
    caller owns the returned file and must close it. If writing the header
    fails, the file is closed before the exception propagates, so it does
    not leak.
    """
    file_name = os.path.join(
        DIR_NAME, 'builtin-shuffle2-{}-{}.cl'.format(type_name, utype_name))
    print(file_name)
    f = open(file_name, 'w')
    try:
        print_config(f, type_name, utype_name)
    except BaseException:
        f.close()
        raise
    return f
85
86
def _write_tests(f, t, ut):
    """Append one ``[test]`` section per (source width, mask width) pair."""
    # data1/data2 are declared as every element type in TYPES (char, uchar,
    # half, ...), so their values must fit, and be exactly representable, in
    # the narrowest of them. Signed char (max 127) bounds all of them.
    # Using the ushort range overflowed char/uchar/short and lost precision
    # in half.
    data_max = MAX_VALUES['char']
    for ss, ds in itertools.product(VEC_SIZES, VEC_SIZES):
        ssize = int(ss) * ELEMENTS
        dsize = int(ds) * ELEMENTS
        stype_name = t + ss
        dtype_name = t + ds
        utype_name = ut + ds
        data1 = gen_array(ssize, data_max)
        data2 = gen_array(ssize, data_max)
        # Mask values may span the full unsigned range; shuffle2 only uses
        # the low bits, which permute() models with a modulo.
        mask = gen_array(dsize, MAX_VALUES[ut])
        perm = permute(data1, data2, mask, int(ss), int(ds))
        f.write(textwrap.dedent("""
        [test]
        name: shuffle2 {stype_name} {utype_name}
        global_size: {elements} 0 0
        kernel_name: test_shuffle2_{stype_name}_{utype_name}
        arg_out: 0 buffer {dtype_name}[{elements}] {perm}
        arg_in:  1 buffer {stype_name}[{elements}] {data1}
        arg_in:  2 buffer {stype_name}[{elements}] {data2}
        arg_in:  3 buffer {utype_name}[{elements}] {mask}
        """.format(stype_name=stype_name, utype_name=utype_name,
                   dtype_name=dtype_name, elements=ELEMENTS,
                   perm=' '.join(map(str, perm)),
                   data1=' '.join(map(str, data1)),
                   data2=' '.join(map(str, data2)),
                   mask=' '.join(map(str, mask)))))


def _write_kernels(f, t, ut):
    """Close the header comment, enable extensions, and append all kernels."""
    f.write("!*/")

    if t == "double":
        f.write("\n#pragma OPENCL EXTENSION cl_khr_fp64: enable\n")
    if t == "half":
        f.write("\n#pragma OPENCL EXTENSION cl_khr_fp16: enable\n")

    # Kernel names must match the kernel_name lines written by _write_tests:
    # test_shuffle2_<t><ss>_<ut><ds>.
    for ss, ds in itertools.product(VEC_SIZES, VEC_SIZES):
        f.write(textwrap.dedent("""
        kernel void test_shuffle2_{type_name}{ssize}_{utype_name}{dsize}(global {type_name}* out, global {type_name}* in1, global {type_name}* in2, global {utype_name}* mask) {{
            vstore{dsize}(shuffle2(vload{ssize}(get_global_id(0), in1), vload{ssize}(get_global_id(0), in2), vload{dsize}(get_global_id(0), mask)), get_global_id(0), out);
        }}
        """.format(type_name=t, utype_name=ut, ssize=ss, dsize=ds)))


def main():
    """Generate one shuffle2 test file per entry of TYPES under DIR_NAME.

    The random generator is seeded with 0, so the output is reproducible.
    """
    random.seed(0)
    utils.safe_makedirs(DIR_NAME)

    for t, ut in TYPES.items():
        # The with-block guarantees the file is closed even if generation fails.
        with begin_test(t, ut) as f:
            _write_tests(f, t, ut)
            _write_kernels(f, t, ut)
140
141
142if __name__ == '__main__':
143    main()
144