#!/usr/bin/env python
#===============================================================================
# Copyright 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#===============================================================================
"""Add format tags present in dnnl_types.h but missing from dnnl.hpp.

The script parses the C ``dnnl_format_tag_t`` enum and the C++
``enum class format_tag``, prints every C++ alias it cannot validate, and
appends a ``name = dnnl_name,`` entry to the C++ enum for each non-special
C tag that has no C++ counterpart.
"""

from __future__ import print_function

import os
import re
import sys

# "name = value" enumerator line, e.g. "dnnl_nchw = dnnl_abcd,".
_TAG_ASSIGN_RE = re.compile(r'\s*(\w+)\s*=\s*(\w+)')
# Bare enumerator line, e.g. "dnnl_abcd,".
_TAG_NAME_RE = re.compile(r'\s*(\w+)')
# The closing "} dnnl_format_tag_t" of the C enum definition.
_C_ENUM_END_RE = re.compile(r'\}\s*dnnl_format_tag_t\b')
# The C++ enum text up to, but not including, its closing "};".
_CPP_ENUM_RE = re.compile(r'(enum class format_tag.*?)};', re.S)


def is_special_tag(t):
    """Return True if *t* names a placeholder tag rather than a layout.

    Placeholders are names containing 'undef', 'any' or 'tag_last'; they are
    never treated as aliases and are skipped when looking for missing tags.
    """
    return 'undef' in t or 'any' in t or 'tag_last' in t


def is_abc_tag(t):
    """Return True if *t* is an "abc"-style tag such as ``dnnl_aBcd16b``.

    After removing ``dnnl_``, such a tag consists only of digits and the
    letters a-l (either case). Special tags never qualify.
    """
    if is_special_tag(t):
        return False
    for c in t.replace('dnnl_', ''):
        c = c.lower()
        if c.isdigit():
            continue
        if c.isalpha() and 'a' <= c <= 'l':
            continue
        return False
    return True


class Tag(object):
    """A format-tag enumerator parsed from one line of an enum definition.

    Attributes:
        lhs: enumerator name (e.g. ``dnnl_nchw`` or ``nchw``).
        rhs: identifier on the right of ``=``, or None when the line has no
            assignment or the tag is special (a special tag's value, such as
            ``0``, is not an alias and is discarded).
        is_special: whether ``lhs`` is a placeholder (see is_special_tag).
        is_abc: whether ``lhs`` is abc-style (see is_abc_tag).
    """

    def __init__(self, line):
        """Parse one enumerator *line*.

        Raises:
            ValueError: if *line* does not start with an identifier, or if it
                declares a non-special, non-abc tag without an alias.
        """
        m = _TAG_ASSIGN_RE.match(line)
        if m:
            self.lhs, self.rhs = m.group(1), m.group(2)
        else:
            m = _TAG_NAME_RE.match(line)
            if m is None:
                raise ValueError('Cannot parse format tag: %r' % line)
            self.lhs, self.rhs = m.group(1), None

        self.is_special = is_special_tag(self.lhs)
        self.is_abc = is_abc_tag(self.lhs)
        if self.is_special:
            self.rhs = None
        elif self.rhs is None and not self.is_abc:
            # Only abc-style tags may appear without an alias. Raise
            # explicitly: an `assert` would vanish under `python -O`.
            raise ValueError('Expected abc-tag: %s' % line)

    def lhs_base_tag(self):
        """Return ``lhs`` with the ``dnnl_`` prefix removed.

        Names containing 'undef', 'any' or 'last' collapse to that word, so
        differently prefixed spellings of a placeholder compare equal.
        """
        for s in ('undef', 'any', 'last'):
            if s in self.lhs:
                return s
        return self.lhs.replace('dnnl_', '')

    def rhs_base_tag(self):
        """Return ``rhs`` without the ``dnnl_`` prefix, or None if unset."""
        if self.rhs is None:
            return None
        return self.rhs.replace('dnnl_', '')

    def __str__(self):
        return str((self.lhs, self.rhs))


def usage():
    """Print the help text and exit with status 1."""
    print('''\
Usage: %s

Updates dnnl.hpp header with missing format tags from dnnl_types.h''' %
          sys.argv[0])
    sys.exit(1)


def parse_c_format_tags(header_text):
    """Return a Tag for every ``dnnl*`` line of the dnnl_format_tag_t enum.

    The enum body is taken between the first ``} dnnl_format_tag_t`` and the
    nearest preceding ``enum`` keyword, so later mentions of the type (e.g.
    in prototypes) cannot make another enum be captured.

    Raises:
        ValueError: if the enum definition cannot be located.
    """
    end = _C_ENUM_END_RE.search(header_text)
    start = header_text.rfind('enum', 0, end.start()) if end else -1
    if start < 0:
        raise ValueError('dnnl_format_tag_t enum not found')
    body = header_text[start + len('enum'):end.start()]
    return [Tag(l) for l in body.split('\n') if l.strip().startswith('dnnl')]


