1"""Flexible enumeration of C types."""
2
3from Enumeration import *
4
5# TODO:
6
7#  - struct improvements (flexible arrays, packed &
8#    unpacked, alignment)
9#  - objective-c qualified id
10#  - anonymous / transparent unions
11#  - VLAs
12#  - block types
13#  - K&R functions
14#  - pass arguments of different types (test extension, transparent union)
15#  - varargs
16
17###
18# Actual type types
19
class Type:
    """Abstract base class for every generated C type.

    Subclasses that rely on the default getTypeName() must supply
    ``getTypedefDef(name, printer)``, returning the C declaration that
    introduces ``name`` as a typedef for the type.
    """

    def isBitField(self):
        """Report whether this type is a bit-field; a plain type is not."""
        return False

    def isPaddingBitField(self):
        """Report whether this type is a zero-width padding bit-field."""
        return False

    def getTypeName(self, printer):
        """Declare a fresh typedef for this type and return its name.

        The name ``T<k>`` is derived from ``len(printer.types)`` *before*
        getTypedefDef() runs; the resulting declaration is then handed to
        ``printer.addDeclaration``.
        """
        typedefName = 'T%d' % (len(printer.types),)
        declaration = self.getTypedefDef(typedefName, printer)
        printer.addDeclaration(declaration)
        return typedefName
32
class BuiltinType(Type):
    """A named scalar C type, optionally used as a bit-field member.

    ``size`` is the value reported by sizeof().  ``bitFieldSize`` is None
    for an ordinary member, 0 for an unnamed zero-width padding bit-field,
    or otherwise the bit-field width.
    """
    def __init__(self, name, size, bitFieldSize=None):
        self.name = name
        self.size = size
        self.bitFieldSize = bitFieldSize

    def isBitField(self):
        """Return True when this type carries a bit-field width."""
        return self.bitFieldSize is not None

    def isPaddingBitField(self):
        """Return True for a zero-width (padding) bit-field."""
        # Compare by value: ``is 0`` only worked because CPython caches small
        # ints, and it triggers a SyntaxWarning on Python 3.8+.
        return self.bitFieldSize == 0

    def getBitFieldSize(self):
        """Return the bit-field width; only valid when isBitField()."""
        assert self.isBitField()
        return self.bitFieldSize

    def getTypeName(self, printer):
        # Builtins are spelled directly by name; no typedef is declared.
        return self.name

    def sizeof(self):
        return self.size

    def __str__(self):
        return self.name
57
class EnumType(Type):
    """A C enum type.

    Enumerator ``i`` is spelled ``enum<index>val<i>``.  When the matching
    entry of ``enumerators`` is truthy it is emitted as that enumerator's
    explicit initializer; a falsy entry (e.g. None) leaves it implicit.
    """
    def __init__(self, index, enumerators):
        self.index = index
        self.enumerators = enumerators

    def getEnumerators(self):
        """Return the comma-separated enumerator list placed inside braces."""
        pieces = []
        for position, initializer in enumerate(self.enumerators):
            piece = 'enum%dval%d' % (self.index, position)
            if initializer:
                piece += ' = %s' % initializer
            pieces.append(piece)
        return ', '.join(pieces)

    def __str__(self):
        return 'enum { %s }' % (self.getEnumerators(),)

    def getTypedefDef(self, name, printer):
        body = self.getEnumerators()
        return 'typedef enum %s { %s } %s;' % (name, body, name)
79
class RecordType(Type):
    """A struct or union type built from a sequence of member types.

    ``isUnion`` indexes ('struct', 'union'), so 0/1 (or False/True) is
    expected.  Bit-field members are emitted with their width; zero-width
    padding bit-fields are left unnamed.
    """
    def __init__(self, index, isUnion, fields):
        self.index = index
        self.isUnion = isUnion
        # Store a list: a lazy iterable (e.g. a Python 3 ``map`` object) could
        # be walked only once, but both __str__ and getTypedefDef walk it.
        self.fields = list(fields)
        self.name = None

    def _keyword(self):
        """Return the C keyword introducing this record."""
        return ('struct', 'union')[self.isUnion]

    def __str__(self):
        def getField(t):
            if t.isBitField():
                return "%s : %d;" % (t, t.getBitFieldSize())
            else:
                return "%s;" % t

        return '%s { %s }' % (self._keyword(),
                              ' '.join([getField(t) for t in self.fields]))

    def getTypedefDef(self, name, printer):
        # Plain (i, t) parameters: tuple-parameter unpacking was removed in
        # Python 3 (PEP 3113).
        def getField(i, t):
            if t.isBitField():
                if t.isPaddingBitField():
                    # A zero-width bit-field must not have a declarator in C.
                    return '%s : 0;' % (printer.getTypeName(t),)
                else:
                    return '%s field%d : %d;' % (printer.getTypeName(t), i,
                                                 t.getBitFieldSize())
            else:
                return '%s field%d;' % (printer.getTypeName(t), i)

        fields = [getField(i, t) for i, t in enumerate(self.fields)]
        # Name the struct for more readable LLVM IR.
        return 'typedef %s %s { %s } %s;' % (self._keyword(), name,
                                             ' '.join(fields), name)
