1package dns 2 3import ( 4 "time" 5) 6 7// Envelope is used when doing a zone transfer with a remote server. 8type Envelope struct { 9 RR []RR // The set of RRs in the answer section of the xfr reply message. 10 Error error // If something went wrong, this contains the error. 11} 12 13// A Transfer defines parameters that are used during a zone transfer. 14type Transfer struct { 15 *Conn 16 DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds 17 ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds 18 WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds 19 TsigSecret map[string]string // Secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be fully qualified 20 tsigTimersOnly bool 21} 22 23// Think we need to away to stop the transfer 24 25// In performs an incoming transfer with the server in a. 26// If you would like to set the source IP, or some other attribute 27// of a Dialer for a Transfer, you can do so by specifying the attributes 28// in the Transfer.Conn: 29// 30// d := net.Dialer{LocalAddr: transfer_source} 31// con, err := d.Dial("tcp", master) 32// dnscon := &dns.Conn{Conn:con} 33// transfer = &dns.Transfer{Conn: dnscon} 34// channel, err := transfer.In(message, master) 35// 36func (t *Transfer) In(q *Msg, a string) (env chan *Envelope, err error) { 37 timeout := dnsTimeout 38 if t.DialTimeout != 0 { 39 timeout = t.DialTimeout 40 } 41 if t.Conn == nil { 42 t.Conn, err = DialTimeout("tcp", a, timeout) 43 if err != nil { 44 return nil, err 45 } 46 } 47 if err := t.WriteMsg(q); err != nil { 48 return nil, err 49 } 50 env = make(chan *Envelope) 51 go func() { 52 if q.Question[0].Qtype == TypeAXFR { 53 go t.inAxfr(q.Id, env) 54 return 55 } 56 if q.Question[0].Qtype == TypeIXFR { 57 go t.inIxfr(q.Id, env) 58 return 59 } 60 }() 61 return env, nil 62} 63 64func (t *Transfer) inAxfr(id uint16, c chan *Envelope) { 65 first := true 66 defer t.Close() 67 defer close(c) 68 timeout := dnsTimeout 69 if t.ReadTimeout != 0 { 70 timeout = t.ReadTimeout 71 } 72 for { 73 t.Conn.SetReadDeadline(time.Now().Add(timeout)) 74 in, err := t.ReadMsg() 75 if err != nil { 76 c <- &Envelope{nil, err} 77 return 78 } 79 if id != in.Id { 80 c <- &Envelope{in.Answer, ErrId} 81 return 82 } 83 if first { 84 if !isSOAFirst(in) { 85 c <- &Envelope{in.Answer, ErrSoa} 86 return 87 } 88 first = !first 89 // only one answer that is SOA, receive more 90 if len(in.Answer) == 1 { 91 t.tsigTimersOnly = true 92 c <- &Envelope{in.Answer, nil} 93 continue 94 } 95 } 96 97 if !first { 98 t.tsigTimersOnly = true // Subsequent envelopes use this. 99 if isSOALast(in) { 100 c <- &Envelope{in.Answer, nil} 101 return 102 } 103 c <- &Envelope{in.Answer, nil} 104 } 105 } 106} 107 108func (t *Transfer) inIxfr(id uint16, c chan *Envelope) { 109 serial := uint32(0) // The first serial seen is the current server serial 110 first := true 111 defer t.Close() 112 defer close(c) 113 timeout := dnsTimeout 114 if t.ReadTimeout != 0 { 115 timeout = t.ReadTimeout 116 } 117 for { 118 t.SetReadDeadline(time.Now().Add(timeout)) 119 in, err := t.ReadMsg() 120 if err != nil { 121 c <- &Envelope{nil, err} 122 return 123 } 124 if id != in.Id { 125 c <- &Envelope{in.Answer, ErrId} 126 return 127 } 128 if first { 129 // A single SOA RR signals "no changes" 130 if len(in.Answer) == 1 && isSOAFirst(in) { 131 c <- &Envelope{in.Answer, nil} 132 return 133 } 134 135 // Check if the returned answer is ok 136 if !isSOAFirst(in) { 137 c <- &Envelope{in.Answer, ErrSoa} 138 return 139 } 140 // This serial is important 141 serial = in.Answer[0].(*SOA).Serial 142 first = !first 143 } 144 145 // Now we need to check each message for SOA records, to see what we need to do 146 if !first { 147 t.tsigTimersOnly = true 148 // If the last record in the IXFR contains the servers' SOA, we should quit 149 if v, ok := in.Answer[len(in.Answer)-1].(*SOA); ok { 150 if v.Serial == serial { 151 c <- &Envelope{in.Answer, nil} 152 return 153 } 154 } 155 c <- &Envelope{in.Answer, nil} 156 } 157 } 158} 159 160// Out performs an outgoing transfer with the client connecting in w. 161// Basic use pattern: 162// 163// ch := make(chan *dns.Envelope) 164// tr := new(dns.Transfer) 165// go tr.Out(w, r, ch) 166// ch <- &dns.Envelope{RR: []dns.RR{soa, rr1, rr2, rr3, soa}} 167// close(ch) 168// w.Hijack() 169// // w.Close() // Client closes connection 170// 171// The server is responsible for sending the correct sequence of RRs through the 172// channel ch. 173func (t *Transfer) Out(w ResponseWriter, q *Msg, ch chan *Envelope) error { 174 for x := range ch { 175 r := new(Msg) 176 // Compress? 177 r.SetReply(q) 178 r.Authoritative = true 179 // assume it fits TODO(miek): fix 180 r.Answer = append(r.Answer, x.RR...) 181 if err := w.WriteMsg(r); err != nil { 182 return err 183 } 184 } 185 w.TsigTimersOnly(true) 186 return nil 187} 188 189// ReadMsg reads a message from the transfer connection t. 190func (t *Transfer) ReadMsg() (*Msg, error) { 191 m := new(Msg) 192 p := make([]byte, MaxMsgSize) 193 n, err := t.Read(p) 194 if err != nil && n == 0 { 195 return nil, err 196 } 197 p = p[:n] 198 if err := m.Unpack(p); err != nil { 199 return nil, err 200 } 201 if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil { 202 if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok { 203 return m, ErrSecret 204 } 205 // Need to work on the original message p, as that was used to calculate the tsig. 206 err = TsigVerify(p, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly) 207 t.tsigRequestMAC = ts.MAC 208 } 209 return m, err 210} 211 212// WriteMsg writes a message through the transfer connection t. 213func (t *Transfer) WriteMsg(m *Msg) (err error) { 214 var out []byte 215 if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil { 216 if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok { 217 return ErrSecret 218 } 219 out, t.tsigRequestMAC, err = TsigGenerate(m, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly) 220 } else { 221 out, err = m.Pack() 222 } 223 if err != nil { 224 return err 225 } 226 if _, err = t.Write(out); err != nil { 227 return err 228 } 229 return nil 230} 231 232func isSOAFirst(in *Msg) bool { 233 if len(in.Answer) > 0 { 234 return in.Answer[0].Header().Rrtype == TypeSOA 235 } 236 return false 237} 238 239func isSOALast(in *Msg) bool { 240 if len(in.Answer) > 0 { 241 return in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA 242 } 243 return false 244} 245