1package physical 2 3import ( 4 "fmt" 5 "io/ioutil" 6 "log" 7 "net" 8 "net/url" 9 "strconv" 10 "strings" 11 "sync" 12 "sync/atomic" 13 "time" 14 15 "crypto/tls" 16 "crypto/x509" 17 18 "github.com/armon/go-metrics" 19 "github.com/hashicorp/consul/api" 20 "github.com/hashicorp/consul/lib" 21 "github.com/hashicorp/errwrap" 22 "github.com/hashicorp/go-cleanhttp" 23) 24 25const ( 26 // checkJitterFactor specifies the jitter factor used to stagger checks 27 checkJitterFactor = 16 28 29 // checkMinBuffer specifies provides a guarantee that a check will not 30 // be executed too close to the TTL check timeout 31 checkMinBuffer = 100 * time.Millisecond 32 33 // consulRetryInterval specifies the retry duration to use when an 34 // API call to the Consul agent fails. 35 consulRetryInterval = 1 * time.Second 36 37 // defaultCheckTimeout changes the timeout of TTL checks 38 defaultCheckTimeout = 5 * time.Second 39 40 // defaultServiceName is the default Consul service name used when 41 // advertising a Vault instance. 42 defaultServiceName = "vault" 43 44 // reconcileTimeout is how often Vault should query Consul to detect 45 // and fix any state drift. 46 reconcileTimeout = 60 * time.Second 47) 48 49type notifyEvent struct{} 50 51// ConsulBackend is a physical backend that stores data at specific 52// prefix within Consul. It is used for most production situations as 53// it allows Vault to run on multiple machines in a highly-available manner. 54type ConsulBackend struct { 55 path string 56 logger *log.Logger 57 client *api.Client 58 kv *api.KV 59 permitPool *PermitPool 60 serviceLock sync.RWMutex 61 advertiseHost string 62 advertisePort int64 63 serviceName string 64 disableRegistration bool 65 checkTimeout time.Duration 66 67 notifyActiveCh chan notifyEvent 68 notifySealedCh chan notifyEvent 69} 70 71// newConsulBackend constructs a Consul backend using the given API client 72// and the prefix in the KV store. 73func newConsulBackend(conf map[string]string, logger *log.Logger) (Backend, error) { 74 // Get the path in Consul 75 path, ok := conf["path"] 76 if !ok { 77 path = "vault/" 78 } 79 logger.Printf("[DEBUG]: consul: config path set to %v", path) 80 81 // Ensure path is suffixed but not prefixed 82 if !strings.HasSuffix(path, "/") { 83 logger.Printf("[WARN]: consul: appending trailing forward slash to path") 84 path += "/" 85 } 86 if strings.HasPrefix(path, "/") { 87 logger.Printf("[WARN]: consul: trimming path of its forward slash") 88 path = strings.TrimPrefix(path, "/") 89 } 90 91 // Allow admins to disable consul integration 92 disableReg, ok := conf["disable_registration"] 93 var disableRegistration bool 94 if ok && disableReg != "" { 95 b, err := strconv.ParseBool(disableReg) 96 if err != nil { 97 return nil, errwrap.Wrapf("failed parsing disable_registration parameter: {{err}}", err) 98 } 99 disableRegistration = b 100 } 101 logger.Printf("[DEBUG]: consul: config disable_registration set to %v", disableRegistration) 102 103 // Get the service name to advertise in Consul 104 service, ok := conf["service"] 105 if !ok { 106 service = defaultServiceName 107 } 108 logger.Printf("[DEBUG]: consul: config service set to %s", service) 109 110 checkTimeout := defaultCheckTimeout 111 checkTimeoutStr, ok := conf["check_timeout"] 112 if ok { 113 d, err := time.ParseDuration(checkTimeoutStr) 114 if err != nil { 115 return nil, err 116 } 117 118 min, _ := lib.DurationMinusBufferDomain(d, checkMinBuffer, checkJitterFactor) 119 if min < checkMinBuffer { 120 return nil, fmt.Errorf("Consul check_timeout must be greater than %v", min) 121 } 122 123 checkTimeout = d 124 logger.Printf("[DEBUG]: consul: config check_timeout set to %v", d) 125 } 126 127 // Configure the client 128 consulConf := api.DefaultConfig() 129 130 if addr, ok := conf["address"]; ok { 131 consulConf.Address = addr 132 logger.Printf("[DEBUG]: consul: config address set to %d", addr) 133 } 134 if scheme, ok := conf["scheme"]; ok { 135 consulConf.Scheme = scheme 136 logger.Printf("[DEBUG]: consul: config scheme set to %d", scheme) 137 } 138 if token, ok := conf["token"]; ok { 139 consulConf.Token = token 140 logger.Printf("[DEBUG]: consul: config token set") 141 } 142 143 if consulConf.Scheme == "https" { 144 tlsClientConfig, err := setupTLSConfig(conf) 145 if err != nil { 146 return nil, err 147 } 148 149 transport := cleanhttp.DefaultPooledTransport() 150 transport.MaxIdleConnsPerHost = 4 151 transport.TLSClientConfig = tlsClientConfig 152 consulConf.HttpClient.Transport = transport 153 logger.Printf("[DEBUG]: consul: configured TLS") 154 } 155 156 client, err := api.NewClient(consulConf) 157 if err != nil { 158 return nil, errwrap.Wrapf("client setup failed: {{err}}", err) 159 } 160 161 maxParStr, ok := conf["max_parallel"] 162 var maxParInt int 163 if ok { 164 maxParInt, err = strconv.Atoi(maxParStr) 165 if err != nil { 166 return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err) 167 } 168 logger.Printf("[DEBUG]: consul: max_parallel set to %d", maxParInt) 169 } 170 171 // Setup the backend 172 c := &ConsulBackend{ 173 path: path, 174 logger: logger, 175 client: client, 176 kv: client.KV(), 177 permitPool: NewPermitPool(maxParInt), 178 serviceName: service, 179 checkTimeout: checkTimeout, 180 disableRegistration: disableRegistration, 181 } 182 return c, nil 183} 184 185func setupTLSConfig(conf map[string]string) (*tls.Config, error) { 186 serverName := strings.Split(conf["address"], ":") 187 188 insecureSkipVerify := false 189 if _, ok := conf["tls_skip_verify"]; ok { 190 insecureSkipVerify = true 191 } 192 193 tlsClientConfig := &tls.Config{ 194 InsecureSkipVerify: insecureSkipVerify, 195 ServerName: serverName[0], 196 } 197 198 _, okCert := conf["tls_cert_file"] 199 _, okKey := conf["tls_key_file"] 200 201 if okCert && okKey { 202 tlsCert, err := tls.LoadX509KeyPair(conf["tls_cert_file"], conf["tls_key_file"]) 203 if err != nil { 204 return nil, fmt.Errorf("client tls setup failed: %v", err) 205 } 206 207 tlsClientConfig.Certificates = []tls.Certificate{tlsCert} 208 } 209 210 if tlsCaFile, ok := conf["tls_ca_file"]; ok { 211 caPool := x509.NewCertPool() 212 213 data, err := ioutil.ReadFile(tlsCaFile) 214 if err != nil { 215 return nil, fmt.Errorf("failed to read CA file: %v", err) 216 } 217 218 if !caPool.AppendCertsFromPEM(data) { 219 return nil, fmt.Errorf("failed to parse CA certificate") 220 } 221 222 tlsClientConfig.RootCAs = caPool 223 } 224 225 return tlsClientConfig, nil 226} 227 228// Put is used to insert or update an entry 229func (c *ConsulBackend) Put(entry *Entry) error { 230 defer metrics.MeasureSince([]string{"consul", "put"}, time.Now()) 231 pair := &api.KVPair{ 232 Key: c.path + entry.Key, 233 Value: entry.Value, 234 } 235 236 c.permitPool.Acquire() 237 defer c.permitPool.Release() 238 239 _, err := c.kv.Put(pair, nil) 240 return err 241} 242 243// Get is used to fetch an entry 244func (c *ConsulBackend) Get(key string) (*Entry, error) { 245 defer metrics.MeasureSince([]string{"consul", "get"}, time.Now()) 246 247 c.permitPool.Acquire() 248 defer c.permitPool.Release() 249 250 pair, _, err := c.kv.Get(c.path+key, nil) 251 if err != nil { 252 return nil, err 253 } 254 if pair == nil { 255 return nil, nil 256 } 257 ent := &Entry{ 258 Key: key, 259 Value: pair.Value, 260 } 261 return ent, nil 262} 263 264// Delete is used to permanently delete an entry 265func (c *ConsulBackend) Delete(key string) error { 266 defer metrics.MeasureSince([]string{"consul", "delete"}, time.Now()) 267 268 c.permitPool.Acquire() 269 defer c.permitPool.Release() 270 271 _, err := c.kv.Delete(c.path+key, nil) 272 return err 273} 274 275// List is used to list all the keys under a given 276// prefix, up to the next prefix. 277func (c *ConsulBackend) List(prefix string) ([]string, error) { 278 defer metrics.MeasureSince([]string{"consul", "list"}, time.Now()) 279 scan := c.path + prefix 280 281 // The TrimPrefix call below will not work correctly if we have "//" at the 282 // end. This can happen in cases where you are e.g. listing the root of a 283 // prefix in a logical backend via "/" instead of "" 284 if strings.HasSuffix(scan, "//") { 285 scan = scan[:len(scan)-1] 286 } 287 288 c.permitPool.Acquire() 289 defer c.permitPool.Release() 290 291 out, _, err := c.kv.Keys(scan, "/", nil) 292 for idx, val := range out { 293 out[idx] = strings.TrimPrefix(val, scan) 294 } 295 296 return out, err 297} 298 299// Lock is used for mutual exclusion based on the given key. 300func (c *ConsulBackend) LockWith(key, value string) (Lock, error) { 301 // Create the lock 302 opts := &api.LockOptions{ 303 Key: c.path + key, 304 Value: []byte(value), 305 SessionName: "Vault Lock", 306 MonitorRetries: 5, 307 } 308 lock, err := c.client.LockOpts(opts) 309 if err != nil { 310 return nil, fmt.Errorf("failed to create lock: %v", err) 311 } 312 cl := &ConsulLock{ 313 client: c.client, 314 key: c.path + key, 315 lock: lock, 316 } 317 return cl, nil 318} 319 320// DetectHostAddr is used to detect the host address by asking the Consul agent 321func (c *ConsulBackend) DetectHostAddr() (string, error) { 322 agent := c.client.Agent() 323 self, err := agent.Self() 324 if err != nil { 325 return "", err 326 } 327 addr, ok := self["Member"]["Addr"].(string) 328 if !ok { 329 return "", fmt.Errorf("Unable to convert an address to string") 330 } 331 return addr, nil 332} 333 334// ConsulLock is used to provide the Lock interface backed by Consul 335type ConsulLock struct { 336 client *api.Client 337 key string 338 lock *api.Lock 339} 340 341func (c *ConsulLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) { 342 return c.lock.Lock(stopCh) 343} 344 345func (c *ConsulLock) Unlock() error { 346 return c.lock.Unlock() 347} 348 349func (c *ConsulLock) Value() (bool, string, error) { 350 kv := c.client.KV() 351 352 pair, _, err := kv.Get(c.key, nil) 353 if err != nil { 354 return false, "", err 355 } 356 if pair == nil { 357 return false, "", nil 358 } 359 held := pair.Session != "" 360 value := string(pair.Value) 361 return held, value, nil 362} 363 364func (c *ConsulBackend) NotifyActiveStateChange() error { 365 select { 366 case c.notifyActiveCh <- notifyEvent{}: 367 default: 368 // NOTE: If this occurs Vault's active status could be out of 369 // sync with Consul until reconcileTimer expires. 370 c.logger.Printf("[WARN]: consul: Concurrent state change notify dropped") 371 } 372 373 return nil 374} 375 376func (c *ConsulBackend) NotifySealedStateChange() error { 377 select { 378 case c.notifySealedCh <- notifyEvent{}: 379 default: 380 // NOTE: If this occurs Vault's sealed status could be out of 381 // sync with Consul until checkTimer expires. 382 c.logger.Printf("[WARN]: consul: Concurrent sealed state change notify dropped") 383 } 384 385 return nil 386} 387 388func (c *ConsulBackend) checkDuration() time.Duration { 389 return lib.DurationMinusBuffer(c.checkTimeout, checkMinBuffer, checkJitterFactor) 390} 391 392func (c *ConsulBackend) RunServiceDiscovery(shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) (err error) { 393 if err := c.setAdvertiseAddr(advertiseAddr); err != nil { 394 return err 395 } 396 397 go c.runEventDemuxer(shutdownCh, advertiseAddr, activeFunc, sealedFunc) 398 399 return nil 400} 401 402func (c *ConsulBackend) runEventDemuxer(shutdownCh ShutdownChannel, advertiseAddr string, activeFunc activeFunction, sealedFunc sealedFunction) { 403 // Fire the reconcileTimer immediately upon starting the event demuxer 404 reconcileTimer := time.NewTimer(0) 405 defer reconcileTimer.Stop() 406 407 // Schedule the first check. Consul TTL checks are passing by 408 // default, checkTimer does not need to be run immediately. 409 checkTimer := time.NewTimer(c.checkDuration()) 410 defer checkTimer.Stop() 411 412 // Use a reactor pattern to handle and dispatch events to singleton 413 // goroutine handlers for execution. It is not acceptable to drop 414 // inbound events from Notify*(). 415 // 416 // goroutines are dispatched if the demuxer can acquire a lock (via 417 // an atomic CAS incr) on the handler. Handlers are responsible for 418 // deregistering themselves (atomic CAS decr). Handlers and the 419 // demuxer share a lock to synchronize information at the beginning 420 // and end of a handler's life (or after a handler wakes up from 421 // sleeping during a back-off/retry). 422 var shutdown bool 423 var checkLock int64 424 var registeredServiceID string 425 var serviceRegLock int64 426shutdown: 427 for { 428 select { 429 case <-c.notifyActiveCh: 430 // Run reconcile immediately upon active state change notification 431 reconcileTimer.Reset(0) 432 case <-c.notifySealedCh: 433 // Run check timer immediately upon a seal state change notification 434 checkTimer.Reset(0) 435 case <-reconcileTimer.C: 436 // Unconditionally rearm the reconcileTimer 437 reconcileTimer.Reset(reconcileTimeout - lib.RandomStagger(reconcileTimeout/checkJitterFactor)) 438 439 // Abort if service discovery is disabled or a 440 // reconcile handler is already active 441 if !c.disableRegistration && atomic.CompareAndSwapInt64(&serviceRegLock, 0, 1) { 442 // Enter handler with serviceRegLock held 443 go func() { 444 defer atomic.CompareAndSwapInt64(&serviceRegLock, 1, 0) 445 for !shutdown { 446 serviceID, err := c.reconcileConsul(registeredServiceID, activeFunc, sealedFunc) 447 if err != nil { 448 c.logger.Printf("[WARN]: consul: reconcile unable to talk with Consul backend: %v", err) 449 time.Sleep(consulRetryInterval) 450 continue 451 } 452 453 c.serviceLock.Lock() 454 defer c.serviceLock.Unlock() 455 456 registeredServiceID = serviceID 457 return 458 } 459 }() 460 } 461 case <-checkTimer.C: 462 checkTimer.Reset(c.checkDuration()) 463 // Abort if service discovery is disabled or a 464 // reconcile handler is active 465 if !c.disableRegistration && atomic.CompareAndSwapInt64(&checkLock, 0, 1) { 466 // Enter handler with checkLock held 467 go func() { 468 defer atomic.CompareAndSwapInt64(&checkLock, 1, 0) 469 for !shutdown { 470 sealed := sealedFunc() 471 if err := c.runCheck(sealed); err != nil { 472 c.logger.Printf("[WARN]: consul: check unable to talk with Consul backend: %v", err) 473 time.Sleep(consulRetryInterval) 474 continue 475 } 476 return 477 } 478 }() 479 } 480 case <-shutdownCh: 481 c.logger.Printf("[INFO]: consul: Shutting down consul backend") 482 shutdown = true 483 break shutdown 484 } 485 } 486 487 c.serviceLock.RLock() 488 defer c.serviceLock.RUnlock() 489 if err := c.client.Agent().ServiceDeregister(registeredServiceID); err != nil { 490 c.logger.Printf("[WARN]: consul: service deregistration failed: %v", err) 491 } 492} 493 494// checkID returns the ID used for a Consul Check. Assume at least a read 495// lock is held. 496func (c *ConsulBackend) checkID() string { 497 return "vault-sealed-check" 498} 499 500// reconcileConsul queries the state of Vault Core and Consul and fixes up 501// Consul's state according to what's in Vault. reconcileConsul is called 502// without any locks held and can be run concurrently, therefore no changes 503// to ConsulBackend can be made in this method (i.e. wtb const receiver for 504// compiler enforced safety). 505func (c *ConsulBackend) reconcileConsul(registeredServiceID string, activeFunc activeFunction, sealedFunc sealedFunction) (serviceID string, err error) { 506 // Query vault Core for its current state 507 active := activeFunc() 508 sealed := sealedFunc() 509 510 agent := c.client.Agent() 511 512 // Get the current state of Vault from Consul 513 var currentVaultService *api.AgentService 514 if services, err := agent.Services(); err == nil { 515 if service, ok := services[c.serviceName]; ok { 516 currentVaultService = service 517 } 518 } 519 520 serviceID = c.serviceID() 521 tags := serviceTags(active) 522 523 var reregister bool 524 switch { 525 case currentVaultService == nil, 526 registeredServiceID == "": 527 reregister = true 528 default: 529 switch { 530 case len(currentVaultService.Tags) != 1, 531 currentVaultService.Tags[0] != tags[0]: 532 reregister = true 533 } 534 } 535 536 if !reregister { 537 return "", nil 538 } 539 540 service := &api.AgentServiceRegistration{ 541 ID: serviceID, 542 Name: c.serviceName, 543 Tags: tags, 544 Port: int(c.advertisePort), 545 Address: c.advertiseHost, 546 EnableTagOverride: false, 547 } 548 549 checkStatus := api.HealthCritical 550 if !sealed { 551 checkStatus = api.HealthPassing 552 } 553 554 sealedCheck := &api.AgentCheckRegistration{ 555 ID: c.checkID(), 556 Name: "Vault Sealed Status", 557 Notes: "Vault service is healthy when Vault is in an unsealed status and can become an active Vault server", 558 ServiceID: serviceID, 559 AgentServiceCheck: api.AgentServiceCheck{ 560 TTL: c.checkTimeout.String(), 561 Status: checkStatus, 562 }, 563 } 564 565 if err := agent.ServiceRegister(service); err != nil { 566 return "", errwrap.Wrapf(`service registration failed: {{err}}`, err) 567 } 568 569 if err := agent.CheckRegister(sealedCheck); err != nil { 570 return serviceID, errwrap.Wrapf(`service check registration failed: {{err}}`, err) 571 } 572 573 return serviceID, nil 574} 575 576// runCheck immediately pushes a TTL check. 577func (c *ConsulBackend) runCheck(sealed bool) error { 578 // Run a TTL check 579 agent := c.client.Agent() 580 if !sealed { 581 return agent.PassTTL(c.checkID(), "Vault Unsealed") 582 } else { 583 return agent.FailTTL(c.checkID(), "Vault Sealed") 584 } 585} 586 587// serviceID returns the Vault ServiceID for use in Consul. Assume at least 588// a read lock is held. 589func (c *ConsulBackend) serviceID() string { 590 return fmt.Sprintf("%s:%s:%d", c.serviceName, c.advertiseHost, c.advertisePort) 591} 592 593// serviceTags returns all of the relevant tags for Consul. 594func serviceTags(active bool) []string { 595 activeTag := "standby" 596 if active { 597 activeTag = "active" 598 } 599 return []string{activeTag} 600} 601 602func (c *ConsulBackend) setAdvertiseAddr(addr string) (err error) { 603 if addr == "" { 604 return fmt.Errorf("advertise address must not be empty") 605 } 606 607 url, err := url.Parse(addr) 608 if err != nil { 609 return errwrap.Wrapf(fmt.Sprintf(`failed to parse advertise URL "%v": {{err}}`, addr), err) 610 } 611 612 var portStr string 613 c.advertiseHost, portStr, err = net.SplitHostPort(url.Host) 614 if err != nil { 615 if url.Scheme == "http" { 616 portStr = "80" 617 } else if url.Scheme == "https" { 618 portStr = "443" 619 } else if url.Scheme == "unix" { 620 portStr = "-1" 621 c.advertiseHost = url.Path 622 } else { 623 return errwrap.Wrapf(fmt.Sprintf(`failed to find a host:port in advertise address "%v": {{err}}`, url.Host), err) 624 } 625 } 626 c.advertisePort, err = strconv.ParseInt(portStr, 10, 0) 627 if err != nil || c.advertisePort < -1 || c.advertisePort > 65535 { 628 return errwrap.Wrapf(fmt.Sprintf(`failed to parse valid port "%v": {{err}}`, portStr), err) 629 } 630 631 return nil 632} 633