111
class ArrayType(Type):
    """A plain C array or a GCC ``vector_size`` vector of ``elementType``.

    For plain arrays ``size`` is the element count, or None for an
    incomplete ``[]`` array.  For vectors ``size`` is the total byte size,
    which must be a positive multiple of ``elementType.sizeof()``.
    """
    def __init__(self, index, isVector, elementType, size):
        if not isVector:
            assert size is None or size >= 0
        else:
            # Vectors: size is measured in bytes.
            assert size > 0
        self.index = index
        self.isVector = isVector
        self.elementType = elementType
        self.size = size
        if not isVector:
            self.numElements = self.size
            return
        elementBytes = self.elementType.sizeof()
        # The byte size must hold a whole number of elements.
        assert not (self.size % elementBytes)
        self.numElements = self.size // elementBytes

    def __str__(self):
        if self.isVector:
            return 'vector (%s)[%d]'%(self.elementType,self.size)
        if self.size is None:
            return '(%s)[]'%(self.elementType,)
        return '(%s)[%d]'%(self.elementType,self.size)

    def getTypedefDef(self, name, printer):
        # Resolve the element name first, whichever form is emitted.
        elementName = printer.getTypeName(self.elementType)
        if self.isVector:
            return ('typedef %s %s __attribute__ ((vector_size (%d)));'
                    % (elementName, name, self.size))
        sizeStr = '' if self.size is None else str(self.size)
        return 'typedef %s %s[%s];'%(elementName, name, sizeStr)
150
class ComplexType(Type):
    """The C99 ``_Complex`` type over ``elementType``."""
    def __init__(self, index, elementType):
        self.index = index
        self.elementType = elementType

    def __str__(self):
        return '_Complex (%s)' % self.elementType

    def getTypedefDef(self, name, printer):
        # Spell the element type through the printer, like other composites.
        elementName = printer.getTypeName(self.elementType)
        return 'typedef _Complex %s %s;' % (elementName, name)
161
class FunctionType(Type):
    """A pointer-to-function type.

    ``returnType`` of None means a ``void`` return; an empty ``argTypes``
    produces a ``(void)`` parameter list.
    """
    def __init__(self, index, returnType, argTypes):
        self.index = index
        self.returnType = returnType
        # Store a list: a lazy iterable (e.g. a Python 3 ``map`` object) is
        # always truthy, which would defeat the empty-argument check below,
        # and can only be consumed once.
        self.argTypes = list(argTypes)

    def _signatureParts(self, spell):
        """Return (returnText, argumentText), spelling each type via ``spell``."""
        if self.returnType is None:
            rt = 'void'
        else:
            rt = spell(self.returnType)
        if not self.argTypes:
            at = 'void'
        else:
            at = ', '.join([spell(t) for t in self.argTypes])
        return rt, at

    def __str__(self):
        return '%s (*)(%s)' % self._signatureParts(str)

    def getTypedefDef(self, name, printer):
        # Spell nested types through the printer, as RecordType, ArrayType and
        # ComplexType do.  Their str() forms (e.g. '(int)[4]' or
        # 'struct { int; }') are readable summaries, not valid C.
        rt, at = self._signatureParts(printer.getTypeName)
        return 'typedef %s (*%s)(%s);' % (rt, name, at)
189
190###
191# Type enumerators
192
class TypeGenerator(object):
    """Abstract base for lazily enumerating a family of types by index.

    Subclasses implement setCardinality(), which must leave the number of
    producible types in ``self.cardinality``, and generateType(N), which
    builds type number N.  get() memoizes generated types per index.
    """
    def __init__(self):
        self.cache = {}

    def setCardinality(self):
        # Subclasses must override; fail with a clear error rather than
        # relying on an undefined name.
        raise NotImplementedError('%s must implement setCardinality()'
                                  % type(self).__name__)

    def get(self, N):
        """Return type number N, generating and caching it on first use.

        An out-of-range N fails the assertion below; test() catches that
        AssertionError when probing the index equal to the cardinality.
        """
        T = self.cache.get(N)
        if T is None:
            assert 0 <= N < self.cardinality
            T = self.cache[N] = self.generateType(N)
        return T

    def generateType(self, N):
        # Subclasses must override.
        raise NotImplementedError('%s must implement generateType()'
                                  % type(self).__name__)
