1# -*- coding: utf-8 -*-
2#
3# include_checker.py
4#
5# This file is part of NEST.
6#
7# Copyright (C) 2004 The NEST Initiative
8#
9# NEST is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 2 of the License, or
12# (at your option) any later version.
13#
14# NEST is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with NEST.  If not, see <http://www.gnu.org/licenses/>.
21
22import os
23import re
24import sys
25
26"""
This script suggests C/C++ include orders that conform to the NEST coding style
guidelines. Run it from the root of the NEST source tree, e.g.:
29
30For one file:
31    python3 extras/include_checker.py -nest $PWD -f nest/main.cpp
32
33For one directory:
34    python3 extras/include_checker.py -nest $PWD -d nest
35
If everything is OK, or at most two includes are in the wrong order, it
prints something like:
38
39    Includes for main.cpp are OK! Includes in wrong order: 0
40
Otherwise it reports the file as WRONG and prints a suggested include order:
42
43    Includes for neststartup.h are WRONG! Includes in wrong order: 5
44
45    ##############################
46    Suggested includes for neststartup.h:
47    ##############################
48
49
50    // C includes:
51    #include <neurosim/pyneurosim.h>
52
53    // C++ includes:
54    #include <string>
55
56    // Generated includes:
57    #include "config.h"
58
59    // Includes from sli:
60    #include "datum.h"
61"""
62
# Files that are not maintained by the NEST Initiative (for example an
# implementation of the Google Sparsetable) should stay exactly as they come
# from upstream, so their include order is never checked or changed.
# Entries are file basenames, compared against the name of the checked file.
excludes_files = list()
67
68
class IncludeInfo():
    """One ``#include`` directive found in a C/C++ source file.

    An instance classifies its directive (the file's own header, C system
    header, C++ standard header, project header) and orders itself relative
    to other instances in the sequence suggested by the NEST coding style.
    Ordering is implemented with ``__lt__`` (via ``sort_key``) so that
    ``sorted()`` works on Python 3, where ``__cmp__`` and ``cmp()`` no longer
    exist; ``__cmp__`` is kept for Python 2 callers.
    """

    filename = ""   # basename of the file containing the directive
    name = ""       # included path, without the <> or "" delimiters
    spiky = False   # True for <...> includes, False for "..." includes
    # Directory providing the header; "a_unknown" when no scanned directory
    # lists it. The "a_" prefix presumably makes such includes sort before
    # named directories -- TODO confirm intent.
    origin = "a_unknown"

    def __init__(self, filename, name, spiky, all_headers):
        """Create the record and look up the header's origin directory.

        :param filename: basename of the including file
        :param name: included path, without delimiters
        :param spiky: whether the include uses angle brackets
        :param all_headers: mapping of directory name to the list of header
            basenames found below it (as built by ``all_includes``)
        """
        self.filename = filename
        self.name = name
        self.spiky = spiky
        self.set_origin(all_headers)

    def __repr__(self):
        return "IncludeInfo(%r, %r, spiky=%r, origin=%r)" % (
            self.filename, self.name, self.spiky, self.origin)

    def is_header_include(self):
        """Return True if this includes the file's own header.

        The include name up to its first '.' must equal the file name up to
        its first '.', or up to '_impl.' -- so both ``foo.cpp`` and
        ``foo_impl.h`` own ``foo.h``.
        """
        stem = self.name.split('.')[0]
        return (stem == self.filename.split('.')[0] or
                stem == self.filename.split('_impl.')[0])

    def is_cpp_include(self):
        """Return True for <...> includes without a .h/.hpp suffix."""
        return (not self.name.endswith(('.h', '.hpp')) and self.spiky)

    def is_c_include(self):
        """Return True for <...> includes ending in .h."""
        return self.name.endswith('.h') and self.spiky

    def is_project_include(self):
        """Return True for "..." includes ending in .h or .hpp."""
        return not self.spiky and self.name.endswith(('.h', '.hpp'))

    def set_origin(self, includes):
        """Set ``origin`` to the first directory in *includes* listing it.

        Uses ``dict.items()`` rather than the Python 2-only ``iteritems()``.
        If no directory lists the header, ``origin`` keeps its default.
        """
        for directory, headers in includes.items():
            if self.name in headers:
                self.origin = directory
                break

    def cmp_value(self):
        """Return the category weight; higher values are ordered first.

        8 = own header, 4 = C header, 2 = C++ header, 1 = project header.
        The own-header weight is added on top of its category, so the own
        header always outranks every other include.
        """
        v = 8 if self.is_header_include() else 0
        v += 4 if self.is_c_include() else 0
        v += 2 if self.is_cpp_include() else 0
        v += 1 if self.is_project_include() else 0
        return v

    def sort_key(self):
        """Return the ordering key: category weight (descending), then
        origin directory, then include name."""
        return (-self.cmp_value(), self.origin, self.name)

    def __lt__(self, other):
        return self.sort_key() < other.sort_key()

    def __cmp__(self, other):
        """Three-way comparison consistent with ``sort_key`` (Python 2)."""
        mine, theirs = self.sort_key(), other.sort_key()
        return (mine > theirs) - (mine < theirs)

    def to_string(self):
        """Return the directive as it would appear in source code."""
        l_guard = '<' if self.spiky else '"'
        r_guard = '>' if self.spiky else '"'
        return '#include ' + l_guard + self.name + r_guard
