1package overlay
2
3import (
4	"bytes"
5	"encoding/binary"
6	"encoding/hex"
7	"fmt"
8	"hash/fnv"
9	"net"
10	"sync"
11	"syscall"
12
13	"strconv"
14
15	"github.com/docker/libnetwork/iptables"
16	"github.com/docker/libnetwork/ns"
17	"github.com/docker/libnetwork/types"
18	"github.com/sirupsen/logrus"
19	"github.com/vishvananda/netlink"
20)
21
const (
	// r is the identifier this driver stamps on what it programs: it is used
	// as the Reqid of the XFRM states and policy templates built in this file
	// and as the value of the packet mark (see spMark and programMangle).
	r            = 0xD0C4E3
	// pktExpansion is the number of bytes subtracted from the MTU of secure
	// networks (see maxMTU); the breakdown of the ESP overhead is given below.
	pktExpansion = 26 // SPI(4) + SeqN(4) + IV(8) + PadLength(1) + NextHeader(1) + ICV(8)
)

// Direction selectors passed to programSA. They are tested as bit flags
// (dir&forward, dir&reverse), so bidir (3) selects both forward (1) and
// reverse (2).
const (
	forward = iota + 1
	reverse
	bidir
)

// spMark is the mark attached to the outbound security policies created by
// this driver; clearEncryptionStates also uses its value to recognise the
// policies to clean up.
var spMark = netlink.XfrmMark{Value: uint32(r), Mask: 0xffffffff}
34
// key is an encryption key together with the tag that identifies it.
type key struct {
	value []byte
	tag   uint32
}

// String returns a short, loggable description of the key. Only the first
// five hex digits of the key material are shown (fewer when the key is
// shorter than that). A nil key is rendered as the empty string.
func (k *key) String() string {
	if k == nil {
		return ""
	}
	digest := hex.EncodeToString(k.value)
	// Slicing past the end of a short (or empty) key would panic.
	if len(digest) > 5 {
		digest = digest[:5]
	}
	return fmt.Sprintf("(key: %s, tag: 0x%x)", digest, k.tag)
}
46
// spi holds the pair of security parameter indexes used with one remote
// node for one key: one for the forward direction and one for the reverse.
type spi struct {
	forward int
	reverse int
}

