1 /*
2 * This file is part of PowerDNS or dnsdist.
3 * Copyright -- PowerDNS.COM B.V. and its contributors
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of version 2 of the GNU General Public License as
7 * published by the Free Software Foundation.
8 *
9 * In addition, for the avoidance of any doubt, permission is granted to
10 * link this program with OpenSSL and to (re)distribute the binaries
11 * produced as the result of such linking.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with this program; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22 #pragma once
23 #include <map>
24 #include <sstream>
25 #include <stdexcept>
26 #include <iostream>
27 #include <vector>
28 #include <errno.h>
29 // #include <netinet/in.h>
30 #include "misc.hh"
31
32 #include <boost/tuple/tuple.hpp>
33 #include <boost/tuple/tuple_comparison.hpp>
34 #include "dns.hh"
35 #include "dnswriter.hh"
36 #include "dnsname.hh"
37 #include "pdnsexception.hh"
38 #include "iputils.hh"
39 #include "svc-records.hh"
40
/** DNS records have three representations:
    1) in the packet
    2) parsed in a class, ready for use
    3) in the zone

    We should implement bidirectional transitions between 1&2 and 2&3.
    Originally we had: 1 -> 2
                       2 -> 3

    The reverse directions were expected to be easy to add: 2 -> 1 by reversing
    the packetwriter, and possibly 3 -> 2 by reversing 2 -> 3. The declarations
    below show both now exist: DNSRecordContent::toPacket() (2 -> 1) and
    DNSRecordContent::mastermake(qtype, qclass, zone) (3 -> 2).
*/
53
54 #include "namespaces.hh"
55
/** Project exception type: a runtime_error carrying a human-readable message.
 *  NOTE(review): presumably thrown by the MOADNS parsing code — confirm at call sites. */
