1#!/usr/bin/env python
2#===============================================================================
3# Copyright 2021 Intel Corporation
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#===============================================================================
17
18from __future__ import print_function
19
20import os
21import re
22import sys
23
24
def is_special_tag(t):
    """Return True when tag name ``t`` contains a special-tag marker.

    The markers are the substrings 'undef', 'any' and 'tag_last'; each one is
    looked up with a plain substring test anywhere inside ``t``.
    """
    markers = ('undef', 'any', 'tag_last')
    return any(marker in t for marker in markers)
27
28
def is_abc_tag(t):
    """Tell whether ``t`` is a plain letter/digit ("abc") format tag.

    Special tags (as reported by ``is_special_tag``) never qualify.
    Otherwise every 'dnnl_' occurrence is removed from ``t`` and each
    remaining character, once lower-cased, has to be either a digit or a
    letter in the range 'a'..'l'.
    """
    if is_special_tag(t):
        return False

    def allowed(ch):
        # Digits are always accepted; letters only within 'a'..'l'.
        return ch.isdigit() or (ch.isalpha() and 'a' <= ch <= 'l')

    stripped = t.replace('dnnl_', '')
    return all(allowed(ch.lower()) for ch in stripped)
42
43
class Tag:
    """One format-tag enumerator parsed from a single line of an enum body.

    Attributes:
        lhs: enumerator name (the identifier before '=', or the only
            identifier on the line).
        rhs: identifier assigned to the enumerator, or None when the line
            has no '= value' part or when the tag is special.
        is_special: result of ``is_special_tag(lhs)``.
        is_abc: result of ``is_abc_tag(lhs)``.
    """

    def __init__(self, line):
        """Parse ``line`` of the form ``name`` or ``name = value``.

        Raises:
            ValueError: if ``line`` does not start with an identifier
                (after optional leading whitespace).
            AssertionError: if a non-special tag has no value and its name
                is not an abc-tag.
        """
        m = re.match(r'\s*(\w+)\s*=\s*(\w+)', line)
        if m:
            self.lhs = m.group(1)
            self.rhs = m.group(2)
        else:
            m = re.match(r'\s*(\w+)', line)
            if m is None:
                # Report the offending line instead of failing with an
                # AttributeError on ``None.group``.
                raise ValueError('Cannot parse tag from line: %r' % line)
            self.lhs = m.group(1)
            self.rhs = None

        self.is_special = is_special_tag(self.lhs)
        self.is_abc = is_abc_tag(self.lhs)
        if self.is_special:
            # The value of a special tag is deliberately discarded.
            self.rhs = None
        elif not self.rhs and not self.is_abc:
            # Raised explicitly (rather than via ``assert``) so the check is
            # not stripped when Python runs with -O; the exception type and
            # message are the same as before.
            raise AssertionError('Expected abc-tag: %s' % line)

    def lhs_base_tag(self):
        """Return the tag name used to compare tags across both headers.

        If the name contains 'undef', 'any' or 'last', that marker itself is
        returned; otherwise the name with every 'dnnl_' removed.
        """
        for s in ['undef', 'any', 'last']:
            if s in self.lhs:
                return s
        return self.lhs.replace('dnnl_', '')

    def rhs_base_tag(self):
        """Return the value with every 'dnnl_' removed, or None if no value."""
        if self.rhs is None:
            # Tags without a value used to crash here with AttributeError.
            return None
        return self.rhs.replace('dnnl_', '')

    def __str__(self):
        return str((self.lhs, self.rhs))
73
74
def usage():
    """Print the command-line usage text and terminate with exit status 1."""
    template = '''\
Usage: %s

Updates dnnl.hpp header with missing format tags from dnnl_types.h'''
    # The placeholder is filled with the name the script was invoked as.
    message = template % sys.argv[0]
    print(message)
    sys.exit(1)
82
83
# Show the usage text when any command-line argument asks for help.  The
# script's own path (sys.argv[0]) is skipped so that a path which merely
# contains '-help' does not trigger it.
for arg in sys.argv[1:]:
    if '-help' in arg:
        usage()

