# -----------------------------------------------------------------------------
# cpp.py
#
# Author:  David Beazley (http://www.dabeaz.com)
# Copyright (C) 2017
# All rights reserved
#
# This module implements an ANSI-C style lexical preprocessor for PLY.
# -----------------------------------------------------------------------------
import sys

# Some Python 3 compatibility shims
if sys.version_info.major < 3:
    STRING_TYPES = (str, unicode)
else:
    STRING_TYPES = str
    xrange = range

# -----------------------------------------------------------------------------
# Default preprocessor lexer definitions.   These tokens are enough to get
# a basic preprocessor working.   Other modules may import these if they want.
# -----------------------------------------------------------------------------

tokens = (
   'CPP_ID','CPP_INTEGER', 'CPP_FLOAT', 'CPP_STRING', 'CPP_CHAR', 'CPP_WS', 'CPP_COMMENT1', 'CPP_COMMENT2', 'CPP_POUND','CPP_DPOUND'
)

literals = "+-*/%|&~^<>=!?()[]{}.,;:\\\'\""

# Whitespace
def t_CPP_WS(t):
    r'\s+'
    t.lexer.lineno += t.value.count("\n")
    return t

t_CPP_POUND = r'\#'
t_CPP_DPOUND = r'\#\#'

# Identifier
t_CPP_ID = r'[A-Za-z_][\w_]*'

# Integer literal
def CPP_INTEGER(t):
    r'(((((0x)|(0X))[0-9a-fA-F]+)|(\d+))([uU][lL]|[lL][uU]|[uU]|[lL])?)'
    return t

t_CPP_INTEGER = CPP_INTEGER

# Floating literal
t_CPP_FLOAT = r'((\d+)(\.\d+)(e(\+|-)?(\d+))? | (\d+)e(\+|-)?(\d+))([lL]|[fF])?'

# String literal
def t_CPP_STRING(t):
    r'\"([^\\\n]|(\\(.|\n)))*?\"'
    t.lexer.lineno += t.value.count("\n")
    return t

# Character constant 'c' or L'c'
def t_CPP_CHAR(t):
    r'(L)?\'([^\\\n]|(\\(.|\n)))*?\''
    t.lexer.lineno += t.value.count("\n")
    return t

# Comment
def t_CPP_COMMENT1(t):
    r'(/\*(.|\n)*?\*/)'
    ncr = t.value.count("\n")
    t.lexer.lineno += ncr
    # replace with one space or a number of '\n'
    t.type = 'CPP_WS'; t.value = '\n' * ncr if ncr else ' '
    return t

# Line comment
def t_CPP_COMMENT2(t):
    r'(//.*?(\n|$))'
    # replace with '\n'
    t.type = 'CPP_WS'; t.value = '\n'
    return t

def t_error(t):
    t.type = t.value[0]
    t.value = t.value[0]
    t.lexer.skip(1)
    return t

import re
import copy
import time
import os.path

# ply.lex is needed when a Preprocessor is created without an explicit lexer
import ply.lex as lex

# -----------------------------------------------------------------------------
# trigraph()
#
# Given an input string, this function replaces all trigraph sequences.
# The following mapping is used:
#
#     ??=    #
#     ??/    \
#     ??'    ^
#     ??(    [
#     ??)    ]
#     ??!    |
#     ??<    {
#     ??>    }
#     ??-    ~
# -----------------------------------------------------------------------------

_trigraph_pat = re.compile(r'''\?\?[=/\'\(\)\!<>\-]''')
_trigraph_rep = {
    '=':'#',
    '/':'\\',
    "'":'^',
    '(':'[',
    ')':']',
    '!':'|',
    '<':'{',
    '>':'}',
    '-':'~'
}

def trigraph(input):
    return _trigraph_pat.sub(lambda g: _trigraph_rep[g.group()[-1]],input)
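
# A minimal usage sketch:
#
#     >>> trigraph("x ??! y ??= z")
#     'x | y # z'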

# ------------------------------------------------------------------
# Macro object
#
# This object holds information about preprocessor macros
#
#    .name      - Macro name (string)
#    .value     - Macro value (a list of tokens)
#    .arglist   - List of argument names
#    .variadic  - Boolean indicating whether or not the macro is variadic
#    .vararg    - Name of the variadic parameter
#
# When a macro is created, the macro replacement token sequence is
# pre-scanned and used to create patch lists that are later used
# during macro expansion
# ------------------------------------------------------------------

class Macro(object):
    def __init__(self,name,value,arglist=None,variadic=False):
        self.name = name
        self.value = value
        self.arglist = arglist
        self.variadic = variadic
        if variadic:
            self.vararg = arglist[-1]
        self.source = None

# ------------------------------------------------------------------
# Preprocessor object
#
# Object representing a preprocessor.  Contains macro definitions,
# include directories, and other information
# ------------------------------------------------------------------

class Preprocessor(object):
    def __init__(self,lexer=None):
        if lexer is None:
            lexer = lex.lexer
        self.lexer = lexer
        self.macros = { }
        self.path = []
        self.temp_path = []

        # Probe the lexer for selected tokens
        self.lexprobe()

        tm = time.localtime()
        self.define("__DATE__ \"%s\"" % time.strftime("%b %d %Y",tm))
        self.define("__TIME__ \"%s\"" % time.strftime("%H:%M:%S",tm))
        self.parser = None

    # -----------------------------------------------------------------------------
    # tokenize()
    #
    # Utility function. Given a string of text, tokenize into a list of tokens
    # -----------------------------------------------------------------------------

    def tokenize(self,text):
        tokens = []
        self.lexer.input(text)
        while True:
            tok = self.lexer.token()
            if not tok: break
            tokens.append(tok)
        return tokens
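
    # A minimal usage sketch (assumes `lexer` was built from this module's
    # token rules, as in the __main__ block below):
    #
    #     >>> pp = Preprocessor(lexer)
    #     >>> [t.value for t in pp.tokenize("a + 1")]
    #     ['a', ' ', '+', ' ', '1']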

    # ----------------------------------------------------------------------
    # error()
    #
    # Report a preprocessor error/warning of some kind
    # ----------------------------------------------------------------------

    def error(self,file,line,msg):
        print("%s:%d %s" % (file,line,msg))

    # ----------------------------------------------------------------------
    # lexprobe()
    #
    # This method probes the preprocessor lexer object to discover
    # the token types of symbols that are important to the preprocessor.
    # If this works right, the preprocessor will simply "work"
    # with any suitable lexer regardless of how tokens have been named.
    # ----------------------------------------------------------------------

    def lexprobe(self):

        # Determine the token type for identifiers
        self.lexer.input("identifier")
        tok = self.lexer.token()
        if not tok or tok.value != "identifier":
            print("Couldn't determine identifier type")
        else:
            self.t_ID = tok.type

        # Determine the token type for integers
        self.lexer.input("12345")
        tok = self.lexer.token()
        if not tok or int(tok.value) != 12345:
            print("Couldn't determine integer type")
        else:
            self.t_INTEGER = tok.type
            self.t_INTEGER_TYPE = type(tok.value)

        # Determine the token type for strings enclosed in double quotes
        self.lexer.input("\"filename\"")
        tok = self.lexer.token()
        if not tok or tok.value != "\"filename\"":
            print("Couldn't determine string type")
        else:
            self.t_STRING = tok.type

        # Determine the token type for whitespace--if any
        self.lexer.input("  ")
        tok = self.lexer.token()
        if not tok or tok.value != "  ":
            self.t_SPACE = None
        else:
            self.t_SPACE = tok.type

        # Determine the token type for newlines
        self.lexer.input("\n")
        tok = self.lexer.token()
        if not tok or tok.value != "\n":
            self.t_NEWLINE = None
            print("Couldn't determine token for newlines")
        else:
            self.t_NEWLINE = tok.type

        self.t_WS = (self.t_SPACE, self.t_NEWLINE)

        # Check for other characters used by the preprocessor
        chars = [ '<','>','#','##','\\','(',')',',','.']
        for c in chars:
            self.lexer.input(c)
            tok = self.lexer.token()
            if not tok or tok.value != c:
                print("Unable to lex '%s' required for preprocessor" % c)

    # ----------------------------------------------------------------------
    # add_path()
    #
    # Adds a search path to the preprocessor.
    # ----------------------------------------------------------------------

    def add_path(self,path):
        self.path.append(path)

    # ----------------------------------------------------------------------
    # group_lines()
    #
    # Given an input string, this function splits it into lines.  Trailing whitespace
    # is removed.   Any line ending with \ is grouped with the next line.  This
    # function forms the lowest level of the preprocessor---grouping text into
    # a line-by-line format.
    # ----------------------------------------------------------------------

    def group_lines(self,input):
        lex = self.lexer.clone()
        lines = [x.rstrip() for x in input.splitlines()]
        for i in xrange(len(lines)):
            j = i+1
            while lines[i].endswith('\\') and (j < len(lines)):
                lines[i] = lines[i][:-1]+lines[j]
                lines[j] = ""
                j += 1

        input = "\n".join(lines)
        lex.input(input)
        lex.lineno = 1

        current_line = []
        while True:
            tok = lex.token()
            if not tok:
                break
            current_line.append(tok)
            if tok.type in self.t_WS and '\n' in tok.value:
                yield current_line
                current_line = []

        if current_line:
            yield current_line
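
    # A minimal sketch: a trailing backslash joins physical lines into one
    # logical line.
    #
    #     >>> groups = list(pp.group_lines("a \\\nb\nc\n"))
    #     >>> ["".join(t.value for t in pp.tokenstrip(g)) for g in groups]
    #     ['a b', 'c']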

    # ----------------------------------------------------------------------
    # tokenstrip()
    #
    # Remove leading/trailing whitespace tokens from a token list
    # ----------------------------------------------------------------------

    def tokenstrip(self,tokens):
        i = 0
        while i < len(tokens) and tokens[i].type in self.t_WS:
            i += 1
        del tokens[:i]
        i = len(tokens)-1
        while i >= 0 and tokens[i].type in self.t_WS:
            i -= 1
        del tokens[i+1:]
        return tokens
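
    # A minimal sketch:
    #
    #     >>> toks = pp.tokenize("  x + y  ")
    #     >>> [t.value for t in pp.tokenstrip(toks)]
    #     ['x', ' ', '+', ' ', 'y']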


    # ----------------------------------------------------------------------
    # collect_args()
    #
    # Collects comma separated arguments from a list of tokens.   The arguments
    # must be enclosed in parentheses.  Returns a tuple (tokencount,args,positions)
    # where tokencount is the number of tokens consumed, args is a list of arguments,
    # and positions is a list of integers containing the starting index of each
    # argument.  Each argument is represented by a list of tokens.
    #
    # When collecting arguments, leading and trailing whitespace is removed
    # from each argument.
    #
    # This function properly handles nested parentheses and commas---these do not
    # define new arguments.
    # ----------------------------------------------------------------------

    def collect_args(self,tokenlist):
        args = []
        positions = []
        current_arg = []
        nesting = 1
        tokenlen = len(tokenlist)

        # Search for the opening '('.
        i = 0
        while (i < tokenlen) and (tokenlist[i].type in self.t_WS):
            i += 1

        if (i < tokenlen) and (tokenlist[i].value == '('):
            positions.append(i+1)
        else:
            self.error(self.source,tokenlist[0].lineno,"Missing '(' in macro arguments")
            return 0, [], []

        i += 1

        while i < tokenlen:
            t = tokenlist[i]
            if t.value == '(':
                current_arg.append(t)
                nesting += 1
            elif t.value == ')':
                nesting -= 1
                if nesting == 0:
                    if current_arg:
                        args.append(self.tokenstrip(current_arg))
                        positions.append(i)
                    return i+1,args,positions
                current_arg.append(t)
            elif t.value == ',' and nesting == 1:
                args.append(self.tokenstrip(current_arg))
                positions.append(i+1)
                current_arg = []
            else:
                current_arg.append(t)
            i += 1

        # Missing end argument
        self.error(self.source,tokenlist[-1].lineno,"Missing ')' in macro arguments")
        return 0, [],[]
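
    # A minimal sketch: nested parentheses do not split arguments.
    #
    #     >>> toks = pp.tokenize("(a, f(b, c), d)")
    #     >>> tokcount, args, positions = pp.collect_args(toks)
    #     >>> ["".join(t.value for t in a) for a in args]
    #     ['a', 'f(b, c)', 'd']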

    # ----------------------------------------------------------------------
    # macro_prescan()
    #
    # Examine the macro value (token sequence) and identify patch points
    # This is used to speed up macro expansion later on---we'll know
    # right away where to apply patches to the value to form the expansion
    # ----------------------------------------------------------------------

    def macro_prescan(self,macro):
        macro.patch     = []             # Standard macro arguments
        macro.str_patch = []             # String conversion expansion
        macro.var_comma_patch = []       # Variadic macro comma patch
        i = 0
        while i < len(macro.value):
            if macro.value[i].type == self.t_ID and macro.value[i].value in macro.arglist:
                argnum = macro.arglist.index(macro.value[i].value)
                # Conversion of argument to a string
                if i > 0 and macro.value[i-1].value == '#':
                    macro.value[i] = copy.copy(macro.value[i])
                    macro.value[i].type = self.t_STRING
                    del macro.value[i-1]
                    macro.str_patch.append((argnum,i-1))
                    continue
                # Concatenation
                elif (i > 0 and macro.value[i-1].value == '##'):
                    macro.patch.append(('c',argnum,i-1))
                    del macro.value[i-1]
                    continue
                elif ((i+1) < len(macro.value) and macro.value[i+1].value == '##'):
                    macro.patch.append(('c',argnum,i))
                    i += 1
                    continue
                # Standard expansion
                else:
                    macro.patch.append(('e',argnum,i))
            elif macro.value[i].value == '##':
                if macro.variadic and (i > 0) and (macro.value[i-1].value == ',') and \
                        ((i+1) < len(macro.value)) and (macro.value[i+1].type == self.t_ID) and \
                        (macro.value[i+1].value == macro.vararg):
                    macro.var_comma_patch.append(i-1)
            i += 1
        macro.patch.sort(key=lambda x: x[2],reverse=True)
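
    # A minimal sketch: for "#define STR(x) #x" the prescan records a single
    # stringizing patch for argument 0 at value position 0.
    #
    #     >>> pp.define("STR(x) #x")
    #     >>> pp.macros['STR'].str_patch
    #     [(0, 0)]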

    # ----------------------------------------------------------------------
    # macro_expand_args()
    #
    # Given a Macro and list of arguments (each a token list), this method
    # returns an expanded version of a macro.  The return value is a token sequence
    # representing the replacement macro tokens
    # ----------------------------------------------------------------------

    def macro_expand_args(self,macro,args):
        # Make a copy of the macro token sequence
        rep = [copy.copy(_x) for _x in macro.value]

        # Make string expansion patches.  These do not alter the length of the replacement sequence

        str_expansion = {}
        for argnum, i in macro.str_patch:
            if argnum not in str_expansion:
                str_expansion[argnum] = ('"%s"' % "".join([x.value for x in args[argnum]])).replace("\\","\\\\")
            rep[i] = copy.copy(rep[i])
            rep[i].value = str_expansion[argnum]

        # Make the variadic macro comma patch.  If the variadic macro argument is empty,
        # we get rid of the comma that precedes it in the replacement sequence.
        comma_patch = False
        if macro.variadic and not args[-1]:
            for i in macro.var_comma_patch:
                rep[i] = None
                comma_patch = True

        # Make all other patches.   The order of these matters.  It is assumed that the patch list
        # has been sorted in reverse order of patch location since replacements will cause the
        # size of the replacement sequence to expand from the patch point.

        expanded = { }
        for ptype, argnum, i in macro.patch:
            # Concatenation.   Argument is left unexpanded
            if ptype == 'c':
                rep[i:i+1] = args[argnum]
            # Normal expansion.  Argument is macro expanded first
            elif ptype == 'e':
                if argnum not in expanded:
                    expanded[argnum] = self.expand_macros(args[argnum])
                rep[i:i+1] = expanded[argnum]

        # Get rid of removed comma if necessary
        if comma_patch:
            rep = [_i for _i in rep if _i]

        return rep
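
    # A minimal sketch:
    #
    #     >>> pp.define("ADD(x,y) ((x)+(y))")
    #     >>> m = pp.macros['ADD']
    #     >>> "".join(t.value for t in pp.macro_expand_args(m, [pp.tokenize("1"), pp.tokenize("2")]))
    #     '((1)+(2))'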


    # ----------------------------------------------------------------------
    # expand_macros()
    #
    # Given a list of tokens, this function performs macro expansion.
    # The expanded argument is a dictionary that contains macros already
    # expanded.  This is used to prevent infinite recursion.
    # ----------------------------------------------------------------------

    def expand_macros(self,tokens,expanded=None):
        if expanded is None:
            expanded = {}
        i = 0
        while i < len(tokens):
            t = tokens[i]
            if t.type == self.t_ID:
                if t.value in self.macros and t.value not in expanded:
                    # Yes, we found a macro match
                    expanded[t.value] = True

                    m = self.macros[t.value]
                    if not m.arglist:
                        # A simple macro
                        ex = self.expand_macros([copy.copy(_x) for _x in m.value],expanded)
                        for e in ex:
                            e.lineno = t.lineno
                        tokens[i:i+1] = ex
                        i += len(ex)
                    else:
                        # A macro with arguments
                        j = i + 1
                        while j < len(tokens) and tokens[j].type in self.t_WS:
                            j += 1
                        if j < len(tokens) and tokens[j].value == '(':
                            tokcount,args,positions = self.collect_args(tokens[j:])
                            if not m.variadic and len(args) != len(m.arglist):
                                self.error(self.source,t.lineno,"Macro %s requires %d arguments" % (t.value,len(m.arglist)))
                                i = j + tokcount
                            elif m.variadic and len(args) < len(m.arglist)-1:
                                if len(m.arglist) > 2:
                                    self.error(self.source,t.lineno,"Macro %s must have at least %d arguments" % (t.value, len(m.arglist)-1))
                                else:
                                    self.error(self.source,t.lineno,"Macro %s must have at least %d argument" % (t.value, len(m.arglist)-1))
                                i = j + tokcount
                            else:
                                if m.variadic:
                                    if len(args) == len(m.arglist)-1:
                                        args.append([])
                                    else:
                                        args[len(m.arglist)-1] = tokens[j+positions[len(m.arglist)-1]:j+tokcount-1]
                                        del args[len(m.arglist):]

                                # Get macro replacement text
                                rep = self.macro_expand_args(m,args)
                                rep = self.expand_macros(rep,expanded)
                                for r in rep:
                                    r.lineno = t.lineno
                                tokens[i:j+tokcount] = rep
                                i += len(rep)
                        else:
                            # Macro name not followed by '(' is not an invocation;
                            # leave the identifier alone and move past it
                            i = j
                    del expanded[t.value]
                    continue
                elif t.value == '__LINE__':
                    t.type = self.t_INTEGER
                    t.value = self.t_INTEGER_TYPE(t.lineno)

            i += 1
        return tokens
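
    # A minimal sketch:
    #
    #     >>> pp.define("PI 3")
    #     >>> "".join(t.value for t in pp.expand_macros(pp.tokenize("2*PI")))
    #     '2*3'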

    # ----------------------------------------------------------------------
    # evalexpr()
    #
    # Evaluate an expression token sequence for the purposes of evaluating
    # integral expressions.
    # ----------------------------------------------------------------------

    def evalexpr(self,tokens):
        # Search for defined macros
        i = 0
        while i < len(tokens):
            if tokens[i].type == self.t_ID and tokens[i].value == 'defined':
                j = i + 1
                needparen = False
                result = "0L"
                while j < len(tokens):
                    if tokens[j].type in self.t_WS:
                        j += 1
                        continue
                    elif tokens[j].type == self.t_ID:
                        if tokens[j].value in self.macros:
                            result = "1L"
                        else:
                            result = "0L"
                        if not needparen: break
                    elif tokens[j].value == '(':
                        needparen = True
                    elif tokens[j].value == ')':
                        break
                    else:
                        self.error(self.source,tokens[i].lineno,"Malformed defined()")
                    j += 1
                tokens[i].type = self.t_INTEGER
                tokens[i].value = self.t_INTEGER_TYPE(result)
                del tokens[i+1:j+1]
            i += 1
        tokens = self.expand_macros(tokens)
        for i,t in enumerate(tokens):
            if t.type == self.t_ID:
                tokens[i] = copy.copy(t)
                tokens[i].type = self.t_INTEGER
                tokens[i].value = self.t_INTEGER_TYPE("0L")
            elif t.type == self.t_INTEGER:
                tokens[i] = copy.copy(t)
                # Strip off any trailing suffixes
                tokens[i].value = str(tokens[i].value)
                while tokens[i].value[-1] not in "0123456789abcdefABCDEF":
                    tokens[i].value = tokens[i].value[:-1]

        expr = "".join([str(x.value) for x in tokens])
        expr = expr.replace("&&"," and ")
        expr = expr.replace("||"," or ")
        # Rewrite '!' as 'not', but leave '!=' intact
        expr = re.sub(r'!(?!=)'," not ",expr)
        try:
            result = eval(expr)
        except Exception:
            self.error(self.source,tokens[0].lineno,"Couldn't evaluate expression")
            result = 0
        return result
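
    # A minimal sketch:
    #
    #     >>> pp.define("LIMIT 10")
    #     >>> pp.evalexpr(pp.tokenize("defined(LIMIT) && LIMIT > 5"))
    #     True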

    # ----------------------------------------------------------------------
    # parsegen()
    #
    # Parse an input string.
    # ----------------------------------------------------------------------
    def parsegen(self,input,source=None):

        # Replace trigraph sequences
        t = trigraph(input)
        lines = self.group_lines(t)

        if not source:
            source = ""

        self.define("__FILE__ \"%s\"" % source)

        self.source = source
        chunk = []
        enable = True
        iftrigger = False
        ifstack = []

        for x in lines:
            for i,tok in enumerate(x):
                if tok.type not in self.t_WS: break
            if tok.value == '#':
                # Preprocessor directive

                # insert necessary whitespace instead of eaten tokens
                for tok in x:
                    if tok.type in self.t_WS and '\n' in tok.value:
                        chunk.append(tok)

                dirtokens = self.tokenstrip(x[i+1:])
                if dirtokens:
                    name = dirtokens[0].value
                    args = self.tokenstrip(dirtokens[1:])
                else:
                    name = ""
                    args = []

                if name == 'define':
                    if enable:
                        for tok in self.expand_macros(chunk):
                            yield tok
                        chunk = []
                        self.define(args)
                elif name == 'include':
                    if enable:
                        for tok in self.expand_macros(chunk):
                            yield tok
                        chunk = []
                        oldfile = self.macros['__FILE__']
                        for tok in self.include(args):
                            yield tok
                        self.macros['__FILE__'] = oldfile
                        self.source = source
                elif name == 'undef':
                    if enable:
                        for tok in self.expand_macros(chunk):
                            yield tok
                        chunk = []
                        self.undef(args)
                elif name == 'ifdef':
                    ifstack.append((enable,iftrigger))
                    if enable:
                        if args[0].value not in self.macros:
                            enable = False
                            iftrigger = False
                        else:
                            iftrigger = True
                elif name == 'ifndef':
                    ifstack.append((enable,iftrigger))
                    if enable:
                        if args[0].value in self.macros:
                            enable = False
                            iftrigger = False
                        else:
                            iftrigger = True
                elif name == 'if':
                    ifstack.append((enable,iftrigger))
                    if enable:
                        result = self.evalexpr(args)
                        if not result:
                            enable = False
                            iftrigger = False
                        else:
                            iftrigger = True
                elif name == 'elif':
                    if ifstack:
                        if ifstack[-1][0]:     # We only pay attention if outer "if" allows this
                            if enable:         # If already true, we flip enable False
                                enable = False
                            elif not iftrigger:   # If False, but not triggered yet, we'll check expression
                                result = self.evalexpr(args)
                                if result:
                                    enable = True
                                    iftrigger = True
                    else:
                        self.error(self.source,dirtokens[0].lineno,"Misplaced #elif")

                elif name == 'else':
                    if ifstack:
                        if ifstack[-1][0]:
                            if enable:
                                enable = False
                            elif not iftrigger:
                                enable = True
                                iftrigger = True
                    else:
                        self.error(self.source,dirtokens[0].lineno,"Misplaced #else")

                elif name == 'endif':
                    if ifstack:
                        enable,iftrigger = ifstack.pop()
                    else:
                        self.error(self.source,dirtokens[0].lineno,"Misplaced #endif")
                else:
                    # Unknown preprocessor directive
                    pass

            else:
                # Normal text
                if enable:
                    chunk.extend(x)

        for tok in self.expand_macros(chunk):
            yield tok
        chunk = []
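
    # A minimal sketch (the source name "demo.h" is hypothetical):
    #
    #     >>> toks = list(pp.parsegen("#define N 4\nint a[N];\n", "demo.h"))
    #     >>> "".join(t.value for t in toks)
    #     '\nint a[4];'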

    # ----------------------------------------------------------------------
    # include()
    #
    # Implementation of file-inclusion
    # ----------------------------------------------------------------------

    def include(self,tokens):
        # Try to extract the filename and then process an include file
        if not tokens:
            return
        if tokens[0].value != '<' and tokens[0].type != self.t_STRING:
            tokens = self.expand_macros(tokens)

        if tokens[0].value == '<':
            # Include <...>
            i = 1
            while i < len(tokens):
                if tokens[i].value == '>':
                    break
                i += 1
            else:
                print("Malformed #include <...>")
                return
            filename = "".join([x.value for x in tokens[1:i]])
            path = self.path + [""] + self.temp_path
        elif tokens[0].type == self.t_STRING:
            filename = tokens[0].value[1:-1]
            path = self.temp_path + [""] + self.path
        else:
            print("Malformed #include statement")
            return
        for p in path:
            iname = os.path.join(p,filename)
            try:
                with open(iname,"r") as f:
                    data = f.read()
                dname = os.path.dirname(iname)
                if dname:
                    self.temp_path.insert(0,dname)
                for tok in self.parsegen(data,filename):
                    yield tok
                if dname:
                    del self.temp_path[0]
                break
            except IOError:
                pass
        else:
            print("Couldn't find '%s'" % filename)
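
    # A minimal sketch (the search path and header name are hypothetical):
    #
    #     pp.add_path("/usr/include")
    #     for tok in pp.include(pp.tokenize('"config.h"')):
    #         print(tok.value)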

    # ----------------------------------------------------------------------
    # define()
    #
    # Define a new macro
    # ----------------------------------------------------------------------

    def define(self,tokens):
        if isinstance(tokens,STRING_TYPES):
            tokens = self.tokenize(tokens)

        linetok = tokens
        try:
            name = linetok[0]
            if len(linetok) > 1:
                mtype = linetok[1]
            else:
                mtype = None
            if not mtype:
                m = Macro(name.value,[])
                self.macros[name.value] = m
            elif mtype.type in self.t_WS:
                # A normal macro
                m = Macro(name.value,self.tokenstrip(linetok[2:]))
                self.macros[name.value] = m
            elif mtype.value == '(':
                # A macro with arguments
                tokcount, args, positions = self.collect_args(linetok[1:])
                variadic = False
                for a in args:
                    if variadic:
                        print("No more arguments may follow a variadic argument")
                        break
                    astr = "".join([str(_i.value) for _i in a])
                    if astr == "...":
                        variadic = True
                        a[0].type = self.t_ID
                        a[0].value = '__VA_ARGS__'
                        del a[1:]
                        continue
                    elif astr[-3:] == "..." and a[0].type == self.t_ID:
                        variadic = True
                        del a[1:]
                        # If the trailing '...' is fused to the parameter name, strip
                        # it off for the purposes of macro expansion
                        if a[0].value[-3:] == '...':
                            a[0].value = a[0].value[:-3]
                        continue
                    if len(a) > 1 or a[0].type != self.t_ID:
                        print("Invalid macro argument")
                        break
                else:
                    mvalue = self.tokenstrip(linetok[1+tokcount:])
                    i = 0
                    while i < len(mvalue):
                        if i+1 < len(mvalue):
                            if mvalue[i].type in self.t_WS and mvalue[i+1].value == '##':
                                del mvalue[i]
                                continue
                            elif mvalue[i].value == '##' and mvalue[i+1].type in self.t_WS:
                                del mvalue[i+1]
                        i += 1
                    m = Macro(name.value,mvalue,[x[0].value for x in args],variadic)
                    self.macro_prescan(m)
                    self.macros[name.value] = m
            else:
                print("Bad macro definition")
        except LookupError:
            print("Bad macro definition")
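
    # A minimal sketch:
    #
    #     >>> pp.define("SQUARE(x) ((x)*(x))")
    #     >>> pp.macros['SQUARE'].arglist
    #     ['x']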

    # ----------------------------------------------------------------------
    # undef()
    #
    # Undefine a macro
    # ----------------------------------------------------------------------

    def undef(self,tokens):
        id = tokens[0].value
        try:
            del self.macros[id]
        except LookupError:
            pass

    # ----------------------------------------------------------------------
    # parse()
    #
    # Parse input text.
    # ----------------------------------------------------------------------
    def parse(self,input,source=None,ignore={}):
        self.ignore = ignore
        self.parser = self.parsegen(input,source)

    # ----------------------------------------------------------------------
    # token()
    #
    # Method to return individual tokens
    # ----------------------------------------------------------------------
    def token(self):
        try:
            while True:
                tok = next(self.parser)
                if tok.type not in self.ignore: return tok
        except StopIteration:
            self.parser = None
            return None

if __name__ == '__main__':
    lexer = lex.lex()

    # Run a preprocessor
    with open(sys.argv[1]) as f:
        input = f.read()

    p = Preprocessor(lexer)
    p.parse(input,sys.argv[1])
    while True:
        tok = p.token()
        if not tok: break
        print(p.source, tok)