class MOADNSException : public runtime_error
{
public:
  MOADNSException(const string& msg) : runtime_error(msg) {}
};
62
63
64 class MOADNSParser;
65
66 class PacketReader
67 {
68 public:
PacketReader(const pdns_string_view & content,uint16_t initialPos=sizeof (dnsheader))69 PacketReader(const pdns_string_view& content, uint16_t initialPos=sizeof(dnsheader))
70 : d_pos(initialPos), d_startrecordpos(initialPos), d_content(content)
71 {
72 if(content.size() > std::numeric_limits<uint16_t>::max())
73 throw std::out_of_range("packet too large");
74
75 d_recordlen = (uint16_t) content.size();
76 not_used = 0;
77 }
78
79 uint32_t get32BitInt();
80 uint16_t get16BitInt();
81 uint8_t get8BitInt();
82
83 void xfrNodeOrLocatorID(NodeOrLocatorID& val);
84 void xfr48BitInt(uint64_t& val);
85
xfr32BitInt(uint32_t & val)86 void xfr32BitInt(uint32_t& val)
87 {
88 val=get32BitInt();
89 }
90
xfrIP(uint32_t & val)91 void xfrIP(uint32_t& val)
92 {
93 xfr32BitInt(val);
94 val=htonl(val);
95 }
96
xfrIP6(std::string & val)97 void xfrIP6(std::string &val) {
98 xfrBlob(val, 16);
99 }
100
xfrCAWithoutPort(uint8_t version,ComboAddress & val)101 void xfrCAWithoutPort(uint8_t version, ComboAddress &val) {
102 string blob;
103 if (version == 4) xfrBlob(blob, 4);
104 else if (version == 6) xfrBlob(blob, 16);
105 else throw runtime_error("invalid IP protocol");
106 val = makeComboAddressFromRaw(version, blob);
107 }
108
xfrCAPort(ComboAddress & val)109 void xfrCAPort(ComboAddress &val) {
110 uint16_t port;
111 xfr16BitInt(port);
112 val.sin4.sin_port = port;
113 }
114
xfrTime(uint32_t & val)115 void xfrTime(uint32_t& val)
116 {
117 xfr32BitInt(val);
118 }
119
120
xfr16BitInt(uint16_t & val)121 void xfr16BitInt(uint16_t& val)
122 {
123 val=get16BitInt();
124 }
125
xfrType(uint16_t & val)126 void xfrType(uint16_t& val)
127 {
128 xfr16BitInt(val);
129 }
130
131
xfr8BitInt(uint8_t & val)132 void xfr8BitInt(uint8_t& val)
133 {
134 val=get8BitInt();
135 }
136
137
xfrName(DNSName & name,bool compress=false,bool noDot=false)138 void xfrName(DNSName &name, bool compress=false, bool noDot=false)
139 {
140 name=getName();
141 }
142
xfrText(string & text,bool multi=false,bool lenField=true)143 void xfrText(string &text, bool multi=false, bool lenField=true)
144 {
145 text=getText(multi, lenField);
146 }
147
xfrUnquotedText(string & text,bool lenField)148 void xfrUnquotedText(string &text, bool lenField){
149 text=getUnquotedText(lenField);
150 }
151
152 void xfrBlob(string& blob);
153 void xfrBlobNoSpaces(string& blob, int len);
154 void xfrBlob(string& blob, int length);
155 void xfrHexBlob(string& blob, bool keepReading=false);
156 void xfrSvcParamKeyVals(set<SvcParam> &kvs);
157
158 void getDnsrecordheader(struct dnsrecordheader &ah);
159 void copyRecord(vector<unsigned char>& dest, uint16_t len);
160 void copyRecord(unsigned char* dest, uint16_t len);
161
162 DNSName getName();
163 string getText(bool multi, bool lenField);
164 string getUnquotedText(bool lenField);
165
166
eof()167 bool eof() { return true; };
getRemaining() const168 const string getRemaining() const {
169 return "";
170 };
171
getPosition() const172 uint16_t getPosition() const
173 {
174 return d_pos;
175 }
176
skip(uint16_t n)177 void skip(uint16_t n)
178 {
179 d_pos += n;
180 }
181
182 private:
183 uint16_t d_pos;
184 uint16_t d_startrecordpos; // needed for getBlob later on
185 uint16_t d_recordlen; // ditto
186 uint16_t not_used; // Aligns the whole class on 8-byte boundaries
187 const pdns_string_view d_content;
188 };
189
190 struct DNSRecord;
191
192 class DNSRecordContent
193 {
194 public:
195 static std::shared_ptr<DNSRecordContent> mastermake(const DNSRecord &dr, PacketReader& pr);
196 static std::shared_ptr<DNSRecordContent> mastermake(const DNSRecord &dr, PacketReader& pr, uint16_t opcode);
197 static std::shared_ptr<DNSRecordContent> mastermake(uint16_t qtype, uint16_t qclass, const string& zone);
198 static string upgradeContent(const DNSName& qname, const QType& qtype, const string& content);
199
200 virtual std::string getZoneRepresentation(bool noDot=false) const = 0;
~DNSRecordContent()201 virtual ~DNSRecordContent() {}
202 virtual void toPacket(DNSPacketWriter& pw)=0;
serialize(const DNSName & qname,bool canonic=false,bool lowerCase=false)203 virtual string serialize(const DNSName& qname, bool canonic=false, bool lowerCase=false) // it would rock if this were const, but it is too hard
204 {
205 vector<uint8_t> packet;
206 DNSPacketWriter pw(packet, g_rootdnsname, 1);
207 if(canonic)
208 pw.setCanonic(true);
209
210 if(lowerCase)
211 pw.setLowercase(true);
212
213 pw.startRecord(qname, this->getType());
214 this->toPacket(pw);
215
216 string record;
217 pw.getRecordPayload(record); // needs to be called before commit()
218 return record;
219 }
220
operator ==(const DNSRecordContent & rhs) const221 virtual bool operator==(const DNSRecordContent& rhs) const
222 {
223 return typeid(*this)==typeid(rhs) && this->getZoneRepresentation() == rhs.getZoneRepresentation();
224 }
225
226 static shared_ptr<DNSRecordContent> deserialize(const DNSName& qname, uint16_t qtype, const string& serialized);
227
doRecordCheck(const struct DNSRecord &)228 void doRecordCheck(const struct DNSRecord&){}
229
230 typedef std::shared_ptr<DNSRecordContent> makerfunc_t(const struct DNSRecord& dr, PacketReader& pr);
231 typedef std::shared_ptr<DNSRecordContent> zmakerfunc_t(const string& str);
232
regist(uint16_t cl,uint16_t ty,makerfunc_t * f,zmakerfunc_t * z,const char * name)233 static void regist(uint16_t cl, uint16_t ty, makerfunc_t* f, zmakerfunc_t* z, const char* name)
234 {
235 if(f)
236 getTypemap()[make_pair(cl,ty)]=f;
237 if(z)
238 getZmakermap()[make_pair(cl,ty)]=z;
239
240 getT2Namemap().insert(make_pair(make_pair(cl,ty), name));
241 getN2Typemap().insert(make_pair(name, make_pair(cl,ty)));
242 }
243
unregist(uint16_t cl,uint16_t ty)244 static void unregist(uint16_t cl, uint16_t ty)
245 {
246 pair<uint16_t, uint16_t> key=make_pair(cl, ty);
247 getTypemap().erase(key);
248 getZmakermap().erase(key);
249 }
250
isUnknownType(const string & name)251 static bool isUnknownType(const string& name)
252 {
253 return boost::starts_with(name, "TYPE") || boost::starts_with(name, "type");
254 }
255
TypeToNumber(const string & name)256 static uint16_t TypeToNumber(const string& name)
257 {
258 n2typemap_t::const_iterator iter = getN2Typemap().find(toUpper(name));
259 if(iter != getN2Typemap().end())
260 return iter->second.second;
261
262 if (isUnknownType(name))
263 return (uint16_t) pdns_stou(name.substr(4));
264
265 throw runtime_error("Unknown DNS type '"+name+"'");
266 }
267
NumberToType(uint16_t num,uint16_t classnum=1)268 static const string NumberToType(uint16_t num, uint16_t classnum=1)
269 {
270 t2namemap_t::const_iterator iter = getT2Namemap().find(make_pair(classnum, num));
271 if(iter == getT2Namemap().end())
272 return "TYPE" + std::to_string(num);
273 // throw runtime_error("Unknown DNS type with numerical id "+std::to_string(num));
274 return iter->second;
275 }
276
277 virtual uint16_t getType() const = 0;
278
279 protected:
280 typedef std::map<std::pair<uint16_t, uint16_t>, makerfunc_t* > typemap_t;
281 typedef std::map<std::pair<uint16_t, uint16_t>, zmakerfunc_t* > zmakermap_t;
282 typedef std::map<std::pair<uint16_t, uint16_t>, string > t2namemap_t;
283 typedef std::map<string, std::pair<uint16_t, uint16_t> > n2typemap_t;
284 static typemap_t& getTypemap();
285 static t2namemap_t& getT2Namemap();
286 static n2typemap_t& getN2Typemap();
287 static zmakermap_t& getZmakermap();
288 };
289
290 struct DNSRecord
291 {
DNSRecordDNSRecord292 DNSRecord() : d_type(0), d_class(QClass::IN), d_ttl(0), d_clen(0), d_place(DNSResourceRecord::ANSWER)
293 {}
294 explicit DNSRecord(const DNSResourceRecord& rr);
295 DNSName d_name;
296 std::shared_ptr<DNSRecordContent> d_content;
297 uint16_t d_type;
298 uint16_t d_class;
299 uint32_t d_ttl;
300 uint16_t d_clen;
301 DNSResourceRecord::Place d_place;
302
operator <DNSRecord303 bool operator<(const DNSRecord& rhs) const
304 {
305 if(tie(d_name, d_type, d_class, d_ttl) < tie(rhs.d_name, rhs.d_type, rhs.d_class, rhs.d_ttl))
306 return true;
307
308 if(tie(d_name, d_type, d_class, d_ttl) != tie(rhs.d_name, rhs.d_type, rhs.d_class, rhs.d_ttl))
309 return false;
310
311 string lzrp, rzrp;
312 if(d_content)
313 lzrp=toLower(d_content->getZoneRepresentation());
314 if(rhs.d_content)
315 rzrp=toLower(rhs.d_content->getZoneRepresentation());
316
317 return lzrp < rzrp;
318 }
319
320 // this orders in canonical order and keeps the SOA record on top
prettyCompareDNSRecord321 static bool prettyCompare(const DNSRecord& a, const DNSRecord& b)
322 {
323 auto aType = (a.d_type == QType::SOA) ? 0 : a.d_type;
324 auto bType = (b.d_type == QType::SOA) ? 0 : b.d_type;
325
326 if(a.d_name.canonCompare(b.d_name))
327 return true;
328 if(b.d_name.canonCompare(a.d_name))
329 return false;
330
331 if(tie(aType, a.d_class, a.d_ttl) < tie(bType, b.d_class, b.d_ttl))
332 return true;
333
334 if(tie(aType, a.d_class, a.d_ttl) != tie(bType, b.d_class, b.d_ttl))
335 return false;
336
337 string lzrp, rzrp;
338 if(a.d_content)
339 lzrp=toLower(a.d_content->getZoneRepresentation());
340 if(b.d_content)
341 rzrp=toLower(b.d_content->getZoneRepresentation());
342
343 return lzrp < rzrp;
344 }
345
346
operator ==DNSRecord347 bool operator==(const DNSRecord& rhs) const
348 {
349 if(d_type != rhs.d_type || d_class != rhs.d_class || d_name != rhs.d_name)
350 return false;
351
352 return *d_content == *rhs.d_content;
353 }
354 };
355
356 struct DNSZoneRecord
357 {
358 int domain_id{-1};
359 uint8_t scopeMask{0};
360 int signttl{0};
361 DNSName wildcardname;
362 bool auth{true};
363 bool disabled{false};
364 DNSRecord dr;
365 };
366
367 class UnknownRecordContent : public DNSRecordContent
368 {
369 public:
UnknownRecordContent(const DNSRecord & dr,PacketReader & pr)370 UnknownRecordContent(const DNSRecord& dr, PacketReader& pr)
371 : d_dr(dr)
372 {
373 pr.copyRecord(d_record, dr.d_clen);
374 }
375
376 UnknownRecordContent(const string& zone);
377
378 string getZoneRepresentation(bool noDot) const override;
379 void toPacket(DNSPacketWriter& pw) override;
getType() const380 uint16_t getType() const override
381 {
382 return d_dr.d_type;
383 }
384
385 private:
386 DNSRecord d_dr;
387 vector<uint8_t> d_record;
388 };
389
390 //! This class can be used to parse incoming packets, and is copyable
391 class MOADNSParser : public boost::noncopyable
392 {
393 public:
394 //! Parse from a string
MOADNSParser(bool query,const string & buffer)395 MOADNSParser(bool query, const string& buffer): d_tsigPos(0)
396 {
397 init(query, buffer);
398 }
399
400 //! Parse from a pointer and length
MOADNSParser(bool query,const char * packet,unsigned int len)401 MOADNSParser(bool query, const char *packet, unsigned int len) : d_tsigPos(0)
402 {
403 init(query, pdns_string_view(packet, len));
404 }
405
406 DNSName d_qname;
407 uint16_t d_qclass, d_qtype;
408 //uint8_t d_rcode;
409 dnsheader d_header;
410
411 typedef vector<pair<DNSRecord, uint16_t > > answers_t;
412
413 //! All answers contained in this packet (everything *but* the question section)
414 answers_t d_answers;
415
getTSIGPos() const416 uint16_t getTSIGPos() const
417 {
418 return d_tsigPos;
419 }
420
421 bool hasEDNS() const;
422
423 private:
424 void init(bool query, const pdns_string_view& packet);
425 uint16_t d_tsigPos;
426 };
427
// Free helpers operating on names and raw packets; all are defined out of line.
std::string simpleCompress(const std::string& label, const std::string& root="");