script_root = os.path.dirname(os.path.realpath(__file__))

dnnl_types_h_path = '%s/../include/oneapi/dnnl/dnnl_types.h' % script_root
dnnl_hpp_path = '%s/../include/oneapi/dnnl/dnnl.hpp' % script_root

c_tags = []
cpp_tags = []

# Parse tags from dnnl_types.h: take the text captured between 'enum' and
# 'dnnl_format_tag_t' and keep the lines that start with 'dnnl'.
with open(dnnl_types_h_path) as f:
    s = f.read()
m = re.search(r'.*enum(.*?)dnnl_format_tag_t', s, re.S)
if m is None:
    # Fail with a readable message instead of AttributeError on None.
    sys.exit('Can\'t find dnnl_format_tag_t enum in %s' % dnnl_types_h_path)
for line in m.group(1).split('\n'):
    if line.strip().startswith('dnnl'):
        c_tags.append(Tag(line))

# Parse tags from dnnl.hpp: take the text from 'enum class format_tag' up to
# the first '};' and keep the lines that contain an assignment.
with open(dnnl_hpp_path) as f:
    dnnl_hpp_contents = f.read()
m = re.search(r'(enum class format_tag.*?)};', dnnl_hpp_contents, re.S)
if m is None:
    sys.exit('Can\'t find format_tag enum in %s' % dnnl_hpp_path)
dnnl_hpp_format_tag = m.group(1)
for line in dnnl_hpp_format_tag.split('\n'):
    if '=' in line:
        cpp_tags.append(Tag(line))

# Validate dnnl.hpp tags: a tag is accepted when its value names the same
# base tag as its own name, or when the C tag with the same base name has the
# same base value.  Anything else is reported.
for cpp_tag in cpp_tags:
    if cpp_tag.is_special or not cpp_tag.rhs:
        continue
    if cpp_tag.lhs_base_tag() == cpp_tag.rhs_base_tag():
        continue
    c_tag = next(
        (t for t in c_tags if t.lhs_base_tag() == cpp_tag.lhs_base_tag()),
        None)
    # A C tag without a value cannot confirm the alias; checking ``rhs``
    # first avoids calling rhs_base_tag() on a tag whose value is None.
    if (c_tag is not None and c_tag.rhs
            and cpp_tag.rhs_base_tag() == c_tag.rhs_base_tag()):
        continue
    print('Can\'t validate tag: %s' % cpp_tag)

# Find missing aliases in dnnl.hpp
cpp_base_tags = set(t.lhs_base_tag() for t in cpp_tags)
missing_dnnl_hpp_tag_lines = []
for c_tag in c_tags:
    if c_tag.is_special:
        continue
    base = c_tag.lhs_base_tag()
    if base not in cpp_base_tags:
        line = '        %s = dnnl_%s,' % (base, base)
        missing_dnnl_hpp_tag_lines.append(line)

if not missing_dnnl_hpp_tag_lines:
    # Nothing to add, so dnnl.hpp is left untouched.  sys.exit is used
    # because the bare exit() helper is only installed by the site module.
    sys.exit(0)

# Append the missing lines to the end of the enum text, restoring the
# indentation that precedes the closing '};'.
old_dnnl_hpp_format_tag = dnnl_hpp_format_tag
dnnl_hpp_format_tag = dnnl_hpp_format_tag.rstrip()
dnnl_hpp_format_tag += '\n' + '\n'.join(missing_dnnl_hpp_tag_lines) + '\n    '

# Replace only the first occurrence, i.e. the enum text that was parsed.
dnnl_hpp_contents = dnnl_hpp_contents.replace(old_dnnl_hpp_format_tag,
                                              dnnl_hpp_format_tag, 1)

# Update dnnl.hpp
with open(dnnl_hpp_path, 'w') as f:
    f.write(dnnl_hpp_contents)
155