1"""Flexible enumeration of C types."""
from __future__ import division, print_function

import itertools

from Enumeration import *
5
# TODO:

#  - struct improvements (flexible arrays, packed &
#    unpacked, alignment)
#  - Objective-C qualified id
#  - anonymous / transparent unions
#  - VLAs
#  - block types
#  - K&R functions
#  - passing arguments of different types (to test extension and
#    transparent unions)
#  - varargs
17
18###
# Concrete C type classes
20
class Type(object):
    """Common base for all generated C types.

    Subclasses that rely on the default :meth:`getTypeName` must provide
    ``getTypedefDef(name, printer)``, returning the typedef text for *name*.
    """

    def isBitField(self):
        """Report whether this type is a bit-field member; the base says no."""
        return False

    def isPaddingBitField(self):
        """Report whether this is a zero-width padding bit-field; the base says no."""
        return False

    def getTypeName(self, printer):
        """Declare a typedef for this type via *printer* and return the new name.

        The name is ``T<k>`` with ``k = len(printer.types)`` at call time; the
        typedef text comes from ``self.getTypedefDef`` and is handed to
        ``printer.addDeclaration``.
        """
        typeName = 'T{}'.format(len(printer.types))
        printer.addDeclaration(self.getTypedefDef(typeName, printer))
        return typeName
33
class BuiltinType(Type):
    """A builtin C type such as ``int`` or ``char``, optionally a bit-field.

    :param name: C spelling of the type, emitted verbatim in generated code.
    :param size: size in bytes, reported by :meth:`sizeof`.
    :param bitFieldSize: bit-field width, or None for an ordinary member.
        A width of 0 denotes an unnamed padding bit-field.
    """

    def __init__(self, name, size, bitFieldSize=None):
        self.name = name
        self.size = size
        self.bitFieldSize = bitFieldSize

    def isBitField(self):
        return self.bitFieldSize is not None

    def isPaddingBitField(self):
        # Compare by value. ``is 0`` relies on CPython's small-int cache and
        # triggers a SyntaxWarning on Python 3.8+.
        return self.bitFieldSize == 0

    def getBitFieldSize(self):
        assert self.isBitField()
        return self.bitFieldSize

    def getTypeName(self, printer):
        # Builtins are spelled directly; no typedef is emitted for them.
        return self.name

    def sizeof(self):
        return self.size

    def __str__(self):
        return self.name
58
class EnumType(Type):
    """An enumeration whose enumerators carry optional initializers.

    Each instance takes its ``unique_id`` from a class-level counter, so the
    generated enumerator names of different instances do not collide.

    :param index: index of this type within its generator, used in names.
    :param enumerators: sequence of initializer strings; a falsy entry
        (e.g. None) means the enumerator has no explicit initializer.
    """
    # Next id to hand out; every constructor call increments it.
    unique_id = 0

    def __init__(self, index, enumerators):
        self.index = index
        self.enumerators = enumerators
        self.unique_id = self.__class__.unique_id
        self.__class__.unique_id += 1

    def getEnumerators(self):
        """Return the comma-separated enumerator list, without braces.

        Enumerators are named ``enum<index>val<position>_<unique_id>``; a
        truthy initializer is appended as `` = <init>``.
        """
        parts = []
        for i, init in enumerate(self.enumerators):
            part = 'enum%dval%d_%d' % (self.index, i, self.unique_id)
            if init:
                part += ' = %s' % (init,)
            parts.append(part)
        return ', '.join(parts)

    def __str__(self):
        return 'enum { %s }' % (self.getEnumerators())

    def getTypedefDef(self, name, printer):
        return 'typedef enum %s { %s } %s;'%(name, self.getEnumerators(), name)
84
class RecordType(Type):
    """A struct or union whose members are other generated types.

    :param index: index of this type within its generator.
    :param isUnion: truthy for a union, falsy for a struct.
    :param fields: member types, in declaration order.
    """

    def __init__(self, index, isUnion, fields):
        self.index = index
        self.isUnion = isUnion
        self.fields = fields
        self.name = None

    def __str__(self):
        keyword = 'union' if self.isUnion else 'struct'
        members = []
        for member in self.fields:
            if member.isBitField():
                members.append("%s : %d;" % (member, member.getBitFieldSize()))
            else:
                members.append("%s;" % member)
        return '%s { %s }' % (keyword, ' '.join(members))

    def getTypedefDef(self, name, printer):
        keyword = 'union' if self.isUnion else 'struct'
        members = []
        for position, member in enumerate(self.fields):
            memberType = printer.getTypeName(member)
            if not member.isBitField():
                members.append('%s field%d;' % (memberType, position))
            elif member.isPaddingBitField():
                # Zero-width padding bit-fields are unnamed.
                members.append('%s : 0;' % (memberType,))
            else:
                members.append('%s field%d : %d;' % (memberType, position,
                                                     member.getBitFieldSize()))
        # The tag repeats the typedef name so the emitted LLVM IR is readable.
        return 'typedef %s %s { %s } %s;' % (keyword, name,
                                             ' '.join(members), name)