def find_cpp_format_tag_enum(hpp_text):
    """Return (start, end) offsets of ``enum class format_tag`` in *hpp_text*.

    The span starts at ``enum class format_tag`` and stops right before the
    first following ``};``.

    Raises:
        ValueError: if the enum cannot be located.
    """
    m = _CPP_ENUM_RE.search(hpp_text)
    if m is None:
        raise ValueError('enum class format_tag not found')
    return m.span(1)


def _is_cpp_enumerator_line(line):
    """True for a non-blank, non-comment line containing an assignment."""
    s = line.strip()
    return bool(s) and '=' in s and not s.startswith('/')


def parse_cpp_format_tags(enum_text):
    """Return a Tag for every assigned enumerator line in *enum_text*.

    Comment lines are skipped even if they contain '='.
    """
    return [Tag(l) for l in enum_text.split('\n') if _is_cpp_enumerator_line(l)]


def validate_cpp_tags(cpp_tags, c_tags):
    """Return a "Can't validate tag" message for each unverifiable C++ tag.

    A non-special C++ tag is accepted when it aliases its own base name, or
    when its alias matches the alias of the first C tag with the same base
    name. Everything else, including tags with no alias, is reported.
    """
    c_by_base = {}
    for t in c_tags:
        c_by_base.setdefault(t.lhs_base_tag(), t)  # first occurrence wins

    problems = []
    for cpp_tag in cpp_tags:
        if cpp_tag.is_special:
            continue
        if cpp_tag.rhs:
            if cpp_tag.lhs_base_tag() == cpp_tag.rhs_base_tag():
                continue
            c_tag = c_by_base.get(cpp_tag.lhs_base_tag())
            if (c_tag is not None
                    and c_tag.rhs_base_tag() == cpp_tag.rhs_base_tag()):
                continue
        problems.append('Can\'t validate tag: %s' % cpp_tag)
    return problems


def find_missing_cpp_tags(c_tags, cpp_tags):
    """Return base names of non-special C tags absent from *cpp_tags*.

    Names are returned in C declaration order.
    """
    cpp_bases = set(t.lhs_base_tag() for t in cpp_tags)
    return [
        t.lhs_base_tag() for t in c_tags
        if not t.is_special and t.lhs_base_tag() not in cpp_bases
    ]


def add_missing_tags(hpp_text, enum_span, missing):
    """Return *hpp_text* with ``name = dnnl_name,`` appended to the enum.

    Args:
        hpp_text: full dnnl.hpp contents.
        enum_span: (start, end) as returned by find_cpp_format_tag_enum.
        missing: base tag names to append, in order.

    New lines reuse the indentation of the last existing enumerator. The
    closing ``};`` keeps the indentation it had before.
    """
    if not missing:
        return hpp_text
    start, end = enum_span
    enum_text = hpp_text[start:end]
    body = enum_text.rstrip()
    # Whitespace after the last newline is the indentation of "};".
    closing_indent = enum_text[len(body):].rsplit('\n', 1)[-1]

    members = [l for l in body.split('\n') if _is_cpp_enumerator_line(l)]
    if members:
        last = members[-1]
        member_indent = last[:len(last) - len(last.lstrip())]
    else:
        member_indent = closing_indent + '    '

    new_lines = ['%s%s = dnnl_%s,' % (member_indent, base, base)
                 for base in missing]
    new_enum = body + '\n' + '\n'.join(new_lines) + '\n' + closing_indent
    return hpp_text[:start] + new_enum + hpp_text[end:]


def main(argv=None):
    """Sync dnnl.hpp format tags with dnnl_types.h; return the exit status.

    Args:
        argv: command line including the program name (default: sys.argv).
    """
    argv = sys.argv if argv is None else argv
    # Skip argv[0] so a script path containing "-help" does not trigger it.
    if any('-help' in arg for arg in argv[1:]):
        usage()

    script_root = os.path.dirname(os.path.realpath(__file__))
    include_dir = os.path.join(script_root, '..', 'include', 'oneapi', 'dnnl')
    dnnl_types_h_path = os.path.join(include_dir, 'dnnl_types.h')
    dnnl_hpp_path = os.path.join(include_dir, 'dnnl.hpp')

    with open(dnnl_types_h_path) as f:
        c_tags = parse_c_format_tags(f.read())

    with open(dnnl_hpp_path) as f:
        dnnl_hpp_contents = f.read()
    enum_span = find_cpp_format_tag_enum(dnnl_hpp_contents)
    enum_start, enum_end = enum_span
    cpp_tags = parse_cpp_format_tags(dnnl_hpp_contents[enum_start:enum_end])

    for problem in validate_cpp_tags(cpp_tags, c_tags):
        print(problem)

    missing = find_missing_cpp_tags(c_tags, cpp_tags)
    if not missing:
        return 0

    with open(dnnl_hpp_path, 'w') as f:
        f.write(add_missing_tags(dnnl_hpp_contents, enum_span, missing))
    return 0


if __name__ == '__main__':
    sys.exit(main())