1import os
2import re
3import ast
4
5
def make_c_files(modules=None, dirpath=None):
    """Render each ``<module>_template.c`` into ``<module>.c``.

    A target is regenerated only when it does not exist yet or when its
    modification time is not newer than its template's.

    Parameters
    ----------
    modules : list of str, optional
        Base names of the templates to render. Defaults to
        ``["reduce", "move", "nonreduce", "nonreduce_axis"]``.
    dirpath : str, optional
        Directory that holds the templates and receives the generated files.
        Defaults to the directory containing this script.

    Raises
    ------
    OSError
        If a template cannot be stat'ed or read, or a target cannot be
        written.
    """
    if modules is None:
        modules = ["reduce", "move", "nonreduce", "nonreduce_axis"]
    if dirpath is None:
        dirpath = os.path.dirname(__file__)
    for module in modules:
        template_file = os.path.join(dirpath, module + "_template.c")
        target_file = os.path.join(dirpath, module + ".c")

        # Skip targets that are newer than their template. NOTE: only the
        # template's mtime is compared, so edits to this generator itself do
        # not trigger a rebuild.
        if (
            os.path.exists(target_file)
            and os.stat(template_file).st_mtime < os.stat(target_file).st_mtime
        ):
            continue

        # Explicit UTF-8 so the generated file does not depend on the locale.
        with open(template_file, "r", encoding="utf-8") as f:
            src_str = f.read()
        src_str = template(src_str)
        # Make sure a non-empty generated file ends with a newline.
        if src_str and not src_str.endswith("\n"):
            src_str += "\n"
        with open(target_file, "w", encoding="utf-8") as f:
            f.write(src_str)
26
27
def template(src_str):
    """Render a whole C template and return the generated source text.

    The text is split into lines and passed through three passes in this
    order: repeat blocks, dtype blocks, then multiline strings. The lines are
    re-joined with newlines. Finally, every run of three or more newlines
    (with only whitespace between them) is collapsed to two, which leaves at
    most one blank line in a row.
    """
    lines = src_str.splitlines()
    for render_pass in (repeat_templating, dtype_templating, string_templating):
        lines = render_pass(lines)
    rendered = "\n".join(lines)
    return re.sub(r"\n\s*\n\s*\n", r"\n\n", rendered)
36
37
38# repeat --------------------------------------------------------------------
39
# Opening line of a repeat block, e.g. "/* repeat = {...}". The line must
# start with "/*", followed by "repeat" and "=" (whitespace optional).
REPEAT_BEGIN = r"^/\*\s*repeat\s*=\s*"
# Closing line of a repeat block: starts with "/*", optional whitespace,
# then "repeat end".
REPEAT_END = r"^/\*\s*repeat end"
# Any line containing "*/". Marks the last line of a block's header comment
# and is shared by the repeat and dtype expanders.
COMMENT_END = r".*\*\/.*"
43
44
def repeat_templating(lines):
    """Expand every repeat block in ``lines`` and return the new line list.

    Blocks are located with ``next_block``, which yields the innermost block
    first when blocks are nested. Each block is replaced by the output of
    ``expand_functions_repeat``. The replaced span runs from its
    ``/* repeat = ...`` header through its ``/* repeat end`` line. Scanning
    then resumes at the old block start, so the freshly inserted lines are
    searched too.
    """
    cursor = 0
    while True:
        start, stop = next_block(lines, cursor, REPEAT_BEGIN, REPEAT_END)
        if start is None:
            return lines
        expansion = expand_functions_repeat(lines[start:stop])
        # ``stop`` indexes the "/* repeat end" line, which is dropped as well.
        head, tail = lines[:start], lines[stop + 1 :]
        lines = head + expansion + tail
        cursor = start
57
58
def expand_functions_repeat(lines):
    """Expand a single repeat block given as a list of lines.

    ``lines`` runs from the ``/* repeat = {...}`` header up to, but not
    including, the ``repeat end`` line. The header ends at the first line
    containing ``*/``, and its dict literal is parsed by ``repeat_info``. The
    remaining lines are re-joined with newlines and instantiated by
    ``expand_repeat``, whose list of lines is returned.
    """
    header_last = first_occurence(COMMENT_END, lines)
    substitutions = repeat_info(lines[: header_last + 1])
    body_text = "\n".join(lines[header_last + 1 :])
    return expand_repeat(body_text, substitutions)