209
class FixedTypeGenerator(TypeGenerator):
    """Enumerate a fixed, caller-supplied sequence of types in order."""
    def __init__(self, types):
        super(FixedTypeGenerator, self).__init__()
        self.types = types
        self.setCardinality()

    def setCardinality(self):
        # Exactly one index per supplied type.
        self.cardinality = len(self.types)

    def generateType(self, N):
        # Index N simply selects the N-th supplied type.
        chosen = self.types[N]
        return chosen
221
def fact(n):
    """Return the factorial n! = n * (n-1) * ... for n > 0.

    Any n <= 0 yields 1 (the empty product).
    """
    product, remaining = 1, n
    while remaining > 0:
        product, remaining = product * remaining, remaining - 1
    return product
229
def num_combinations(n, k):
    """Return the binomial coefficient C(n, k) ("n choose k") as an int.

    Returns 0 when k < 0 or k > n.  Uses exact integer arithmetic: the old
    ``fact(n) / (...)`` form produced floats under Python 3 true division,
    and its 0 for out-of-range k depended on Python 2 integer division.
    """
    if k < 0 or k > n:
        return 0
    k = min(k, n - k)  # C(n, k) == C(n, n - k); iterate over the smaller one
    result = 1
    for i in range(1, k + 1):
        # After step i, result == C(n - k + i, i), so the division is exact.
        result = result * (n - k + i) // i
    return result
233
def combinations(values, k):
    """Yield each k-element combination of ``values`` as a list.

    Combinations come out in lexicographic order of element positions
    (e.g. [a, b], [a, c], [b, c]); k == 0 yields a single empty list, and
    k > len(values) yields nothing.  Delegates to itertools.combinations
    instead of the old recursive recipe, which relied on Python 2's xrange.
    """
    import itertools
    for combo in itertools.combinations(values, k):
        yield list(combo)
243
class EnumTypeGenerator(TypeGenerator):
    """Enumerate enum types whose initializers are combinations of ``values``.

    Each generated enum takes between ``minEnumerators`` and
    ``maxEnumerators`` distinct entries of ``values`` (kept in their
    original order) as its enumerator initializers.  Indices cover every
    combination of the smallest size first, then the next size, and so on.
    """
    def __init__(self, values, minEnumerators, maxEnumerators):
        super(EnumTypeGenerator, self).__init__()
        self.values = values
        self.minEnumerators = minEnumerators
        self.maxEnumerators = maxEnumerators
        self.setCardinality()

    def setCardinality(self):
        numValues = len(self.values)
        total = 0
        for count in range(self.minEnumerators, self.maxEnumerators + 1):
            total += num_combinations(numValues, count)
        self.cardinality = total

    def generateType(self, n):
        numValues = len(self.values)
        # Peel off whole blocks of same-sized combinations until the
        # remaining offset falls inside one; the largest size takes the rest.
        count = self.minEnumerators
        offset = n
        while count < self.maxEnumerators:
            blockSize = num_combinations(numValues, count)
            if offset < blockSize:
                break
            offset -= blockSize
            count += 1

        # Walk to the offset-th combination of that size.
        for position, enumerators in enumerate(combinations(self.values, count)):
            if position == offset:
                return EnumType(n, enumerators)

        assert False
278
class ComplexTypeGenerator(TypeGenerator):
    """Enumerate ``_Complex`` types, one for each type from ``typeGen``."""
    def __init__(self, typeGen):
        super(ComplexTypeGenerator, self).__init__()
        self.typeGen = typeGen
        self.setCardinality()

    def setCardinality(self):
        # Index N here corresponds one-to-one with index N of typeGen.
        self.cardinality = self.typeGen.cardinality

    def generateType(self, N):
        elementType = self.typeGen.get(N)
        return ComplexType(N, elementType)
290
class VectorTypeGenerator(TypeGenerator):
    """Enumerate GCC vector types for each (byte size, element type) pair.

    ``sizes`` are vector sizes in bytes (coerced with int()); element types
    come from ``typeGen``.
    """
    def __init__(self, typeGen, sizes):
        super(VectorTypeGenerator, self).__init__()
        self.typeGen = typeGen
        self.sizes = tuple([int(s) for s in sizes])
        self.setCardinality()

    def setCardinality(self):
        numSizes = len(self.sizes)
        self.cardinality = numSizes * self.typeGen.cardinality

    def generateType(self, N):
        sizeIndex, typeIndex = getNthPairBounded(N, len(self.sizes),
                                                 self.typeGen.cardinality)
        elementType = self.typeGen.get(typeIndex)
        return ArrayType(N, True, elementType, self.sizes[sizeIndex])