117
class ArrayType(Type):
    """A C array, or a GCC ``vector_size`` vector, of some element type.

    For vectors *size* is the total size in bytes and must be a positive
    multiple of the element size. For arrays it is the element count, or
    None for an incomplete array ``T[]``.
    """

    def __init__(self, index, isVector, elementType, size):
        if not isVector:
            assert size is None or size >= 0
        else:
            # Note that for vectors, this is the size in bytes.
            assert size > 0
        self.index = index
        self.isVector = isVector
        self.elementType = elementType
        self.size = size
        if not isVector:
            self.numElements = self.size
        else:
            elementBytes = self.elementType.sizeof()
            assert not (self.size % elementBytes)
            self.numElements = self.size // elementBytes

    def __str__(self):
        if self.isVector:
            return 'vector (%s)[%d]' % (self.elementType, self.size)
        if self.size is None:
            return '(%s)[]' % (self.elementType,)
        return '(%s)[%d]' % (self.elementType, self.size)

    def getTypedefDef(self, name, printer):
        elementName = printer.getTypeName(self.elementType)
        if self.isVector:
            return ('typedef %s %s __attribute__ ((vector_size (%d)));'
                    % (elementName, name, self.size))
        # An incomplete array is declared with empty brackets.
        bounds = '' if self.size is None else str(self.size)
        return 'typedef %s %s[%s];' % (elementName, name, bounds)
156
class ComplexType(Type):
    """A ``_Complex`` type over some element type."""

    def __init__(self, index, elementType):
        self.index, self.elementType = index, elementType

    def __str__(self):
        return '_Complex (%s)' % (self.elementType,)

    def getTypedefDef(self, name, printer):
        elementName = printer.getTypeName(self.elementType)
        return 'typedef _Complex %s %s;' % (elementName, name)
167
class FunctionType(Type):
    """A pointer-to-function type.

    :param index: index of this type within its generator.
    :param returnType: result type, or None for ``void``.
    :param argTypes: sequence of argument types; empty means ``(void)``.
    """

    def __init__(self, index, returnType, argTypes):
        self.index = index
        self.returnType = returnType
        self.argTypes = argTypes

    def _spell(self, spellType):
        """Return (return-type text, argument-list text), using *spellType*
        to turn each component type into text."""
        if self.returnType is None:
            rt = 'void'
        else:
            rt = spellType(self.returnType)
        if not self.argTypes:
            at = 'void'
        else:
            at = ', '.join(spellType(t) for t in self.argTypes)
        return rt, at

    def __str__(self):
        return '%s (*)(%s)' % self._spell(str)

    def getTypedefDef(self, name, printer):
        # Spell components through the printer, as the other getTypedefDef
        # methods do. ``str()`` yields descriptive text such as
        # ``struct { ... }`` or ``_Complex (float)``, which is not valid C
        # inside a function-pointer typedef.
        rt, at = self._spell(printer.getTypeName)
        return 'typedef %s (*%s)(%s);' % (rt, name, at)
195
196###
# Type generators: enumerate types by integer index
198
class TypeGenerator(object):
    """Base class for lazily enumerating a family of types by index.

    Subclasses set ``self.cardinality`` in :meth:`setCardinality` and build
    the N-th type in :meth:`generateType`; :meth:`get` caches the results.
    """

    def __init__(self):
        # Maps index -> generated type.
        self.cache = {}

    def setCardinality(self):
        """Recompute ``self.cardinality``; subclasses must override this."""
        # The original body referenced the undefined name ``abstract``, which
        # surfaced as a confusing NameError.
        raise NotImplementedError(
            '%s must implement setCardinality' % type(self).__name__)

    def get(self, N):
        """Return the N-th type, generating and caching it on first use.

        :raises AssertionError: if N is outside ``[0, cardinality)``. The
            module's ``test`` function relies on this exception type.
        """
        T = self.cache.get(N)
        if T is None:
            assert 0 <= N < self.cardinality
            T = self.cache[N] = self.generateType(N)
        return T

    def generateType(self, N):
        """Build the N-th type; subclasses must override this."""
        raise NotImplementedError(
            '%s must implement generateType' % type(self).__name__)
