1"""
codegen.py generates pika/spec.py.

The required AMQP spec JSON file is provided by the rabbitmq-codegen
repository: https://github.com/rabbitmq/rabbitmq-codegen

After cloning that repository, run the following from the directory that
contains this script to generate pika/spec.py:
python2 ./codegen.py ../../rabbitmq-codegen
10"""
11from __future__ import nested_scopes
12
13import os
14import re
15import sys
16
if sys.version_info[0] != 2:
    sys.exit('Python 2 is required at this time')

# Fail with a usage message instead of an IndexError when the path to the
# rabbitmq-codegen checkout is missing.
if len(sys.argv) < 2:
    sys.exit('Usage: python2 %s <path-to-rabbitmq-codegen>' % sys.argv[0])

RABBITMQ_CODEGEN_PATH = sys.argv[1]
# Resolve the output file relative to this script rather than the current
# working directory, so the generator writes to the right place no matter
# where it is launched from.
PIKA_SPEC = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                         os.pardir, 'pika', 'spec.py')
print('codegen-path: %s' % RABBITMQ_CODEGEN_PATH)
# Make amqp_codegen importable from the rabbitmq-codegen checkout.
sys.path.append(RABBITMQ_CODEGEN_PATH)

import amqp_codegen
26
# Appears to map each client-initiated method to the reply method(s) it
# waits for; an empty list means no reply is expected.
# NOTE(review): not referenced in this file -- confirm its users elsewhere.
DRIVER_METHODS = {
    # Exchange class
    "Exchange.Bind": ["Exchange.BindOk"],
    "Exchange.Unbind": ["Exchange.UnbindOk"],
    "Exchange.Declare": ["Exchange.DeclareOk"],
    "Exchange.Delete": ["Exchange.DeleteOk"],

    # Queue class
    "Queue.Declare": ["Queue.DeclareOk"],
    "Queue.Bind": ["Queue.BindOk"],
    "Queue.Purge": ["Queue.PurgeOk"],
    "Queue.Delete": ["Queue.DeleteOk"],
    "Queue.Unbind": ["Queue.UnbindOk"],

    # Basic class
    "Basic.Qos": ["Basic.QosOk"],
    "Basic.Get": ["Basic.GetOk", "Basic.GetEmpty"],
    "Basic.Ack": [],
    "Basic.Reject": [],
    "Basic.Recover": ["Basic.RecoverOk"],
    "Basic.RecoverAsync": [],

    # Tx class
    "Tx.Select": ["Tx.SelectOk"],
    "Tx.Commit": ["Tx.CommitOk"],
    "Tx.Rollback": ["Tx.RollbackOk"],
}
47
48
def fieldvalue(v):
    """Return the Python source literal used as a field's default value.

    dict and list values are emitted as ``None``. ``unicode`` values are
    first encoded to ASCII (raising ``UnicodeEncodeError`` for non-ASCII
    text) and then ``repr``-ed; every other value is ``repr``-ed as is.
    """
    if isinstance(v, (dict, list)):
        return repr(None)
    if isinstance(v, unicode):
        literal = v.encode('ascii')
    else:
        literal = v
    return repr(literal)
58
59
def normalize_separators(s):
    """Return *s* with every hyphen and every space replaced by ``_``."""
    for separator in ('-', ' '):
        s = s.replace(separator, '_')
    return s
64
65
def pyize(s):
    """Turn an AMQP field name into the Python attribute name used for it.

    Hyphens and spaces become underscores. Names that clash with Python
    keywords are renamed: ``class`` -> ``class_``, and both ``global`` and
    ``global_`` -> ``global_qos``. All other names are returned unchanged.
    """
    renamed = {
        'class': 'class_',
        'global': 'global_qos',
        'global_': 'global_qos',
    }
    name = normalize_separators(s)
    return renamed.get(name, name)