304
class FixedArrayTypeGenerator(TypeGenerator):
    """Enumerate plain array types for each (size, element type) pair.

    ``sizes`` lists the array lengths to use (ArrayType also accepts None,
    meaning an incomplete ``[]`` array); element types come from ``typeGen``.
    """
    def __init__(self, typeGen, sizes):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        # Use the ``sizes`` parameter (the previous ``size`` was undefined).
        self.sizes = tuple(sizes)
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = len(self.sizes)*self.typeGen.cardinality

    def generateType(self, N):
        S, T = getNthPairBounded(N, len(self.sizes), self.typeGen.cardinality)
        # Python's ``False`` marks a plain (non-vector) array.
        return ArrayType(N, False, self.typeGen.get(T), self.sizes[S])
318
class ArrayTypeGenerator(TypeGenerator):
    """Enumerate plain C array types over element types from ``typeGen``.

    Each index pairs an element type with one of ``W`` size slots.  The
    first slot is an incomplete ``[]`` array (if ``useIncomplete``).  The
    remaining slots are sizes ``0..maxSize`` (if ``useZero``) or
    ``1..maxSize``.
    """
    def __init__(self, typeGen, maxSize, useIncomplete=False, useZero=False):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        # Normalize the flags to bool: W counts them as 0/1 while
        # generateType tests their truthiness, so any other truthy value
        # (e.g. 2) would make W disagree with the sizes actually produced.
        self.useIncomplete = bool(useIncomplete)
        self.useZero = bool(useZero)
        self.maxSize = int(maxSize)
        self.W = self.useIncomplete + self.useZero + self.maxSize
        self.setCardinality()

    def setCardinality(self):
        self.cardinality = self.W * self.typeGen.cardinality

    def generateType(self, N):
        S, T = getNthPairBounded(N, self.W, self.typeGen.cardinality)
        if self.useIncomplete:
            if S == 0:
                # Slot 0 is the incomplete array type.
                return ArrayType(N, False, self.typeGen.get(T), None)
            S -= 1
        # Remaining slots map to sizes 0..maxSize or 1..maxSize.
        size = S if self.useZero else S + 1
        return ArrayType(N, False, self.typeGen.get(T), size)
346
class RecordTypeGenerator(TypeGenerator):
    """Enumerate struct (and optionally union) types with up to maxSize fields.

    Field types are drawn from ``typeGen``.  With ``useUnion``, the low bit
    of the index selects struct (0) or union (1) and the remaining bits
    select the field tuple.
    """
    def __init__(self, typeGen, useUnion, maxSize):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.useUnion = bool(useUnion)
        # Leave aleph0 untouched, as FunctionTypeGenerator does.  Passing it
        # through int() cannot be expected to return the aleph0 singleton,
        # so the ``is aleph0`` branch in setCardinality would be unreachable.
        if maxSize is aleph0:
            self.maxSize = maxSize
        else:
            self.maxSize = int(maxSize)
        self.setCardinality()

    def setCardinality(self):
        M = 1 + self.useUnion
        if self.maxSize is aleph0:
            S = aleph0 * self.typeGen.cardinality
        else:
            # Field tuples of length 0..maxSize, doubled when unions are on.
            S = 0
            for i in range(self.maxSize+1):
                S += M * (self.typeGen.cardinality ** i)
        self.cardinality = S

    def generateType(self, N):
        isUnion, I = False, N
        if self.useUnion:
            isUnion, I = (I & 1), I >> 1
        indices = getNthTuple(I, self.maxSize, self.typeGen.cardinality)
        # A real list (not a lazy Python 3 map) so RecordType can walk it
        # more than once.
        fields = [self.typeGen.get(i) for i in indices]
        return RecordType(N, isUnion, fields)
