package consul

import (
	"context"
	"errors"
	"fmt"
	"io/ioutil"
	"net"
	"net/http"
	"net/url"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"golang.org/x/net/http2"

	log "github.com/hashicorp/go-hclog"

	"crypto/tls"
	"crypto/x509"

	metrics "github.com/armon/go-metrics"
	"github.com/hashicorp/consul/api"
	"github.com/hashicorp/consul/lib"
	"github.com/hashicorp/errwrap"
	multierror "github.com/hashicorp/go-multierror"
	"github.com/hashicorp/vault/helper/consts"
	"github.com/hashicorp/vault/helper/parseutil"
	"github.com/hashicorp/vault/helper/strutil"
	"github.com/hashicorp/vault/helper/tlsutil"
	"github.com/hashicorp/vault/physical"
)

const (
	// checkJitterFactor specifies the jitter factor used to stagger checks
	checkJitterFactor = 16

	// checkMinBuffer provides a guarantee that a check will not be
	// executed too close to the TTL check timeout
	checkMinBuffer = 100 * time.Millisecond

	// consulRetryInterval specifies the retry duration to use when an
	// API call to the Consul agent fails.
	consulRetryInterval = 1 * time.Second

	// defaultCheckTimeout is the TTL check timeout used when the
	// check_timeout configuration value is not set.
	defaultCheckTimeout = 5 * time.Second

	// DefaultServiceName is the default Consul service name used when
	// advertising a Vault instance.
	DefaultServiceName = "vault"

	// reconcileTimeout is how often Vault should query Consul to detect
	// and fix any state drift.
	reconcileTimeout = 60 * time.Second

	// consistencyModeDefault is the configuration value used to tell
	// consul to use default consistency.
	consistencyModeDefault = "default"

	// consistencyModeStrong is the configuration value used to tell
	// consul to use strong consistency.
	consistencyModeStrong = "strong"
)

// notifyEvent is the empty payload sent on the backend's notify channels;
// only the arrival of a value matters.
type notifyEvent struct{}

// Verify ConsulBackend satisfies the correct interfaces
var _ physical.Backend = (*ConsulBackend)(nil)
var _ physical.HABackend = (*ConsulBackend)(nil)
var _ physical.Lock = (*ConsulLock)(nil)
var _ physical.Transactional = (*ConsulBackend)(nil)
var _ physical.ServiceDiscovery = (*ConsulBackend)(nil)