66
67
def repeat_info(lines):
    """Parse the repeat dictionary out of a repeat block's header lines.

    The header lines are concatenated with no separator. The text from the
    first ``{`` to the last ``}`` is then evaluated with
    ``ast.literal_eval``.

    Parameters
    ----------
    lines : list of str
        Header lines, from the ``/* repeat =`` line through the line that
        closes the comment.

    Returns
    -------
    dict
        Mapping of placeholder text to its list of substitution values.

    Raises
    ------
    ValueError
        If no ``{...}`` specification is present, or it is not a dict literal
        (mirrors the validation done by ``dtype_info``).
    """
    line = "".join(lines)
    # Lines come from splitlines(), so the joined text has no newlines and the
    # greedy pattern yields at most one match.
    matches = re.findall(r"\{.*\}", line)
    if len(matches) != 1:
        raise ValueError("expecting exactly one repeat specification")
    repeat_dict = ast.literal_eval(matches[0])
    if not isinstance(repeat_dict, dict):
        raise ValueError("repeat specification must be a dict literal")
    return repeat_dict
73
74
def expand_repeat(func_str, repeat_dict):
    """Instantiate ``func_str`` once per position in the repeat lists.

    Parameters
    ----------
    func_str : str
        Template text, possibly multi-line, containing placeholder keys.
    repeat_dict : dict of str -> list of str
        Maps each placeholder to its substitution values. Copy ``i`` of
        ``func_str`` uses value ``i`` of every list.

    Returns
    -------
    list of str
        The expanded copies, split into individual lines. Each copy is
        preceded by a newline, so a non-empty result starts with an empty
        line.

    Raises
    ------
    ValueError
        If ``repeat_dict`` is empty or its lists differ in length.
    """
    if not repeat_dict:
        raise ValueError("repeat dict must contain at least one key")
    nrepeats = {len(values) for values in repeat_dict.values()}
    if len(nrepeats) != 1:
        raise ValueError("All repeat lists must be the same length")
    (nrepeat,) = nrepeats
    # Substitute every placeholder in a single pass and try longer keys first.
    # That way a key that is a prefix of another (NAME vs NAME2) cannot
    # clobber it, and substituted values are never re-scanned for keys.
    keys = sorted(repeat_dict, key=len, reverse=True)
    pattern = re.compile("|".join(re.escape(key) for key in keys))
    func_list = []
    for i in range(nrepeat):
        f = pattern.sub(lambda m, i=i: repeat_dict[m.group(0)][i], func_str)
        func_list.append("\n" + f)
    return "".join(func_list).splitlines()
88
89
90# dtype ---------------------------------------------------------------------
91
# Opening line of a dtype block, e.g. "/* dtype = [[...], ...]". The line
# must start with "/*", followed by "dtype" and "=" (whitespace optional).
DTYPE_BEGIN = r"^/\*\s*dtype\s*=\s*"
# Closing line of a dtype block: starts with "/*", optional whitespace,
# then "dtype end".
DTYPE_END = r"^/\*\s*dtype end"
94
95
def dtype_templating(lines):
    """Expand every dtype block in ``lines`` and return the new line list.

    ``next_block`` finds each block, innermost first when nested. The span
    from the ``/* dtype = ...`` header through the ``/* dtype end`` line is
    replaced by the output of ``expand_functions_dtype``. Scanning then
    resumes at the old block start.
    """
    cursor = 0
    while True:
        start, stop = next_block(lines, cursor, DTYPE_BEGIN, DTYPE_END)
        if start is None:
            return lines
        expansion = expand_functions_dtype(lines[start:stop])
        # ``stop`` indexes the "/* dtype end" line, which is dropped as well.
        head, tail = lines[:start], lines[stop + 1 :]
        lines = head + expansion + tail
        cursor = start
108
109
def expand_functions_dtype(lines):
    """Expand a single dtype block given as a list of lines.

    ``lines`` runs from the ``/* dtype = [...]`` header up to, but not
    including, the ``dtype end`` line. The header ends at the first line
    containing ``*/``, and its list literal is parsed by ``dtype_info``. The
    remaining lines are re-joined with newlines and instantiated once per
    dtype entry by ``expand_dtypes``.
    """
    header_last = first_occurence(COMMENT_END, lines)
    signatures = dtype_info(lines[: header_last + 1])
    body_text = "\n".join(lines[header_last + 1 :])
    return expand_dtypes(body_text, signatures)
117
118
def dtype_info(lines):
    """Parse the dtype list literal out of a dtype block's header lines.

    The lines are concatenated with no separator. The text from the first
    ``[`` to the last ``]`` is evaluated with ``ast.literal_eval``. A
    ValueError is raised unless exactly one such span is found.
    """
    matches = re.findall(r"\[.*\]", "".join(lines))
    if len(matches) == 1:
        return ast.literal_eval(matches[0])
    raise ValueError("expecting exactly one dtype specification")
