"""Generate C source files from ``*_template.c`` templates.

Template markup is expanded in three passes, in this order:

* repeat blocks: a ``/* repeat = {...} */`` header (a dict literal mapping
  placeholder text to equal-length lists of replacement strings) up to a
  ``/* repeat end */`` line. The body is emitted once per list position
  with every placeholder replaced.
* dtype blocks: a ``/* dtype = [[...], ...] */`` header up to a
  ``/* dtype end */`` line. The body is emitted once per inner list with
  ``DTYPE0``, ``DTYPE1``, ... replaced by that list's items.
* multiline strings: the lines between a ``MULTILINE STRING BEGIN`` line and
  a ``MULTILINE STRING END`` line become quoted C string-literal lines; the
  two marker lines are removed.
"""

import os
import re
import ast


def make_c_files(modules=None, dirpath=None):
    """Render each ``<module>_template.c`` into ``<module>.c``.

    Parameters
    ----------
    modules : list of str, optional
        Base names of the modules to render. Defaults to
        ``["reduce", "move", "nonreduce", "nonreduce_axis"]``.
    dirpath : str, optional
        Directory that holds the templates and receives the generated files.
        Defaults to the directory containing this script.

    A module is skipped when its generated ``.c`` file already exists and is
    newer (by modification time) than its template. The generated text is
    always written with a trailing newline.
    """
    if modules is None:
        modules = ["reduce", "move", "nonreduce", "nonreduce_axis"]
    if dirpath is None:
        dirpath = os.path.dirname(__file__)
    for module in modules:
        template_file = os.path.join(dirpath, module + "_template.c")
        target_file = os.path.join(dirpath, module + ".c")

        # Up-to-date output: nothing to regenerate.
        if (
            os.path.exists(target_file)
            and os.stat(template_file).st_mtime < os.stat(target_file).st_mtime
        ):
            continue

        with open(template_file, "r") as f:
            src_str = f.read()
        src_str = template(src_str)
        if src_str and not src_str.endswith("\n"):
            src_str += "\n"
        with open(target_file, "w") as f:
            f.write(src_str)


def template(src_str):
    """Expand all template markup in ``src_str`` and return the C source.

    Runs the repeat, dtype and multiline-string passes in that order, then
    collapses any run of three or more (whitespace-separated) newlines down
    to a single blank line.
    """
    src_list = src_str.splitlines()
    src_list = repeat_templating(src_list)
    src_list = dtype_templating(src_list)
    src_list = string_templating(src_list)
    src_str = "\n".join(src_list)
    src_str = re.sub(r"\n\s*\n\s*\n", r"\n\n", src_str)
    return src_str


def _expand_blocks(lines, begin_pattern, end_pattern, expand):
    """Replace every begin/end-delimited block in ``lines`` with ``expand(block)``.

    ``expand`` receives the block from its begin line up to, but excluding,
    its end line, and returns the replacement lines. The end line itself is
    dropped. Returns the new list of lines.
    """
    index = 0
    while True:
        idx0, idx1 = next_block(lines, index, begin_pattern, end_pattern)
        if idx0 is None:
            return lines
        # idx1 + 1 skips the end-marker line.
        lines = lines[:idx0] + expand(lines[idx0:idx1]) + lines[idx1 + 1 :]
        # Resume scanning where the expansion starts so later blocks are found.
        index = idx0


# repeat --------------------------------------------------------------------

REPEAT_BEGIN = r"^/\*\s*repeat\s*=\s*"
REPEAT_END = r"^/\*\s*repeat end"
COMMENT_END = r".*\*\/.*"


def repeat_templating(lines):
    """Expand every ``/* repeat = {...} */ ... /* repeat end */`` block."""
    return _expand_blocks(lines, REPEAT_BEGIN, REPEAT_END, expand_functions_repeat)


def expand_functions_repeat(lines):
    """Expand one repeat block given its header and body lines.

    The header runs from the first line up to the first line containing
    ``*/``; everything after it is the body to repeat.
    """
    idx = first_occurence(COMMENT_END, lines)
    repeat_dict = repeat_info(lines[: idx + 1])
    func_str = "\n".join(lines[idx + 1 :])
    return expand_repeat(func_str, repeat_dict)


def repeat_info(lines):
    """Parse the ``{...}`` dict literal out of a repeat header.

    Raises
    ------
    ValueError
        If the header contains no ``{...}`` specification.
    """
    line = "".join(lines)
    repeat = re.findall(r"\{.*\}", line)
    if len(repeat) != 1:
        raise ValueError("expecting exactly one repeat specification")
    # literal_eval only accepts Python literals, so the header cannot run code.
    return ast.literal_eval(repeat[0])


