1package dns
2
3import (
4	"strings"
5	"sync"
6)
7
// ServeMux is a DNS request multiplexer. It matches the zone name of
// each incoming request against a list of registered patterns and calls
// the handler for the pattern that most closely matches the zone name.
//
// ServeMux is DNSSEC aware, meaning that queries for the DS record are
// redirected to the parent zone (if that is also registered), otherwise
// the child gets the query.
//
// ServeMux is also safe for concurrent access from multiple goroutines.
//
// The zero ServeMux is empty and ready for use.
type ServeMux struct {
	// z maps registered patterns, as normalized by Handle, to their
	// handlers. It is allocated lazily by Handle, so it is nil in the
	// zero ServeMux.
	z map[string]Handler
	// m guards z: lookups take the read lock, while Handle and
	// HandleRemove take the write lock.
	m sync.RWMutex
}
23
24// NewServeMux allocates and returns a new ServeMux.
25func NewServeMux() *ServeMux {
26	return new(ServeMux)
27}
28
29// DefaultServeMux is the default ServeMux used by Serve.
30var DefaultServeMux = NewServeMux()
31
32func (mux *ServeMux) match(q string, t uint16) Handler {
33	mux.m.RLock()
34	defer mux.m.RUnlock()
35	if mux.z == nil {
36		return nil
37	}
38
39	var handler Handler
40
41	// TODO(tmthrgd): Once https://go-review.googlesource.com/c/go/+/137575
42	// lands in a go release, replace the following with strings.ToLower.
43	var sb strings.Builder
44	for i := 0; i < len(q); i++ {
45		c := q[i]
46		if !(c >= 'A' && c <= 'Z') {
47			continue
48		}
49
50		sb.Grow(len(q))
51		sb.WriteString(q[:i])
52
53		for ; i < len(q); i++ {
54			c := q[i]
55			if c >= 'A' && c <= 'Z' {
56				c += 'a' - 'A'
57			}
58
59			sb.WriteByte(c)
60		}
61
62		q = sb.String()
63		break
64	}
65
66	for off, end := 0, false; !end; off, end = NextLabel(q, off) {
67		if h, ok := mux.z[q[off:]]; ok {
68			if t != TypeDS {
69				return h
70			}
71			// Continue for DS to see if we have a parent too, if so delegate to the parent
72			handler = h
73		}
74	}
75
76	// Wildcard match, if we have found nothing try the root zone as a last resort.
77	if h, ok := mux.z["."]; ok {
78		return h
79	}
80
81	return handler
82}
83
84// Handle adds a handler to the ServeMux for pattern.
85func (mux *ServeMux) Handle(pattern string, handler Handler) {
86	if pattern == "" {
87		panic("dns: invalid pattern " + pattern)
88	}
89	mux.m.Lock()
90	if mux.z == nil {
91		mux.z = make(map[string]Handler)
92	}
93	mux.z[Fqdn(pattern)] = handler
94	mux.m.Unlock()
95}
96
97// HandleFunc adds a handler function to the ServeMux for pattern.
98func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
99	mux.Handle(pattern, HandlerFunc(handler))
100}
101
102// HandleRemove deregisters the handler specific for pattern from the ServeMux.
103func (mux *ServeMux) HandleRemove(pattern string) {
104	if pattern == "" {
105		panic("dns: invalid pattern " + pattern)
106	}
107	mux.m.Lock()
108	delete(mux.z, Fqdn(pattern))
109	mux.m.Unlock()
110}
111
112// ServeDNS dispatches the request to the handler whose pattern most
113// closely matches the request message.
114//
115// ServeDNS is DNSSEC aware, meaning that queries for the DS record
116// are redirected to the parent zone (if that is also registered),
117// otherwise the child gets the query.
118//
119// If no handler is found, or there is no question, a standard SERVFAIL
120// message is returned
121func (mux *ServeMux) ServeDNS(w ResponseWriter, req *Msg) {
122	var h Handler
123	if len(req.Question) >= 1 { // allow more than one question
124		h = mux.match(req.Question[0].Name, req.Question[0].Qtype)
125	}
126
127	if h != nil {
128		h.ServeDNS(w, req)
129	} else {
130		HandleFailed(w, req)
131	}
132}
133
134// Handle registers the handler with the given pattern
135// in the DefaultServeMux. The documentation for
136// ServeMux explains how patterns are matched.
137func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) }
138
139// HandleRemove deregisters the handle with the given pattern
140// in the DefaultServeMux.
141func HandleRemove(pattern string) { DefaultServeMux.HandleRemove(pattern) }
142
143// HandleFunc registers the handler function with the given pattern
144// in the DefaultServeMux.
145func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
146	DefaultServeMux.HandleFunc(pattern, handler)
147}
148