1package overlay 2 3import ( 4 "bytes" 5 "encoding/binary" 6 "encoding/hex" 7 "fmt" 8 "hash/fnv" 9 "net" 10 "sync" 11 "syscall" 12 13 "strconv" 14 15 "github.com/docker/libnetwork/iptables" 16 "github.com/docker/libnetwork/ns" 17 "github.com/docker/libnetwork/types" 18 "github.com/sirupsen/logrus" 19 "github.com/vishvananda/netlink" 20) 21 22const ( 23 r = 0xD0C4E3 24 pktExpansion = 26 // SPI(4) + SeqN(4) + IV(8) + PadLength(1) + NextHeader(1) + ICV(8) 25) 26 27const ( 28 forward = iota + 1 29 reverse 30 bidir 31) 32 33var spMark = netlink.XfrmMark{Value: uint32(r), Mask: 0xffffffff} 34 35type key struct { 36 value []byte 37 tag uint32 38} 39 40func (k *key) String() string { 41 if k != nil { 42 return fmt.Sprintf("(key: %s, tag: 0x%x)", hex.EncodeToString(k.value)[0:5], k.tag) 43 } 44 return "" 45} 46 47type spi struct { 48 forward int 49 reverse int 50} 51 52func (s *spi) String() string { 53 return fmt.Sprintf("SPI(FWD: 0x%x, REV: 0x%x)", uint32(s.forward), uint32(s.reverse)) 54} 55 56type encrMap struct { 57 nodes map[string][]*spi 58 sync.Mutex 59} 60 61func (e *encrMap) String() string { 62 e.Lock() 63 defer e.Unlock() 64 b := new(bytes.Buffer) 65 for k, v := range e.nodes { 66 b.WriteString("\n") 67 b.WriteString(k) 68 b.WriteString(":") 69 b.WriteString("[") 70 for _, s := range v { 71 b.WriteString(s.String()) 72 b.WriteString(",") 73 } 74 b.WriteString("]") 75 76 } 77 return b.String() 78} 79 80func (d *driver) checkEncryption(nid string, rIP net.IP, vxlanID uint32, isLocal, add bool) error { 81 logrus.Debugf("checkEncryption(%s, %v, %d, %t)", nid[0:7], rIP, vxlanID, isLocal) 82 83 n := d.network(nid) 84 if n == nil || !n.secure { 85 return nil 86 } 87 88 if len(d.keys) == 0 { 89 return types.ForbiddenErrorf("encryption key is not present") 90 } 91 92 lIP := net.ParseIP(d.bindAddress) 93 aIP := net.ParseIP(d.advertiseAddress) 94 nodes := map[string]net.IP{} 95 96 switch { 97 case isLocal: 98 if err := d.peerDbNetworkWalk(nid, func(pKey *peerKey, pEntry *peerEntry) bool { 99 if !aIP.Equal(pEntry.vtep) { 100 nodes[pEntry.vtep.String()] = pEntry.vtep 101 } 102 return false 103 }); err != nil { 104 logrus.Warnf("Failed to retrieve list of participating nodes in overlay network %s: %v", nid[0:5], err) 105 } 106 default: 107 if len(d.network(nid).endpoints) > 0 { 108 nodes[rIP.String()] = rIP 109 } 110 } 111 112 logrus.Debugf("List of nodes: %s", nodes) 113 114 if add { 115 for _, rIP := range nodes { 116 if err := setupEncryption(lIP, aIP, rIP, vxlanID, d.secMap, d.keys); err != nil { 117 logrus.Warnf("Failed to program network encryption between %s and %s: %v", lIP, rIP, err) 118 } 119 } 120 } else { 121 if len(nodes) == 0 { 122 if err := removeEncryption(lIP, rIP, d.secMap); err != nil { 123 logrus.Warnf("Failed to remove network encryption between %s and %s: %v", lIP, rIP, err) 124 } 125 } 126 } 127 128 return nil 129} 130 131func setupEncryption(localIP, advIP, remoteIP net.IP, vni uint32, em *encrMap, keys []*key) error { 132 logrus.Debugf("Programming encryption for vxlan %d between %s and %s", vni, localIP, remoteIP) 133 rIPs := remoteIP.String() 134 135 indices := make([]*spi, 0, len(keys)) 136 137 err := programMangle(vni, true) 138 if err != nil { 139 logrus.Warn(err) 140 } 141 142 err = programInput(vni, true) 143 if err != nil { 144 logrus.Warn(err) 145 } 146 147 for i, k := range keys { 148 spis := &spi{buildSPI(advIP, remoteIP, k.tag), buildSPI(remoteIP, advIP, k.tag)} 149 dir := reverse 150 if i == 0 { 151 dir = bidir 152 } 153 fSA, rSA, err := programSA(localIP, remoteIP, spis, k, dir, true) 154 if err != nil { 155 logrus.Warn(err) 156 } 157 indices = append(indices, spis) 158 if i != 0 { 159 continue 160 } 161 err = programSP(fSA, rSA, true) 162 if err != nil { 163 logrus.Warn(err) 164 } 165 } 166 167 em.Lock() 168 em.nodes[rIPs] = indices 169 em.Unlock() 170 171 return nil 172} 173 174func removeEncryption(localIP, remoteIP net.IP, em *encrMap) error { 175 em.Lock() 176 indices, ok := em.nodes[remoteIP.String()] 177 em.Unlock() 178 if !ok { 179 return nil 180 } 181 for i, idxs := range indices { 182 dir := reverse 183 if i == 0 { 184 dir = bidir 185 } 186 fSA, rSA, err := programSA(localIP, remoteIP, idxs, nil, dir, false) 187 if err != nil { 188 logrus.Warn(err) 189 } 190 if i != 0 { 191 continue 192 } 193 err = programSP(fSA, rSA, false) 194 if err != nil { 195 logrus.Warn(err) 196 } 197 } 198 return nil 199} 200 201func programMangle(vni uint32, add bool) (err error) { 202 var ( 203 p = strconv.FormatUint(uint64(vxlanPort), 10) 204 c = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8) 205 m = strconv.FormatUint(uint64(r), 10) 206 chain = "OUTPUT" 207 rule = []string{"-p", "udp", "--dport", p, "-m", "u32", "--u32", c, "-j", "MARK", "--set-mark", m} 208 a = "-A" 209 action = "install" 210 ) 211 212 if add == iptables.Exists(iptables.Mangle, chain, rule...) { 213 return 214 } 215 216 if !add { 217 a = "-D" 218 action = "remove" 219 } 220 221 if err = iptables.RawCombinedOutput(append([]string{"-t", string(iptables.Mangle), a, chain}, rule...)...); err != nil { 222 logrus.Warnf("could not %s mangle rule: %v", action, err) 223 } 224 225 return 226} 227 228func programInput(vni uint32, add bool) (err error) { 229 var ( 230 port = strconv.FormatUint(uint64(vxlanPort), 10) 231 vniMatch = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8) 232 plainVxlan = []string{"-p", "udp", "--dport", port, "-m", "u32", "--u32", vniMatch, "-j"} 233 ipsecVxlan = append([]string{"-m", "policy", "--dir", "in", "--pol", "ipsec"}, plainVxlan...) 234 block = append(plainVxlan, "DROP") 235 accept = append(ipsecVxlan, "ACCEPT") 236 chain = "INPUT" 237 action = iptables.Append 238 msg = "add" 239 ) 240 241 if !add { 242 action = iptables.Delete 243 msg = "remove" 244 } 245 246 if err := iptables.ProgramRule(iptables.Filter, chain, action, accept); err != nil { 247 logrus.Errorf("could not %s input rule: %v. Please do it manually.", msg, err) 248 } 249 250 if err := iptables.ProgramRule(iptables.Filter, chain, action, block); err != nil { 251 logrus.Errorf("could not %s input rule: %v. Please do it manually.", msg, err) 252 } 253 254 return 255} 256 257func programSA(localIP, remoteIP net.IP, spi *spi, k *key, dir int, add bool) (fSA *netlink.XfrmState, rSA *netlink.XfrmState, err error) { 258 var ( 259 action = "Removing" 260 xfrmProgram = ns.NlHandle().XfrmStateDel 261 ) 262 263 if add { 264 action = "Adding" 265 xfrmProgram = ns.NlHandle().XfrmStateAdd 266 } 267 268 if dir&reverse > 0 { 269 rSA = &netlink.XfrmState{ 270 Src: remoteIP, 271 Dst: localIP, 272 Proto: netlink.XFRM_PROTO_ESP, 273 Spi: spi.reverse, 274 Mode: netlink.XFRM_MODE_TRANSPORT, 275 Reqid: r, 276 } 277 if add { 278 rSA.Aead = buildAeadAlgo(k, spi.reverse) 279 } 280 281 exists, err := saExists(rSA) 282 if err != nil { 283 exists = !add 284 } 285 286 if add != exists { 287 logrus.Debugf("%s: rSA{%s}", action, rSA) 288 if err := xfrmProgram(rSA); err != nil { 289 logrus.Warnf("Failed %s rSA{%s}: %v", action, rSA, err) 290 } 291 } 292 } 293 294 if dir&forward > 0 { 295 fSA = &netlink.XfrmState{ 296 Src: localIP, 297 Dst: remoteIP, 298 Proto: netlink.XFRM_PROTO_ESP, 299 Spi: spi.forward, 300 Mode: netlink.XFRM_MODE_TRANSPORT, 301 Reqid: r, 302 } 303 if add { 304 fSA.Aead = buildAeadAlgo(k, spi.forward) 305 } 306 307 exists, err := saExists(fSA) 308 if err != nil { 309 exists = !add 310 } 311 312 if add != exists { 313 logrus.Debugf("%s fSA{%s}", action, fSA) 314 if err := xfrmProgram(fSA); err != nil { 315 logrus.Warnf("Failed %s fSA{%s}: %v.", action, fSA, err) 316 } 317 } 318 } 319 320 return 321} 322 323func programSP(fSA *netlink.XfrmState, rSA *netlink.XfrmState, add bool) error { 324 action := "Removing" 325 xfrmProgram := ns.NlHandle().XfrmPolicyDel 326 if add { 327 action = "Adding" 328 xfrmProgram = ns.NlHandle().XfrmPolicyAdd 329 } 330 331 // Create a congruent cidr 332 s := types.GetMinimalIP(fSA.Src) 333 d := types.GetMinimalIP(fSA.Dst) 334 fullMask := net.CIDRMask(8*len(s), 8*len(s)) 335 336 fPol := &netlink.XfrmPolicy{ 337 Src: &net.IPNet{IP: s, Mask: fullMask}, 338 Dst: &net.IPNet{IP: d, Mask: fullMask}, 339 Dir: netlink.XFRM_DIR_OUT, 340 Proto: 17, 341 DstPort: 4789, 342 Mark: &spMark, 343 Tmpls: []netlink.XfrmPolicyTmpl{ 344 { 345 Src: fSA.Src, 346 Dst: fSA.Dst, 347 Proto: netlink.XFRM_PROTO_ESP, 348 Mode: netlink.XFRM_MODE_TRANSPORT, 349 Spi: fSA.Spi, 350 Reqid: r, 351 }, 352 }, 353 } 354 355 exists, err := spExists(fPol) 356 if err != nil { 357 exists = !add 358 } 359 360 if add != exists { 361 logrus.Debugf("%s fSP{%s}", action, fPol) 362 if err := xfrmProgram(fPol); err != nil { 363 logrus.Warnf("%s fSP{%s}: %v", action, fPol, err) 364 } 365 } 366 367 return nil 368} 369 370func saExists(sa *netlink.XfrmState) (bool, error) { 371 _, err := ns.NlHandle().XfrmStateGet(sa) 372 switch err { 373 case nil: 374 return true, nil 375 case syscall.ESRCH: 376 return false, nil 377 default: 378 err = fmt.Errorf("Error while checking for SA existence: %v", err) 379 logrus.Warn(err) 380 return false, err 381 } 382} 383 384func spExists(sp *netlink.XfrmPolicy) (bool, error) { 385 _, err := ns.NlHandle().XfrmPolicyGet(sp) 386 switch err { 387 case nil: 388 return true, nil 389 case syscall.ENOENT: 390 return false, nil 391 default: 392 err = fmt.Errorf("Error while checking for SP existence: %v", err) 393 logrus.Warn(err) 394 return false, err 395 } 396} 397 398func buildSPI(src, dst net.IP, st uint32) int { 399 b := make([]byte, 4) 400 binary.BigEndian.PutUint32(b, st) 401 h := fnv.New32a() 402 h.Write(src) 403 h.Write(b) 404 h.Write(dst) 405 return int(binary.BigEndian.Uint32(h.Sum(nil))) 406} 407 408func buildAeadAlgo(k *key, s int) *netlink.XfrmStateAlgo { 409 salt := make([]byte, 4) 410 binary.BigEndian.PutUint32(salt, uint32(s)) 411 return &netlink.XfrmStateAlgo{ 412 Name: "rfc4106(gcm(aes))", 413 Key: append(k.value, salt...), 414 ICVLen: 64, 415 } 416} 417 418func (d *driver) secMapWalk(f func(string, []*spi) ([]*spi, bool)) error { 419 d.secMap.Lock() 420 for node, indices := range d.secMap.nodes { 421 idxs, stop := f(node, indices) 422 if idxs != nil { 423 d.secMap.nodes[node] = idxs 424 } 425 if stop { 426 break 427 } 428 } 429 d.secMap.Unlock() 430 return nil 431} 432 433func (d *driver) setKeys(keys []*key) error { 434 // Remove any stale policy, state 435 clearEncryptionStates() 436 // Accept the encryption keys and clear any stale encryption map 437 d.Lock() 438 d.keys = keys 439 d.secMap = &encrMap{nodes: map[string][]*spi{}} 440 d.Unlock() 441 logrus.Debugf("Initial encryption keys: %v", d.keys) 442 return nil 443} 444 445// updateKeys allows to add a new key and/or change the primary key and/or prune an existing key 446// The primary key is the key used in transmission and will go in first position in the list. 447func (d *driver) updateKeys(newKey, primary, pruneKey *key) error { 448 logrus.Debugf("Updating Keys. New: %v, Primary: %v, Pruned: %v", newKey, primary, pruneKey) 449 450 logrus.Debugf("Current: %v", d.keys) 451 452 var ( 453 newIdx = -1 454 priIdx = -1 455 delIdx = -1 456 lIP = net.ParseIP(d.bindAddress) 457 aIP = net.ParseIP(d.advertiseAddress) 458 ) 459 460 d.Lock() 461 // add new 462 if newKey != nil { 463 d.keys = append(d.keys, newKey) 464 newIdx += len(d.keys) 465 } 466 for i, k := range d.keys { 467 if primary != nil && k.tag == primary.tag { 468 priIdx = i 469 } 470 if pruneKey != nil && k.tag == pruneKey.tag { 471 delIdx = i 472 } 473 } 474 d.Unlock() 475 476 if (newKey != nil && newIdx == -1) || 477 (primary != nil && priIdx == -1) || 478 (pruneKey != nil && delIdx == -1) { 479 return types.BadRequestErrorf("cannot find proper key indices while processing key update:"+ 480 "(newIdx,priIdx,delIdx):(%d, %d, %d)", newIdx, priIdx, delIdx) 481 } 482 483 d.secMapWalk(func(rIPs string, spis []*spi) ([]*spi, bool) { 484 rIP := net.ParseIP(rIPs) 485 return updateNodeKey(lIP, aIP, rIP, spis, d.keys, newIdx, priIdx, delIdx), false 486 }) 487 488 d.Lock() 489 // swap primary 490 if priIdx != -1 { 491 swp := d.keys[0] 492 d.keys[0] = d.keys[priIdx] 493 d.keys[priIdx] = swp 494 } 495 // prune 496 if delIdx != -1 { 497 if delIdx == 0 { 498 delIdx = priIdx 499 } 500 d.keys = append(d.keys[:delIdx], d.keys[delIdx+1:]...) 501 } 502 d.Unlock() 503 504 logrus.Debugf("Updated: %v", d.keys) 505 506 return nil 507} 508 509/******************************************************** 510 * Steady state: rSA0, rSA1, rSA2, fSA1, fSP1 511 * Rotation --> -rSA0, +rSA3, +fSA2, +fSP2/-fSP1, -fSA1 512 * Steady state: rSA1, rSA2, rSA3, fSA2, fSP2 513 *********************************************************/ 514 515// Spis and keys are sorted in such away the one in position 0 is the primary 516func updateNodeKey(lIP, aIP, rIP net.IP, idxs []*spi, curKeys []*key, newIdx, priIdx, delIdx int) []*spi { 517 logrus.Debugf("Updating keys for node: %s (%d,%d,%d)", rIP, newIdx, priIdx, delIdx) 518 519 spis := idxs 520 logrus.Debugf("Current: %v", spis) 521 522 // add new 523 if newIdx != -1 { 524 spis = append(spis, &spi{ 525 forward: buildSPI(aIP, rIP, curKeys[newIdx].tag), 526 reverse: buildSPI(rIP, aIP, curKeys[newIdx].tag), 527 }) 528 } 529 530 if delIdx != -1 { 531 // -rSA0 532 programSA(lIP, rIP, spis[delIdx], nil, reverse, false) 533 } 534 535 if newIdx > -1 { 536 // +rSA2 537 programSA(lIP, rIP, spis[newIdx], curKeys[newIdx], reverse, true) 538 } 539 540 if priIdx > 0 { 541 // +fSA2 542 fSA2, _, _ := programSA(lIP, rIP, spis[priIdx], curKeys[priIdx], forward, true) 543 544 // +fSP2, -fSP1 545 s := types.GetMinimalIP(fSA2.Src) 546 d := types.GetMinimalIP(fSA2.Dst) 547 fullMask := net.CIDRMask(8*len(s), 8*len(s)) 548 549 fSP1 := &netlink.XfrmPolicy{ 550 Src: &net.IPNet{IP: s, Mask: fullMask}, 551 Dst: &net.IPNet{IP: d, Mask: fullMask}, 552 Dir: netlink.XFRM_DIR_OUT, 553 Proto: 17, 554 DstPort: 4789, 555 Mark: &spMark, 556 Tmpls: []netlink.XfrmPolicyTmpl{ 557 { 558 Src: fSA2.Src, 559 Dst: fSA2.Dst, 560 Proto: netlink.XFRM_PROTO_ESP, 561 Mode: netlink.XFRM_MODE_TRANSPORT, 562 Spi: fSA2.Spi, 563 Reqid: r, 564 }, 565 }, 566 } 567 logrus.Debugf("Updating fSP{%s}", fSP1) 568 if err := ns.NlHandle().XfrmPolicyUpdate(fSP1); err != nil { 569 logrus.Warnf("Failed to update fSP{%s}: %v", fSP1, err) 570 } 571 572 // -fSA1 573 programSA(lIP, rIP, spis[0], nil, forward, false) 574 } 575 576 // swap 577 if priIdx > 0 { 578 swp := spis[0] 579 spis[0] = spis[priIdx] 580 spis[priIdx] = swp 581 } 582 // prune 583 if delIdx != -1 { 584 if delIdx == 0 { 585 delIdx = priIdx 586 } 587 spis = append(spis[:delIdx], spis[delIdx+1:]...) 588 } 589 590 logrus.Debugf("Updated: %v", spis) 591 592 return spis 593} 594 595func (n *network) maxMTU() int { 596 mtu := 1500 597 if n.mtu != 0 { 598 mtu = n.mtu 599 } 600 mtu -= vxlanEncap 601 if n.secure { 602 // In case of encryption account for the 603 // esp packet espansion and padding 604 mtu -= pktExpansion 605 mtu -= (mtu % 4) 606 } 607 return mtu 608} 609 610func clearEncryptionStates() { 611 nlh := ns.NlHandle() 612 spList, err := nlh.XfrmPolicyList(netlink.FAMILY_ALL) 613 if err != nil { 614 logrus.Warnf("Failed to retrieve SP list for cleanup: %v", err) 615 } 616 saList, err := nlh.XfrmStateList(netlink.FAMILY_ALL) 617 if err != nil { 618 logrus.Warnf("Failed to retrieve SA list for cleanup: %v", err) 619 } 620 for _, sp := range spList { 621 if sp.Mark != nil && sp.Mark.Value == spMark.Value { 622 if err := nlh.XfrmPolicyDel(&sp); err != nil { 623 logrus.Warnf("Failed to delete stale SP %s: %v", sp, err) 624 continue 625 } 626 logrus.Debugf("Removed stale SP: %s", sp) 627 } 628 } 629 for _, sa := range saList { 630 if sa.Reqid == r { 631 if err := nlh.XfrmStateDel(&sa); err != nil { 632 logrus.Warnf("Failed to delete stale SA %s: %v", sa, err) 633 continue 634 } 635 logrus.Debugf("Removed stale SA: %s", sa) 636 } 637 } 638} 639