var (
	// hostnameRegex matches dot-separated labels made of alphanumerics and
	// interior dashes; it is used to validate the configured service name.
	hostnameRegex = regexp.MustCompile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`)
)

// ConsulBackend is a physical backend that stores data at specific
// prefix within Consul. It is used for most production situations as
// it allows Vault to run on multiple machines in a highly-available manner.
//
// Every KV key is prefixed with path, and KV calls are bounded by
// permitPool. The service*, redirect*, checkTimeout and notify* fields
// drive registration of the Vault service and its sealed-status TTL check
// with the Consul agent; sessionTTL and lockWaitTime are passed through to
// the locks created by LockWith.
type ConsulBackend struct {
	path                string
	logger              log.Logger
	client              *api.Client
	kv                  *api.KV
	permitPool          *physical.PermitPool
	serviceLock         sync.RWMutex
	redirectHost        string
	redirectPort        int64
	serviceName         string
	serviceTags         []string
	serviceAddress      *string
	disableRegistration bool
	checkTimeout        time.Duration
	consistencyMode     string

	notifyActiveCh      chan notifyEvent
	notifySealedCh      chan notifyEvent
	notifyPerfStandbyCh chan notifyEvent

	sessionTTL   string
	lockWaitTime time.Duration
}

// NewConsulBackend constructs a Consul backend from the given configuration
// map, building the Consul API client itself, and logs through the supplied
// logger.
111func NewConsulBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { 112 // Get the path in Consul 113 path, ok := conf["path"] 114 if !ok { 115 path = "vault/" 116 } 117 if logger.IsDebug() { 118 logger.Debug("config path set", "path", path) 119 } 120 121 // Ensure path is suffixed but not prefixed 122 if !strings.HasSuffix(path, "/") { 123 logger.Warn("appending trailing forward slash to path") 124 path += "/" 125 } 126 if strings.HasPrefix(path, "/") { 127 logger.Warn("trimming path of its forward slash") 128 path = strings.TrimPrefix(path, "/") 129 } 130 131 // Allow admins to disable consul integration 132 disableReg, ok := conf["disable_registration"] 133 var disableRegistration bool 134 if ok && disableReg != "" { 135 b, err := parseutil.ParseBool(disableReg) 136 if err != nil { 137 return nil, errwrap.Wrapf("failed parsing disable_registration parameter: {{err}}", err) 138 } 139 disableRegistration = b 140 } 141 if logger.IsDebug() { 142 logger.Debug("config disable_registration set", "disable_registration", disableRegistration) 143 } 144 145 // Get the service name to advertise in Consul 146 service, ok := conf["service"] 147 if !ok { 148 service = DefaultServiceName 149 } 150 if !hostnameRegex.MatchString(service) { 151 return nil, errors.New("service name must be valid per RFC 1123 and can contain only alphanumeric characters or dashes") 152 } 153 if logger.IsDebug() { 154 logger.Debug("config service set", "service", service) 155 } 156 157 // Get the additional tags to attach to the registered service name 158 tags := conf["service_tags"] 159 if logger.IsDebug() { 160 logger.Debug("config service_tags set", "service_tags", tags) 161 } 162 163 // Get the service-specific address to override the use of the HA redirect address 164 var serviceAddr *string 165 serviceAddrStr, ok := conf["service_address"] 166 if ok { 167 serviceAddr = &serviceAddrStr 168 } 169 if logger.IsDebug() { 170 logger.Debug("config service_address set", 
"service_address", serviceAddr) 171 } 172 173 checkTimeout := defaultCheckTimeout 174 checkTimeoutStr, ok := conf["check_timeout"] 175 if ok { 176 d, err := parseutil.ParseDurationSecond(checkTimeoutStr) 177 if err != nil { 178 return nil, err 179 } 180 181 min, _ := lib.DurationMinusBufferDomain(d, checkMinBuffer, checkJitterFactor) 182 if min < checkMinBuffer { 183 return nil, fmt.Errorf("consul check_timeout must be greater than %v", min) 184 } 185 186 checkTimeout = d 187 if logger.IsDebug() { 188 logger.Debug("config check_timeout set", "check_timeout", d) 189 } 190 } 191 192 sessionTTL := api.DefaultLockSessionTTL 193 sessionTTLStr, ok := conf["session_ttl"] 194 if ok { 195 _, err := parseutil.ParseDurationSecond(sessionTTLStr) 196 if err != nil { 197 return nil, errwrap.Wrapf("invalid session_ttl: {{err}}", err) 198 } 199 sessionTTL = sessionTTLStr 200 if logger.IsDebug() { 201 logger.Debug("config session_ttl set", "session_ttl", sessionTTL) 202 } 203 } 204 205 lockWaitTime := api.DefaultLockWaitTime 206 lockWaitTimeRaw, ok := conf["lock_wait_time"] 207 if ok { 208 d, err := parseutil.ParseDurationSecond(lockWaitTimeRaw) 209 if err != nil { 210 return nil, errwrap.Wrapf("invalid lock_wait_time: {{err}}", err) 211 } 212 lockWaitTime = d 213 if logger.IsDebug() { 214 logger.Debug("config lock_wait_time set", "lock_wait_time", d) 215 } 216 } 217 218 // Configure the client 219 consulConf := api.DefaultConfig() 220 // Set MaxIdleConnsPerHost to the number of processes used in expiration.Restore 221 consulConf.Transport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount 222 223 if addr, ok := conf["address"]; ok { 224 consulConf.Address = addr 225 if logger.IsDebug() { 226 logger.Debug("config address set", "address", addr) 227 } 228 } 229 if scheme, ok := conf["scheme"]; ok { 230 consulConf.Scheme = scheme 231 if logger.IsDebug() { 232 logger.Debug("config scheme set", "scheme", scheme) 233 } 234 } 235 if token, ok := conf["token"]; ok { 236 
consulConf.Token = token 237 logger.Debug("config token set") 238 } 239 240 if consulConf.Scheme == "https" { 241 tlsClientConfig, err := setupTLSConfig(conf) 242 if err != nil { 243 return nil, err 244 } 245 246 consulConf.Transport.TLSClientConfig = tlsClientConfig 247 if err := http2.ConfigureTransport(consulConf.Transport); err != nil { 248 return nil, err 249 } 250 logger.Debug("configured TLS") 251 } 252 253 consulConf.HttpClient = &http.Client{Transport: consulConf.Transport} 254 client, err := api.NewClient(consulConf) 255 if err != nil { 256 return nil, errwrap.Wrapf("client setup failed: {{err}}", err) 257 } 258 259 maxParStr, ok := conf["max_parallel"] 260 var maxParInt int 261 if ok { 262 maxParInt, err = strconv.Atoi(maxParStr) 263 if err != nil { 264 return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err) 265 } 266 if logger.IsDebug() { 267 logger.Debug("max_parallel set", "max_parallel", maxParInt) 268 } 269 } 270 271 consistencyMode, ok := conf["consistency_mode"] 272 if ok { 273 switch consistencyMode { 274 case consistencyModeDefault, consistencyModeStrong: 275 default: 276 return nil, fmt.Errorf("invalid consistency_mode value: %q", consistencyMode) 277 } 278 } else { 279 consistencyMode = consistencyModeDefault 280 } 281 282 // Setup the backend 283 c := &ConsulBackend{ 284 path: path, 285 logger: logger, 286 client: client, 287 kv: client.KV(), 288 permitPool: physical.NewPermitPool(maxParInt), 289 serviceName: service, 290 serviceTags: strutil.ParseDedupLowercaseAndSortStrings(tags, ","), 291 serviceAddress: serviceAddr, 292 checkTimeout: checkTimeout, 293 disableRegistration: disableRegistration, 294 consistencyMode: consistencyMode, 295 notifyActiveCh: make(chan notifyEvent), 296 notifySealedCh: make(chan notifyEvent), 297 notifyPerfStandbyCh: make(chan notifyEvent), 298 sessionTTL: sessionTTL, 299 lockWaitTime: lockWaitTime, 300 } 301 return c, nil 302} 303 304func setupTLSConfig(conf map[string]string) 
(*tls.Config, error) { 305 serverName, _, err := net.SplitHostPort(conf["address"]) 306 switch { 307 case err == nil: 308 case strings.Contains(err.Error(), "missing port"): 309 serverName = conf["address"] 310 default: 311 return nil, err 312 } 313 314 insecureSkipVerify := false 315 tlsSkipVerify, ok := conf["tls_skip_verify"] 316 317 if ok && tlsSkipVerify != "" { 318 b, err := parseutil.ParseBool(tlsSkipVerify) 319 if err != nil { 320 return nil, errwrap.Wrapf("failed parsing tls_skip_verify parameter: {{err}}", err) 321 } 322 insecureSkipVerify = b 323 } 324 325 tlsMinVersionStr, ok := conf["tls_min_version"] 326 if !ok { 327 // Set the default value 328 tlsMinVersionStr = "tls12" 329 } 330 331 tlsMinVersion, ok := tlsutil.TLSLookup[tlsMinVersionStr] 332 if !ok { 333 return nil, fmt.Errorf("invalid 'tls_min_version'") 334 } 335 336 tlsClientConfig := &tls.Config{ 337 MinVersion: tlsMinVersion, 338 InsecureSkipVerify: insecureSkipVerify, 339 ServerName: serverName, 340 } 341 342 _, okCert := conf["tls_cert_file"] 343 _, okKey := conf["tls_key_file"] 344 345 if okCert && okKey { 346 tlsCert, err := tls.LoadX509KeyPair(conf["tls_cert_file"], conf["tls_key_file"]) 347 if err != nil { 348 return nil, errwrap.Wrapf("client tls setup failed: {{err}}", err) 349 } 350 351 tlsClientConfig.Certificates = []tls.Certificate{tlsCert} 352 } 353 354 if tlsCaFile, ok := conf["tls_ca_file"]; ok { 355 caPool := x509.NewCertPool() 356 357 data, err := ioutil.ReadFile(tlsCaFile) 358 if err != nil { 359 return nil, errwrap.Wrapf("failed to read CA file: {{err}}", err) 360 } 361 362 if !caPool.AppendCertsFromPEM(data) { 363 return nil, fmt.Errorf("failed to parse CA certificate") 364 } 365 366 tlsClientConfig.RootCAs = caPool 367 } 368 369 return tlsClientConfig, nil 370} 371 372// Used to run multiple entries via a transaction 373func (c *ConsulBackend) Transaction(ctx context.Context, txns []*physical.TxnEntry) error { 374 if len(txns) == 0 { 375 return nil 376 } 377 378 ops := 
make([]*api.KVTxnOp, 0, len(txns)) 379 380 for _, op := range txns { 381 cop := &api.KVTxnOp{ 382 Key: c.path + op.Entry.Key, 383 } 384 switch op.Operation { 385 case physical.DeleteOperation: 386 cop.Verb = api.KVDelete 387 case physical.PutOperation: 388 cop.Verb = api.KVSet 389 cop.Value = op.Entry.Value 390 default: 391 return fmt.Errorf("%q is not a supported transaction operation", op.Operation) 392 } 393 394 ops = append(ops, cop) 395 } 396 397 c.permitPool.Acquire() 398 defer c.permitPool.Release() 399 400 queryOpts := &api.QueryOptions{} 401 queryOpts = queryOpts.WithContext(ctx) 402 403 ok, resp, _, err := c.kv.Txn(ops, queryOpts) 404 if err != nil { 405 return err 406 } 407 if ok && len(resp.Errors) == 0 { 408 return nil 409 } 410 411 var retErr *multierror.Error 412 for _, res := range resp.Errors { 413 retErr = multierror.Append(retErr, errors.New(res.What)) 414 } 415 416 return retErr 417} 418 419// Put is used to insert or update an entry 420func (c *ConsulBackend) Put(ctx context.Context, entry *physical.Entry) error { 421 defer metrics.MeasureSince([]string{"consul", "put"}, time.Now()) 422 423 c.permitPool.Acquire() 424 defer c.permitPool.Release() 425 426 pair := &api.KVPair{ 427 Key: c.path + entry.Key, 428 Value: entry.Value, 429 } 430 431 writeOpts := &api.WriteOptions{} 432 writeOpts = writeOpts.WithContext(ctx) 433 434 _, err := c.kv.Put(pair, writeOpts) 435 return err 436} 437 438// Get is used to fetch an entry 439func (c *ConsulBackend) Get(ctx context.Context, key string) (*physical.Entry, error) { 440 defer metrics.MeasureSince([]string{"consul", "get"}, time.Now()) 441 442 c.permitPool.Acquire() 443 defer c.permitPool.Release() 444 445 queryOpts := &api.QueryOptions{} 446 queryOpts = queryOpts.WithContext(ctx) 447 448 if c.consistencyMode == consistencyModeStrong { 449 queryOpts.RequireConsistent = true 450 } 451 452 pair, _, err := c.kv.Get(c.path+key, queryOpts) 453 if err != nil { 454 return nil, err 455 } 456 if pair == nil { 457 
return nil, nil 458 } 459 ent := &physical.Entry{ 460 Key: key, 461 Value: pair.Value, 462 } 463 return ent, nil 464} 465 466// Delete is used to permanently delete an entry 467func (c *ConsulBackend) Delete(ctx context.Context, key string) error { 468 defer metrics.MeasureSince([]string{"consul", "delete"}, time.Now()) 469 470 c.permitPool.Acquire() 471 defer c.permitPool.Release() 472 473 writeOpts := &api.WriteOptions{} 474 writeOpts = writeOpts.WithContext(ctx) 475 476 _, err := c.kv.Delete(c.path+key, writeOpts) 477 return err 478} 479 480// List is used to list all the keys under a given 481// prefix, up to the next prefix. 482func (c *ConsulBackend) List(ctx context.Context, prefix string) ([]string, error) { 483 defer metrics.MeasureSince([]string{"consul", "list"}, time.Now()) 484 scan := c.path + prefix 485 486 // The TrimPrefix call below will not work correctly if we have "//" at the 487 // end. This can happen in cases where you are e.g. listing the root of a 488 // prefix in a logical backend via "/" instead of "" 489 if strings.HasSuffix(scan, "//") { 490 scan = scan[:len(scan)-1] 491 } 492 493 c.permitPool.Acquire() 494 defer c.permitPool.Release() 495 496 queryOpts := &api.QueryOptions{} 497 queryOpts = queryOpts.WithContext(ctx) 498 499 out, _, err := c.kv.Keys(scan, "/", queryOpts) 500 for idx, val := range out { 501 out[idx] = strings.TrimPrefix(val, scan) 502 } 503 504 return out, err 505} 506 507// Lock is used for mutual exclusion based on the given key. 
508func (c *ConsulBackend) LockWith(key, value string) (physical.Lock, error) { 509 // Create the lock 510 opts := &api.LockOptions{ 511 Key: c.path + key, 512 Value: []byte(value), 513 SessionName: "Vault Lock", 514 MonitorRetries: 5, 515 SessionTTL: c.sessionTTL, 516 LockWaitTime: c.lockWaitTime, 517 } 518 lock, err := c.client.LockOpts(opts) 519 if err != nil { 520 return nil, errwrap.Wrapf("failed to create lock: {{err}}", err) 521 } 522 cl := &ConsulLock{ 523 client: c.client, 524 key: c.path + key, 525 lock: lock, 526 consistencyMode: c.consistencyMode, 527 } 528 return cl, nil 529} 530 531// HAEnabled indicates whether the HA functionality should be exposed. 532// Currently always returns true. 533func (c *ConsulBackend) HAEnabled() bool { 534 return true 535} 536 537// DetectHostAddr is used to detect the host address by asking the Consul agent 538func (c *ConsulBackend) DetectHostAddr() (string, error) { 539 agent := c.client.Agent() 540 self, err := agent.Self() 541 if err != nil { 542 return "", err 543 } 544 addr, ok := self["Member"]["Addr"].(string) 545 if !ok { 546 return "", fmt.Errorf("unable to convert an address to string") 547 } 548 return addr, nil 549} 550 551// ConsulLock is used to provide the Lock interface backed by Consul 552type ConsulLock struct { 553 client *api.Client 554 key string 555 lock *api.Lock 556 consistencyMode string 557} 558 559func (c *ConsulLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) { 560 return c.lock.Lock(stopCh) 561} 562 563func (c *ConsulLock) Unlock() error { 564 return c.lock.Unlock() 565} 566 567func (c *ConsulLock) Value() (bool, string, error) { 568 kv := c.client.KV() 569 570 var queryOptions *api.QueryOptions 571 if c.consistencyMode == consistencyModeStrong { 572 queryOptions = &api.QueryOptions{ 573 RequireConsistent: true, 574 } 575 } 576 577 pair, _, err := kv.Get(c.key, queryOptions) 578 if err != nil { 579 return false, "", err 580 } 581 if pair == nil { 582 return false, "", nil 583 } 
584 held := pair.Session != "" 585 value := string(pair.Value) 586 return held, value, nil 587} 588 589func (c *ConsulBackend) NotifyActiveStateChange() error { 590 select { 591 case c.notifyActiveCh <- notifyEvent{}: 592 default: 593 // NOTE: If this occurs Vault's active status could be out of 594 // sync with Consul until reconcileTimer expires. 595 c.logger.Warn("concurrent state change notify dropped") 596 } 597 598 return nil 599} 600 601func (c *ConsulBackend) NotifyPerformanceStandbyStateChange() error { 602 select { 603 case c.notifyPerfStandbyCh <- notifyEvent{}: 604 default: 605 // NOTE: If this occurs Vault's active status could be out of 606 // sync with Consul until reconcileTimer expires. 607 c.logger.Warn("concurrent state change notify dropped") 608 } 609 610 return nil 611} 612 613func (c *ConsulBackend) NotifySealedStateChange() error { 614 select { 615 case c.notifySealedCh <- notifyEvent{}: 616 default: 617 // NOTE: If this occurs Vault's sealed status could be out of 618 // sync with Consul until checkTimer expires. 
619 c.logger.Warn("concurrent sealed state change notify dropped") 620 } 621 622 return nil 623} 624 625func (c *ConsulBackend) checkDuration() time.Duration { 626 return lib.DurationMinusBuffer(c.checkTimeout, checkMinBuffer, checkJitterFactor) 627} 628 629func (c *ConsulBackend) RunServiceDiscovery(waitGroup *sync.WaitGroup, shutdownCh physical.ShutdownChannel, redirectAddr string, activeFunc physical.ActiveFunction, sealedFunc physical.SealedFunction, perfStandbyFunc physical.PerformanceStandbyFunction) (err error) { 630 if err := c.setRedirectAddr(redirectAddr); err != nil { 631 return err 632 } 633 634 // 'server' command will wait for the below goroutine to complete 635 waitGroup.Add(1) 636 637 go c.runEventDemuxer(waitGroup, shutdownCh, redirectAddr, activeFunc, sealedFunc, perfStandbyFunc) 638 639 return nil 640} 641 642func (c *ConsulBackend) runEventDemuxer(waitGroup *sync.WaitGroup, shutdownCh physical.ShutdownChannel, redirectAddr string, activeFunc physical.ActiveFunction, sealedFunc physical.SealedFunction, perfStandbyFunc physical.PerformanceStandbyFunction) { 643 // This defer statement should be executed last. So push it first. 644 defer waitGroup.Done() 645 646 // Fire the reconcileTimer immediately upon starting the event demuxer 647 reconcileTimer := time.NewTimer(0) 648 defer reconcileTimer.Stop() 649 650 // Schedule the first check. Consul TTL checks are passing by 651 // default, checkTimer does not need to be run immediately. 652 checkTimer := time.NewTimer(c.checkDuration()) 653 defer checkTimer.Stop() 654 655 // Use a reactor pattern to handle and dispatch events to singleton 656 // goroutine handlers for execution. It is not acceptable to drop 657 // inbound events from Notify*(). 658 // 659 // goroutines are dispatched if the demuxer can acquire a lock (via 660 // an atomic CAS incr) on the handler. Handlers are responsible for 661 // deregistering themselves (atomic CAS decr). 
Handlers and the 662 // demuxer share a lock to synchronize information at the beginning 663 // and end of a handler's life (or after a handler wakes up from 664 // sleeping during a back-off/retry). 665 var shutdown bool 666 var registeredServiceID string 667 checkLock := new(int32) 668 serviceRegLock := new(int32) 669 670 for !shutdown { 671 select { 672 case <-c.notifyActiveCh: 673 // Run reconcile immediately upon active state change notification 674 reconcileTimer.Reset(0) 675 case <-c.notifySealedCh: 676 // Run check timer immediately upon a seal state change notification 677 checkTimer.Reset(0) 678 case <-c.notifyPerfStandbyCh: 679 // Run check timer immediately upon a seal state change notification 680 checkTimer.Reset(0) 681 case <-reconcileTimer.C: 682 // Unconditionally rearm the reconcileTimer 683 reconcileTimer.Reset(reconcileTimeout - lib.RandomStagger(reconcileTimeout/checkJitterFactor)) 684 685 // Abort if service discovery is disabled or a 686 // reconcile handler is already active 687 if !c.disableRegistration && atomic.CompareAndSwapInt32(serviceRegLock, 0, 1) { 688 // Enter handler with serviceRegLock held 689 go func() { 690 defer atomic.CompareAndSwapInt32(serviceRegLock, 1, 0) 691 for !shutdown { 692 serviceID, err := c.reconcileConsul(registeredServiceID, activeFunc, sealedFunc, perfStandbyFunc) 693 if err != nil { 694 if c.logger.IsWarn() { 695 c.logger.Warn("reconcile unable to talk with Consul backend", "error", err) 696 } 697 time.Sleep(consulRetryInterval) 698 continue 699 } 700 701 c.serviceLock.Lock() 702 defer c.serviceLock.Unlock() 703 704 registeredServiceID = serviceID 705 return 706 } 707 }() 708 } 709 case <-checkTimer.C: 710 checkTimer.Reset(c.checkDuration()) 711 // Abort if service discovery is disabled or a 712 // reconcile handler is active 713 if !c.disableRegistration && atomic.CompareAndSwapInt32(checkLock, 0, 1) { 714 // Enter handler with checkLock held 715 go func() { 716 defer atomic.CompareAndSwapInt32(checkLock, 1, 
0) 717 for !shutdown { 718 sealed := sealedFunc() 719 if err := c.runCheck(sealed); err != nil { 720 if c.logger.IsWarn() { 721 c.logger.Warn("check unable to talk with Consul backend", "error", err) 722 } 723 time.Sleep(consulRetryInterval) 724 continue 725 } 726 return 727 } 728 }() 729 } 730 case <-shutdownCh: 731 c.logger.Info("shutting down consul backend") 732 shutdown = true 733 } 734 } 735 736 c.serviceLock.RLock() 737 defer c.serviceLock.RUnlock() 738 if err := c.client.Agent().ServiceDeregister(registeredServiceID); err != nil { 739 if c.logger.IsWarn() { 740 c.logger.Warn("service deregistration failed", "error", err) 741 } 742 } 743} 744 745// checkID returns the ID used for a Consul Check. Assume at least a read 746// lock is held. 747func (c *ConsulBackend) checkID() string { 748 return fmt.Sprintf("%s:vault-sealed-check", c.serviceID()) 749} 750 751// serviceID returns the Vault ServiceID for use in Consul. Assume at least 752// a read lock is held. 753func (c *ConsulBackend) serviceID() string { 754 return fmt.Sprintf("%s:%s:%d", c.serviceName, c.redirectHost, c.redirectPort) 755} 756 757// reconcileConsul queries the state of Vault Core and Consul and fixes up 758// Consul's state according to what's in Vault. reconcileConsul is called 759// without any locks held and can be run concurrently, therefore no changes 760// to ConsulBackend can be made in this method (i.e. wtb const receiver for 761// compiler enforced safety). 
762func (c *ConsulBackend) reconcileConsul(registeredServiceID string, activeFunc physical.ActiveFunction, sealedFunc physical.SealedFunction, perfStandbyFunc physical.PerformanceStandbyFunction) (serviceID string, err error) { 763 // Query vault Core for its current state 764 active := activeFunc() 765 sealed := sealedFunc() 766 perfStandby := perfStandbyFunc() 767 768 agent := c.client.Agent() 769 catalog := c.client.Catalog() 770 771 serviceID = c.serviceID() 772 773 // Get the current state of Vault from Consul 774 var currentVaultService *api.CatalogService 775 if services, _, err := catalog.Service(c.serviceName, "", &api.QueryOptions{AllowStale: true}); err == nil { 776 for _, service := range services { 777 if serviceID == service.ServiceID { 778 currentVaultService = service 779 break 780 } 781 } 782 } 783 784 tags := c.fetchServiceTags(active, perfStandby) 785 786 var reregister bool 787 788 switch { 789 case currentVaultService == nil, registeredServiceID == "": 790 reregister = true 791 default: 792 switch { 793 case !strutil.EquivalentSlices(currentVaultService.ServiceTags, tags): 794 reregister = true 795 } 796 } 797 798 if !reregister { 799 // When re-registration is not required, return a valid serviceID 800 // to avoid registration in the next cycle. 801 return serviceID, nil 802 } 803 804 // If service address was set explicitly in configuration, use that 805 // as the service-specific address instead of the HA redirect address. 
806 var serviceAddress string 807 if c.serviceAddress == nil { 808 serviceAddress = c.redirectHost 809 } else { 810 serviceAddress = *c.serviceAddress 811 } 812 813 service := &api.AgentServiceRegistration{ 814 ID: serviceID, 815 Name: c.serviceName, 816 Tags: tags, 817 Port: int(c.redirectPort), 818 Address: serviceAddress, 819 EnableTagOverride: false, 820 } 821 822 checkStatus := api.HealthCritical 823 if !sealed { 824 checkStatus = api.HealthPassing 825 } 826 827 sealedCheck := &api.AgentCheckRegistration{ 828 ID: c.checkID(), 829 Name: "Vault Sealed Status", 830 Notes: "Vault service is healthy when Vault is in an unsealed status and can become an active Vault server", 831 ServiceID: serviceID, 832 AgentServiceCheck: api.AgentServiceCheck{ 833 TTL: c.checkTimeout.String(), 834 Status: checkStatus, 835 }, 836 } 837 838 if err := agent.ServiceRegister(service); err != nil { 839 return "", errwrap.Wrapf(`service registration failed: {{err}}`, err) 840 } 841 842 if err := agent.CheckRegister(sealedCheck); err != nil { 843 return serviceID, errwrap.Wrapf(`service check registration failed: {{err}}`, err) 844 } 845 846 return serviceID, nil 847} 848 849// runCheck immediately pushes a TTL check. 850func (c *ConsulBackend) runCheck(sealed bool) error { 851 // Run a TTL check 852 agent := c.client.Agent() 853 if !sealed { 854 return agent.PassTTL(c.checkID(), "Vault Unsealed") 855 } else { 856 return agent.FailTTL(c.checkID(), "Vault Sealed") 857 } 858} 859 860// fetchServiceTags returns all of the relevant tags for Consul. 
861func (c *ConsulBackend) fetchServiceTags(active bool, perfStandby bool) []string { 862 activeTag := "standby" 863 if active { 864 activeTag = "active" 865 } 866 867 result := append(c.serviceTags, activeTag) 868 869 if perfStandby { 870 result = append(c.serviceTags, "performance-standby") 871 } 872 873 return result 874} 875 876func (c *ConsulBackend) setRedirectAddr(addr string) (err error) { 877 if addr == "" { 878 return fmt.Errorf("redirect address must not be empty") 879 } 880 881 url, err := url.Parse(addr) 882 if err != nil { 883 return errwrap.Wrapf(fmt.Sprintf("failed to parse redirect URL %q: {{err}}", addr), err) 884 } 885 886 var portStr string 887 c.redirectHost, portStr, err = net.SplitHostPort(url.Host) 888 if err != nil { 889 if url.Scheme == "http" { 890 portStr = "80" 891 } else if url.Scheme == "https" { 892 portStr = "443" 893 } else if url.Scheme == "unix" { 894 portStr = "-1" 895 c.redirectHost = url.Path 896 } else { 897 return errwrap.Wrapf(fmt.Sprintf(`failed to find a host:port in redirect address "%v": {{err}}`, url.Host), err) 898 } 899 } 900 c.redirectPort, err = strconv.ParseInt(portStr, 10, 0) 901 if err != nil || c.redirectPort < -1 || c.redirectPort > 65535 { 902 return errwrap.Wrapf(fmt.Sprintf(`failed to parse valid port "%v": {{err}}`, portStr), err) 903 } 904 905 return nil 906} 907