1package overlay
2
3import (
4	"bytes"
5	"encoding/binary"
6	"encoding/hex"
7	"fmt"
8	"hash/fnv"
9	"net"
10	"sync"
11	"syscall"
12
13	"strconv"
14
15	"github.com/sirupsen/logrus"
16	"github.com/docker/libnetwork/iptables"
17	"github.com/docker/libnetwork/ns"
18	"github.com/docker/libnetwork/types"
19	"github.com/vishvananda/netlink"
20)
21
// Encryption-related constants.
const (
	// mark is the packet mark that programMangle sets on matching VXLAN
	// packets and that the XFRM policies built in this file match on.
	mark         = uint32(0xD0C4E3)
	// timeout is not referenced by the code visible in this part of the file.
	timeout      = 30
	// pktExpansion is the number of bytes maxMTU subtracts from the MTU of a
	// secure network; the breakdown of the 26 bytes is given on the right.
	pktExpansion = 26 // SPI(4) + SeqN(4) + IV(8) + PadLength(1) + NextHeader(1) + ICV(8)
)
27
// Direction selectors passed to programSA. The values are 1, 2 and 3, and
// programSA tests them as bit flags (dir&forward, dir&reverse), so bidir
// selects both directions.
const (
	// forward selects the local -> remote security association.
	forward = iota + 1
	// reverse selects the remote -> local security association.
	reverse
	// bidir selects both the forward and the reverse association.
	bidir
)
33
// key is a piece of encryption key material together with the tag that
// identifies it.
type key struct {
	value []byte
	tag   uint32
}

// String returns a short description of the key for log messages. At most
// the first five hex digits of the key material are shown, followed by the
// tag in hexadecimal. A nil receiver yields the empty string.
func (k *key) String() string {
	if k == nil {
		return ""
	}
	// Truncate defensively: slicing [0:5] unconditionally panics when the
	// key material is shorter than three bytes (fewer than five hex digits).
	digits := hex.EncodeToString(k.value)
	if len(digits) > 5 {
		digits = digits[:5]
	}
	return fmt.Sprintf("(key: %s, tag: 0x%x)", digits, k.tag)
}
45
// spi holds the pair of security parameter indexes used with one key for one
// remote node: forward is used for the local -> remote SA and reverse for the
// remote -> local SA (see programSA).
type spi struct {
	forward int
	reverse int
}