// Two ageDNSPacket overloads: raw buffer + length, or a whole std::string.
// NOTE(review): presumably these decrement record TTLs by `seconds` — confirm.
void ageDNSPacket(char* packet, size_t length, uint32_t seconds);
void ageDNSPacket(std::string& packet, uint32_t seconds);

// The visitor receives (uint8_t, uint16_t, uint16_t, uint32_t) and returns a uint32_t.
void editDNSPacketTTL(char* packet, size_t length, const std::function<uint32_t(uint8_t, uint16_t, uint16_t, uint32_t)>& visitor);

// Packet inspection helpers.
uint32_t getDNSPacketMinTTL(const char* packet, size_t length, bool* seenAuthSOA=nullptr);
uint32_t getDNSPacketLength(const char* packet, size_t length);
uint16_t getRecordsOfTypeCount(const char* packet, size_t length, uint8_t section, uint16_t type);
bool getEDNSUDPPayloadSizeAndZ(const char* packet, size_t length, uint16_t* payloadSize, uint16_t* z);
436
437 template<typename T>
getRR(const DNSRecord & dr)438 std::shared_ptr<T> getRR(const DNSRecord& dr)
439 {
440 return std::dynamic_pointer_cast<T>(dr.d_content);
441 }
442
/** Simple in-place cursor over a wire-format DNS packet.
 *
 * Ritual is: get a pointer into the packet and moveOffset() to beyond your needs.
 * If you survive that (moveOffset() throws std::out_of_range when the packet is
 * too short), feel free to read from the pointer.
 *
 * The cursor starts just past the 12-byte DNS header and can never be rewound
 * into it. A failed move leaves the cursor where it was.
 *
 * The mangler stores a raw pointer to the packet and does not own it: the
 * buffer must outlive the mangler and must not be resized while in use. Copies
 * of a mangler share the buffer but track their offsets independently.
 */
