1# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
2
3# Copyright (C) 2001-2017 Nominum, Inc.
4#
5# Permission to use, copy, modify, and distribute this software and its
6# documentation for any purpose with or without fee is hereby granted,
7# provided that the above copyright notice and this permission notice
8# appear in all copies.
9#
10# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
11# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
13# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
16# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17
18"""Help for building DNS wire format messages"""
19
import contextlib
import io
import random
import secrets
import struct
import time

import dns.exception
import dns.tsig
28
29
# Message section identifiers, numbered in the order the sections appear
# on the wire.  They are plain ints so they can index per-section lists.
QUESTION, ANSWER, AUTHORITY, ADDITIONAL = range(4)
34
35
class Renderer:
    """Helper class for building DNS wire-format messages.

    Most applications can use the higher-level ``dns.message.Message``
    class and its to_wire() method to generate wire-format messages.
    This class is for those applications which need finer control
    over the generation of messages.

    Typical use::

        r = dns.renderer.Renderer(id=1, flags=0x80, max_size=512)
        r.add_question(qname, qtype, qclass)
        r.add_rrset(dns.renderer.ANSWER, rrset_1)
        r.add_rrset(dns.renderer.ANSWER, rrset_2)
        r.add_rrset(dns.renderer.AUTHORITY, ns_rrset)
        r.add_edns(0, 0, 4096)
        r.add_rrset(dns.renderer.ADDITIONAL, ad_rrset_1)
        r.add_rrset(dns.renderer.ADDITIONAL, ad_rrset_2)
        r.write_header()
        r.add_tsig(keyname, secret, 300, 1, 0, b'', request_mac)
        wire = r.get_wire()

    Attributes:

    output: an io.BytesIO into which the message is rendered.

    id: the message id.

    flags: the message flags.

    max_size: the maximum size of the message, in bytes.

    origin: the origin to use when rendering relative names.

    compress: the name compression table passed to the to_wire()
    methods; its values are offsets into *output*.

    section: an int, the section currently being rendered.

    counts: list of the number of RRs in each section, indexed by
    QUESTION, ANSWER, AUTHORITY and ADDITIONAL.

    mac: the MAC of the most recent TSIG signature written by
    add_tsig() or add_multi_tsig(), or b'' if none has been written.
    """

    def __init__(self, id=None, flags=0, max_size=65535, origin=None):
        """Initialize a new renderer.

        *id*, an ``int`` or ``None``, the message id.  If ``None``, an id
        in [0, 65535] is drawn from a cryptographically secure source;
        predictable ids make off-path response spoofing easier.

        *flags*, an ``int``, the message flags.

        *max_size*, an ``int``, the maximum rendered size in bytes.

        *origin*, the origin used when rendering relative names.
        """

        self.output = io.BytesIO()
        if id is None:
            self.id = secrets.randbelow(65536)
        else:
            self.id = id
        self.flags = flags
        self.max_size = max_size
        self.origin = origin
        self.compress = {}
        self.section = QUESTION
        self.counts = [0, 0, 0, 0]
        # Reserve space for the 12-byte header; write_header() fills it in
        # once the section counts are known.
        self.output.write(b'\x00' * 12)
        # MACs are bytes, so start with an empty bytes value.
        self.mac = b''

    def _rollback(self, where):
        """Truncate the output buffer at offset *where*, and remove any
        compression table entries that pointed beyond the truncation
        point.
        """

        self.output.seek(where)
        self.output.truncate()
        # Collect the keys first; a dict must not change size while it is
        # being iterated.
        stale = [k for k, v in self.compress.items() if v >= where]
        for k in stale:
            del self.compress[k]

    def _set_section(self, section):
        """Set the renderer's current section.

        Sections must be rendered in order: QUESTION, ANSWER, AUTHORITY,
        ADDITIONAL.  Sections may be empty.

        Raises dns.exception.FormError if an attempt was made to set
        a section value less than the current section.
        """

        if self.section != section:
            if self.section > section:
                raise dns.exception.FormError
            self.section = section

    @contextlib.contextmanager
    def _track_size(self):
        """Context manager guarding one write to *output*.

        Yields the starting offset.  If the body raises, or if the output
        has grown beyond *max_size* when the body finishes, the output and
        compression table are rolled back to that offset.  A body
        exception is re-raised after rollback; an overflow raises
        dns.exception.TooBig.
        """
        start = self.output.tell()
        try:
            yield start
        except Exception:
            # Do not leave a half-written record (or compression pointers
            # into it) behind when rendering fails part way through.
            self._rollback(start)
            raise
        if self.output.tell() > self.max_size:
            self._rollback(start)
            raise dns.exception.TooBig

    def add_question(self, qname, rdtype, rdclass=dns.rdataclass.IN):
        """Add a question to the message.

        Raises dns.exception.FormError if a later section has already
        been started, and dns.exception.TooBig if the question does not
        fit within *max_size*.
        """

        self._set_section(QUESTION)
        with self._track_size():
            qname.to_wire(self.output, self.compress, self.origin)
            self.output.write(struct.pack("!HH", rdtype, rdclass))
        self.counts[QUESTION] += 1

    def add_rrset(self, section, rrset, **kw):
        """Add the rrset to the specified section.

        Any keyword arguments are passed on to the rrset's to_wire()
        routine.  Raises dns.exception.TooBig if the rrset does not fit
        within *max_size*.
        """

        self._set_section(section)
        with self._track_size():
            n = rrset.to_wire(self.output, self.compress, self.origin, **kw)
        self.counts[section] += n

    def add_rdataset(self, section, name, rdataset, **kw):
        """Add the rdataset to the specified section, using the specified
        name as the owner name.

        Any keyword arguments are passed on to the rdataset's to_wire()
        routine.  Raises dns.exception.TooBig if the rdataset does not fit
        within *max_size*.
        """

        self._set_section(section)
        with self._track_size():
            n = rdataset.to_wire(name, self.output, self.compress, self.origin,
                                 **kw)
        self.counts[section] += n

    def add_edns(self, edns, ednsflags, payload, options=None):
        """Add an EDNS OPT record to the message."""

        # The EDNS version lives in bits 16-23 of the flags word; clear
        # whatever the caller put there and store *edns* instead so the
        # two always agree.
        ednsflags &= 0xFF00FFFF
        ednsflags |= (edns << 16)
        opt = dns.message.Message._make_opt(ednsflags, payload, options)
        self.add_rrset(ADDITIONAL, opt)

    def _make_tsig_template(self, keyname, secret, fudge, id, tsig_error,
                            other_data, algorithm):
        """Return ``(key, rdata)``: the signing key and an unsigned TSIG
        rdata template (empty MAC) for add_tsig()/add_multi_tsig().

        *secret* may already be a ``dns.tsig.Key``, in which case it is
        used as-is and *keyname*/*algorithm* are not used to build it.
        """
        if isinstance(secret, dns.tsig.Key):
            key = secret
        else:
            key = dns.tsig.Key(keyname, secret, algorithm)
        tsig = dns.message.Message._make_tsig(keyname, algorithm, 0, fudge,
                                              b'', id, tsig_error, other_data)
        return (key, tsig[0])

    def add_tsig(self, keyname, secret, fudge, id, tsig_error, other_data,
                 request_mac, algorithm=dns.tsig.default_algorithm):
        """Add a TSIG signature to the message.

        The signature covers everything rendered so far, so call this
        after write_header().  On success the signature's MAC is stored
        in *mac*.
        """

        s = self.output.getvalue()
        key, template = self._make_tsig_template(keyname, secret, fudge, id,
                                                 tsig_error, other_data,
                                                 algorithm)
        (tsig, _) = dns.tsig.sign(s, key, template, int(time.time()),
                                  request_mac)
        self._write_tsig(tsig, keyname)
        # Only record the MAC once the TSIG RR has actually been written.
        self.mac = tsig.mac

    def add_multi_tsig(self, ctx, keyname, secret, fudge, id, tsig_error,
                       other_data, request_mac,
                       algorithm=dns.tsig.default_algorithm):
        """Add a TSIG signature to the message. Unlike add_tsig(), this can be
        used for a series of consecutive DNS envelopes, e.g. for a zone
        transfer over TCP [RFC2845, 4.4].

        For the first message in the sequence, give ctx=None. For each
        subsequent message, give the ctx that was returned from the
        add_multi_tsig() call for the previous message.

        On success the signature's MAC is stored in *mac*.
        """

        s = self.output.getvalue()
        key, template = self._make_tsig_template(keyname, secret, fudge, id,
                                                 tsig_error, other_data,
                                                 algorithm)
        (tsig, ctx) = dns.tsig.sign(s, key, template, int(time.time()),
                                    request_mac, ctx, True)
        self._write_tsig(tsig, keyname)
        self.mac = tsig.mac
        return ctx

    def _write_tsig(self, tsig, keyname):
        """Append a TSIG RR (owner *keyname*, class ANY, TTL 0, rdata
        *tsig*) to the additional section and patch the ARCOUNT field of
        the already-written header.

        Raises dns.exception.TooBig if the RR does not fit.
        """
        self._set_section(ADDITIONAL)
        with self._track_size():
            keyname.to_wire(self.output, self.compress, self.origin)
            # The trailing 0 is a placeholder rdlength, patched below once
            # the rdata size is known.
            self.output.write(struct.pack('!HHIH', dns.rdatatype.TSIG,
                                          dns.rdataclass.ANY, 0, 0))
            rdata_start = self.output.tell()
            tsig.to_wire(self.output)

        after = self.output.tell()
        self.output.seek(rdata_start - 2)
        self.output.write(struct.pack('!H', after - rdata_start))
        self.counts[ADDITIONAL] += 1
        # Offset 10 is the additional-section count in the header layout
        # used by write_header() (id, flags, then four 16-bit counts).
        self.output.seek(10)
        self.output.write(struct.pack('!H', self.counts[ADDITIONAL]))
        self.output.seek(0, io.SEEK_END)

    def write_header(self):
        """Write the DNS message header.

        Writing the DNS message header is done after all sections
        have been rendered, but before the optional TSIG signature
        is added.
        """

        self.output.seek(0)
        self.output.write(struct.pack('!HHHHHH', self.id, self.flags,
                                      self.counts[0], self.counts[1],
                                      self.counts[2], self.counts[3]))
        self.output.seek(0, io.SEEK_END)

    def get_wire(self):
        """Return the wire format message."""

        return self.output.getvalue()
251