// String renders both indexes as 32-bit hexadecimal values.
func (s *spi) String() string {
	fwd, rev := uint32(s.forward), uint32(s.reverse)
	return fmt.Sprintf("SPI(FWD: 0x%x, REV: 0x%x)", fwd, rev)
}
54
55type encrMap struct {
56	nodes map[string][]*spi
57	sync.Mutex
58}
59
60func (e *encrMap) String() string {
61	e.Lock()
62	defer e.Unlock()
63	b := new(bytes.Buffer)
64	for k, v := range e.nodes {
65		b.WriteString("\n")
66		b.WriteString(k)
67		b.WriteString(":")
68		b.WriteString("[")
69		for _, s := range v {
70			b.WriteString(s.String())
71			b.WriteString(",")
72		}
73		b.WriteString("]")
74
75	}
76	return b.String()
77}
78
79func (d *driver) checkEncryption(nid string, rIP net.IP, vxlanID uint32, isLocal, add bool) error {
80	logrus.Debugf("checkEncryption(%s, %v, %d, %t)", nid[0:7], rIP, vxlanID, isLocal)
81
82	n := d.network(nid)
83	if n == nil || !n.secure {
84		return nil
85	}
86
87	if len(d.keys) == 0 {
88		return types.ForbiddenErrorf("encryption key is not present")
89	}
90
91	lIP := net.ParseIP(d.bindAddress)
92	aIP := net.ParseIP(d.advertiseAddress)
93	nodes := map[string]net.IP{}
94
95	switch {
96	case isLocal:
97		if err := d.peerDbNetworkWalk(nid, func(pKey *peerKey, pEntry *peerEntry) bool {
98			if !aIP.Equal(pEntry.vtep) {
99				nodes[pEntry.vtep.String()] = pEntry.vtep
100			}
101			return false
102		}); err != nil {
103			logrus.Warnf("Failed to retrieve list of participating nodes in overlay network %s: %v", nid[0:5], err)
104		}
105	default:
106		if len(d.network(nid).endpoints) > 0 {
107			nodes[rIP.String()] = rIP
108		}
109	}
110
111	logrus.Debugf("List of nodes: %s", nodes)
112
113	if add {
114		for _, rIP := range nodes {
115			if err := setupEncryption(lIP, aIP, rIP, vxlanID, d.secMap, d.keys); err != nil {
116				logrus.Warnf("Failed to program network encryption between %s and %s: %v", lIP, rIP, err)
117			}
118		}
119	} else {
120		if len(nodes) == 0 {
121			if err := removeEncryption(lIP, rIP, d.secMap); err != nil {
122				logrus.Warnf("Failed to remove network encryption between %s and %s: %v", lIP, rIP, err)
123			}
124		}
125	}
126
127	return nil
128}
129
130func setupEncryption(localIP, advIP, remoteIP net.IP, vni uint32, em *encrMap, keys []*key) error {
131	logrus.Debugf("Programming encryption for vxlan %d between %s and %s", vni, localIP, remoteIP)
132	rIPs := remoteIP.String()
133
134	indices := make([]*spi, 0, len(keys))
135
136	err := programMangle(vni, true)
137	if err != nil {
138		logrus.Warn(err)
139	}
140
141	for i, k := range keys {
142		spis := &spi{buildSPI(advIP, remoteIP, k.tag), buildSPI(remoteIP, advIP, k.tag)}
143		dir := reverse
144		if i == 0 {
145			dir = bidir
146		}
147		fSA, rSA, err := programSA(localIP, remoteIP, spis, k, dir, true)
148		if err != nil {
149			logrus.Warn(err)
150		}
151		indices = append(indices, spis)
152		if i != 0 {
153			continue
154		}
155		err = programSP(fSA, rSA, true)
156		if err != nil {
157			logrus.Warn(err)
158		}
159	}
160
161	em.Lock()
162	em.nodes[rIPs] = indices
163	em.Unlock()
164
165	return nil
166}
167
168func removeEncryption(localIP, remoteIP net.IP, em *encrMap) error {
169	em.Lock()
170	indices, ok := em.nodes[remoteIP.String()]
171	em.Unlock()
172	if !ok {
173		return nil
174	}
175	for i, idxs := range indices {
176		dir := reverse
177		if i == 0 {
178			dir = bidir
179		}
180		fSA, rSA, err := programSA(localIP, remoteIP, idxs, nil, dir, false)
181		if err != nil {
182			logrus.Warn(err)
183		}
184		if i != 0 {
185			continue
186		}
187		err = programSP(fSA, rSA, false)
188		if err != nil {
189			logrus.Warn(err)
190		}
191	}
192	return nil
193}
194
195func programMangle(vni uint32, add bool) (err error) {
196	var (
197		p      = strconv.FormatUint(uint64(vxlanPort), 10)
198		c      = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8)
199		m      = strconv.FormatUint(uint64(mark), 10)
200		chain  = "OUTPUT"
201		rule   = []string{"-p", "udp", "--dport", p, "-m", "u32", "--u32", c, "-j", "MARK", "--set-mark", m}
202		a      = "-A"
203		action = "install"
204	)
205
206	if add == iptables.Exists(iptables.Mangle, chain, rule...) {
207		return
208	}
209
210	if !add {
211		a = "-D"
212		action = "remove"
213	}
214
215	if err = iptables.RawCombinedOutput(append([]string{"-t", string(iptables.Mangle), a, chain}, rule...)...); err != nil {
216		logrus.Warnf("could not %s mangle rule: %v", action, err)
217	}
218
219	return
220}
221
222func programSA(localIP, remoteIP net.IP, spi *spi, k *key, dir int, add bool) (fSA *netlink.XfrmState, rSA *netlink.XfrmState, err error) {
223	var (
224		action      = "Removing"
225		xfrmProgram = ns.NlHandle().XfrmStateDel
226	)
227
228	if add {
229		action = "Adding"
230		xfrmProgram = ns.NlHandle().XfrmStateAdd
231	}
232
233	if dir&reverse > 0 {
234		rSA = &netlink.XfrmState{
235			Src:   remoteIP,
236			Dst:   localIP,
237			Proto: netlink.XFRM_PROTO_ESP,
238			Spi:   spi.reverse,
239			Mode:  netlink.XFRM_MODE_TRANSPORT,
240		}
241		if add {
242			rSA.Aead = buildAeadAlgo(k, spi.reverse)
243		}
244
245		exists, err := saExists(rSA)
246		if err != nil {
247			exists = !add
248		}
249
250		if add != exists {
251			logrus.Debugf("%s: rSA{%s}", action, rSA)
252			if err := xfrmProgram(rSA); err != nil {
253				logrus.Warnf("Failed %s rSA{%s}: %v", action, rSA, err)
254			}
255		}
256	}
257
258	if dir&forward > 0 {
259		fSA = &netlink.XfrmState{
260			Src:   localIP,
261			Dst:   remoteIP,
262			Proto: netlink.XFRM_PROTO_ESP,
263			Spi:   spi.forward,
264			Mode:  netlink.XFRM_MODE_TRANSPORT,
265		}
266		if add {
267			fSA.Aead = buildAeadAlgo(k, spi.forward)
268		}
269
270		exists, err := saExists(fSA)
271		if err != nil {
272			exists = !add
273		}
274
275		if add != exists {
276			logrus.Debugf("%s fSA{%s}", action, fSA)
277			if err := xfrmProgram(fSA); err != nil {
278				logrus.Warnf("Failed %s fSA{%s}: %v.", action, fSA, err)
279			}
280		}
281	}
282
283	return
284}
285
286func programSP(fSA *netlink.XfrmState, rSA *netlink.XfrmState, add bool) error {
287	action := "Removing"
288	xfrmProgram := ns.NlHandle().XfrmPolicyDel
289	if add {
290		action = "Adding"
291		xfrmProgram = ns.NlHandle().XfrmPolicyAdd
292	}
293
294	fullMask := net.CIDRMask(8*len(fSA.Src), 8*len(fSA.Src))
295
296	fPol := &netlink.XfrmPolicy{
297		Src:     &net.IPNet{IP: fSA.Src, Mask: fullMask},
298		Dst:     &net.IPNet{IP: fSA.Dst, Mask: fullMask},
299		Dir:     netlink.XFRM_DIR_OUT,
300		Proto:   17,
301		DstPort: 4789,
302		Mark: &netlink.XfrmMark{
303			Value: mark,
304		},
305		Tmpls: []netlink.XfrmPolicyTmpl{
306			{
307				Src:   fSA.Src,
308				Dst:   fSA.Dst,
309				Proto: netlink.XFRM_PROTO_ESP,
310				Mode:  netlink.XFRM_MODE_TRANSPORT,
311				Spi:   fSA.Spi,
312			},
313		},
314	}
315
316	exists, err := spExists(fPol)
317	if err != nil {
318		exists = !add
319	}
320
321	if add != exists {
322		logrus.Debugf("%s fSP{%s}", action, fPol)
323		if err := xfrmProgram(fPol); err != nil {
324			logrus.Warnf("%s fSP{%s}: %v", action, fPol, err)
325		}
326	}
327
328	return nil
329}
330
331func saExists(sa *netlink.XfrmState) (bool, error) {
332	_, err := ns.NlHandle().XfrmStateGet(sa)
333	switch err {
334	case nil:
335		return true, nil
336	case syscall.ESRCH:
337		return false, nil
338	default:
339		err = fmt.Errorf("Error while checking for SA existence: %v", err)
340		logrus.Warn(err)
341		return false, err
342	}
343}
344
345func spExists(sp *netlink.XfrmPolicy) (bool, error) {
346	_, err := ns.NlHandle().XfrmPolicyGet(sp)
347	switch err {
348	case nil:
349		return true, nil
350	case syscall.ENOENT:
351		return false, nil
352	default:
353		err = fmt.Errorf("Error while checking for SP existence: %v", err)
354		logrus.Warn(err)
355		return false, err
356	}
357}
358
// buildSPI derives an SPI as the 32-bit FNV-1a hash of src, the big-endian
// encoding of the key tag st, and dst, in that order. Swapping src and dst
// therefore generally yields a different value.
func buildSPI(src, dst net.IP, st uint32) int {
	var tag [4]byte
	binary.BigEndian.PutUint32(tag[:], st)

	hasher := fnv.New32a()
	hasher.Write(src)
	hasher.Write(tag[:])
	hasher.Write(dst)
	return int(hasher.Sum32())
}
368
369func buildAeadAlgo(k *key, s int) *netlink.XfrmStateAlgo {
370	salt := make([]byte, 4)
371	binary.BigEndian.PutUint32(salt, uint32(s))
372	return &netlink.XfrmStateAlgo{
373		Name:   "rfc4106(gcm(aes))",
374		Key:    append(k.value, salt...),
375		ICVLen: 64,
376	}
377}
378
379func (d *driver) secMapWalk(f func(string, []*spi) ([]*spi, bool)) error {
380	d.secMap.Lock()
381	for node, indices := range d.secMap.nodes {
382		idxs, stop := f(node, indices)
383		if idxs != nil {
384			d.secMap.nodes[node] = idxs
385		}
386		if stop {
387			break
388		}
389	}
390	d.secMap.Unlock()
391	return nil
392}
393
394func (d *driver) setKeys(keys []*key) error {
395	// Accept the encryption keys and clear any stale encryption map
396	d.Lock()
397	d.keys = keys
398	d.secMap = &encrMap{nodes: map[string][]*spi{}}
399	d.Unlock()
400	logrus.Debugf("Initial encryption keys: %v", d.keys)
401	return nil
402}
403
404// updateKeys allows to add a new key and/or change the primary key and/or prune an existing key
405// The primary key is the key used in transmission and will go in first position in the list.
406func (d *driver) updateKeys(newKey, primary, pruneKey *key) error {
407	logrus.Debugf("Updating Keys. New: %v, Primary: %v, Pruned: %v", newKey, primary, pruneKey)
408
409	logrus.Debugf("Current: %v", d.keys)
410
411	var (
412		newIdx = -1
413		priIdx = -1
414		delIdx = -1
415		lIP    = net.ParseIP(d.bindAddress)
416	)
417
418	d.Lock()
419	// add new
420	if newKey != nil {
421		d.keys = append(d.keys, newKey)
422		newIdx += len(d.keys)
423	}
424	for i, k := range d.keys {
425		if primary != nil && k.tag == primary.tag {
426			priIdx = i
427		}
428		if pruneKey != nil && k.tag == pruneKey.tag {
429			delIdx = i
430		}
431	}
432	d.Unlock()
433
434	if (newKey != nil && newIdx == -1) ||
435		(primary != nil && priIdx == -1) ||
436		(pruneKey != nil && delIdx == -1) {
437		return types.BadRequestErrorf("cannot find proper key indices while processing key update:"+
438			"(newIdx,priIdx,delIdx):(%d, %d, %d)", newIdx, priIdx, delIdx)
439	}
440
441	d.secMapWalk(func(rIPs string, spis []*spi) ([]*spi, bool) {
442		rIP := net.ParseIP(rIPs)
443		return updateNodeKey(lIP, rIP, spis, d.keys, newIdx, priIdx, delIdx), false
444	})
445
446	d.Lock()
447	// swap primary
448	if priIdx != -1 {
449		swp := d.keys[0]
450		d.keys[0] = d.keys[priIdx]
451		d.keys[priIdx] = swp
452	}
453	// prune
454	if delIdx != -1 {
455		if delIdx == 0 {
456			delIdx = priIdx
457		}
458		d.keys = append(d.keys[:delIdx], d.keys[delIdx+1:]...)
459	}
460	d.Unlock()
461
462	logrus.Debugf("Updated: %v", d.keys)
463
464	return nil
465}
466
467/********************************************************
468 * Steady state: rSA0, rSA1, rSA2, fSA1, fSP1
469 * Rotation --> -rSA0, +rSA3, +fSA2, +fSP2/-fSP1, -fSA1
470 * Steady state: rSA1, rSA2, rSA3, fSA2, fSP2
471 *********************************************************/
472
473// Spis and keys are sorted in such away the one in position 0 is the primary
474func updateNodeKey(lIP, rIP net.IP, idxs []*spi, curKeys []*key, newIdx, priIdx, delIdx int) []*spi {
475	logrus.Debugf("Updating keys for node: %s (%d,%d,%d)", rIP, newIdx, priIdx, delIdx)
476
477	spis := idxs
478	logrus.Debugf("Current: %v", spis)
479
480	// add new
481	if newIdx != -1 {
482		spis = append(spis, &spi{
483			forward: buildSPI(lIP, rIP, curKeys[newIdx].tag),
484			reverse: buildSPI(rIP, lIP, curKeys[newIdx].tag),
485		})
486	}
487
488	if delIdx != -1 {
489		// -rSA0
490		programSA(lIP, rIP, spis[delIdx], nil, reverse, false)
491	}
492
493	if newIdx > -1 {
494		// +RSA2
495		programSA(lIP, rIP, spis[newIdx], curKeys[newIdx], reverse, true)
496	}
497
498	if priIdx > 0 {
499		// +fSA2
500		fSA2, _, _ := programSA(lIP, rIP, spis[priIdx], curKeys[priIdx], forward, true)
501
502		// +fSP2, -fSP1
503		fullMask := net.CIDRMask(8*len(fSA2.Src), 8*len(fSA2.Src))
504		fSP1 := &netlink.XfrmPolicy{
505			Src:     &net.IPNet{IP: fSA2.Src, Mask: fullMask},
506			Dst:     &net.IPNet{IP: fSA2.Dst, Mask: fullMask},
507			Dir:     netlink.XFRM_DIR_OUT,
508			Proto:   17,
509			DstPort: 4789,
510			Mark: &netlink.XfrmMark{
511				Value: mark,
512			},
513			Tmpls: []netlink.XfrmPolicyTmpl{
514				{
515					Src:   fSA2.Src,
516					Dst:   fSA2.Dst,
517					Proto: netlink.XFRM_PROTO_ESP,
518					Mode:  netlink.XFRM_MODE_TRANSPORT,
519					Spi:   fSA2.Spi,
520				},
521			},
522		}
523		logrus.Debugf("Updating fSP{%s}", fSP1)
524		if err := ns.NlHandle().XfrmPolicyUpdate(fSP1); err != nil {
525			logrus.Warnf("Failed to update fSP{%s}: %v", fSP1, err)
526		}
527
528		// -fSA1
529		programSA(lIP, rIP, spis[0], nil, forward, false)
530	}
531
532	// swap
533	if priIdx > 0 {
534		swp := spis[0]
535		spis[0] = spis[priIdx]
536		spis[priIdx] = swp
537	}
538	// prune
539	if delIdx != -1 {
540		if delIdx == 0 {
541			delIdx = priIdx
542		}
543		spis = append(spis[:delIdx], spis[delIdx+1:]...)
544	}
545
546	logrus.Debugf("Updated: %v", spis)
547
548	return spis
549}
550
551func (n *network) maxMTU() int {
552	mtu := 1500
553	if n.mtu != 0 {
554		mtu = n.mtu
555	}
556	mtu -= vxlanEncap
557	if n.secure {
558		// In case of encryption account for the
559		// esp packet espansion and padding
560		mtu -= pktExpansion
561		mtu -= (mtu % 4)
562	}
563	return mtu
564}
565