371
class FunctionTypeGenerator(TypeGenerator):
    """Enumerate function-pointer types over types from ``typeGen``.

    With ``useReturn``, the first element of each enumerated tuple is the
    return type and up to ``maxSize`` argument types follow.  Otherwise the
    return type is void (None) and the whole tuple, of length 0..maxSize,
    is the argument list.
    """
    def __init__(self, typeGen, useReturn, maxSize):
        TypeGenerator.__init__(self)
        self.typeGen = typeGen
        self.useReturn = useReturn
        self.maxSize = maxSize
        self.setCardinality()

    def setCardinality(self):
        # ``cardinality`` is a value (an attribute, or a property on
        # AnyTypeGenerator), never a method: calling it raised TypeError.
        card = self.typeGen.cardinality
        if self.maxSize is aleph0:
            S = aleph0 * card
        elif self.useReturn:
            # Non-empty tuples of length 1..maxSize+1 (return type + args).
            S = 0
            for i in range(1, self.maxSize + 2):
                S += card ** i
        else:
            # Tuples of length 0..maxSize (arguments only).
            S = 0
            for i in range(self.maxSize + 1):
                S += card ** i
        self.cardinality = S

    def generateType(self, N):
        if self.useReturn:
            # Skip the empty tuple: there is always a return type.
            argIndices = getNthTuple(N+1, self.maxSize+1, self.typeGen.cardinality)
            retIndex, argIndices = argIndices[0], argIndices[1:]
            retTy = self.typeGen.get(retIndex)
        else:
            retTy = None
            argIndices = getNthTuple(N, self.maxSize, self.typeGen.cardinality)
        # Build a real list: a Python 3 ``map`` object is always truthy and
        # single-use, which would break FunctionType's empty-argument check.
        args = [self.typeGen.get(i) for i in argIndices]
        return FunctionType(N, retTy, args)
404
class AnyTypeGenerator(TypeGenerator):
    """Concatenate the index spaces of several sub-generators.

    Index N is split over the sub-generators' cardinalities by
    getNthPairVariableBounds.  While ``_cardinality`` is None (unknown), the
    ``cardinality`` property reports aleph0.
    """
    def __init__(self):
        TypeGenerator.__init__(self)
        self.generators = []
        self.bounds = []
        self.setCardinality()
        # Start as "unknown" so the property reports aleph0 until the first
        # addGenerator() computes a real value.
        self._cardinality = None

    def getCardinality(self):
        if self._cardinality is None:
            return aleph0
        else:
            return self._cardinality

    def setCardinality(self):
        self.bounds = [g.cardinality for g in self.generators]
        self._cardinality = sum(self.bounds)

    cardinality = property(getCardinality, None)

    def addGenerator(self, g):
        """Add sub-generator ``g`` and iterate all cardinalities to a fixpoint.

        Stops when this generator's cardinality becomes aleph0 or stops
        changing.  Raises RuntimeError after 100 rounds without settling.
        """
        self.generators.append(g)
        for _ in range(100):
            prev = self._cardinality
            # Report aleph0 while recomputing, presumably so sub-generators
            # that refer back to this one see an unbounded size.
            self._cardinality = None
            # ``gen``, not ``g``: don't shadow the parameter.
            for gen in self.generators:
                gen.setCardinality()
            self.setCardinality()
            if (self._cardinality is aleph0) or prev == self._cardinality:
                break
        else:
            raise RuntimeError("Infinite loop in setting cardinality")

    def generateType(self, N):
        index, M = getNthPairVariableBounds(N, self.bounds)
        return self.generators[index].get(M)
439
def test():
    """Print up to 100 types from a small generator mix.

    When the index reaches the reported cardinality, fetching it must fail
    TypeGenerator.get's bounds assertion; if it succeeds instead, the
    cardinality was wrong and RuntimeError is raised.
    """
    fbtg = FixedTypeGenerator([BuiltinType('char', 4),
                               BuiltinType('char', 4, 0),
                               BuiltinType('int',  4, 5)])

    fields1 = AnyTypeGenerator()
    fields1.addGenerator( fbtg )

    fields0 = AnyTypeGenerator()
    fields0.addGenerator( fbtg )
    # Nested records (disabled):
    # fields0.addGenerator( RecordTypeGenerator(fields1, False, 4) )

    btg = FixedTypeGenerator([BuiltinType('char', 4),
                              BuiltinType('int',  4)])
    etg = EnumTypeGenerator([None, '-1', '1', '1u'], 0, 3)

    atg = AnyTypeGenerator()
    atg.addGenerator( btg )
    atg.addGenerator( RecordTypeGenerator(fields0, False, 4) )
    atg.addGenerator( etg )
    # A single pre-formatted argument prints identically via the Python 2
    # print statement and the Python 3 print() function.
    print('Cardinality: %s' % (atg.cardinality,))
    for i in range(100):
        if i == atg.cardinality:
            try:
                atg.get(i)
            except AssertionError:
                break
            raise RuntimeError("Cardinality was wrong")
        print('%4d: %s' % (i, atg.get(i)))
469
# When executed as a script, run the test() smoke test defined in this file.
if __name__ == '__main__':
    test()
472