126
127
def all_includes(path):
    """Map each top-level directory of *path* to the headers found below it.

    Hidden top-level directories (name starting with '.') are skipped.
    Header files (``.h``/``.hpp``) anywhere below a top-level directory,
    including nested subdirectories, are collected under that directory's
    name. Directories without any headers are omitted.

    :param path: root of the NEST source tree
    :returns: dict mapping directory name to a list of header basenames
    """
    result = {}
    # next() default guards against an empty walk (e.g. a missing path).
    top_dirs = [d for d in next(os.walk(path), (path, [], []))[1]
                if not d.startswith('.')]
    for top in top_dirs:
        for _root, _subdirs, files in os.walk(os.path.join(path, top)):
            headers = [f for f in files if f.endswith(('.h', '.hpp'))]
            if headers:
                # Accumulate: assigning here would keep only the headers of
                # the last subdirectory walked below `top`.
                result.setdefault(top, []).extend(headers)

    return result
138
139
def create_include_info(line, filename, all_headers):
    """Build an ``IncludeInfo`` from one ``#include`` line.

    :param line: source line starting with ``#include``
    :param filename: basename of the file the line comes from
    :param all_headers: directory -> header names mapping, passed on to
        ``IncludeInfo``
    :raises ValueError: if the line is not a ``<...>`` or ``"..."`` include
        (for example an include through a macro)
    """
    # The name stops at the first closing delimiter, so a trailing comment
    # such as `#include "a.h" // see "b.h"` cannot leak into the name.
    match = re.search(r'^#include\s*([<"])([^<>"]+)[>"]', line)
    if match is None:
        raise ValueError("Cannot parse include directive: " + line.strip())
    spiky = match.group(1) == '<'
    return IncludeInfo(filename, match.group(2), spiky, all_headers)
145
146
def get_includes_from(file, all_headers):
    """Parse every ``#include`` line of *file* into an ``IncludeInfo``.

    Only lines beginning with ``#include`` in the first column are used.

    :param file: path of the source file to read
    :param all_headers: directory -> header names mapping used to determine
        each include's origin
    :returns: list of ``IncludeInfo`` objects, in file order
    """
    base_name = os.path.basename(file)
    with open(file, 'r') as source:
        return [create_include_info(line, base_name, all_headers)
                for line in source
                if line.startswith('#include')]
156
157
def is_include_order_ok(includes):
    """Count the includes that are not where the suggested order puts them.

    Despite its name this returns a number, not a boolean: 0 means the file
    order already matches the suggested (sorted) order exactly.

    :param includes: list of ``IncludeInfo`` objects in file order
    :returns: number of positions whose include name differs from the
        suggested order
    """
    suggested = sorted(includes)
    matching = sum(1 for actual, wanted in zip(includes, suggested)
                   if actual.name == wanted.name)
    return len(includes) - matching
