1 /*
2  * This file is part of PowerDNS or dnsdist.
3  * Copyright -- PowerDNS.COM B.V. and its contributors
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of version 2 of the GNU General Public License as
7  * published by the Free Software Foundation.
8  *
9  * In addition, for the avoidance of any doubt, permission is granted to
10  * link this program with OpenSSL and to (re)distribute the binaries
11  * produced as the result of such linking.
12  *
13  * This program is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  * GNU General Public License for more details.
17  *
18  * You should have received a copy of the GNU General Public License
19  * along with this program; if not, write to the Free Software
20  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21  */
22 #pragma once
23 #include <map>
24 #include <sstream>
25 #include <stdexcept>
26 #include <iostream>
27 #include <vector>
28 #include <errno.h>
29 // #include <netinet/in.h>
30 #include "misc.hh"
31 
32 #include <boost/tuple/tuple.hpp>
33 #include <boost/tuple/tuple_comparison.hpp>
34 #include "dns.hh"
35 #include "dnswriter.hh"
36 #include "dnsname.hh"
37 #include "pdnsexception.hh"
38 #include "iputils.hh"
39 #include "svc-records.hh"
40 
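/** DNS records have three representations:
    1) in the packet
    2) parsed in a class, ready for use
    3) in the zone

    We should implement bidirectional transitions between 1&2 and 2&3.
    Currently we have: 1 -> 2  (DNSRecordContent::mastermake() with a PacketReader)
                       2 -> 3  (DNSRecordContent::getZoneRepresentation())
                       2 -> 1  (DNSRecordContent::toPacket() / serialize(), via DNSPacketWriter)
                       3 -> 2  (DNSRecordContent::mastermake() with a zone-format string)

    The 2 -> 1 direction was obtained by reversing the packet writer, as originally planned.
*/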
53 
54 #include "namespaces.hh"
55 
//! Exception type for errors raised by the DNS message parsing code in this header
//! (NOTE(review): the exact throw sites live in the implementation file — confirm).
//! The message passed in is returned verbatim by what().
class MOADNSException : public runtime_error
{
public:
  MOADNSException(const string& str) :
    runtime_error(str)
  {
  }
};
62 
63 
64 class MOADNSParser;
65 
66 class PacketReader
67 {
68 public:
PacketReader(const pdns_string_view & content,uint16_t initialPos=sizeof (dnsheader))69   PacketReader(const pdns_string_view& content, uint16_t initialPos=sizeof(dnsheader))
70     : d_pos(initialPos), d_startrecordpos(initialPos), d_content(content)
71   {
72     if(content.size() > std::numeric_limits<uint16_t>::max())
73       throw std::out_of_range("packet too large");
74 
75     d_recordlen = (uint16_t) content.size();
76     not_used = 0;
77   }
78 
79   uint32_t get32BitInt();
80   uint16_t get16BitInt();
81   uint8_t get8BitInt();
82 
83   void xfrNodeOrLocatorID(NodeOrLocatorID& val);
84   void xfr48BitInt(uint64_t& val);
85 
xfr32BitInt(uint32_t & val)86   void xfr32BitInt(uint32_t& val)
87   {
88     val=get32BitInt();
89   }
90 
xfrIP(uint32_t & val)91   void xfrIP(uint32_t& val)
92   {
93     xfr32BitInt(val);
94     val=htonl(val);
95   }
96 
xfrIP6(std::string & val)97   void xfrIP6(std::string &val) {
98     xfrBlob(val, 16);
99   }
100 
xfrCAWithoutPort(uint8_t version,ComboAddress & val)101   void xfrCAWithoutPort(uint8_t version, ComboAddress &val) {
102     string blob;
103     if (version == 4) xfrBlob(blob, 4);
104     else if (version == 6) xfrBlob(blob, 16);
105     else throw runtime_error("invalid IP protocol");
106     val = makeComboAddressFromRaw(version, blob);
107   }
108 
xfrCAPort(ComboAddress & val)109   void xfrCAPort(ComboAddress &val) {
110     uint16_t port;
111     xfr16BitInt(port);
112     val.sin4.sin_port = port;
113   }
114 
xfrTime(uint32_t & val)115   void xfrTime(uint32_t& val)
116   {
117     xfr32BitInt(val);
118   }
119 
120 
xfr16BitInt(uint16_t & val)121   void xfr16BitInt(uint16_t& val)
122   {
123     val=get16BitInt();
124   }
125 
xfrType(uint16_t & val)126   void xfrType(uint16_t& val)
127   {
128     xfr16BitInt(val);
129   }
130 
131 
xfr8BitInt(uint8_t & val)132   void xfr8BitInt(uint8_t& val)
133   {
134     val=get8BitInt();
135   }
136 
137 
xfrName(DNSName & name,bool compress=false,bool noDot=false)138   void xfrName(DNSName &name, bool compress=false, bool noDot=false)
139   {
140     name=getName();
141   }
142 
xfrText(string & text,bool multi=false,bool lenField=true)143   void xfrText(string &text, bool multi=false, bool lenField=true)
144   {
145     text=getText(multi, lenField);
146   }
147 
xfrUnquotedText(string & text,bool lenField)148   void xfrUnquotedText(string &text, bool lenField){
149     text=getUnquotedText(lenField);
150   }
151 
152   void xfrBlob(string& blob);
153   void xfrBlobNoSpaces(string& blob, int len);
154   void xfrBlob(string& blob, int length);
155   void xfrHexBlob(string& blob, bool keepReading=false);
156   void xfrSvcParamKeyVals(set<SvcParam> &kvs);
157 
158   void getDnsrecordheader(struct dnsrecordheader &ah);
159   void copyRecord(vector<unsigned char>& dest, uint16_t len);
160   void copyRecord(unsigned char* dest, uint16_t len);
161 
162   DNSName getName();
163   string getText(bool multi, bool lenField);
164   string getUnquotedText(bool lenField);
165 
166 
eof()167   bool eof() { return true; };
getRemaining() const168   const string getRemaining() const {
169     return "";
170   };
171 
getPosition() const172   uint16_t getPosition() const
173   {
174     return d_pos;
175   }
176 
skip(uint16_t n)177   void skip(uint16_t n)
178   {
179     d_pos += n;
180   }
181 
182 private:
183   uint16_t d_pos;
184   uint16_t d_startrecordpos; // needed for getBlob later on
185   uint16_t d_recordlen;      // ditto
186   uint16_t not_used; // Aligns the whole class on 8-byte boundaries
187   const pdns_string_view d_content;
188 };
189 
190 struct DNSRecord;
191 
192 class DNSRecordContent
193 {
194 public:
195   static std::shared_ptr<DNSRecordContent> mastermake(const DNSRecord &dr, PacketReader& pr);
196   static std::shared_ptr<DNSRecordContent> mastermake(const DNSRecord &dr, PacketReader& pr, uint16_t opcode);
197   static std::shared_ptr<DNSRecordContent> mastermake(uint16_t qtype, uint16_t qclass, const string& zone);
198   static string upgradeContent(const DNSName& qname, const QType& qtype, const string& content);
199 
200   virtual std::string getZoneRepresentation(bool noDot=false) const = 0;
~DNSRecordContent()201   virtual ~DNSRecordContent() {}
202   virtual void toPacket(DNSPacketWriter& pw)=0;
serialize(const DNSName & qname,bool canonic=false,bool lowerCase=false)203   virtual string serialize(const DNSName& qname, bool canonic=false, bool lowerCase=false) // it would rock if this were const, but it is too hard
204   {
205     vector<uint8_t> packet;
206     DNSPacketWriter pw(packet, g_rootdnsname, 1);
207     if(canonic)
208       pw.setCanonic(true);
209 
210     if(lowerCase)
211       pw.setLowercase(true);
212 
213     pw.startRecord(qname, this->getType());
214     this->toPacket(pw);
215 
216     string record;
217     pw.getRecordPayload(record); // needs to be called before commit()
218     return record;
219   }
220 
operator ==(const DNSRecordContent & rhs) const221   virtual bool operator==(const DNSRecordContent& rhs) const
222   {
223     return typeid(*this)==typeid(rhs) && this->getZoneRepresentation() == rhs.getZoneRepresentation();
224   }
225 
226   static shared_ptr<DNSRecordContent> deserialize(const DNSName& qname, uint16_t qtype, const string& serialized);
227 
doRecordCheck(const struct DNSRecord &)228   void doRecordCheck(const struct DNSRecord&){}
229 
230   typedef std::shared_ptr<DNSRecordContent> makerfunc_t(const struct DNSRecord& dr, PacketReader& pr);
231   typedef std::shared_ptr<DNSRecordContent> zmakerfunc_t(const string& str);
232 
regist(uint16_t cl,uint16_t ty,makerfunc_t * f,zmakerfunc_t * z,const char * name)233   static void regist(uint16_t cl, uint16_t ty, makerfunc_t* f, zmakerfunc_t* z, const char* name)
234   {
235     if(f)
236       getTypemap()[make_pair(cl,ty)]=f;
237     if(z)
238       getZmakermap()[make_pair(cl,ty)]=z;
239 
240     getT2Namemap().insert(make_pair(make_pair(cl,ty), name));
241     getN2Typemap().insert(make_pair(name, make_pair(cl,ty)));
242   }
243 
unregist(uint16_t cl,uint16_t ty)244   static void unregist(uint16_t cl, uint16_t ty)
245   {
246     pair<uint16_t, uint16_t> key=make_pair(cl, ty);
247     getTypemap().erase(key);
248     getZmakermap().erase(key);
249   }
250 
isUnknownType(const string & name)251   static bool isUnknownType(const string& name)
252   {
253     return boost::starts_with(name, "TYPE") || boost::starts_with(name, "type");
254   }
255 
TypeToNumber(const string & name)256   static uint16_t TypeToNumber(const string& name)
257   {
258     n2typemap_t::const_iterator iter = getN2Typemap().find(toUpper(name));
259     if(iter != getN2Typemap().end())
260       return iter->second.second;
261 
262     if (isUnknownType(name))
263       return (uint16_t) pdns_stou(name.substr(4));
264 
265     throw runtime_error("Unknown DNS type '"+name+"'");
266   }
267 
NumberToType(uint16_t num,uint16_t classnum=1)268   static const string NumberToType(uint16_t num, uint16_t classnum=1)
269   {
270     t2namemap_t::const_iterator iter = getT2Namemap().find(make_pair(classnum, num));
271     if(iter == getT2Namemap().end())
272       return "TYPE" + std::to_string(num);
273       //      throw runtime_error("Unknown DNS type with numerical id "+std::to_string(num));
274     return iter->second;
275   }
276 
277   virtual uint16_t getType() const = 0;
278 
279 protected:
280   typedef std::map<std::pair<uint16_t, uint16_t>, makerfunc_t* > typemap_t;
281   typedef std::map<std::pair<uint16_t, uint16_t>, zmakerfunc_t* > zmakermap_t;
282   typedef std::map<std::pair<uint16_t, uint16_t>, string > t2namemap_t;
283   typedef std::map<string, std::pair<uint16_t, uint16_t> > n2typemap_t;
284   static typemap_t& getTypemap();
285   static t2namemap_t& getT2Namemap();
286   static n2typemap_t& getN2Typemap();
287   static zmakermap_t& getZmakermap();
288 };
289 
290 struct DNSRecord
291 {
DNSRecordDNSRecord292   DNSRecord() : d_type(0), d_class(QClass::IN), d_ttl(0), d_clen(0), d_place(DNSResourceRecord::ANSWER)
293   {}
294   explicit DNSRecord(const DNSResourceRecord& rr);
295   DNSName d_name;
296   std::shared_ptr<DNSRecordContent> d_content;
297   uint16_t d_type;
298   uint16_t d_class;
299   uint32_t d_ttl;
300   uint16_t d_clen;
301   DNSResourceRecord::Place d_place;
302 
operator <DNSRecord303   bool operator<(const DNSRecord& rhs) const
304   {
305     if(tie(d_name, d_type, d_class, d_ttl) < tie(rhs.d_name, rhs.d_type, rhs.d_class, rhs.d_ttl))
306       return true;
307 
308     if(tie(d_name, d_type, d_class, d_ttl) != tie(rhs.d_name, rhs.d_type, rhs.d_class, rhs.d_ttl))
309       return false;
310 
311     string lzrp, rzrp;
312     if(d_content)
313       lzrp=toLower(d_content->getZoneRepresentation());
314     if(rhs.d_content)
315       rzrp=toLower(rhs.d_content->getZoneRepresentation());
316 
317     return lzrp < rzrp;
318   }
319 
320   // this orders in canonical order and keeps the SOA record on top
prettyCompareDNSRecord321   static bool prettyCompare(const DNSRecord& a, const DNSRecord& b)
322   {
323     auto aType = (a.d_type == QType::SOA) ? 0 : a.d_type;
324     auto bType = (b.d_type == QType::SOA) ? 0 : b.d_type;
325 
326     if(a.d_name.canonCompare(b.d_name))
327       return true;
328     if(b.d_name.canonCompare(a.d_name))
329       return false;
330 
331     if(tie(aType, a.d_class, a.d_ttl) < tie(bType, b.d_class, b.d_ttl))
332       return true;
333 
334     if(tie(aType, a.d_class, a.d_ttl) != tie(bType, b.d_class, b.d_ttl))
335       return false;
336 
337     string lzrp, rzrp;
338     if(a.d_content)
339       lzrp=toLower(a.d_content->getZoneRepresentation());
340     if(b.d_content)
341       rzrp=toLower(b.d_content->getZoneRepresentation());
342 
343     return lzrp < rzrp;
344   }
345 
346 
operator ==DNSRecord347   bool operator==(const DNSRecord& rhs) const
348   {
349     if(d_type != rhs.d_type || d_class != rhs.d_class || d_name != rhs.d_name)
350       return false;
351 
352     return *d_content == *rhs.d_content;
353   }
354 };
355 
356 struct DNSZoneRecord
357 {
358   int domain_id{-1};
359   uint8_t scopeMask{0};
360   int signttl{0};
361   DNSName wildcardname;
362   bool auth{true};
363   bool disabled{false};
364   DNSRecord dr;
365 };
366 
367 class UnknownRecordContent : public DNSRecordContent
368 {
369 public:
UnknownRecordContent(const DNSRecord & dr,PacketReader & pr)370   UnknownRecordContent(const DNSRecord& dr, PacketReader& pr)
371     : d_dr(dr)
372   {
373     pr.copyRecord(d_record, dr.d_clen);
374   }
375 
376   UnknownRecordContent(const string& zone);
377 
378   string getZoneRepresentation(bool noDot) const override;
379   void toPacket(DNSPacketWriter& pw) override;
getType() const380   uint16_t getType() const override
381   {
382     return d_dr.d_type;
383   }
384 
385 private:
386   DNSRecord d_dr;
387   vector<uint8_t> d_record;
388 };
389 
390 //! This class can be used to parse incoming packets, and is copyable
391 class MOADNSParser : public boost::noncopyable
392 {
393 public:
394   //! Parse from a string
MOADNSParser(bool query,const string & buffer)395   MOADNSParser(bool query, const string& buffer): d_tsigPos(0)
396   {
397     init(query, buffer);
398   }
399 
400   //! Parse from a pointer and length
MOADNSParser(bool query,const char * packet,unsigned int len)401   MOADNSParser(bool query, const char *packet, unsigned int len) : d_tsigPos(0)
402   {
403     init(query, pdns_string_view(packet, len));
404   }
405 
406   DNSName d_qname;
407   uint16_t d_qclass, d_qtype;
408   //uint8_t d_rcode;
409   dnsheader d_header;
410 
411   typedef vector<pair<DNSRecord, uint16_t > > answers_t;
412 
413   //! All answers contained in this packet (everything *but* the question section)
414   answers_t d_answers;
415 
getTSIGPos() const416   uint16_t getTSIGPos() const
417   {
418     return d_tsigPos;
419   }
420 
421   bool hasEDNS() const;
422 
423 private:
424   void init(bool query, const pdns_string_view& packet);
425   uint16_t d_tsigPos;
426 };
427 
// Free helpers working on wire-format data; they are only declared here and
// implemented elsewhere. NOTE(review): the summaries below are inferred from the
// signatures and names only — confirm against the implementations.
string simpleCompress(const string& label, const string& root="");
// presumably lowers TTLs in the packet by `seconds`
void ageDNSPacket(char* packet, size_t length, uint32_t seconds);
void ageDNSPacket(std::string& packet, uint32_t seconds);
// visitor receives (section?, class?, type?, ttl?) and returns the new TTL — verify argument order
void editDNSPacketTTL(char* packet, size_t length, const std::function<uint32_t(uint8_t, uint16_t, uint16_t, uint32_t)>& visitor);
// seenAuthSOA is an optional out-parameter (defaults to nullptr)
uint32_t getDNSPacketMinTTL(const char* packet, size_t length, bool* seenAuthSOA=nullptr);
uint32_t getDNSPacketLength(const char* packet, size_t length);
uint16_t getRecordsOfTypeCount(const char* packet, size_t length, uint8_t section, uint16_t type);
// payloadSize and z are out-parameters; the bool presumably signals whether EDNS data was found
bool getEDNSUDPPayloadSizeAndZ(const char* packet, size_t length, uint16_t* payloadSize, uint16_t* z);
436 
437 template<typename T>
getRR(const DNSRecord & dr)438 std::shared_ptr<T> getRR(const DNSRecord& dr)
439 {
440   return std::dynamic_pointer_cast<T>(dr.d_content);
441 }
442 
/** Simple in-place walker/editor for a wire-format DNS packet.
 *
 *  Ritual: every accessor first advances the offset via moveOffset(), which throws
 *  std::out_of_range when the new offset would pass the end of the buffer; only
 *  after that check succeeds are bytes read or written. Walking starts right after
 *  the 12-byte DNS header. On a throw the offset is left unchanged.
 *
 *  The buffer is not owned and must outlive the mangler. Copies are independent:
 *  each carries its own offset.
 */
class DNSPacketMangler
{
public:
  explicit DNSPacketMangler(std::string& packet)
    // &packet[0] is a writable pointer; writing through c_str() would be undefined behaviour
    : d_packet(&packet[0]), d_length(packet.length()), d_notyouroffset(12)
  {}
  DNSPacketMangler(char* packet, size_t length)
    : d_packet(packet), d_length(length), d_notyouroffset(12)
  {}

  /*! Advances past a wire-format domain name
   * The name is not checked for adherence to length restrictions.
   * Compression pointers are not followed.
   */
  void skipDomainName()
  {
    uint8_t len;
    while((len=get8BitInt())) {
      if(len >= 0xc0) { // extended label
        get8BitInt();
        return;
      }
      skipBytes(len);
    }
  }

  void skipBytes(uint16_t bytes)
  {
    moveOffset(bytes);
  }
  void rewindBytes(uint16_t by)
  {
    rewindOffset(by);
  }
  //! Reads a big-endian 32-bit value and returns it in host order.
  uint32_t get32BitInt()
  {
    const uint32_t start = getOffset();
    moveOffset(4);
    uint32_t ret;
    memcpy(&ret, d_packet + start, sizeof(ret));
    return ntohl(ret);
  }
  //! Reads a big-endian 16-bit value and returns it in host order.
  uint16_t get16BitInt()
  {
    const uint32_t start = getOffset();
    moveOffset(2);
    uint16_t ret;
    memcpy(&ret, d_packet + start, sizeof(ret));
    return ntohs(ret);
  }

  uint8_t get8BitInt()
  {
    const uint32_t start = getOffset();
    moveOffset(1);
    return static_cast<uint8_t>(d_packet[start]);
  }

  //! Reads a 16-bit RDLENGTH and skips that many bytes.
  void skipRData()
  {
    int toskip = get16BitInt();
    moveOffset(toskip);
  }

  //! Lowers the 32-bit big-endian value at the current offset by `decrease`,
  //! clamping at 0, writes it back in place and advances past it.
  void decreaseAndSkip32BitInt(uint32_t decrease)
  {
    const uint32_t start = getOffset();
    moveOffset(4);

    uint32_t tmp;
    memcpy(&tmp, d_packet + start, sizeof(tmp));
    tmp = ntohl(tmp);
    tmp = (tmp > decrease) ? tmp - decrease : 0;
    tmp = htonl(tmp);
    memcpy(d_packet + start, &tmp, sizeof(tmp));
  }

  //! Overwrites the 32-bit value at the current offset (big-endian) and advances past it.
  void setAndSkip32BitInt(uint32_t value)
  {
    const uint32_t start = getOffset();
    moveOffset(4);

    value = htonl(value);
    memcpy(d_packet + start, &value, sizeof(value));
  }

  uint32_t getOffset() const
  {
    return d_notyouroffset;
  }

private:
  //! Advances by `by` bytes; throws (leaving the offset untouched) if that would pass the end.
  void moveOffset(uint16_t by)
  {
    const uint32_t newOffset = d_notyouroffset + by;
    if(newOffset > d_length)
      throw std::out_of_range("dns packet out of range: "+std::to_string(newOffset) +" > "
      + std::to_string(d_length) );
    d_notyouroffset = newOffset;
  }

  //! Moves back by `by` bytes; never into the 12-byte header. Throws (leaving the
  //! offset untouched) when the request is out of range.
  void rewindOffset(uint16_t by)
  {
    if(d_notyouroffset < by)
      throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset) +" < "
                              + std::to_string(by));
    const uint32_t newOffset = d_notyouroffset - by;
    if(newOffset < 12)
      throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(newOffset) +" < "
                              + std::to_string(12));
    d_notyouroffset = newOffset;
  }

  char* d_packet;
  size_t d_length;

  // Only moveOffset()/rewindOffset() change this; everything else reads it via getOffset().
  // (A plain value member keeps copies independent, unlike the former reference alias.)
  uint32_t d_notyouroffset;
};
565