1package consul 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "io/ioutil" 8 "net" 9 "net/http" 10 "net/url" 11 "regexp" 12 "strconv" 13 "strings" 14 "sync" 15 "sync/atomic" 16 "time" 17 18 "golang.org/x/net/http2" 19 20 log "github.com/hashicorp/go-hclog" 21 22 "crypto/tls" 23 "crypto/x509" 24 25 metrics "github.com/armon/go-metrics" 26 "github.com/hashicorp/consul/api" 27 "github.com/hashicorp/errwrap" 28 multierror "github.com/hashicorp/go-multierror" 29 "github.com/hashicorp/vault/sdk/helper/consts" 30 "github.com/hashicorp/vault/sdk/helper/parseutil" 31 "github.com/hashicorp/vault/sdk/helper/strutil" 32 "github.com/hashicorp/vault/sdk/helper/tlsutil" 33 "github.com/hashicorp/vault/sdk/physical" 34) 35 36const ( 37 // checkJitterFactor specifies the jitter factor used to stagger checks 38 checkJitterFactor = 16 39 40 // checkMinBuffer specifies provides a guarantee that a check will not 41 // be executed too close to the TTL check timeout 42 checkMinBuffer = 100 * time.Millisecond 43 44 // consulRetryInterval specifies the retry duration to use when an 45 // API call to the Consul agent fails. 46 consulRetryInterval = 1 * time.Second 47 48 // defaultCheckTimeout changes the timeout of TTL checks 49 defaultCheckTimeout = 5 * time.Second 50 51 // DefaultServiceName is the default Consul service name used when 52 // advertising a Vault instance. 53 DefaultServiceName = "vault" 54 55 // reconcileTimeout is how often Vault should query Consul to detect 56 // and fix any state drift. 57 reconcileTimeout = 60 * time.Second 58 59 // consistencyModeDefault is the configuration value used to tell 60 // consul to use default consistency. 61 consistencyModeDefault = "default" 62 63 // consistencyModeStrong is the configuration value used to tell 64 // consul to use strong consistency. 65 consistencyModeStrong = "strong" 66) 67 68type notifyEvent struct{} 69 70// Verify ConsulBackend satisfies the correct interfaces 71var _ physical.Backend = (*ConsulBackend)(nil) 72var _ physical.HABackend = (*ConsulBackend)(nil) 73var _ physical.Lock = (*ConsulLock)(nil) 74var _ physical.Transactional = (*ConsulBackend)(nil) 75var _ physical.ServiceDiscovery = (*ConsulBackend)(nil) 76 77var ( 78 hostnameRegex = regexp.MustCompile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`) 79) 80 81// ConsulBackend is a physical backend that stores data at specific 82// prefix within Consul. It is used for most production situations as 83// it allows Vault to run on multiple machines in a highly-available manner. 84type ConsulBackend struct { 85 path string 86 logger log.Logger 87 client *api.Client 88 kv *api.KV 89 permitPool *physical.PermitPool 90 serviceLock sync.RWMutex 91 redirectHost string 92 redirectPort int64 93 serviceName string 94 serviceTags []string 95 serviceAddress *string 96 disableRegistration bool 97 checkTimeout time.Duration 98 consistencyMode string 99 100 notifyActiveCh chan notifyEvent 101 notifySealedCh chan notifyEvent 102 notifyPerfStandbyCh chan notifyEvent 103 104 sessionTTL string 105 lockWaitTime time.Duration 106} 107 108// NewConsulBackend constructs a Consul backend using the given API client 109// and the prefix in the KV store. 110func NewConsulBackend(conf map[string]string, logger log.Logger) (physical.Backend, error) { 111 // Get the path in Consul 112 path, ok := conf["path"] 113 if !ok { 114 path = "vault/" 115 } 116 if logger.IsDebug() { 117 logger.Debug("config path set", "path", path) 118 } 119 120 // Ensure path is suffixed but not prefixed 121 if !strings.HasSuffix(path, "/") { 122 logger.Warn("appending trailing forward slash to path") 123 path += "/" 124 } 125 if strings.HasPrefix(path, "/") { 126 logger.Warn("trimming path of its forward slash") 127 path = strings.TrimPrefix(path, "/") 128 } 129 130 // Allow admins to disable consul integration 131 disableReg, ok := conf["disable_registration"] 132 var disableRegistration bool 133 if ok && disableReg != "" { 134 b, err := parseutil.ParseBool(disableReg) 135 if err != nil { 136 return nil, errwrap.Wrapf("failed parsing disable_registration parameter: {{err}}", err) 137 } 138 disableRegistration = b 139 } 140 if logger.IsDebug() { 141 logger.Debug("config disable_registration set", "disable_registration", disableRegistration) 142 } 143 144 // Get the service name to advertise in Consul 145 service, ok := conf["service"] 146 if !ok { 147 service = DefaultServiceName 148 } 149 if !hostnameRegex.MatchString(service) { 150 return nil, errors.New("service name must be valid per RFC 1123 and can contain only alphanumeric characters or dashes") 151 } 152 if logger.IsDebug() { 153 logger.Debug("config service set", "service", service) 154 } 155 156 // Get the additional tags to attach to the registered service name 157 tags := conf["service_tags"] 158 if logger.IsDebug() { 159 logger.Debug("config service_tags set", "service_tags", tags) 160 } 161 162 // Get the service-specific address to override the use of the HA redirect address 163 var serviceAddr *string 164 serviceAddrStr, ok := conf["service_address"] 165 if ok { 166 serviceAddr = &serviceAddrStr 167 } 168 if logger.IsDebug() { 169 logger.Debug("config service_address set", "service_address", serviceAddr) 170 } 171 172 checkTimeout := defaultCheckTimeout 173 checkTimeoutStr, ok := conf["check_timeout"] 174 if ok { 175 d, err := parseutil.ParseDurationSecond(checkTimeoutStr) 176 if err != nil { 177 return nil, err 178 } 179 180 min, _ := DurationMinusBufferDomain(d, checkMinBuffer, checkJitterFactor) 181 if min < checkMinBuffer { 182 return nil, fmt.Errorf("consul check_timeout must be greater than %v", min) 183 } 184 185 checkTimeout = d 186 if logger.IsDebug() { 187 logger.Debug("config check_timeout set", "check_timeout", d) 188 } 189 } 190 191 sessionTTL := api.DefaultLockSessionTTL 192 sessionTTLStr, ok := conf["session_ttl"] 193 if ok { 194 _, err := parseutil.ParseDurationSecond(sessionTTLStr) 195 if err != nil { 196 return nil, errwrap.Wrapf("invalid session_ttl: {{err}}", err) 197 } 198 sessionTTL = sessionTTLStr 199 if logger.IsDebug() { 200 logger.Debug("config session_ttl set", "session_ttl", sessionTTL) 201 } 202 } 203 204 lockWaitTime := api.DefaultLockWaitTime 205 lockWaitTimeRaw, ok := conf["lock_wait_time"] 206 if ok { 207 d, err := parseutil.ParseDurationSecond(lockWaitTimeRaw) 208 if err != nil { 209 return nil, errwrap.Wrapf("invalid lock_wait_time: {{err}}", err) 210 } 211 lockWaitTime = d 212 if logger.IsDebug() { 213 logger.Debug("config lock_wait_time set", "lock_wait_time", d) 214 } 215 } 216 217 // Configure the client 218 consulConf := api.DefaultConfig() 219 // Set MaxIdleConnsPerHost to the number of processes used in expiration.Restore 220 consulConf.Transport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount 221 222 if addr, ok := conf["address"]; ok { 223 consulConf.Address = addr 224 if logger.IsDebug() { 225 logger.Debug("config address set", "address", addr) 226 } 227 228 // Copied from the Consul API module; set the Scheme based on 229 // the protocol field if address looks ike a URL. 230 // This can enable the TLS configuration below. 231 parts := strings.SplitN(addr, "://", 2) 232 if len(parts) == 2 { 233 if parts[0] == "http" || parts[0] == "https" { 234 consulConf.Scheme = parts[0] 235 consulConf.Address = parts[1] 236 if logger.IsDebug() { 237 logger.Debug("config address parsed", "scheme", parts[0]) 238 logger.Debug("config scheme parsed", "address", parts[1]) 239 } 240 } // allow "unix:" or whatever else consul supports in the future 241 } 242 } 243 if scheme, ok := conf["scheme"]; ok { 244 consulConf.Scheme = scheme 245 if logger.IsDebug() { 246 logger.Debug("config scheme set", "scheme", scheme) 247 } 248 } 249 if token, ok := conf["token"]; ok { 250 consulConf.Token = token 251 logger.Debug("config token set") 252 } 253 254 if consulConf.Scheme == "https" { 255 // Use the parsed Address instead of the raw conf['address'] 256 tlsClientConfig, err := setupTLSConfig(conf, consulConf.Address) 257 if err != nil { 258 return nil, err 259 } 260 261 consulConf.Transport.TLSClientConfig = tlsClientConfig 262 if err := http2.ConfigureTransport(consulConf.Transport); err != nil { 263 return nil, err 264 } 265 logger.Debug("configured TLS") 266 } 267 268 consulConf.HttpClient = &http.Client{Transport: consulConf.Transport} 269 client, err := api.NewClient(consulConf) 270 if err != nil { 271 return nil, errwrap.Wrapf("client setup failed: {{err}}", err) 272 } 273 274 maxParStr, ok := conf["max_parallel"] 275 var maxParInt int 276 if ok { 277 maxParInt, err = strconv.Atoi(maxParStr) 278 if err != nil { 279 return nil, errwrap.Wrapf("failed parsing max_parallel parameter: {{err}}", err) 280 } 281 if logger.IsDebug() { 282 logger.Debug("max_parallel set", "max_parallel", maxParInt) 283 } 284 } 285 286 consistencyMode, ok := conf["consistency_mode"] 287 if ok { 288 switch consistencyMode { 289 case consistencyModeDefault, consistencyModeStrong: 290 default: 291 return nil, fmt.Errorf("invalid consistency_mode value: %q", consistencyMode) 292 } 293 } else { 294 consistencyMode = consistencyModeDefault 295 } 296 297 // Setup the backend 298 c := &ConsulBackend{ 299 path: path, 300 logger: logger, 301 client: client, 302 kv: client.KV(), 303 permitPool: physical.NewPermitPool(maxParInt), 304 serviceName: service, 305 serviceTags: strutil.ParseDedupLowercaseAndSortStrings(tags, ","), 306 serviceAddress: serviceAddr, 307 checkTimeout: checkTimeout, 308 disableRegistration: disableRegistration, 309 consistencyMode: consistencyMode, 310 notifyActiveCh: make(chan notifyEvent), 311 notifySealedCh: make(chan notifyEvent), 312 notifyPerfStandbyCh: make(chan notifyEvent), 313 sessionTTL: sessionTTL, 314 lockWaitTime: lockWaitTime, 315 } 316 return c, nil 317} 318 319func setupTLSConfig(conf map[string]string, address string) (*tls.Config, error) { 320 serverName, _, err := net.SplitHostPort(address) 321 switch { 322 case err == nil: 323 case strings.Contains(err.Error(), "missing port"): 324 serverName = conf["address"] 325 default: 326 return nil, err 327 } 328 329 insecureSkipVerify := false 330 tlsSkipVerify, ok := conf["tls_skip_verify"] 331 332 if ok && tlsSkipVerify != "" { 333 b, err := parseutil.ParseBool(tlsSkipVerify) 334 if err != nil { 335 return nil, errwrap.Wrapf("failed parsing tls_skip_verify parameter: {{err}}", err) 336 } 337 insecureSkipVerify = b 338 } 339 340 tlsMinVersionStr, ok := conf["tls_min_version"] 341 if !ok { 342 // Set the default value 343 tlsMinVersionStr = "tls12" 344 } 345 346 tlsMinVersion, ok := tlsutil.TLSLookup[tlsMinVersionStr] 347 if !ok { 348 return nil, fmt.Errorf("invalid 'tls_min_version'") 349 } 350 351 tlsClientConfig := &tls.Config{ 352 MinVersion: tlsMinVersion, 353 InsecureSkipVerify: insecureSkipVerify, 354 ServerName: serverName, 355 } 356 357 _, okCert := conf["tls_cert_file"] 358 _, okKey := conf["tls_key_file"] 359 360 if okCert && okKey { 361 tlsCert, err := tls.LoadX509KeyPair(conf["tls_cert_file"], conf["tls_key_file"]) 362 if err != nil { 363 return nil, errwrap.Wrapf("client tls setup failed: {{err}}", err) 364 } 365 366 tlsClientConfig.Certificates = []tls.Certificate{tlsCert} 367 } 368 369 if tlsCaFile, ok := conf["tls_ca_file"]; ok { 370 caPool := x509.NewCertPool() 371 372 data, err := ioutil.ReadFile(tlsCaFile) 373 if err != nil { 374 return nil, errwrap.Wrapf("failed to read CA file: {{err}}", err) 375 } 376 377 if !caPool.AppendCertsFromPEM(data) { 378 return nil, fmt.Errorf("failed to parse CA certificate") 379 } 380 381 tlsClientConfig.RootCAs = caPool 382 } 383 384 return tlsClientConfig, nil 385} 386 387// Used to run multiple entries via a transaction 388func (c *ConsulBackend) Transaction(ctx context.Context, txns []*physical.TxnEntry) error { 389 if len(txns) == 0 { 390 return nil 391 } 392 393 ops := make([]*api.KVTxnOp, 0, len(txns)) 394 395 for _, op := range txns { 396 cop := &api.KVTxnOp{ 397 Key: c.path + op.Entry.Key, 398 } 399 switch op.Operation { 400 case physical.DeleteOperation: 401 cop.Verb = api.KVDelete 402 case physical.PutOperation: 403 cop.Verb = api.KVSet 404 cop.Value = op.Entry.Value 405 default: 406 return fmt.Errorf("%q is not a supported transaction operation", op.Operation) 407 } 408 409 ops = append(ops, cop) 410 } 411 412 c.permitPool.Acquire() 413 defer c.permitPool.Release() 414 415 queryOpts := &api.QueryOptions{} 416 queryOpts = queryOpts.WithContext(ctx) 417 418 ok, resp, _, err := c.kv.Txn(ops, queryOpts) 419 if err != nil { 420 if strings.Contains(err.Error(), "is too large") { 421 return errwrap.Wrapf(fmt.Sprintf("%s: {{err}}", physical.ErrValueTooLarge), err) 422 } 423 return err 424 } 425 if ok && len(resp.Errors) == 0 { 426 return nil 427 } 428 429 var retErr *multierror.Error 430 for _, res := range resp.Errors { 431 retErr = multierror.Append(retErr, errors.New(res.What)) 432 } 433 434 return retErr 435} 436 437// Put is used to insert or update an entry 438func (c *ConsulBackend) Put(ctx context.Context, entry *physical.Entry) error { 439 defer metrics.MeasureSince([]string{"consul", "put"}, time.Now()) 440 441 c.permitPool.Acquire() 442 defer c.permitPool.Release() 443 444 pair := &api.KVPair{ 445 Key: c.path + entry.Key, 446 Value: entry.Value, 447 } 448 449 writeOpts := &api.WriteOptions{} 450 writeOpts = writeOpts.WithContext(ctx) 451 452 _, err := c.kv.Put(pair, writeOpts) 453 if err != nil { 454 if strings.Contains(err.Error(), "Value exceeds") { 455 return errwrap.Wrapf(fmt.Sprintf("%s: {{err}}", physical.ErrValueTooLarge), err) 456 } 457 return err 458 } 459 return nil 460} 461 462// Get is used to fetch an entry 463func (c *ConsulBackend) Get(ctx context.Context, key string) (*physical.Entry, error) { 464 defer metrics.MeasureSince([]string{"consul", "get"}, time.Now()) 465 466 c.permitPool.Acquire() 467 defer c.permitPool.Release() 468 469 queryOpts := &api.QueryOptions{} 470 queryOpts = queryOpts.WithContext(ctx) 471 472 if c.consistencyMode == consistencyModeStrong { 473 queryOpts.RequireConsistent = true 474 } 475 476 pair, _, err := c.kv.Get(c.path+key, queryOpts) 477 if err != nil { 478 return nil, err 479 } 480 if pair == nil { 481 return nil, nil 482 } 483 ent := &physical.Entry{ 484 Key: key, 485 Value: pair.Value, 486 } 487 return ent, nil 488} 489 490// Delete is used to permanently delete an entry 491func (c *ConsulBackend) Delete(ctx context.Context, key string) error { 492 defer metrics.MeasureSince([]string{"consul", "delete"}, time.Now()) 493 494 c.permitPool.Acquire() 495 defer c.permitPool.Release() 496 497 writeOpts := &api.WriteOptions{} 498 writeOpts = writeOpts.WithContext(ctx) 499 500 _, err := c.kv.Delete(c.path+key, writeOpts) 501 return err 502} 503 504// List is used to list all the keys under a given 505// prefix, up to the next prefix. 506func (c *ConsulBackend) List(ctx context.Context, prefix string) ([]string, error) { 507 defer metrics.MeasureSince([]string{"consul", "list"}, time.Now()) 508 scan := c.path + prefix 509 510 // The TrimPrefix call below will not work correctly if we have "//" at the 511 // end. This can happen in cases where you are e.g. listing the root of a 512 // prefix in a logical backend via "/" instead of "" 513 if strings.HasSuffix(scan, "//") { 514 scan = scan[:len(scan)-1] 515 } 516 517 c.permitPool.Acquire() 518 defer c.permitPool.Release() 519 520 queryOpts := &api.QueryOptions{} 521 queryOpts = queryOpts.WithContext(ctx) 522 523 out, _, err := c.kv.Keys(scan, "/", queryOpts) 524 for idx, val := range out { 525 out[idx] = strings.TrimPrefix(val, scan) 526 } 527 528 return out, err 529} 530 531// Lock is used for mutual exclusion based on the given key. 532func (c *ConsulBackend) LockWith(key, value string) (physical.Lock, error) { 533 // Create the lock 534 opts := &api.LockOptions{ 535 Key: c.path + key, 536 Value: []byte(value), 537 SessionName: "Vault Lock", 538 MonitorRetries: 5, 539 SessionTTL: c.sessionTTL, 540 LockWaitTime: c.lockWaitTime, 541 } 542 lock, err := c.client.LockOpts(opts) 543 if err != nil { 544 return nil, errwrap.Wrapf("failed to create lock: {{err}}", err) 545 } 546 cl := &ConsulLock{ 547 client: c.client, 548 key: c.path + key, 549 lock: lock, 550 consistencyMode: c.consistencyMode, 551 } 552 return cl, nil 553} 554 555// HAEnabled indicates whether the HA functionality should be exposed. 556// Currently always returns true. 557func (c *ConsulBackend) HAEnabled() bool { 558 return true 559} 560 561// DetectHostAddr is used to detect the host address by asking the Consul agent 562func (c *ConsulBackend) DetectHostAddr() (string, error) { 563 agent := c.client.Agent() 564 self, err := agent.Self() 565 if err != nil { 566 return "", err 567 } 568 addr, ok := self["Member"]["Addr"].(string) 569 if !ok { 570 return "", fmt.Errorf("unable to convert an address to string") 571 } 572 return addr, nil 573} 574 575// ConsulLock is used to provide the Lock interface backed by Consul 576type ConsulLock struct { 577 client *api.Client 578 key string 579 lock *api.Lock 580 consistencyMode string 581} 582 583func (c *ConsulLock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) { 584 return c.lock.Lock(stopCh) 585} 586 587func (c *ConsulLock) Unlock() error { 588 return c.lock.Unlock() 589} 590 591func (c *ConsulLock) Value() (bool, string, error) { 592 kv := c.client.KV() 593 594 var queryOptions *api.QueryOptions 595 if c.consistencyMode == consistencyModeStrong { 596 queryOptions = &api.QueryOptions{ 597 RequireConsistent: true, 598 } 599 } 600 601 pair, _, err := kv.Get(c.key, queryOptions) 602 if err != nil { 603 return false, "", err 604 } 605 if pair == nil { 606 return false, "", nil 607 } 608 held := pair.Session != "" 609 value := string(pair.Value) 610 return held, value, nil 611} 612 613func (c *ConsulBackend) NotifyActiveStateChange() error { 614 select { 615 case c.notifyActiveCh <- notifyEvent{}: 616 default: 617 // NOTE: If this occurs Vault's active status could be out of 618 // sync with Consul until reconcileTimer expires. 619 c.logger.Warn("concurrent state change notify dropped") 620 } 621 622 return nil 623} 624 625func (c *ConsulBackend) NotifyPerformanceStandbyStateChange() error { 626 select { 627 case c.notifyPerfStandbyCh <- notifyEvent{}: 628 default: 629 // NOTE: If this occurs Vault's active status could be out of 630 // sync with Consul until reconcileTimer expires. 631 c.logger.Warn("concurrent state change notify dropped") 632 } 633 634 return nil 635} 636 637func (c *ConsulBackend) NotifySealedStateChange() error { 638 select { 639 case c.notifySealedCh <- notifyEvent{}: 640 default: 641 // NOTE: If this occurs Vault's sealed status could be out of 642 // sync with Consul until checkTimer expires. 643 c.logger.Warn("concurrent sealed state change notify dropped") 644 } 645 646 return nil 647} 648 649func (c *ConsulBackend) checkDuration() time.Duration { 650 return DurationMinusBuffer(c.checkTimeout, checkMinBuffer, checkJitterFactor) 651} 652 653func (c *ConsulBackend) RunServiceDiscovery(waitGroup *sync.WaitGroup, shutdownCh physical.ShutdownChannel, redirectAddr string, activeFunc physical.ActiveFunction, sealedFunc physical.SealedFunction, perfStandbyFunc physical.PerformanceStandbyFunction) (err error) { 654 if err := c.setRedirectAddr(redirectAddr); err != nil { 655 return err 656 } 657 658 // 'server' command will wait for the below goroutine to complete 659 waitGroup.Add(1) 660 661 go c.runEventDemuxer(waitGroup, shutdownCh, redirectAddr, activeFunc, sealedFunc, perfStandbyFunc) 662 663 return nil 664} 665 666func (c *ConsulBackend) runEventDemuxer(waitGroup *sync.WaitGroup, shutdownCh physical.ShutdownChannel, redirectAddr string, activeFunc physical.ActiveFunction, sealedFunc physical.SealedFunction, perfStandbyFunc physical.PerformanceStandbyFunction) { 667 // This defer statement should be executed last. So push it first. 668 defer waitGroup.Done() 669 670 // Fire the reconcileTimer immediately upon starting the event demuxer 671 reconcileTimer := time.NewTimer(0) 672 defer reconcileTimer.Stop() 673 674 // Schedule the first check. Consul TTL checks are passing by 675 // default, checkTimer does not need to be run immediately. 676 checkTimer := time.NewTimer(c.checkDuration()) 677 defer checkTimer.Stop() 678 679 // Use a reactor pattern to handle and dispatch events to singleton 680 // goroutine handlers for execution. It is not acceptable to drop 681 // inbound events from Notify*(). 682 // 683 // goroutines are dispatched if the demuxer can acquire a lock (via 684 // an atomic CAS incr) on the handler. Handlers are responsible for 685 // deregistering themselves (atomic CAS decr). Handlers and the 686 // demuxer share a lock to synchronize information at the beginning 687 // and end of a handler's life (or after a handler wakes up from 688 // sleeping during a back-off/retry). 689 var shutdown bool 690 var registeredServiceID string 691 checkLock := new(int32) 692 serviceRegLock := new(int32) 693 694 for !shutdown { 695 select { 696 case <-c.notifyActiveCh: 697 // Run reconcile immediately upon active state change notification 698 reconcileTimer.Reset(0) 699 case <-c.notifySealedCh: 700 // Run check timer immediately upon a seal state change notification 701 checkTimer.Reset(0) 702 case <-c.notifyPerfStandbyCh: 703 // Run check timer immediately upon a seal state change notification 704 checkTimer.Reset(0) 705 case <-reconcileTimer.C: 706 // Unconditionally rearm the reconcileTimer 707 reconcileTimer.Reset(reconcileTimeout - RandomStagger(reconcileTimeout/checkJitterFactor)) 708 709 // Abort if service discovery is disabled or a 710 // reconcile handler is already active 711 if !c.disableRegistration && atomic.CompareAndSwapInt32(serviceRegLock, 0, 1) { 712 // Enter handler with serviceRegLock held 713 go func() { 714 defer atomic.CompareAndSwapInt32(serviceRegLock, 1, 0) 715 for !shutdown { 716 serviceID, err := c.reconcileConsul(registeredServiceID, activeFunc, sealedFunc, perfStandbyFunc) 717 if err != nil { 718 if c.logger.IsWarn() { 719 c.logger.Warn("reconcile unable to talk with Consul backend", "error", err) 720 } 721 time.Sleep(consulRetryInterval) 722 continue 723 } 724 725 c.serviceLock.Lock() 726 defer c.serviceLock.Unlock() 727 728 registeredServiceID = serviceID 729 return 730 } 731 }() 732 } 733 case <-checkTimer.C: 734 checkTimer.Reset(c.checkDuration()) 735 // Abort if service discovery is disabled or a 736 // reconcile handler is active 737 if !c.disableRegistration && atomic.CompareAndSwapInt32(checkLock, 0, 1) { 738 // Enter handler with checkLock held 739 go func() { 740 defer atomic.CompareAndSwapInt32(checkLock, 1, 0) 741 for !shutdown { 742 sealed := sealedFunc() 743 if err := c.runCheck(sealed); err != nil { 744 if c.logger.IsWarn() { 745 c.logger.Warn("check unable to talk with Consul backend", "error", err) 746 } 747 time.Sleep(consulRetryInterval) 748 continue 749 } 750 return 751 } 752 }() 753 } 754 case <-shutdownCh: 755 c.logger.Info("shutting down consul backend") 756 shutdown = true 757 } 758 } 759 760 c.serviceLock.RLock() 761 defer c.serviceLock.RUnlock() 762 if err := c.client.Agent().ServiceDeregister(registeredServiceID); err != nil { 763 if c.logger.IsWarn() { 764 c.logger.Warn("service deregistration failed", "error", err) 765 } 766 } 767} 768 769// checkID returns the ID used for a Consul Check. Assume at least a read 770// lock is held. 771func (c *ConsulBackend) checkID() string { 772 return fmt.Sprintf("%s:vault-sealed-check", c.serviceID()) 773} 774 775// serviceID returns the Vault ServiceID for use in Consul. Assume at least 776// a read lock is held. 777func (c *ConsulBackend) serviceID() string { 778 return fmt.Sprintf("%s:%s:%d", c.serviceName, c.redirectHost, c.redirectPort) 779} 780 781// reconcileConsul queries the state of Vault Core and Consul and fixes up 782// Consul's state according to what's in Vault. reconcileConsul is called 783// without any locks held and can be run concurrently, therefore no changes 784// to ConsulBackend can be made in this method (i.e. wtb const receiver for 785// compiler enforced safety). 786func (c *ConsulBackend) reconcileConsul(registeredServiceID string, activeFunc physical.ActiveFunction, sealedFunc physical.SealedFunction, perfStandbyFunc physical.PerformanceStandbyFunction) (serviceID string, err error) { 787 // Query vault Core for its current state 788 active := activeFunc() 789 sealed := sealedFunc() 790 perfStandby := perfStandbyFunc() 791 792 agent := c.client.Agent() 793 catalog := c.client.Catalog() 794 795 serviceID = c.serviceID() 796 797 // Get the current state of Vault from Consul 798 var currentVaultService *api.CatalogService 799 if services, _, err := catalog.Service(c.serviceName, "", &api.QueryOptions{AllowStale: true}); err == nil { 800 for _, service := range services { 801 if serviceID == service.ServiceID { 802 currentVaultService = service 803 break 804 } 805 } 806 } 807 808 tags := c.fetchServiceTags(active, perfStandby) 809 810 var reregister bool 811 812 switch { 813 case currentVaultService == nil, registeredServiceID == "": 814 reregister = true 815 default: 816 switch { 817 case !strutil.EquivalentSlices(currentVaultService.ServiceTags, tags): 818 reregister = true 819 } 820 } 821 822 if !reregister { 823 // When re-registration is not required, return a valid serviceID 824 // to avoid registration in the next cycle. 825 return serviceID, nil 826 } 827 828 // If service address was set explicitly in configuration, use that 829 // as the service-specific address instead of the HA redirect address. 830 var serviceAddress string 831 if c.serviceAddress == nil { 832 serviceAddress = c.redirectHost 833 } else { 834 serviceAddress = *c.serviceAddress 835 } 836 837 service := &api.AgentServiceRegistration{ 838 ID: serviceID, 839 Name: c.serviceName, 840 Tags: tags, 841 Port: int(c.redirectPort), 842 Address: serviceAddress, 843 EnableTagOverride: false, 844 } 845 846 checkStatus := api.HealthCritical 847 if !sealed { 848 checkStatus = api.HealthPassing 849 } 850 851 sealedCheck := &api.AgentCheckRegistration{ 852 ID: c.checkID(), 853 Name: "Vault Sealed Status", 854 Notes: "Vault service is healthy when Vault is in an unsealed status and can become an active Vault server", 855 ServiceID: serviceID, 856 AgentServiceCheck: api.AgentServiceCheck{ 857 TTL: c.checkTimeout.String(), 858 Status: checkStatus, 859 }, 860 } 861 862 if err := agent.ServiceRegister(service); err != nil { 863 return "", errwrap.Wrapf(`service registration failed: {{err}}`, err) 864 } 865 866 if err := agent.CheckRegister(sealedCheck); err != nil { 867 return serviceID, errwrap.Wrapf(`service check registration failed: {{err}}`, err) 868 } 869 870 return serviceID, nil 871} 872 873// runCheck immediately pushes a TTL check. 874func (c *ConsulBackend) runCheck(sealed bool) error { 875 // Run a TTL check 876 agent := c.client.Agent() 877 if !sealed { 878 return agent.PassTTL(c.checkID(), "Vault Unsealed") 879 } else { 880 return agent.FailTTL(c.checkID(), "Vault Sealed") 881 } 882} 883 884// fetchServiceTags returns all of the relevant tags for Consul. 885func (c *ConsulBackend) fetchServiceTags(active bool, perfStandby bool) []string { 886 activeTag := "standby" 887 if active { 888 activeTag = "active" 889 } 890 891 result := append(c.serviceTags, activeTag) 892 893 if perfStandby { 894 result = append(c.serviceTags, "performance-standby") 895 } 896 897 return result 898} 899 900func (c *ConsulBackend) setRedirectAddr(addr string) (err error) { 901 if addr == "" { 902 return fmt.Errorf("redirect address must not be empty") 903 } 904 905 url, err := url.Parse(addr) 906 if err != nil { 907 return errwrap.Wrapf(fmt.Sprintf("failed to parse redirect URL %q: {{err}}", addr), err) 908 } 909 910 var portStr string 911 c.redirectHost, portStr, err = net.SplitHostPort(url.Host) 912 if err != nil { 913 if url.Scheme == "http" { 914 portStr = "80" 915 } else if url.Scheme == "https" { 916 portStr = "443" 917 } else if url.Scheme == "unix" { 918 portStr = "-1" 919 c.redirectHost = url.Path 920 } else { 921 return errwrap.Wrapf(fmt.Sprintf(`failed to find a host:port in redirect address "%v": {{err}}`, url.Host), err) 922 } 923 } 924 c.redirectPort, err = strconv.ParseInt(portStr, 10, 0) 925 if err != nil || c.redirectPort < -1 || c.redirectPort > 65535 { 926 return errwrap.Wrapf(fmt.Sprintf(`failed to parse valid port "%v": {{err}}`, portStr), err) 927 } 928 929 return nil 930} 931