// String renders both indexes as unsigned 32-bit hexadecimal values.
func (s *spi) String() string {
	fwd, rev := uint32(s.forward), uint32(s.reverse)
	return fmt.Sprintf("SPI(FWD: 0x%x, REV: 0x%x)", fwd, rev)
}
55
56type encrMap struct {
57	nodes map[string][]*spi
58	sync.Mutex
59}
60
61func (e *encrMap) String() string {
62	e.Lock()
63	defer e.Unlock()
64	b := new(bytes.Buffer)
65	for k, v := range e.nodes {
66		b.WriteString("\n")
67		b.WriteString(k)
68		b.WriteString(":")
69		b.WriteString("[")
70		for _, s := range v {
71			b.WriteString(s.String())
72			b.WriteString(",")
73		}
74		b.WriteString("]")
75
76	}
77	return b.String()
78}
79
80func (d *driver) checkEncryption(nid string, rIP net.IP, vxlanID uint32, isLocal, add bool) error {
81	logrus.Debugf("checkEncryption(%s, %v, %d, %t)", nid[0:7], rIP, vxlanID, isLocal)
82
83	n := d.network(nid)
84	if n == nil || !n.secure {
85		return nil
86	}
87
88	if len(d.keys) == 0 {
89		return types.ForbiddenErrorf("encryption key is not present")
90	}
91
92	lIP := net.ParseIP(d.bindAddress)
93	aIP := net.ParseIP(d.advertiseAddress)
94	nodes := map[string]net.IP{}
95
96	switch {
97	case isLocal:
98		if err := d.peerDbNetworkWalk(nid, func(pKey *peerKey, pEntry *peerEntry) bool {
99			if !aIP.Equal(pEntry.vtep) {
100				nodes[pEntry.vtep.String()] = pEntry.vtep
101			}
102			return false
103		}); err != nil {
104			logrus.Warnf("Failed to retrieve list of participating nodes in overlay network %s: %v", nid[0:5], err)
105		}
106	default:
107		if len(d.network(nid).endpoints) > 0 {
108			nodes[rIP.String()] = rIP
109		}
110	}
111
112	logrus.Debugf("List of nodes: %s", nodes)
113
114	if add {
115		for _, rIP := range nodes {
116			if err := setupEncryption(lIP, aIP, rIP, vxlanID, d.secMap, d.keys); err != nil {
117				logrus.Warnf("Failed to program network encryption between %s and %s: %v", lIP, rIP, err)
118			}
119		}
120	} else {
121		if len(nodes) == 0 {
122			if err := removeEncryption(lIP, rIP, d.secMap); err != nil {
123				logrus.Warnf("Failed to remove network encryption between %s and %s: %v", lIP, rIP, err)
124			}
125		}
126	}
127
128	return nil
129}
130
131func setupEncryption(localIP, advIP, remoteIP net.IP, vni uint32, em *encrMap, keys []*key) error {
132	logrus.Debugf("Programming encryption for vxlan %d between %s and %s", vni, localIP, remoteIP)
133	rIPs := remoteIP.String()
134
135	indices := make([]*spi, 0, len(keys))
136
137	err := programMangle(vni, true)
138	if err != nil {
139		logrus.Warn(err)
140	}
141
142	err = programInput(vni, true)
143	if err != nil {
144		logrus.Warn(err)
145	}
146
147	for i, k := range keys {
148		spis := &spi{buildSPI(advIP, remoteIP, k.tag), buildSPI(remoteIP, advIP, k.tag)}
149		dir := reverse
150		if i == 0 {
151			dir = bidir
152		}
153		fSA, rSA, err := programSA(localIP, remoteIP, spis, k, dir, true)
154		if err != nil {
155			logrus.Warn(err)
156		}
157		indices = append(indices, spis)
158		if i != 0 {
159			continue
160		}
161		err = programSP(fSA, rSA, true)
162		if err != nil {
163			logrus.Warn(err)
164		}
165	}
166
167	em.Lock()
168	em.nodes[rIPs] = indices
169	em.Unlock()
170
171	return nil
172}
173
174func removeEncryption(localIP, remoteIP net.IP, em *encrMap) error {
175	em.Lock()
176	indices, ok := em.nodes[remoteIP.String()]
177	em.Unlock()
178	if !ok {
179		return nil
180	}
181	for i, idxs := range indices {
182		dir := reverse
183		if i == 0 {
184			dir = bidir
185		}
186		fSA, rSA, err := programSA(localIP, remoteIP, idxs, nil, dir, false)
187		if err != nil {
188			logrus.Warn(err)
189		}
190		if i != 0 {
191			continue
192		}
193		err = programSP(fSA, rSA, false)
194		if err != nil {
195			logrus.Warn(err)
196		}
197	}
198	return nil
199}
200
201func programMangle(vni uint32, add bool) (err error) {
202	var (
203		p      = strconv.FormatUint(uint64(vxlanPort), 10)
204		c      = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8)
205		m      = strconv.FormatUint(uint64(r), 10)
206		chain  = "OUTPUT"
207		rule   = []string{"-p", "udp", "--dport", p, "-m", "u32", "--u32", c, "-j", "MARK", "--set-mark", m}
208		a      = "-A"
209		action = "install"
210	)
211
212	if add == iptables.Exists(iptables.Mangle, chain, rule...) {
213		return
214	}
215
216	if !add {
217		a = "-D"
218		action = "remove"
219	}
220
221	if err = iptables.RawCombinedOutput(append([]string{"-t", string(iptables.Mangle), a, chain}, rule...)...); err != nil {
222		logrus.Warnf("could not %s mangle rule: %v", action, err)
223	}
224
225	return
226}
227
228func programInput(vni uint32, add bool) (err error) {
229	var (
230		port       = strconv.FormatUint(uint64(vxlanPort), 10)
231		vniMatch   = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8)
232		plainVxlan = []string{"-p", "udp", "--dport", port, "-m", "u32", "--u32", vniMatch, "-j"}
233		ipsecVxlan = append([]string{"-m", "policy", "--dir", "in", "--pol", "ipsec"}, plainVxlan...)
234		block      = append(plainVxlan, "DROP")
235		accept     = append(ipsecVxlan, "ACCEPT")
236		chain      = "INPUT"
237		action     = iptables.Append
238		msg        = "add"
239	)
240
241	if !add {
242		action = iptables.Delete
243		msg = "remove"
244	}
245
246	if err := iptables.ProgramRule(iptables.Filter, chain, action, accept); err != nil {
247		logrus.Errorf("could not %s input rule: %v. Please do it manually.", msg, err)
248	}
249
250	if err := iptables.ProgramRule(iptables.Filter, chain, action, block); err != nil {
251		logrus.Errorf("could not %s input rule: %v. Please do it manually.", msg, err)
252	}
253
254	return
255}
256
257func programSA(localIP, remoteIP net.IP, spi *spi, k *key, dir int, add bool) (fSA *netlink.XfrmState, rSA *netlink.XfrmState, err error) {
258	var (
259		action      = "Removing"
260		xfrmProgram = ns.NlHandle().XfrmStateDel
261	)
262
263	if add {
264		action = "Adding"
265		xfrmProgram = ns.NlHandle().XfrmStateAdd
266	}
267
268	if dir&reverse > 0 {
269		rSA = &netlink.XfrmState{
270			Src:   remoteIP,
271			Dst:   localIP,
272			Proto: netlink.XFRM_PROTO_ESP,
273			Spi:   spi.reverse,
274			Mode:  netlink.XFRM_MODE_TRANSPORT,
275			Reqid: r,
276		}
277		if add {
278			rSA.Aead = buildAeadAlgo(k, spi.reverse)
279		}
280
281		exists, err := saExists(rSA)
282		if err != nil {
283			exists = !add
284		}
285
286		if add != exists {
287			logrus.Debugf("%s: rSA{%s}", action, rSA)
288			if err := xfrmProgram(rSA); err != nil {
289				logrus.Warnf("Failed %s rSA{%s}: %v", action, rSA, err)
290			}
291		}
292	}
293
294	if dir&forward > 0 {
295		fSA = &netlink.XfrmState{
296			Src:   localIP,
297			Dst:   remoteIP,
298			Proto: netlink.XFRM_PROTO_ESP,
299			Spi:   spi.forward,
300			Mode:  netlink.XFRM_MODE_TRANSPORT,
301			Reqid: r,
302		}
303		if add {
304			fSA.Aead = buildAeadAlgo(k, spi.forward)
305		}
306
307		exists, err := saExists(fSA)
308		if err != nil {
309			exists = !add
310		}
311
312		if add != exists {
313			logrus.Debugf("%s fSA{%s}", action, fSA)
314			if err := xfrmProgram(fSA); err != nil {
315				logrus.Warnf("Failed %s fSA{%s}: %v.", action, fSA, err)
316			}
317		}
318	}
319
320	return
321}
322
323func programSP(fSA *netlink.XfrmState, rSA *netlink.XfrmState, add bool) error {
324	action := "Removing"
325	xfrmProgram := ns.NlHandle().XfrmPolicyDel
326	if add {
327		action = "Adding"
328		xfrmProgram = ns.NlHandle().XfrmPolicyAdd
329	}
330
331	// Create a congruent cidr
332	s := types.GetMinimalIP(fSA.Src)
333	d := types.GetMinimalIP(fSA.Dst)
334	fullMask := net.CIDRMask(8*len(s), 8*len(s))
335
336	fPol := &netlink.XfrmPolicy{
337		Src:     &net.IPNet{IP: s, Mask: fullMask},
338		Dst:     &net.IPNet{IP: d, Mask: fullMask},
339		Dir:     netlink.XFRM_DIR_OUT,
340		Proto:   17,
341		DstPort: 4789,
342		Mark:    &spMark,
343		Tmpls: []netlink.XfrmPolicyTmpl{
344			{
345				Src:   fSA.Src,
346				Dst:   fSA.Dst,
347				Proto: netlink.XFRM_PROTO_ESP,
348				Mode:  netlink.XFRM_MODE_TRANSPORT,
349				Spi:   fSA.Spi,
350				Reqid: r,
351			},
352		},
353	}
354
355	exists, err := spExists(fPol)
356	if err != nil {
357		exists = !add
358	}
359
360	if add != exists {
361		logrus.Debugf("%s fSP{%s}", action, fPol)
362		if err := xfrmProgram(fPol); err != nil {
363			logrus.Warnf("%s fSP{%s}: %v", action, fPol, err)
364		}
365	}
366
367	return nil
368}
369
370func saExists(sa *netlink.XfrmState) (bool, error) {
371	_, err := ns.NlHandle().XfrmStateGet(sa)
372	switch err {
373	case nil:
374		return true, nil
375	case syscall.ESRCH:
376		return false, nil
377	default:
378		err = fmt.Errorf("Error while checking for SA existence: %v", err)
379		logrus.Warn(err)
380		return false, err
381	}
382}
383
384func spExists(sp *netlink.XfrmPolicy) (bool, error) {
385	_, err := ns.NlHandle().XfrmPolicyGet(sp)
386	switch err {
387	case nil:
388		return true, nil
389	case syscall.ENOENT:
390		return false, nil
391	default:
392		err = fmt.Errorf("Error while checking for SP existence: %v", err)
393		logrus.Warn(err)
394		return false, err
395	}
396}
397
// buildSPI derives a security parameter index from a source address, a key
// tag and a destination address: it is the 32-bit FNV-1a hash of the bytes
// of src, followed by st in big-endian order, followed by the bytes of dst.
// The addresses are hashed exactly as given, so the result depends on
// their byte length as well as on their order.
func buildSPI(src, dst net.IP, st uint32) int {
	var tag [4]byte
	binary.BigEndian.PutUint32(tag[:], st)

	hasher := fnv.New32a()
	for _, chunk := range [][]byte{src, tag[:], dst} {
		hasher.Write(chunk)
	}
	return int(hasher.Sum32())
}
407
408func buildAeadAlgo(k *key, s int) *netlink.XfrmStateAlgo {
409	salt := make([]byte, 4)
410	binary.BigEndian.PutUint32(salt, uint32(s))
411	return &netlink.XfrmStateAlgo{
412		Name:   "rfc4106(gcm(aes))",
413		Key:    append(k.value, salt...),
414		ICVLen: 64,
415	}
416}
417
418func (d *driver) secMapWalk(f func(string, []*spi) ([]*spi, bool)) error {
419	d.secMap.Lock()
420	for node, indices := range d.secMap.nodes {
421		idxs, stop := f(node, indices)
422		if idxs != nil {
423			d.secMap.nodes[node] = idxs
424		}
425		if stop {
426			break
427		}
428	}
429	d.secMap.Unlock()
430	return nil
431}
432
433func (d *driver) setKeys(keys []*key) error {
434	// Remove any stale policy, state
435	clearEncryptionStates()
436	// Accept the encryption keys and clear any stale encryption map
437	d.Lock()
438	d.keys = keys
439	d.secMap = &encrMap{nodes: map[string][]*spi{}}
440	d.Unlock()
441	logrus.Debugf("Initial encryption keys: %v", d.keys)
442	return nil
443}
444
445// updateKeys allows to add a new key and/or change the primary key and/or prune an existing key
446// The primary key is the key used in transmission and will go in first position in the list.
447func (d *driver) updateKeys(newKey, primary, pruneKey *key) error {
448	logrus.Debugf("Updating Keys. New: %v, Primary: %v, Pruned: %v", newKey, primary, pruneKey)
449
450	logrus.Debugf("Current: %v", d.keys)
451
452	var (
453		newIdx = -1
454		priIdx = -1
455		delIdx = -1
456		lIP    = net.ParseIP(d.bindAddress)
457		aIP    = net.ParseIP(d.advertiseAddress)
458	)
459
460	d.Lock()
461	// add new
462	if newKey != nil {
463		d.keys = append(d.keys, newKey)
464		newIdx += len(d.keys)
465	}
466	for i, k := range d.keys {
467		if primary != nil && k.tag == primary.tag {
468			priIdx = i
469		}
470		if pruneKey != nil && k.tag == pruneKey.tag {
471			delIdx = i
472		}
473	}
474	d.Unlock()
475
476	if (newKey != nil && newIdx == -1) ||
477		(primary != nil && priIdx == -1) ||
478		(pruneKey != nil && delIdx == -1) {
479		return types.BadRequestErrorf("cannot find proper key indices while processing key update:"+
480			"(newIdx,priIdx,delIdx):(%d, %d, %d)", newIdx, priIdx, delIdx)
481	}
482
483	d.secMapWalk(func(rIPs string, spis []*spi) ([]*spi, bool) {
484		rIP := net.ParseIP(rIPs)
485		return updateNodeKey(lIP, aIP, rIP, spis, d.keys, newIdx, priIdx, delIdx), false
486	})
487
488	d.Lock()
489	// swap primary
490	if priIdx != -1 {
491		swp := d.keys[0]
492		d.keys[0] = d.keys[priIdx]
493		d.keys[priIdx] = swp
494	}
495	// prune
496	if delIdx != -1 {
497		if delIdx == 0 {
498			delIdx = priIdx
499		}
500		d.keys = append(d.keys[:delIdx], d.keys[delIdx+1:]...)
501	}
502	d.Unlock()
503
504	logrus.Debugf("Updated: %v", d.keys)
505
506	return nil
507}
508
509/********************************************************
510 * Steady state: rSA0, rSA1, rSA2, fSA1, fSP1
511 * Rotation --> -rSA0, +rSA3, +fSA2, +fSP2/-fSP1, -fSA1
512 * Steady state: rSA1, rSA2, rSA3, fSA2, fSP2
513 *********************************************************/
514
515// Spis and keys are sorted in such away the one in position 0 is the primary
516func updateNodeKey(lIP, aIP, rIP net.IP, idxs []*spi, curKeys []*key, newIdx, priIdx, delIdx int) []*spi {
517	logrus.Debugf("Updating keys for node: %s (%d,%d,%d)", rIP, newIdx, priIdx, delIdx)
518
519	spis := idxs
520	logrus.Debugf("Current: %v", spis)
521
522	// add new
523	if newIdx != -1 {
524		spis = append(spis, &spi{
525			forward: buildSPI(aIP, rIP, curKeys[newIdx].tag),
526			reverse: buildSPI(rIP, aIP, curKeys[newIdx].tag),
527		})
528	}
529
530	if delIdx != -1 {
531		// -rSA0
532		programSA(lIP, rIP, spis[delIdx], nil, reverse, false)
533	}
534
535	if newIdx > -1 {
536		// +rSA2
537		programSA(lIP, rIP, spis[newIdx], curKeys[newIdx], reverse, true)
538	}
539
540	if priIdx > 0 {
541		// +fSA2
542		fSA2, _, _ := programSA(lIP, rIP, spis[priIdx], curKeys[priIdx], forward, true)
543
544		// +fSP2, -fSP1
545		s := types.GetMinimalIP(fSA2.Src)
546		d := types.GetMinimalIP(fSA2.Dst)
547		fullMask := net.CIDRMask(8*len(s), 8*len(s))
548
549		fSP1 := &netlink.XfrmPolicy{
550			Src:     &net.IPNet{IP: s, Mask: fullMask},
551			Dst:     &net.IPNet{IP: d, Mask: fullMask},
552			Dir:     netlink.XFRM_DIR_OUT,
553			Proto:   17,
554			DstPort: 4789,
555			Mark:    &spMark,
556			Tmpls: []netlink.XfrmPolicyTmpl{
557				{
558					Src:   fSA2.Src,
559					Dst:   fSA2.Dst,
560					Proto: netlink.XFRM_PROTO_ESP,
561					Mode:  netlink.XFRM_MODE_TRANSPORT,
562					Spi:   fSA2.Spi,
563					Reqid: r,
564				},
565			},
566		}
567		logrus.Debugf("Updating fSP{%s}", fSP1)
568		if err := ns.NlHandle().XfrmPolicyUpdate(fSP1); err != nil {
569			logrus.Warnf("Failed to update fSP{%s}: %v", fSP1, err)
570		}
571
572		// -fSA1
573		programSA(lIP, rIP, spis[0], nil, forward, false)
574	}
575
576	// swap
577	if priIdx > 0 {
578		swp := spis[0]
579		spis[0] = spis[priIdx]
580		spis[priIdx] = swp
581	}
582	// prune
583	if delIdx != -1 {
584		if delIdx == 0 {
585			delIdx = priIdx
586		}
587		spis = append(spis[:delIdx], spis[delIdx+1:]...)
588	}
589
590	logrus.Debugf("Updated: %v", spis)
591
592	return spis
593}
594
595func (n *network) maxMTU() int {
596	mtu := 1500
597	if n.mtu != 0 {
598		mtu = n.mtu
599	}
600	mtu -= vxlanEncap
601	if n.secure {
602		// In case of encryption account for the
603		// esp packet espansion and padding
604		mtu -= pktExpansion
605		mtu -= (mtu % 4)
606	}
607	return mtu
608}
609
610func clearEncryptionStates() {
611	nlh := ns.NlHandle()
612	spList, err := nlh.XfrmPolicyList(netlink.FAMILY_ALL)
613	if err != nil {
614		logrus.Warnf("Failed to retrieve SP list for cleanup: %v", err)
615	}
616	saList, err := nlh.XfrmStateList(netlink.FAMILY_ALL)
617	if err != nil {
618		logrus.Warnf("Failed to retrieve SA list for cleanup: %v", err)
619	}
620	for _, sp := range spList {
621		if sp.Mark != nil && sp.Mark.Value == spMark.Value {
622			if err := nlh.XfrmPolicyDel(&sp); err != nil {
623				logrus.Warnf("Failed to delete stale SP %s: %v", sp, err)
624				continue
625			}
626			logrus.Debugf("Removed stale SP: %s", sp)
627		}
628	}
629	for _, sa := range saList {
630		if sa.Reqid == r {
631			if err := nlh.XfrmStateDel(&sa); err != nil {
632				logrus.Warnf("Failed to delete stale SA %s: %v", sa, err)
633				continue
634			}
635			logrus.Debugf("Removed stale SA: %s", sa)
636		}
637	}
638}
639