class DNSPacketMangler
{
public:
  //! Mangle the contents of @p packet in place.
  explicit DNSPacketMangler(std::string& packet)
    : d_packet(&packet[0]), d_length(packet.length()), d_notyouroffset(12)
  {}
  DNSPacketMangler(char* packet, size_t length)
    : d_packet(packet), d_length(length), d_notyouroffset(12)
  {}

  /*! Advances past a wire-format domain name.
   * The name is not checked for adherence to length restrictions.
   * Compression pointers are not followed: a pointer (first byte >= 0xc0)
   * ends the name after its two bytes.
   * @throws std::out_of_range if the name runs past the end of the packet
   */
  void skipDomainName()
  {
    uint8_t len;
    while((len=get8BitInt())) {
      if(len >= 0xc0) { // compression pointer: consume its second byte, then the name ends
        get8BitInt();
        return;
      }
      skipBytes(len);
    }
  }

  //! @throws std::out_of_range if fewer than @p bytes remain (offset unchanged)
  void skipBytes(uint16_t bytes)
  {
    moveOffset(bytes);
  }
  //! @throws std::out_of_range if this would move before the 12-byte header boundary (offset unchanged)
  void rewindBytes(uint16_t by)
  {
    rewindOffset(by);
  }
  //! Reads a big-endian 32-bit value and advances past it.
  uint32_t get32BitInt()
  {
    const char* p = d_packet + getOffset();
    moveOffset(4); // bounds check before touching p
    uint32_t ret;
    memcpy(&ret, p, sizeof(ret));
    return ntohl(ret);
  }
  //! Reads a big-endian 16-bit value and advances past it.
  uint16_t get16BitInt()
  {
    const char* p = d_packet + getOffset();
    moveOffset(2);
    uint16_t ret;
    memcpy(&ret, p, sizeof(ret));
    return ntohs(ret);
  }