215
class FixedTypeGenerator(TypeGenerator):
    """Generator over an explicit list of types; index N yields ``types[N]``."""

    def __init__(self, types):
        TypeGenerator.__init__(self)
        self.types = types
        self.setCardinality()

    def setCardinality(self):
        # The list is fixed, so its length is the size of the enumeration.
        count = len(self.types)
        self.cardinality = count

    def generateType(self, N):
        chosen = self.types[N]
        return chosen
227
228# Factorial
def fact(n):
    """Return n! by repeated multiplication; any n <= 0 gives 1.

    ``num_combinations`` relies on the n <= 0 case returning 1, so that
    C(n, k) evaluates to 0 when k > n.
    """
    product = 1
    remaining = n
    while remaining > 0:
        product *= remaining
        remaining -= 1
    return product
235
236# Compute the number of combinations (n choose k)
def num_combinations(n, k):
    """Return the binomial coefficient C(n, k), computed from factorials.

    Because ``fact`` returns 1 for non-positive arguments, this evaluates to
    0 when k > n.
    """
    denominator = fact(k) * fact(n - k)
    return fact(n) // denominator
239
240# Enumerate the combinations choosing k elements from the list of values
def combinations(values, k):
    """Yield every k-element combination of *values* as a new list.

    Combinations are produced in lexicographic order of element positions,
    e.g. ``combinations([a, b, c], 2)`` yields ``[a, b]``, ``[a, c]``,
    ``[b, c]``. A single empty list is yielded for ``k == 0``, and nothing is
    yielded when ``k`` exceeds ``len(values)``.
    """
    # itertools.combinations produces the same order as the previous
    # hand-rolled recursive recipe, without recursion or list slicing.
    for combo in itertools.combinations(values, k):
        yield list(combo)
249
class EnumTypeGenerator(TypeGenerator):
    """Enumerates enum types built from combinations of initializer values.

    Each generated enum has between *minEnumerators* and *maxEnumerators*
    (inclusive) enumerators, whose initializers are a combination drawn from
    *values*. Indices are ordered first by enumerator count, then by
    combination order.
    """

    def __init__(self, values, minEnumerators, maxEnumerators):
        TypeGenerator.__init__(self)
        self.values = values
        self.minEnumerators = minEnumerators
        self.maxEnumerators = maxEnumerators
        self.setCardinality()

    def setCardinality(self):
        poolSize = len(self.values)
        counts = range(self.minEnumerators, self.maxEnumerators + 1)
        self.cardinality = sum(num_combinations(poolSize, c) for c in counts)

    def generateType(self, n):
        # Step through the enumerator counts, skipping the index ranges that
        # belong to smaller counts. The last count (maxEnumerators) absorbs
        # any remaining index.
        numEnumerators = self.minEnumerators
        skipped = 0
        while numEnumerators < self.maxEnumerators:
            blockSize = num_combinations(len(self.values), numEnumerators)
            if n < skipped + blockSize:
                break
            skipped += blockSize
            numEnumerators += 1

        # Select the (n - skipped)-th combination of that size.
        target = n - skipped
        for position, enumerators in enumerate(
                combinations(self.values, numEnumerators)):
            if position == target:
                return EnumType(n, enumerators)

        assert False
284
class ComplexTypeGenerator(TypeGenerator):
    """Enumerates ``_Complex T`` for each element type T produced by *typeGen*."""

    def __init__(self, typeGen):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.setCardinality()

    def setCardinality(self):
        # Exactly one complex type per element type.
        self.cardinality = self.typeGen.cardinality

    def generateType(self, N):
        elementType = self.typeGen.get(N)
        return ComplexType(N, elementType)
296
class VectorTypeGenerator(TypeGenerator):
    """Enumerates vector types: every (byte size, element type) pair.

    :param typeGen: generator supplying the element types.
    :param sizes: iterable of vector sizes in bytes (converted with ``int``).
    """

    def __init__(self, typeGen, sizes):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.sizes = tuple(int(s) for s in sizes)
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = len(self.sizes)*self.typeGen.cardinality

    def generateType(self, N):
        sizeIndex, typeIndex = getNthPairBounded(N, len(self.sizes),
                                                 self.typeGen.cardinality)
        elementType = self.typeGen.get(typeIndex)
        return ArrayType(N, True, elementType, self.sizes[sizeIndex])