def expand_repeat(func_str, repeat_dict):
    """Emit ``func_str`` once per list position in ``repeat_dict``.

    For copy ``i``, every key of ``repeat_dict`` is replaced (in dict order)
    with ``repeat_dict[key][i]``. Each copy is preceded by a newline, and the
    result is returned as a list of single lines.

    Raises
    ------
    ValueError
        If ``repeat_dict`` is empty or its lists differ in length.
    """
    if not repeat_dict:
        raise ValueError("repeat specification is empty")
    nrepeats = {len(values) for values in repeat_dict.values()}
    if len(nrepeats) != 1:
        raise ValueError("All repeat lists must be the same length")
    (nrepeat,) = nrepeats
    func_list = []
    for i in range(nrepeat):
        f = func_str
        for key, values in repeat_dict.items():
            f = f.replace(key, values[i])
        func_list.append("\n" + f)
    return "".join(func_list).splitlines()


# dtype ---------------------------------------------------------------------

DTYPE_BEGIN = r"^/\*\s*dtype\s*=\s*"
DTYPE_END = r"^/\*\s*dtype end"


def dtype_templating(lines):
    """Expand every ``/* dtype = [[...]] */ ... /* dtype end */`` block."""
    return _expand_blocks(lines, DTYPE_BEGIN, DTYPE_END, expand_functions_dtype)


def expand_functions_dtype(lines):
    """Expand one dtype block given its header and body lines.

    The header runs from the first line up to the first line containing
    ``*/``; everything after it is the body to instantiate per dtype.
    """
    idx = first_occurence(COMMENT_END, lines)
    dtypes = dtype_info(lines[: idx + 1])
    func_str = "\n".join(lines[idx + 1 :])
    return expand_dtypes(func_str, dtypes)


def dtype_info(lines):
    """Parse the ``[...]`` list literal out of a dtype header.

    Raises
    ------
    ValueError
        If the header does not contain exactly one ``[...]`` specification.
    """
    line = "".join(lines)
    dtypes = re.findall(r"\[.*\]", line)
    if len(dtypes) != 1:
        raise ValueError("expecting exactly one dtype specification")
    return ast.literal_eval(dtypes[0])


def expand_dtypes(func_str, dtypes):
    """Emit ``func_str`` once per entry of ``dtypes``.

    For each entry, ``DTYPE<i>`` is replaced by the entry's i-th item.
    Each copy is prefixed with two newlines and followed by one newline per
    item after the first. Extra blank lines are collapsed later by
    ``template``.

    NOTE(review): unlike ``expand_repeat``, the returned elements contain
    embedded newlines and are not split into single lines. Line-based passes
    that run afterwards (``string_templating``) therefore cannot see markers
    inside them.

    Raises
    ------
    ValueError
        If ``func_str`` contains no ``DTYPE`` marker.
    """
    if "DTYPE" not in func_str:
        raise ValueError("cannot find dtype marker")
    func_list = []
    for dtype in dtypes:
        f = func_str
        # Substitute the highest index first so DTYPE1 cannot clobber the
        # prefix of DTYPE10, DTYPE11, ...
        for i in reversed(range(len(dtype))):
            f = f.replace("DTYPE%d" % i, dtype[i])
        f += "\n" * (len(dtype) - 1)
        func_list.append("\n\n" + f)
    return func_list


# multiline strings ---------------------------------------------------------

STRING_BEGIN = r".*MULTILINE STRING BEGIN.*"
STRING_END = r".*MULTILINE STRING END.*"


def string_templating(lines):
    """Turn every MULTILINE STRING BEGIN/END region into C string literals.

    Both marker lines are removed; the lines between them are quoted by
    ``quote_string``.
    """
    return _expand_blocks(
        lines, STRING_BEGIN, STRING_END, lambda block: quote_string(block[1:])
    )


def quote_string(lines):
    """Return ``lines`` as C string-literal lines terminated by ``;``.

    Each line becomes ``"<line>\\n"`` and a ``;`` is appended to the last one.
    An empty input yields the single line ``"";``. The input list is not
    modified.

    NOTE(review): quotes and backslashes inside the text are not escaped, so
    such characters must already be C-escaped in the template.
    """
    if not lines:
        return ['"";']
    quoted = ['"' + line + r"\n" + '"' for line in lines]
    quoted[-1] += ";"
    return quoted


# utility -------------------------------------------------------------------


def first_occurence(pattern, lines):
    """Return the index of the first line that ``re.match``-es ``pattern``.

    Raises
    ------
    ValueError
        If no line matches.
    """
    for i, line in enumerate(lines):
        if re.match(pattern, line):
            return i
    raise ValueError("`pattern` not found")


def next_block(lines, index, begine_pattern, end_pattern):
    """Find the next begin/end-delimited block at or after ``lines[index]``.

    Returns ``(begin_idx, end_idx)`` for the first end line found, paired
    with the closest preceding begin line. Returns ``(None, None)`` when no
    block starts in the scanned range.

    Raises
    ------
    ValueError
        If an end line appears before any begin line, or a begin line is
        never closed.
    """
    idx = None
    for i in range(index, len(lines)):
        line = lines[i]
        if re.match(begine_pattern, line):
            idx = i
        elif re.match(end_pattern, line):
            if idx is None:
                raise ValueError("found end of function before beginning")
            return idx, i
    if idx is not None:
        # Leaving the marker in place would leak raw template text into C.
        raise ValueError("found beginning of block without end")
    return None, None