162
163
def print_includes(includes):
    """Print *includes* in the suggested order, grouped under headings.

    The file's own header is printed first without a heading. C system
    headers, C++ standard headers and project headers (one group per origin
    directory; headers of unknown origin are labelled "Generated") each get
    a ``// ...`` heading whenever such a group starts.
    """
    in_c_group = False
    in_cpp_group = False
    current_origin = ""

    for inc in sorted(includes):
        if inc.is_header_include():
            print(inc.to_string())
            continue

        if not in_c_group and inc.is_c_include():
            in_c_group, in_cpp_group, current_origin = True, False, ""
            print("\n// C includes:")

        if not in_cpp_group and inc.is_cpp_include():
            in_c_group, in_cpp_group, current_origin = False, True, ""
            print("\n// C++ includes:")

        if inc.is_project_include() and current_origin != inc.origin:
            in_c_group, in_cpp_group = False, False
            current_origin = inc.origin
            if current_origin == "a_unknown":
                heading = "\n// Generated includes:"
            else:
                heading = "\n// Includes from " + current_origin + ":"
            print(heading)

        print(inc.to_string())
195
196
def process_source(path, f, all_headers, print_suggestion):
    """Check the include order of one source file and report the result.

    Up to two misplaced includes are reported as OK; more are reported as
    WRONG, optionally followed by the suggested include order.

    :param path: directory containing the file
    :param f: basename of the file to check
    :param all_headers: directory -> header names mapping
    :param print_suggestion: if True, print the suggested order for files
        reported as WRONG
    :returns: number of misplaced includes (0 for excluded files)
    """
    if f in excludes_files:
        print("Not checking file " + f + " as it is in the exclude list. " +
              "Please do not change the order of includes.")
        return 0

    includes = get_includes_from(os.path.join(path, f), all_headers)
    misplaced = is_include_order_ok(includes)

    if misplaced > 2:
        print("Includes for " + f + " are WRONG! Includes in wrong order: " +
              str(misplaced))
        if print_suggestion:
            print("\n##############################")
            print("Suggested includes for " + f + ":")
            print("##############################\n")
            print_includes(includes)
            print("\n##############################")
    else:
        print("Includes for " + f + " are OK! Includes in wrong order: " +
              str(misplaced))

    return misplaced
218
219
def process_all_sources(path, all_headers, print_suggestion):
    """Check every C/C++ source and header file below *path*.

    ``os.walk`` already descends into every subdirectory, so no explicit
    recursion is needed; recursing as well would check a file at depth k
    2**k times and inflate the returned count.

    :param path: directory to scan recursively
    :param all_headers: directory -> header names mapping
    :param print_suggestion: forwarded to ``process_source``
    :returns: total number of misplaced includes over all checked files
    """
    count = 0
    for root, _dirs, files in os.walk(path):
        for f in files:
            # Suffixes are matched at the end of the name (".cc" included).
            if f.endswith(('.h', '.hpp', '.c', '.cc', '.cpp')):
                count += process_source(root, f, all_headers, print_suggestion)
    return count
231
232
def usage(exitcode):
    """Print the command-line synopsis and exit with status *exitcode*."""
    synopsis = ("  " + sys.argv[0] + " -nest <nest-base-dir>" +
                " (-f <filename> | -d <base-directory>)")
    print("Use like:")
    print(synopsis)
    sys.exit(exitcode)
238
239
if __name__ == '__main__':
    # Always show the suggested include order for files that fail the check.
    print_suggestion = True

    # Expected: -nest <nest-base-dir> (-f <filename> | -d <base-directory>)
    if len(sys.argv) != 5:
        usage(1)
    nest_flag, nest_dir, mode, target = sys.argv[1:5]

    if not (nest_flag == '-nest' and os.path.isdir(nest_dir)):
        usage(2)
    all_headers = all_includes(nest_dir)

    if mode == '-f' and os.path.isfile(target):
        process_source(os.path.dirname(target), os.path.basename(target),
                       all_headers, print_suggestion)
    elif mode == '-d' and os.path.isdir(target):
        process_all_sources(target, all_headers, print_suggestion)
    else:
        usage(3)
261