310
class FixedArrayTypeGenerator(TypeGenerator):
    """Enumerates fixed-length arrays: every (length, element type) pair.

    :param typeGen: generator supplying the element types.
    :param sizes: iterable of array lengths to use.
    """

    def __init__(self, typeGen, sizes):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        # Must read the ``sizes`` parameter; ``size`` is undefined here.
        self.sizes = tuple(sizes)
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = len(self.sizes)*self.typeGen.cardinality

    def generateType(self, N):
        S, T = getNthPairBounded(N, len(self.sizes), self.typeGen.cardinality)
        # isVector=False: these are plain arrays. Python's literal is
        # ``False``; ``false`` is an undefined name.
        return ArrayType(N, False, self.typeGen.get(T), self.sizes[S])
324
class ArrayTypeGenerator(TypeGenerator):
    """Enumerates arrays of element types from *typeGen* over a range of lengths.

    :param maxSize: largest array length to generate.
    :param useIncomplete: also generate the incomplete array ``T[]``.
    :param useZero: also generate zero-length arrays.
    """

    def __init__(self, typeGen, maxSize, useIncomplete=False, useZero=False):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.useIncomplete = useIncomplete
        self.useZero = useZero
        self.maxSize = int(maxSize)
        # Number of size choices: an optional incomplete slot, an optional
        # zero-length slot, then lengths 1..maxSize.
        self.W = useIncomplete + useZero + self.maxSize
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = self.W * self.typeGen.cardinality

    def generateType(self, N):
        sizeIndex, typeIndex = getNthPairBounded(N, self.W,
                                                 self.typeGen.cardinality)
        if self.useIncomplete and sizeIndex == 0:
            # Slot 0 is reserved for the incomplete array when enabled.
            size = None
        else:
            if self.useIncomplete:
                sizeIndex -= 1
            # Remaining slots map to lengths 0..maxSize (useZero) or 1..maxSize.
            size = sizeIndex if self.useZero else sizeIndex + 1
        return ArrayType(N, False, self.typeGen.get(typeIndex), size)
352
class RecordTypeGenerator(TypeGenerator):
    """Enumerates structs (and optionally unions) with up to *maxSize* fields.

    Field types come from *typeGen*. The cardinality counts field-index
    tuples of every length 0..maxSize, once per record kind (struct, plus
    union when *useUnion* is set).
    """

    def __init__(self, typeGen, useUnion, maxSize):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.useUnion = bool(useUnion)
        # NOTE(review): int() runs before the ``is aleph0`` test in
        # setCardinality; confirm whether an aleph0 maxSize survives it.
        self.maxSize = int(maxSize)
        self.setCardinality()

    def setCardinality(self):
        kinds = 2 if self.useUnion else 1
        if self.maxSize is aleph0:
            total = aleph0 * self.typeGen.cardinality
        else:
            total = 0
            for length in range(self.maxSize + 1):
                total += kinds * (self.typeGen.cardinality ** length)
        self.cardinality = total

    def generateType(self, N):
        # With unions enabled, the low bit of N selects struct (0) or
        # union (1), and the remaining bits index the field tuple.
        if self.useUnion:
            isUnion, tupleIndex = N & 1, N >> 1
        else:
            isUnion, tupleIndex = False, N
        fieldIndices = getNthTuple(tupleIndex, self.maxSize,
                                   self.typeGen.cardinality)
        fields = [self.typeGen.get(i) for i in fieldIndices]
        return RecordType(N, isUnion, fields)