  uint8_t get8BitInt()
  {
    const char* p = d_packet + getOffset();
    moveOffset(1);
    return static_cast<uint8_t>(*p);
  }

  //! Reads a 16-bit length and skips that many bytes.
  void skipRData()
  {
    const uint16_t toskip = get16BitInt();
    moveOffset(toskip);
  }

  //! Subtracts @p decrease from the big-endian 32-bit value at the cursor (clamping at 0) and advances past it.
  void decreaseAndSkip32BitInt(uint32_t decrease)
  {
    char* p = d_packet + getOffset();
    moveOffset(4);

    uint32_t tmp;
    memcpy(&tmp, p, sizeof(tmp));
    tmp = ntohl(tmp);
    tmp = (tmp > decrease) ? (tmp - decrease) : 0;
    tmp = htonl(tmp);
    memcpy(p, &tmp, sizeof(tmp));
  }

  //! Overwrites the 32-bit value at the cursor with @p value (big-endian) and advances past it.
  void setAndSkip32BitInt(uint32_t value)
  {
    char* p = d_packet + getOffset();
    moveOffset(4);

    value = htonl(value);
    memcpy(p, &value, sizeof(value));
  }

  uint32_t getOffset() const
  {
    return d_notyouroffset;
  }

private:
  // Compute the new offset first and commit only when it is in range.
  void moveOffset(uint16_t by)
  {
    const uint32_t target = d_notyouroffset + by;
    if(target > d_length)
      throw std::out_of_range("dns packet out of range: "+std::to_string(target) +" > "
                              + std::to_string(d_length) );
    d_notyouroffset = target;
  }

  // Never rewinds into the 12-byte header; the offset is unchanged on failure.
  void rewindOffset(uint16_t by)
  {
    if(d_notyouroffset < by)
      throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset) +" < "
                              + std::to_string(by));
    const uint32_t target = d_notyouroffset - by;
    if(target < 12)
      throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(target) +" < "
                              + std::to_string(12));
    d_notyouroffset = target;
  }

  char* d_packet;
  size_t d_length;

  // Only moveOffset()/rewindOffset() may modify this; everything else reads it via getOffset().
  // (A self-referencing `const uint32_t& d_offset` used to live here: copies of the
  // mangler then aliased the original's counter.)
  uint32_t d_notyouroffset;
};
565