1package dns
2
3import (
4	"fmt"
5	"time"
6)
7
8// Envelope is used when doing a zone transfer with a remote server.
9type Envelope struct {
10	RR    []RR  // The set of RRs in the answer section of the xfr reply message.
11	Error error // If something went wrong, this contains the error.
12}
13
14// A Transfer defines parameters that are used during a zone transfer.
15type Transfer struct {
16	*Conn
17	DialTimeout    time.Duration     // net.DialTimeout, defaults to 2 seconds
18	ReadTimeout    time.Duration     // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds
19	WriteTimeout   time.Duration     // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds
20	TsigSecret     map[string]string // Secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
21	tsigTimersOnly bool
22}
23
// TODO: there is currently no way to stop a transfer once it has started.
25
26// In performs an incoming transfer with the server in a.
27// If you would like to set the source IP, or some other attribute
28// of a Dialer for a Transfer, you can do so by specifying the attributes
29// in the Transfer.Conn:
30//
31//	d := net.Dialer{LocalAddr: transfer_source}
32//	con, err := d.Dial("tcp", master)
33//	dnscon := &dns.Conn{Conn:con}
34//	transfer = &dns.Transfer{Conn: dnscon}
35//	channel, err := transfer.In(message, master)
36//
37func (t *Transfer) In(q *Msg, a string) (env chan *Envelope, err error) {
38	timeout := dnsTimeout
39	if t.DialTimeout != 0 {
40		timeout = t.DialTimeout
41	}
42	if t.Conn == nil {
43		t.Conn, err = DialTimeout("tcp", a, timeout)
44		if err != nil {
45			return nil, err
46		}
47	}
48	if err := t.WriteMsg(q); err != nil {
49		return nil, err
50	}
51	env = make(chan *Envelope)
52	go func() {
53		if q.Question[0].Qtype == TypeAXFR {
54			go t.inAxfr(q, env)
55			return
56		}
57		if q.Question[0].Qtype == TypeIXFR {
58			go t.inIxfr(q, env)
59			return
60		}
61	}()
62	return env, nil
63}
64
65func (t *Transfer) inAxfr(q *Msg, c chan *Envelope) {
66	first := true
67	defer t.Close()
68	defer close(c)
69	timeout := dnsTimeout
70	if t.ReadTimeout != 0 {
71		timeout = t.ReadTimeout
72	}
73	for {
74		t.Conn.SetReadDeadline(time.Now().Add(timeout))
75		in, err := t.ReadMsg()
76		if err != nil {
77			c <- &Envelope{nil, err}
78			return
79		}
80		if q.Id != in.Id {
81			c <- &Envelope{in.Answer, ErrId}
82			return
83		}
84		if first {
85			if in.Rcode != RcodeSuccess {
86				c <- &Envelope{in.Answer, &Error{err: fmt.Sprintf(errXFR, in.Rcode)}}
87				return
88			}
89			if !isSOAFirst(in) {
90				c <- &Envelope{in.Answer, ErrSoa}
91				return
92			}
93			first = !first
94			// only one answer that is SOA, receive more
95			if len(in.Answer) == 1 {
96				t.tsigTimersOnly = true
97				c <- &Envelope{in.Answer, nil}
98				continue
99			}
100		}
101
102		if !first {
103			t.tsigTimersOnly = true // Subsequent envelopes use this.
104			if isSOALast(in) {
105				c <- &Envelope{in.Answer, nil}
106				return
107			}
108			c <- &Envelope{in.Answer, nil}
109		}
110	}
111}
112
113func (t *Transfer) inIxfr(q *Msg, c chan *Envelope) {
114	serial := uint32(0) // The first serial seen is the current server serial
115	axfr := true
116	n := 0
117	qser := q.Ns[0].(*SOA).Serial
118	defer t.Close()
119	defer close(c)
120	timeout := dnsTimeout
121	if t.ReadTimeout != 0 {
122		timeout = t.ReadTimeout
123	}
124	for {
125		t.SetReadDeadline(time.Now().Add(timeout))
126		in, err := t.ReadMsg()
127		if err != nil {
128			c <- &Envelope{nil, err}
129			return
130		}
131		if q.Id != in.Id {
132			c <- &Envelope{in.Answer, ErrId}
133			return
134		}
135		if in.Rcode != RcodeSuccess {
136			c <- &Envelope{in.Answer, &Error{err: fmt.Sprintf(errXFR, in.Rcode)}}
137			return
138		}
139		if n == 0 {
140			// Check if the returned answer is ok
141			if !isSOAFirst(in) {
142				c <- &Envelope{in.Answer, ErrSoa}
143				return
144			}
145			// This serial is important
146			serial = in.Answer[0].(*SOA).Serial
147			// Check if there are no changes in zone
148			if qser >= serial {
149				c <- &Envelope{in.Answer, nil}
150				return
151			}
152		}
153		// Now we need to check each message for SOA records, to see what we need to do
154		t.tsigTimersOnly = true
155		for _, rr := range in.Answer {
156			if v, ok := rr.(*SOA); ok {
157				if v.Serial == serial {
158					n++
159					// quit if it's a full axfr or the the servers' SOA is repeated the third time
160					if axfr && n == 2 || n == 3 {
161						c <- &Envelope{in.Answer, nil}
162						return
163					}
164				} else if axfr {
165					// it's an ixfr
166					axfr = false
167				}
168			}
169		}
170		c <- &Envelope{in.Answer, nil}
171	}
172}
173
174// Out performs an outgoing transfer with the client connecting in w.
175// Basic use pattern:
176//
177//	ch := make(chan *dns.Envelope)
178//	tr := new(dns.Transfer)
179//	go tr.Out(w, r, ch)
180//	ch <- &dns.Envelope{RR: []dns.RR{soa, rr1, rr2, rr3, soa}}
181//	close(ch)
182//	w.Hijack()
183//	// w.Close() // Client closes connection
184//
185// The server is responsible for sending the correct sequence of RRs through the
186// channel ch.
187func (t *Transfer) Out(w ResponseWriter, q *Msg, ch chan *Envelope) error {
188	for x := range ch {
189		r := new(Msg)
190		// Compress?
191		r.SetReply(q)
192		r.Authoritative = true
193		// assume it fits TODO(miek): fix
194		r.Answer = append(r.Answer, x.RR...)
195		if err := w.WriteMsg(r); err != nil {
196			return err
197		}
198	}
199	w.TsigTimersOnly(true)
200	return nil
201}
202
203// ReadMsg reads a message from the transfer connection t.
204func (t *Transfer) ReadMsg() (*Msg, error) {
205	m := new(Msg)
206	p := make([]byte, MaxMsgSize)
207	n, err := t.Read(p)
208	if err != nil && n == 0 {
209		return nil, err
210	}
211	p = p[:n]
212	if err := m.Unpack(p); err != nil {
213		return nil, err
214	}
215	if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil {
216		if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok {
217			return m, ErrSecret
218		}
219		// Need to work on the original message p, as that was used to calculate the tsig.
220		err = TsigVerify(p, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly)
221		t.tsigRequestMAC = ts.MAC
222	}
223	return m, err
224}
225
226// WriteMsg writes a message through the transfer connection t.
227func (t *Transfer) WriteMsg(m *Msg) (err error) {
228	var out []byte
229	if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil {
230		if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok {
231			return ErrSecret
232		}
233		out, t.tsigRequestMAC, err = TsigGenerate(m, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly)
234	} else {
235		out, err = m.Pack()
236	}
237	if err != nil {
238		return err
239	}
240	if _, err = t.Write(out); err != nil {
241		return err
242	}
243	return nil
244}
245
246func isSOAFirst(in *Msg) bool {
247	if len(in.Answer) > 0 {
248		return in.Answer[0].Header().Rrtype == TypeSOA
249	}
250	return false
251}
252
253func isSOALast(in *Msg) bool {
254	if len(in.Answer) > 0 {
255		return in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA
256	}
257	return false
258}
259
// errXFR is the format of the error reported when a transfer reply carries
// a non-success rcode; the %d verb receives the numeric rcode.
const (
	errXFR = "bad xfr rcode: %d"
)
261