377
class FunctionTypeGenerator(TypeGenerator):
    """Enumerates function-pointer types whose component types come from *typeGen*.

    :param typeGen: generator supplying return and argument types.
    :param useReturn: if true, each function also gets a return type drawn
        from *typeGen*; otherwise the return type is ``void``.
    :param maxSize: maximum number of arguments, or aleph0 for unbounded.
    """

    def __init__(self, typeGen, useReturn, maxSize):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.useReturn = useReturn
        self.maxSize = maxSize
        self.setCardinality()

    def setCardinality(self):
        if self.maxSize is aleph0:
            # ``cardinality`` is a value (a property on AnyTypeGenerator, a
            # plain attribute elsewhere), not a method; calling it raised
            # TypeError.
            S = aleph0 * self.typeGen.cardinality
        elif self.useReturn:
            # Tuples of length 1..maxSize+1: the first entry is the return
            # type, the rest are argument types.
            S = 0
            for i in range(1, self.maxSize + 1 + 1):
                S += self.typeGen.cardinality ** i
        else:
            # Argument tuples of length 0..maxSize.
            S = 0
            for i in range(self.maxSize + 1):
                S += self.typeGen.cardinality ** i
        self.cardinality = S

    def generateType(self, N):
        if self.useReturn:
            # N+1 skips the empty tuple; the first entry is the return type.
            argIndices = getNthTuple(N+1, self.maxSize+1, self.typeGen.cardinality)
            retIndex, argIndices = argIndices[0], argIndices[1:]
            retTy = self.typeGen.get(retIndex)
        else:
            # None is spelled as ``void`` by FunctionType.
            retTy = None
            argIndices = getNthTuple(N, self.maxSize, self.typeGen.cardinality)
        args = [self.typeGen.get(i) for i in argIndices]
        return FunctionType(N, retTy, args)
410
class AnyTypeGenerator(TypeGenerator):
    """Concatenates the enumerations of several sub-generators.

    While the cardinality is being recomputed (``_cardinality is None``),
    the ``cardinality`` property reports aleph0.
    """

    def __init__(self):
        TypeGenerator.__init__(self)
        self.generators = []
        self.bounds = []
        self.setCardinality()
        # Start in the "unknown" state until a generator is added.
        self._cardinality = None

    def getCardinality(self):
        return aleph0 if self._cardinality is None else self._cardinality

    def setCardinality(self):
        self.bounds = [child.cardinality for child in self.generators]
        self._cardinality = sum(self.bounds)

    cardinality = property(getCardinality, None)

    def addGenerator(self, g):
        """Append generator *g* and recompute cardinalities to a fixed point.

        Every sub-generator is re-asked for its cardinality (which may depend
        on this generator's) until the total stops changing or becomes
        aleph0.

        :raises RuntimeError: if no fixed point is reached within 100 rounds.
        """
        self.generators.append(g)
        for _round in range(100):
            previous = self._cardinality
            # Report aleph0 to anyone reading our cardinality mid-recompute.
            self._cardinality = None
            for child in self.generators:
                child.setCardinality()
            self.setCardinality()
            if (self._cardinality is aleph0) or previous == self._cardinality:
                break
        else:
            raise RuntimeError("Infinite loop in setting cardinality")

    def generateType(self, N):
        # Split N into (which sub-generator, index within it).
        which, inner = getNthPairVariableBounds(N, self.bounds)
        return self.generators[which].get(inner)
445
def test():
    """Print a sample of types from a small mixed generator.

    Builds builtin, record and enum generators, prints the combined
    cardinality and up to 100 types. Requesting the index equal to the
    cardinality must fail the bounds assertion in ``get``; otherwise a
    RuntimeError is raised.
    """
    bitFieldGen = FixedTypeGenerator([BuiltinType('char', 4),
                                      BuiltinType('char', 4, 0),
                                      BuiltinType('int',  4, 5)])

    fields1 = AnyTypeGenerator()
    fields1.addGenerator(bitFieldGen)

    fields0 = AnyTypeGenerator()
    fields0.addGenerator(bitFieldGen)
    # Disabled: nested records.
    # fields0.addGenerator(RecordTypeGenerator(fields1, False, 4))

    builtinGen = FixedTypeGenerator([BuiltinType('char', 4),
                                     BuiltinType('int',  4)])
    enumGen = EnumTypeGenerator([None, '-1', '1', '1u'], 0, 3)

    atg = AnyTypeGenerator()
    atg.addGenerator(builtinGen)
    atg.addGenerator(RecordTypeGenerator(fields0, False, 4))
    atg.addGenerator(enumGen)

    cardinality = atg.cardinality
    print('Cardinality:', cardinality)
    for i in range(100):
        if i == cardinality:
            # One past the last valid index: get() must reject it.
            try:
                atg.get(i)
            except AssertionError:
                break
            raise RuntimeError("Cardinality was wrong")
        print('%4d: %s' % (i, atg.get(i)))
475
if __name__ == '__main__':
    # Running the module directly prints a sample of generated types.
    test()
478