73
74
def camel(s):
    """Return *s* in CamelCase, e.g. ``'recover-async'`` -> ``'RecoverAsync'``."""
    words = normalize_separators(s).title().split('_')
    return ''.join(words)
77
78
def _method_struct_name(method):
    """Return '<Class>.<Method>' in CamelCase, e.g. 'Basic.Qos'."""
    return '%s.%s' % (camel(method.klass.name), camel(method.name))


def _class_struct_name(klass):
    """Return '<Class>Properties' in CamelCase, e.g. 'BasicProperties'."""
    return '%s%s' % (camel(klass.name), "Properties")


# Teach amqp_codegen's model objects the names pika's spec module uses.
amqp_codegen.AmqpMethod.structName = _method_struct_name
amqp_codegen.AmqpClass.structName = _class_struct_name
81
82
def constantName(s):
    """Return *s* as an UPPER_SNAKE_CASE constant name.

    The text is upper-cased and each hyphen or space becomes an underscore,
    e.g. ``'frame-end'`` -> ``'FRAME_END'``.
    """
    return re.sub(r'[- ]', '_', s.upper())
85
86
def flagName(c, f):
    """Return the flag-constant name for property field *f*.

    With a class *c* the name is qualified by ``c.structName()``
    (``<Class>Properties.FLAG_<FIELD>``); when *c* is falsy the bare
    ``FLAG_<FIELD>`` name is returned.
    """
    constant = constantName('flag_' + f.name)
    if not c:
        return constant
    return '%s.%s' % (c.structName(), constant)
