1package overlay 2 3import ( 4 "bytes" 5 "encoding/binary" 6 "encoding/hex" 7 "fmt" 8 "hash/fnv" 9 "net" 10 "sync" 11 "syscall" 12 13 "strconv" 14 15 "github.com/sirupsen/logrus" 16 "github.com/docker/libnetwork/iptables" 17 "github.com/docker/libnetwork/ns" 18 "github.com/docker/libnetwork/types" 19 "github.com/vishvananda/netlink" 20) 21 22const ( 23 mark = uint32(0xD0C4E3) 24 timeout = 30 25 pktExpansion = 26 // SPI(4) + SeqN(4) + IV(8) + PadLength(1) + NextHeader(1) + ICV(8) 26) 27 28const ( 29 forward = iota + 1 30 reverse 31 bidir 32) 33 34type key struct { 35 value []byte 36 tag uint32 37} 38 39func (k *key) String() string { 40 if k != nil { 41 return fmt.Sprintf("(key: %s, tag: 0x%x)", hex.EncodeToString(k.value)[0:5], k.tag) 42 } 43 return "" 44} 45 46type spi struct { 47 forward int 48 reverse int 49} 50 51func (s *spi) String() string { 52 return fmt.Sprintf("SPI(FWD: 0x%x, REV: 0x%x)", uint32(s.forward), uint32(s.reverse)) 53} 54 55type encrMap struct { 56 nodes map[string][]*spi 57 sync.Mutex 58} 59 60func (e *encrMap) String() string { 61 e.Lock() 62 defer e.Unlock() 63 b := new(bytes.Buffer) 64 for k, v := range e.nodes { 65 b.WriteString("\n") 66 b.WriteString(k) 67 b.WriteString(":") 68 b.WriteString("[") 69 for _, s := range v { 70 b.WriteString(s.String()) 71 b.WriteString(",") 72 } 73 b.WriteString("]") 74 75 } 76 return b.String() 77} 78 79func (d *driver) checkEncryption(nid string, rIP net.IP, vxlanID uint32, isLocal, add bool) error { 80 logrus.Debugf("checkEncryption(%s, %v, %d, %t)", nid[0:7], rIP, vxlanID, isLocal) 81 82 n := d.network(nid) 83 if n == nil || !n.secure { 84 return nil 85 } 86 87 if len(d.keys) == 0 { 88 return types.ForbiddenErrorf("encryption key is not present") 89 } 90 91 lIP := net.ParseIP(d.bindAddress) 92 aIP := net.ParseIP(d.advertiseAddress) 93 nodes := map[string]net.IP{} 94 95 switch { 96 case isLocal: 97 if err := d.peerDbNetworkWalk(nid, func(pKey *peerKey, pEntry *peerEntry) bool { 98 if !aIP.Equal(pEntry.vtep) { 99 nodes[pEntry.vtep.String()] = pEntry.vtep 100 } 101 return false 102 }); err != nil { 103 logrus.Warnf("Failed to retrieve list of participating nodes in overlay network %s: %v", nid[0:5], err) 104 } 105 default: 106 if len(d.network(nid).endpoints) > 0 { 107 nodes[rIP.String()] = rIP 108 } 109 } 110 111 logrus.Debugf("List of nodes: %s", nodes) 112 113 if add { 114 for _, rIP := range nodes { 115 if err := setupEncryption(lIP, aIP, rIP, vxlanID, d.secMap, d.keys); err != nil { 116 logrus.Warnf("Failed to program network encryption between %s and %s: %v", lIP, rIP, err) 117 } 118 } 119 } else { 120 if len(nodes) == 0 { 121 if err := removeEncryption(lIP, rIP, d.secMap); err != nil { 122 logrus.Warnf("Failed to remove network encryption between %s and %s: %v", lIP, rIP, err) 123 } 124 } 125 } 126 127 return nil 128} 129 130func setupEncryption(localIP, advIP, remoteIP net.IP, vni uint32, em *encrMap, keys []*key) error { 131 logrus.Debugf("Programming encryption for vxlan %d between %s and %s", vni, localIP, remoteIP) 132 rIPs := remoteIP.String() 133 134 indices := make([]*spi, 0, len(keys)) 135 136 err := programMangle(vni, true) 137 if err != nil { 138 logrus.Warn(err) 139 } 140 141 for i, k := range keys { 142 spis := &spi{buildSPI(advIP, remoteIP, k.tag), buildSPI(remoteIP, advIP, k.tag)} 143 dir := reverse 144 if i == 0 { 145 dir = bidir 146 } 147 fSA, rSA, err := programSA(localIP, remoteIP, spis, k, dir, true) 148 if err != nil { 149 logrus.Warn(err) 150 } 151 indices = append(indices, spis) 152 if i != 0 { 153 continue 154 } 155 err = programSP(fSA, rSA, true) 156 if err != nil { 157 logrus.Warn(err) 158 } 159 } 160 161 em.Lock() 162 em.nodes[rIPs] = indices 163 em.Unlock() 164 165 return nil 166} 167 168func removeEncryption(localIP, remoteIP net.IP, em *encrMap) error { 169 em.Lock() 170 indices, ok := em.nodes[remoteIP.String()] 171 em.Unlock() 172 if !ok { 173 return nil 174 } 175 for i, idxs := range indices { 176 dir := reverse 177 if i == 0 { 178 dir = bidir 179 } 180 fSA, rSA, err := programSA(localIP, remoteIP, idxs, nil, dir, false) 181 if err != nil { 182 logrus.Warn(err) 183 } 184 if i != 0 { 185 continue 186 } 187 err = programSP(fSA, rSA, false) 188 if err != nil { 189 logrus.Warn(err) 190 } 191 } 192 return nil 193} 194 195func programMangle(vni uint32, add bool) (err error) { 196 var ( 197 p = strconv.FormatUint(uint64(vxlanPort), 10) 198 c = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8) 199 m = strconv.FormatUint(uint64(mark), 10) 200 chain = "OUTPUT" 201 rule = []string{"-p", "udp", "--dport", p, "-m", "u32", "--u32", c, "-j", "MARK", "--set-mark", m} 202 a = "-A" 203 action = "install" 204 ) 205 206 if add == iptables.Exists(iptables.Mangle, chain, rule...) { 207 return 208 } 209 210 if !add { 211 a = "-D" 212 action = "remove" 213 } 214 215 if err = iptables.RawCombinedOutput(append([]string{"-t", string(iptables.Mangle), a, chain}, rule...)...); err != nil { 216 logrus.Warnf("could not %s mangle rule: %v", action, err) 217 } 218 219 return 220} 221 222func programSA(localIP, remoteIP net.IP, spi *spi, k *key, dir int, add bool) (fSA *netlink.XfrmState, rSA *netlink.XfrmState, err error) { 223 var ( 224 action = "Removing" 225 xfrmProgram = ns.NlHandle().XfrmStateDel 226 ) 227 228 if add { 229 action = "Adding" 230 xfrmProgram = ns.NlHandle().XfrmStateAdd 231 } 232 233 if dir&reverse > 0 { 234 rSA = &netlink.XfrmState{ 235 Src: remoteIP, 236 Dst: localIP, 237 Proto: netlink.XFRM_PROTO_ESP, 238 Spi: spi.reverse, 239 Mode: netlink.XFRM_MODE_TRANSPORT, 240 } 241 if add { 242 rSA.Aead = buildAeadAlgo(k, spi.reverse) 243 } 244 245 exists, err := saExists(rSA) 246 if err != nil { 247 exists = !add 248 } 249 250 if add != exists { 251 logrus.Debugf("%s: rSA{%s}", action, rSA) 252 if err := xfrmProgram(rSA); err != nil { 253 logrus.Warnf("Failed %s rSA{%s}: %v", action, rSA, err) 254 } 255 } 256 } 257 258 if dir&forward > 0 { 259 fSA = &netlink.XfrmState{ 260 Src: localIP, 261 Dst: remoteIP, 262 Proto: netlink.XFRM_PROTO_ESP, 263 Spi: spi.forward, 264 Mode: netlink.XFRM_MODE_TRANSPORT, 265 } 266 if add { 267 fSA.Aead = buildAeadAlgo(k, spi.forward) 268 } 269 270 exists, err := saExists(fSA) 271 if err != nil { 272 exists = !add 273 } 274 275 if add != exists { 276 logrus.Debugf("%s fSA{%s}", action, fSA) 277 if err := xfrmProgram(fSA); err != nil { 278 logrus.Warnf("Failed %s fSA{%s}: %v.", action, fSA, err) 279 } 280 } 281 } 282 283 return 284} 285 286func programSP(fSA *netlink.XfrmState, rSA *netlink.XfrmState, add bool) error { 287 action := "Removing" 288 xfrmProgram := ns.NlHandle().XfrmPolicyDel 289 if add { 290 action = "Adding" 291 xfrmProgram = ns.NlHandle().XfrmPolicyAdd 292 } 293 294 fullMask := net.CIDRMask(8*len(fSA.Src), 8*len(fSA.Src)) 295 296 fPol := &netlink.XfrmPolicy{ 297 Src: &net.IPNet{IP: fSA.Src, Mask: fullMask}, 298 Dst: &net.IPNet{IP: fSA.Dst, Mask: fullMask}, 299 Dir: netlink.XFRM_DIR_OUT, 300 Proto: 17, 301 DstPort: 4789, 302 Mark: &netlink.XfrmMark{ 303 Value: mark, 304 }, 305 Tmpls: []netlink.XfrmPolicyTmpl{ 306 { 307 Src: fSA.Src, 308 Dst: fSA.Dst, 309 Proto: netlink.XFRM_PROTO_ESP, 310 Mode: netlink.XFRM_MODE_TRANSPORT, 311 Spi: fSA.Spi, 312 }, 313 }, 314 } 315 316 exists, err := spExists(fPol) 317 if err != nil { 318 exists = !add 319 } 320 321 if add != exists { 322 logrus.Debugf("%s fSP{%s}", action, fPol) 323 if err := xfrmProgram(fPol); err != nil { 324 logrus.Warnf("%s fSP{%s}: %v", action, fPol, err) 325 } 326 } 327 328 return nil 329} 330 331func saExists(sa *netlink.XfrmState) (bool, error) { 332 _, err := ns.NlHandle().XfrmStateGet(sa) 333 switch err { 334 case nil: 335 return true, nil 336 case syscall.ESRCH: 337 return false, nil 338 default: 339 err = fmt.Errorf("Error while checking for SA existence: %v", err) 340 logrus.Warn(err) 341 return false, err 342 } 343} 344 345func spExists(sp *netlink.XfrmPolicy) (bool, error) { 346 _, err := ns.NlHandle().XfrmPolicyGet(sp) 347 switch err { 348 case nil: 349 return true, nil 350 case syscall.ENOENT: 351 return false, nil 352 default: 353 err = fmt.Errorf("Error while checking for SP existence: %v", err) 354 logrus.Warn(err) 355 return false, err 356 } 357} 358 359func buildSPI(src, dst net.IP, st uint32) int { 360 b := make([]byte, 4) 361 binary.BigEndian.PutUint32(b, st) 362 h := fnv.New32a() 363 h.Write(src) 364 h.Write(b) 365 h.Write(dst) 366 return int(binary.BigEndian.Uint32(h.Sum(nil))) 367} 368 369func buildAeadAlgo(k *key, s int) *netlink.XfrmStateAlgo { 370 salt := make([]byte, 4) 371 binary.BigEndian.PutUint32(salt, uint32(s)) 372 return &netlink.XfrmStateAlgo{ 373 Name: "rfc4106(gcm(aes))", 374 Key: append(k.value, salt...), 375 ICVLen: 64, 376 } 377} 378 379func (d *driver) secMapWalk(f func(string, []*spi) ([]*spi, bool)) error { 380 d.secMap.Lock() 381 for node, indices := range d.secMap.nodes { 382 idxs, stop := f(node, indices) 383 if idxs != nil { 384 d.secMap.nodes[node] = idxs 385 } 386 if stop { 387 break 388 } 389 } 390 d.secMap.Unlock() 391 return nil 392} 393 394func (d *driver) setKeys(keys []*key) error { 395 // Accept the encryption keys and clear any stale encryption map 396 d.Lock() 397 d.keys = keys 398 d.secMap = &encrMap{nodes: map[string][]*spi{}} 399 d.Unlock() 400 logrus.Debugf("Initial encryption keys: %v", d.keys) 401 return nil 402} 403 404// updateKeys allows to add a new key and/or change the primary key and/or prune an existing key 405// The primary key is the key used in transmission and will go in first position in the list. 406func (d *driver) updateKeys(newKey, primary, pruneKey *key) error { 407 logrus.Debugf("Updating Keys. New: %v, Primary: %v, Pruned: %v", newKey, primary, pruneKey) 408 409 logrus.Debugf("Current: %v", d.keys) 410 411 var ( 412 newIdx = -1 413 priIdx = -1 414 delIdx = -1 415 lIP = net.ParseIP(d.bindAddress) 416 ) 417 418 d.Lock() 419 // add new 420 if newKey != nil { 421 d.keys = append(d.keys, newKey) 422 newIdx += len(d.keys) 423 } 424 for i, k := range d.keys { 425 if primary != nil && k.tag == primary.tag { 426 priIdx = i 427 } 428 if pruneKey != nil && k.tag == pruneKey.tag { 429 delIdx = i 430 } 431 } 432 d.Unlock() 433 434 if (newKey != nil && newIdx == -1) || 435 (primary != nil && priIdx == -1) || 436 (pruneKey != nil && delIdx == -1) { 437 return types.BadRequestErrorf("cannot find proper key indices while processing key update:"+ 438 "(newIdx,priIdx,delIdx):(%d, %d, %d)", newIdx, priIdx, delIdx) 439 } 440 441 d.secMapWalk(func(rIPs string, spis []*spi) ([]*spi, bool) { 442 rIP := net.ParseIP(rIPs) 443 return updateNodeKey(lIP, rIP, spis, d.keys, newIdx, priIdx, delIdx), false 444 }) 445 446 d.Lock() 447 // swap primary 448 if priIdx != -1 { 449 swp := d.keys[0] 450 d.keys[0] = d.keys[priIdx] 451 d.keys[priIdx] = swp 452 } 453 // prune 454 if delIdx != -1 { 455 if delIdx == 0 { 456 delIdx = priIdx 457 } 458 d.keys = append(d.keys[:delIdx], d.keys[delIdx+1:]...) 459 } 460 d.Unlock() 461 462 logrus.Debugf("Updated: %v", d.keys) 463 464 return nil 465} 466 467/******************************************************** 468 * Steady state: rSA0, rSA1, rSA2, fSA1, fSP1 469 * Rotation --> -rSA0, +rSA3, +fSA2, +fSP2/-fSP1, -fSA1 470 * Steady state: rSA1, rSA2, rSA3, fSA2, fSP2 471 *********************************************************/ 472 473// Spis and keys are sorted in such away the one in position 0 is the primary 474func updateNodeKey(lIP, rIP net.IP, idxs []*spi, curKeys []*key, newIdx, priIdx, delIdx int) []*spi { 475 logrus.Debugf("Updating keys for node: %s (%d,%d,%d)", rIP, newIdx, priIdx, delIdx) 476 477 spis := idxs 478 logrus.Debugf("Current: %v", spis) 479 480 // add new 481 if newIdx != -1 { 482 spis = append(spis, &spi{ 483 forward: buildSPI(lIP, rIP, curKeys[newIdx].tag), 484 reverse: buildSPI(rIP, lIP, curKeys[newIdx].tag), 485 }) 486 } 487 488 if delIdx != -1 { 489 // -rSA0 490 programSA(lIP, rIP, spis[delIdx], nil, reverse, false) 491 } 492 493 if newIdx > -1 { 494 // +RSA2 495 programSA(lIP, rIP, spis[newIdx], curKeys[newIdx], reverse, true) 496 } 497 498 if priIdx > 0 { 499 // +fSA2 500 fSA2, _, _ := programSA(lIP, rIP, spis[priIdx], curKeys[priIdx], forward, true) 501 502 // +fSP2, -fSP1 503 fullMask := net.CIDRMask(8*len(fSA2.Src), 8*len(fSA2.Src)) 504 fSP1 := &netlink.XfrmPolicy{ 505 Src: &net.IPNet{IP: fSA2.Src, Mask: fullMask}, 506 Dst: &net.IPNet{IP: fSA2.Dst, Mask: fullMask}, 507 Dir: netlink.XFRM_DIR_OUT, 508 Proto: 17, 509 DstPort: 4789, 510 Mark: &netlink.XfrmMark{ 511 Value: mark, 512 }, 513 Tmpls: []netlink.XfrmPolicyTmpl{ 514 { 515 Src: fSA2.Src, 516 Dst: fSA2.Dst, 517 Proto: netlink.XFRM_PROTO_ESP, 518 Mode: netlink.XFRM_MODE_TRANSPORT, 519 Spi: fSA2.Spi, 520 }, 521 }, 522 } 523 logrus.Debugf("Updating fSP{%s}", fSP1) 524 if err := ns.NlHandle().XfrmPolicyUpdate(fSP1); err != nil { 525 logrus.Warnf("Failed to update fSP{%s}: %v", fSP1, err) 526 } 527 528 // -fSA1 529 programSA(lIP, rIP, spis[0], nil, forward, false) 530 } 531 532 // swap 533 if priIdx > 0 { 534 swp := spis[0] 535 spis[0] = spis[priIdx] 536 spis[priIdx] = swp 537 } 538 // prune 539 if delIdx != -1 { 540 if delIdx == 0 { 541 delIdx = priIdx 542 } 543 spis = append(spis[:delIdx], spis[delIdx+1:]...) 544 } 545 546 logrus.Debugf("Updated: %v", spis) 547 548 return spis 549} 550 551func (n *network) maxMTU() int { 552 mtu := 1500 553 if n.mtu != 0 { 554 mtu = n.mtu 555 } 556 mtu -= vxlanEncap 557 if n.secure { 558 // In case of encryption account for the 559 // esp packet espansion and padding 560 mtu -= pktExpansion 561 mtu -= (mtu % 4) 562 } 563 return mtu 564} 565