126
127
def expand_dtypes(func_str, dtypes):
    """Instantiate ``func_str`` once per dtype signature.

    Parameters
    ----------
    func_str : str
        Template text containing ``DTYPE0``, ``DTYPE1``, ... markers.
    dtypes : list of list of str
        One entry per copy. Element ``i`` of an entry replaces ``DTYPE<i>``.

    Returns
    -------
    list of str
        The expanded text as individual lines. Each copy is preceded by two
        newlines and followed by one newline per column beyond the first.
        ``"\\n".join(result)`` equals the per-dtype copies joined with
        ``"\\n"``. An empty ``dtypes`` yields ``[]``.

    Raises
    ------
    ValueError
        If ``func_str`` contains no ``DTYPE`` marker.
    """
    if "DTYPE" not in func_str:
        raise ValueError("cannot find dtype marker")
    chunks = []
    for dtype in dtypes:
        replacements = {"DTYPE%d" % i: dt for i, dt in enumerate(dtype)}
        f = func_str
        if replacements:
            # Single pass, longest marker first, so DTYPE1 cannot clobber
            # DTYPE10 and substituted values are not re-scanned.
            keys = sorted(replacements, key=len, reverse=True)
            pattern = re.compile("|".join(re.escape(key) for key in keys))
            f = pattern.sub(lambda m, r=replacements: r[m.group(0)], f)
        # Preserve the existing spacing: one trailing newline per extra column.
        f += "\n" * max(len(dtype) - 1, 0)
        chunks.append("\n\n" + f)
    if not chunks:
        return []
    # Return real lines (as expand_repeat does) so later line-based passes
    # such as next_block and the MULTILINE STRING markers see every generated
    # line. Joining and re-splitting on "\n" keeps the output byte-identical.
    return "\n".join(chunks).split("\n")
140
141
142# multiline strings ---------------------------------------------------------
143
# Delimiters of a region of raw text to be emitted as a C string literal.
# They are used with re.match, and the leading ".*" lets the marker text
# appear anywhere on the line. string_templating drops both marker lines.
STRING_BEGIN = r".*MULTILINE STRING BEGIN.*"
STRING_END = r".*MULTILINE STRING END.*"
146
147
def string_templating(lines):
    """Turn every MULTILINE STRING BEGIN/END region into quoted C lines.

    Both marker lines are removed. The lines strictly between them are
    replaced by the output of ``quote_string``. Scanning resumes where the
    region started.
    """
    cursor = 0
    while True:
        first, last = next_block(lines, cursor, STRING_BEGIN, STRING_END)
        if first is None:
            return lines
        quoted = quote_string(lines[first + 1 : last])
        lines = lines[:first] + quoted + lines[last + 1 :]
        cursor = first
159
160
def quote_string(lines):
    """Convert raw text lines into the lines of a C string-literal initializer.

    Each line becomes ``"<line>\\n"`` (a literal backslash-n inside the
    quotes), and a ``;`` is appended after the last one. A new list is
    returned and ``lines`` is left unmodified. Empty input yields
    ``['"";']``, so the generated C initializer stays valid.

    NOTE(review): embedded double quotes and backslashes are not escaped;
    the text between the markers must already be valid C string content.
    """
    quoted = ['"' + line + r"\n" + '"' for line in lines]
    if not quoted:
        # An empty region still needs a terminated initializer.
        return ['"";']
    quoted[-1] += ";"
    return quoted
166
167
168# utility -------------------------------------------------------------------
169
170
def first_occurence(pattern, lines):
    """Return the index of the first element of ``lines`` matching ``pattern``.

    Matching uses ``re.match``, so it is anchored at the start of each line.
    ValueError is raised when no line matches.
    """
    found = next(
        (idx for idx, text in enumerate(lines) if re.match(pattern, text)),
        None,
    )
    if found is None:
        raise ValueError("`pattern` not found")
    return found
176
177
def next_block(lines, index, begine_pattern, end_pattern):
    """Locate the next innermost begin/end block at or after ``index``.

    Lines are scanned from ``index``. The most recent line matching
    ``begine_pattern`` before the first line matching ``end_pattern`` opens
    the block, so nested blocks come back innermost first. A line matching
    both patterns counts as a begin line. (The ``begine_pattern`` spelling
    is kept for keyword callers.)

    Returns
    -------
    tuple
        ``(begin_index, end_index)``, or ``(None, None)`` when no block
        begins at or after ``index``.

    Raises
    ------
    ValueError
        If an end line appears before any begin line, or a begin line has no
        matching end line.
    """
    idx = None
    for i in range(index, len(lines)):
        line = lines[i]
        if re.match(begine_pattern, line):
            idx = i
        elif re.match(end_pattern, line):
            if idx is None:
                raise ValueError("found end of function before beginning")
            return idx, i
    if idx is not None:
        # Returning (None, None) here would look like "no more blocks" and
        # leave an unterminated block unexpanded in the generated output.
        raise ValueError("found beginning of block without end")
    return None, None
189