92
93
def generate(specPath):
    """Print the Python source of pika's ``spec`` module.

    The module text is produced entirely with ``print``, i.e. it goes to
    whatever ``sys.stdout`` currently is; the ``__main__`` block points
    stdout at the output file before calling this.

    :param specPath: forwarded unchanged to ``amqp_codegen.AmqpSpec``; the
        ``__main__`` block passes a one-element list holding the path of the
        ``amqp-rabbitmq-0.9.1.json`` spec file.
    :raises Exception: if a field's resolved domain is not supported by the
        single-value encode/decode emitters.
    """
    spec = amqp_codegen.AmqpSpec(specPath)

    def genSingleDecode(prefix, cLvalue, unresolved_domain):
        # Print statements (each starting with ``prefix``) that read one
        # non-bit value from ``encoded`` at ``offset`` into ``cLvalue`` and
        # advance ``offset`` past it.
        domain_type = spec.resolveDomain(unresolved_domain)
        if domain_type == 'shortstr':
            print(prefix +
                  "%s, offset = data.decode_short_string(encoded, offset)" %
                  cLvalue)
        elif domain_type == 'longstr':
            print(prefix +
                  "length = struct.unpack_from('>I', encoded, offset)[0]")
            print(prefix + "offset += 4")
            print(prefix + "%s = encoded[offset:offset + length]" % cLvalue)
            print(prefix + "try:")
            print(prefix + "    %s = str(%s)" % (cLvalue, cLvalue))
            print(prefix + "except UnicodeEncodeError:")
            print(prefix + "    pass")
            print(prefix + "offset += length")
        elif domain_type == 'octet':
            print(prefix +
                  "%s = struct.unpack_from('B', encoded, offset)[0]" % cLvalue)
            print(prefix + "offset += 1")
        elif domain_type == 'short':
            print(prefix +
                  "%s = struct.unpack_from('>H', encoded, offset)[0]" % cLvalue)
            print(prefix + "offset += 2")
        elif domain_type == 'long':
            print(prefix +
                  "%s = struct.unpack_from('>I', encoded, offset)[0]" % cLvalue)
            print(prefix + "offset += 4")
        elif domain_type in ('longlong', 'timestamp'):
            # Both are decoded as unsigned 64-bit big-endian integers.
            print(prefix +
                  "%s = struct.unpack_from('>Q', encoded, offset)[0]" % cLvalue)
            print(prefix + "offset += 8")
        elif domain_type == 'bit':
            # Bit fields are packed; the callers decode them directly.
            raise Exception("Can't decode bit in genSingleDecode")
        elif domain_type == 'table':
            print(prefix +
                  "(%s, offset) = data.decode_table(encoded, offset)" % cLvalue)
        else:
            raise Exception("Illegal domain in genSingleDecode", domain_type)

    def genStringTypeAssert(prefix, cValue):
        # Print a generated assertion that ``cValue`` is a str/bytes value.
        print(
            prefix +
            "assert isinstance(%s, str_or_bytes),\\\n%s       'A non-string value was supplied for %s'"
            % (cValue, prefix, cValue))

    def genSingleEncode(prefix, cValue, unresolved_domain):
        # Print statements (each starting with ``prefix``) that append the
        # wire encoding of one non-bit value ``cValue`` to ``pieces``.
        domain_type = spec.resolveDomain(unresolved_domain)
        if domain_type == 'shortstr':
            genStringTypeAssert(prefix, cValue)
            print(prefix + "data.encode_short_string(pieces, %s)" % cValue)
        elif domain_type == 'longstr':
            genStringTypeAssert(prefix, cValue)
            print(
                prefix +
                "value = %s.encode('utf-8') if isinstance(%s, unicode_type) else %s"
                % (cValue, cValue, cValue))
            print(prefix + "pieces.append(struct.pack('>I', len(value)))")
            print(prefix + "pieces.append(value)")
        elif domain_type == 'octet':
            print(prefix + "pieces.append(struct.pack('B', %s))" % cValue)
        elif domain_type == 'short':
            print(prefix + "pieces.append(struct.pack('>H', %s))" % cValue)
        elif domain_type == 'long':
            print(prefix + "pieces.append(struct.pack('>I', %s))" % cValue)
        elif domain_type in ('longlong', 'timestamp'):
            # Both are encoded as unsigned 64-bit big-endian integers.
            print(prefix + "pieces.append(struct.pack('>Q', %s))" % cValue)
        elif domain_type == 'bit':
            # Bit fields are packed; the callers encode them directly.
            raise Exception("Can't encode bit in genSingleEncode")
        elif domain_type == 'table':
            print(prefix + "data.encode_table(pieces, %s)" % cValue)
        else:
            raise Exception("Illegal domain in genSingleEncode", domain_type)

    def genDecodeMethodFields(m):
        # Print the generated ``decode`` method of a Method class.
        print("        def decode(self, encoded, offset=0):")
        bitindex = None
        for f in m.arguments:
            if spec.resolveDomain(f.domain) == 'bit':
                # Consecutive bit arguments share octets, 8 bits per octet;
                # a fresh octet is read at the start of a run or when full.
                if bitindex is None or bitindex >= 8:
                    bitindex = 0
                if not bitindex:
                    print(
                        "            bit_buffer = struct.unpack_from('B', encoded, offset)[0]"
                    )
                    print("            offset += 1")
                print("            self.%s = (bit_buffer & (1 << %d)) != 0" %
                      (pyize(f.name), bitindex))
                bitindex += 1
            else:
                bitindex = None
                genSingleDecode("            ", "self.%s" % (pyize(f.name),),
                                f.domain)
        print("            return self")
        print('')

    def genDecodeProperties(c):
        # Print the generated ``decode`` method of a Properties class. The
        # generated code reads 16-bit flag words until one has bit 0 clear.
        print("    def decode(self, encoded, offset=0):")
        print("        flags = 0")
        print("        flagword_index = 0")
        print("        while True:")
        print(
            "            partial_flags = struct.unpack_from('>H', encoded, offset)[0]"
        )
        print("            offset += 2")
        print(
            "            flags = flags | (partial_flags << (flagword_index * 16))"
        )
        print("            if not (partial_flags & 1):")
        print("                break")
        print("            flagword_index += 1")
        for f in c.fields:
            if spec.resolveDomain(f.domain) == 'bit':
                print("        self.%s = (flags & %s) != 0" % (pyize(f.name),
                                                               flagName(c, f)))
            else:
                print("        if flags & %s:" % (flagName(c, f),))
                genSingleDecode("            ", "self.%s" % (pyize(f.name),),
                                f.domain)
                print("        else:")
                print("            self.%s = None" % (pyize(f.name),))
        print("        return self")
        print('')

    def genEncodeMethodFields(m):
        # Print the generated ``encode`` method of a Method class.
        print("        def encode(self):")
        print("            pieces = list()")
        bitindex = None

        def finishBits():
            # Flush the octet currently being filled by bit arguments, if any.
            if bitindex is not None:
                print("            pieces.append(struct.pack('B', bit_buffer))")

        for f in m.arguments:
            if spec.resolveDomain(f.domain) == 'bit':
                if bitindex is None:
                    bitindex = 0
                    print("            bit_buffer = 0")
                if bitindex >= 8:
                    finishBits()
                    print("            bit_buffer = 0")
                    bitindex = 0
                print("            if self.%s:" % pyize(f.name))
                print("                bit_buffer |= 1 << %d" % bitindex)
                bitindex += 1
            else:
                finishBits()
                bitindex = None
                genSingleEncode("            ", "self.%s" % (pyize(f.name),),
                                f.domain)
        finishBits()
        print("            return pieces")
        print('')

    def genEncodeProperties(c):
        # Print the generated ``encode`` method of a Properties class. Flag
        # words carry 15 flags each; bit 0 is set when another word follows.
        print("    def encode(self):")
        print("        pieces = list()")
        print("        flags = 0")
        for f in c.fields:
            if spec.resolveDomain(f.domain) == 'bit':
                print("        if self.%s: flags = flags | %s" % (pyize(
                    f.name), flagName(c, f)))
            else:
                print("        if self.%s is not None:" % (pyize(f.name),))
                print("            flags = flags | %s" % (flagName(c, f),))
                genSingleEncode("            ", "self.%s" % (pyize(f.name),),
                                f.domain)
        print("        flag_pieces = list()")
        print("        while True:")
        print("            remainder = flags >> 16")
        print("            partial_flags = flags & 0xFFFE")
        print("            if remainder != 0:")
        print("                partial_flags |= 1")
        print(
            "            flag_pieces.append(struct.pack('>H', partial_flags))")
        print("            flags = remainder")
        print("            if not flags:")
        print("                break")
        print("        return flag_pieces + pieces")
        print('')

    def fieldDeclList(fields):
        # ", name=default" for every field, appended to a generated signature.
        return ''.join([
            ", %s=%s" % (pyize(f.name), fieldvalue(f.defaultvalue))
            for f in fields
        ])

    def fieldInitList(prefix, fields):
        # One "self.name = name" line per field, or "pass" if there are none.
        if fields:
            return ''.join(["%sself.%s = %s\n" % (prefix, pyize(f.name), pyize(f.name))
                            for f in fields])
        else:
            return '%spass\n' % (prefix,)

    print("""\"\"\"
AMQP Specification
==================
This module implements the constants and classes that comprise AMQP protocol
level constructs. It should rarely be directly referenced outside of Pika's
own internal use.

.. note:: Auto-generated code by codegen.py, do not edit directly. Pull
requests to this file without accompanying ``utils/codegen.py`` changes will be
rejected.

\"\"\"

import struct
from pika import amqp_object
from pika import data
from pika.compat import str_or_bytes, unicode_type

# Python 3 support for str object
str = bytes
""")

    print("PROTOCOL_VERSION = (%d, %d, %d)" % (spec.major, spec.minor,
                                               spec.revision))
    print("PORT = %d" % spec.port)
    print('')

    # Append some constants that aren't in the spec JSON file.
    spec.constants.append(('FRAME_MAX_SIZE', 131072, ''))
    spec.constants.append(('FRAME_HEADER_SIZE', 7, ''))
    spec.constants.append(('FRAME_END_SIZE', 1, ''))
    spec.constants.append(('TRANSIENT_DELIVERY_MODE', 1, ''))
    spec.constants.append(('PERSISTENT_DELIVERY_MODE', 2, ''))

    # Later entries win if two constants normalize to the same name.
    constants = {constantName(c): v for c, v, _ in spec.constants}
    for key in sorted(constants):
        print("%s = %s" % (key, constants[key]))
    print('')

    for c in spec.allClasses():
        print('')
        print('class %s(amqp_object.Class):' % (camel(c.name),))
        print('')
        print("    INDEX = 0x%.04X  # %d" % (c.index, c.index))
        print("    NAME = %s" % (fieldvalue(camel(c.name)),))
        print('')

        for m in c.allMethods():
            print('    class %s(amqp_object.Method):' % (camel(m.name),))
            print('')
            methodid = m.klass.index << 16 | m.index
            print("        INDEX = 0x%.08X  # %d, %d; %d" %
                  (methodid, m.klass.index, m.index, methodid))
            print("        NAME = %s" % (fieldvalue(m.structName()),))
            print('')
            print(
                "        def __init__(self%s):" % (fieldDeclList(m.arguments),))
            print(fieldInitList('            ', m.arguments))
            print("        @property")
            print("        def synchronous(self):")
            print("            return %s" % m.isSynchronous)
            print('')
            genDecodeMethodFields(m)
            genEncodeMethodFields(m)

    for c in spec.allClasses():
        if not c.fields:
            continue
        print('')
        print('class %s(amqp_object.Properties):' % (c.structName(),))
        print('')
        print("    CLASS = %s" % (camel(c.name),))
        print("    INDEX = 0x%.04X  # %d" % (c.index, c.index))
        print("    NAME = %s" % (fieldvalue(c.structName()),))
        print('')

        # Flags occupy bits 15..1 of each 16-bit flag word; bit 0 is the
        # "another word follows" marker used by the generated encode/decode,
        # so every 16th slot is skipped.
        index = 0
        for f in c.fields:
            if index % 16 == 15:
                index += 1
            # Floor division keeps the bit position an integer even where
            # "/" would be true division.
            shortnum = index // 16
            partialindex = 15 - (index % 16)
            bitindex = shortnum * 16 + partialindex
            print('    %s = (1 << %d)' % (flagName(None, f), bitindex))
            index += 1
        print('')

        print("    def __init__(self%s):" % (fieldDeclList(c.fields),))
        print(fieldInitList('        ', c.fields))
        genDecodeProperties(c)
        genEncodeProperties(c)

    print("methods = {")
    print(',\n'.join([
        "    0x%08X: %s" % (m.klass.index << 16 | m.index, m.structName())
        for m in spec.allMethods()
    ]))
    print("}")
    print('')

    print("props = {")
    print(',\n'.join([
        "    0x%04X: %s" % (c.index, c.structName())
        for c in spec.allClasses()
        if c.fields
    ]))
    print("}")
    print('')
    print('')

    print("def has_content(methodNumber):")
    print('    return methodNumber in (')
    for m in spec.allMethods():
        if m.hasContent:
            print('        %s.INDEX,' % m.structName())
    print('    )')
420
421
if __name__ == "__main__":
    spec_file = os.path.join(RABBITMQ_CODEGEN_PATH, 'amqp-rabbitmq-0.9.1.json')
    real_stdout = sys.stdout
    with open(PIKA_SPEC, 'w') as handle:
        # generate() prints the module source, so point stdout at the file.
        sys.stdout = handle
        try:
            generate([spec_file])
        finally:
            # Restore stdout before the file is closed so later output (and
            # interpreter shutdown) never writes to a closed file.
            sys.stdout = real_stdout
426