1# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
2
3# Copyright (C) 2001-2017 Nominum, Inc.
4#
5# Permission to use, copy, modify, and distribute this software and its
6# documentation for any purpose with or without fee is hereby granted,
7# provided that the above copyright notice and this permission notice
8# appear in all copies.
9#
10# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
11# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
13# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
16# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17
18"""Help for building DNS wire format messages"""
19
20from io import BytesIO
21import struct
22import random
23import time
24
25import dns.exception
26import dns.tsig
27from ._compat import long
28
29
# Message section indices, in the order sections must be rendered.
# They also index Renderer.counts (one RR count per section).
QUESTION, ANSWER, AUTHORITY, ADDITIONAL = range(4)
34
35
class Renderer(object):
    """Helper class for building DNS wire-format messages.

    Most applications can use the higher-level L{dns.message.Message}
    class and its to_wire() method to generate wire-format messages.
    This class is for those applications which need finer control
    over the generation of messages.

    Typical use::

        r = dns.renderer.Renderer(id=1, flags=0x80, max_size=512)
        r.add_question(qname, qtype, qclass)
        r.add_rrset(dns.renderer.ANSWER, rrset_1)
        r.add_rrset(dns.renderer.ANSWER, rrset_2)
        r.add_rrset(dns.renderer.AUTHORITY, ns_rrset)
        r.add_edns(0, 0, 4096)
        r.add_rrset(dns.renderer.ADDITIONAL, ad_rrset_1)
        r.add_rrset(dns.renderer.ADDITIONAL, ad_rrset_2)
        r.write_header()
        r.add_tsig(keyname, secret, 300, 1, 0, '', request_mac)
        wire = r.get_wire()

    output, a BytesIO, where rendering is written

    id: the message id

    flags: the message flags

    max_size: the maximum size of the message, in bytes.  A message whose
    length is exactly max_size is accepted; any addition that would make
    it longer is rolled back and dns.exception.TooBig is raised.

    origin: the origin to use when rendering relative names

    compress: the compression table (its values are output offsets;
    entries at or past a rollback point are discarded by _rollback)

    section: an int, the section currently being rendered

    counts: list of the number of RRs in each section

    mac: the MAC of the rendered message (if TSIG was used)
    """

    def __init__(self, id=None, flags=0, max_size=65535, origin=None):
        """Initialize a new renderer.

        If *id* is None, a random 16-bit message id is drawn from the
        operating system's random source rather than the seeded
        Mersenne Twister, so that ids are not predictable by an
        off-path attacker.
        """

        self.output = BytesIO()
        if id is None:
            self.id = random.SystemRandom().randint(0, 65535)
        else:
            self.id = id
        self.flags = flags
        self.max_size = max_size
        self.origin = origin
        self.compress = {}
        self.section = QUESTION
        self.counts = [0, 0, 0, 0]
        # Reserve the 12-byte header; write_header() fills it in later.
        self.output.write(b'\x00' * 12)
        self.mac = ''

    def _rollback(self, where):
        """Truncate the output buffer at offset *where*, and remove any
        compression table entries that pointed at or beyond the truncation
        point.
        """

        self.output.seek(where)
        self.output.truncate()
        stale = [key for key, offset in self.compress.items()
                 if offset >= where]
        for key in stale:
            del self.compress[key]

    def _check_size(self, before):
        """Enforce max_size after data was appended starting at *before*.

        A message of exactly max_size bytes is allowed; only a longer one
        is rejected.  On rejection the output is rolled back to *before*,
        discarding the partial addition, and dns.exception.TooBig is raised.
        Assumes the output position is at the end of the buffer.
        """

        if self.output.tell() > self.max_size:
            self._rollback(before)
            raise dns.exception.TooBig

    def _set_section(self, section):
        """Set the renderer's current section.

        Sections must be rendered in order: QUESTION, ANSWER, AUTHORITY,
        ADDITIONAL.  Sections may be empty.

        Raises dns.exception.FormError if an attempt was made to set
        a section value less than the current section.
        """

        if self.section != section:
            if self.section > section:
                raise dns.exception.FormError
            self.section = section

    def add_question(self, qname, rdtype, rdclass=dns.rdataclass.IN):
        """Add a question to the message.

        *qname* is rendered via qname.to_wire() with the renderer's
        compression table and origin, followed by the 16-bit *rdtype*
        and *rdclass*.

        Raises dns.exception.FormError if a later section has already been
        started, and dns.exception.TooBig if the question would make the
        message longer than max_size.
        """

        self._set_section(QUESTION)
        before = self.output.tell()
        qname.to_wire(self.output, self.compress, self.origin)
        self.output.write(struct.pack("!HH", rdtype, rdclass))
        self._check_size(before)
        self.counts[QUESTION] += 1

    def add_rrset(self, section, rrset, **kw):
        """Add the rrset to the specified section.

        Any keyword arguments are passed on to the rdataset's to_wire()
        routine.  The section's RR count is increased by the number that
        to_wire() returns.

        Raises dns.exception.FormError if *section* precedes the current
        section, and dns.exception.TooBig if the rrset would make the
        message longer than max_size (in which case nothing is added).
        """

        self._set_section(section)
        before = self.output.tell()
        n = rrset.to_wire(self.output, self.compress, self.origin, **kw)
        self._check_size(before)
        self.counts[section] += n

    def add_rdataset(self, section, name, rdataset, **kw):
        """Add the rdataset to the specified section, using the specified
        name as the owner name.

        Any keyword arguments are passed on to the rdataset's to_wire()
        routine.  The section's RR count is increased by the number that
        to_wire() returns.

        Raises dns.exception.FormError if *section* precedes the current
        section, and dns.exception.TooBig if the rdataset would make the
        message longer than max_size (in which case nothing is added).
        """

        self._set_section(section)
        before = self.output.tell()
        n = rdataset.to_wire(name, self.output, self.compress, self.origin,
                             **kw)
        self._check_size(before)
        self.counts[section] += n

    def add_edns(self, edns, ednsflags, payload, options=None):
        """Add an EDNS OPT record to the additional section.

        *edns* is the EDNS version; it replaces bits 16-23 of *ednsflags*,
        which is written into the record's TTL field.  *payload* is written
        into the CLASS field.  *options*, if not None, is an iterable of
        option objects exposing ``otype`` and ``to_wire(file)``; each is
        encoded as code, length, then body.

        Raises dns.exception.TooBig if the record would make the message
        longer than max_size (in which case nothing is added).
        """

        # make sure the EDNS version in ednsflags agrees with edns
        ednsflags &= long(0xFF00FFFF)
        ednsflags |= (edns << 16)
        self._set_section(ADDITIONAL)
        before = self.output.tell()
        # Root owner (single zero byte), TYPE=OPT, CLASS=payload,
        # TTL=ednsflags, RDLENGTH=0 placeholder patched below if needed.
        self.output.write(struct.pack('!BHHIH', 0, dns.rdatatype.OPT, payload,
                                      ednsflags, 0))
        if options is not None:
            lstart = self.output.tell()
            for opt in options:
                # Option code plus a zero length placeholder; the real length
                # is only known once the option body has been written.
                self.output.write(struct.pack("!HH", opt.otype, 0))
                start = self.output.tell()
                opt.to_wire(self.output)
                end = self.output.tell()
                assert end - start < 65536
                self.output.seek(start - 2)
                self.output.write(struct.pack("!H", end - start))
                self.output.seek(0, 2)
            lend = self.output.tell()
            assert lend - lstart < 65536
            # Patch RDLENGTH, which sits just before the options.
            self.output.seek(lstart - 2)
            self.output.write(struct.pack("!H", lend - lstart))
            self.output.seek(0, 2)
        self._check_size(before)
        self.counts[ADDITIONAL] += 1

    def add_tsig(self, keyname, secret, fudge, id, tsig_error, other_data,
                 request_mac, algorithm=dns.tsig.default_algorithm):
        """Add a TSIG signature to the message.

        The message rendered so far (header included) is signed with the
        current time, and the resulting MAC is stored in self.mac.
        """

        s = self.output.getvalue()
        (tsig_rdata, self.mac, _ctx) = dns.tsig.sign(s,
                                                     keyname,
                                                     secret,
                                                     int(time.time()),
                                                     fudge,
                                                     id,
                                                     tsig_error,
                                                     other_data,
                                                     request_mac,
                                                     algorithm=algorithm)
        self._write_tsig(tsig_rdata, keyname)

    def add_multi_tsig(self, ctx, keyname, secret, fudge, id, tsig_error,
                       other_data, request_mac,
                       algorithm=dns.tsig.default_algorithm):
        """Add a TSIG signature to the message. Unlike add_tsig(), this can be
        used for a series of consecutive DNS envelopes, e.g. for a zone
        transfer over TCP [RFC2845, 4.4].

        For the first message in the sequence, give ctx=None. For each
        subsequent message, give the ctx that was returned from the
        add_multi_tsig() call for the previous message."""

        s = self.output.getvalue()
        (tsig_rdata, self.mac, ctx) = dns.tsig.sign(s,
                                                    keyname,
                                                    secret,
                                                    int(time.time()),
                                                    fudge,
                                                    id,
                                                    tsig_error,
                                                    other_data,
                                                    request_mac,
                                                    ctx=ctx,
                                                    first=ctx is None,
                                                    multi=True,
                                                    algorithm=algorithm)
        self._write_tsig(tsig_rdata, keyname)
        return ctx

    def _write_tsig(self, tsig_rdata, keyname):
        """Append a TSIG record with owner *keyname* and rdata *tsig_rdata*.

        On success the additional-section count is incremented and the
        ARCOUNT field of the header (offset 10) is rewritten in place.
        Raises dns.exception.TooBig, leaving the output unchanged, if the
        record would make the message longer than max_size.
        """

        self._set_section(ADDITIONAL)
        before = self.output.tell()

        keyname.to_wire(self.output, self.compress, self.origin)
        # TYPE=TSIG, CLASS=ANY, TTL=0, RDLENGTH=0 placeholder (patched below).
        self.output.write(struct.pack('!HHIH', dns.rdatatype.TSIG,
                                      dns.rdataclass.ANY, 0, 0))
        rdata_start = self.output.tell()
        self.output.write(tsig_rdata)

        after = self.output.tell()
        assert after - rdata_start < 65536
        self._check_size(before)

        self.output.seek(rdata_start - 2)
        self.output.write(struct.pack('!H', after - rdata_start))
        self.counts[ADDITIONAL] += 1
        self.output.seek(10)
        self.output.write(struct.pack('!H', self.counts[ADDITIONAL]))
        self.output.seek(0, 2)

    def write_header(self):
        """Write the DNS message header.

        Writing the DNS message header is done after all sections
        have been rendered, but before the optional TSIG signature
        is added.
        """

        self.output.seek(0)
        self.output.write(struct.pack('!HHHHHH', self.id, self.flags,
                                      self.counts[0], self.counts[1],
                                      self.counts[2], self.counts[3]))
        self.output.seek(0, 2)

    def get_wire(self):
        """Return the wire format message."""

        return self.output.getvalue()
292