package consul

import (
	"context"
	"errors"
	"fmt"
	"math/rand"
	"net"
	"net/http"
	"net/url"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/hashicorp/consul/api"
	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/vault/sdk/helper/consts"
	"github.com/hashicorp/vault/sdk/helper/parseutil"
	"github.com/hashicorp/vault/sdk/helper/strutil"
	"github.com/hashicorp/vault/sdk/helper/tlsutil"
	sr "github.com/hashicorp/vault/serviceregistration"
	"github.com/hashicorp/vault/vault/diagnose"
	atomicB "go.uber.org/atomic"
	"golang.org/x/net/http2"
)

const (
	// checkJitterFactor is the divisor used to stagger checks: the check
	// interval is shortened by a random amount of up to 1/checkJitterFactor
	// of itself (see durationMinusBuffer).
	checkJitterFactor = 16

	// checkMinBuffer guarantees that a TTL check is refreshed at least this
	// long before the check's TTL expires.
	checkMinBuffer = 100 * time.Millisecond

	// consulRetryInterval is how long a handler waits before retrying after
	// an API call to the Consul agent fails.
	consulRetryInterval = 1 * time.Second

	// defaultCheckTimeout is the TTL of the sealed-status check when the
	// check_timeout option is not configured.
	defaultCheckTimeout = 5 * time.Second

	// DefaultServiceName is the default Consul service name used when
	// advertising a Vault instance.
	DefaultServiceName = "vault"

	// reconcileTimeout is how often Vault queries Consul to detect and fix
	// any state drift (the actual interval is jittered slightly below this).
	reconcileTimeout = 60 * time.Second

	// metaExternalSource is the value of the "external-source" service
	// metadata key, which the Consul UI can use to show where a service
	// registration came from.
	metaExternalSource = "vault"
)

// hostnameRegex matches RFC 1123 hostnames (dot-separated labels of
// alphanumerics and inner dashes); it validates the configured service name.
var hostnameRegex = regexp.MustCompile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`)

// serviceRegistration is the sr.ServiceRegistration implementation that
// advertises the state of Vault to Consul.
62type serviceRegistration struct { 63 Client *api.Client 64 65 logger log.Logger 66 serviceLock sync.RWMutex 67 redirectHost string 68 redirectPort int64 69 serviceName string 70 serviceTags []string 71 serviceAddress *string 72 disableRegistration bool 73 checkTimeout time.Duration 74 75 notifyActiveCh chan struct{} 76 notifySealedCh chan struct{} 77 notifyPerfStandbyCh chan struct{} 78 notifyInitializedCh chan struct{} 79 80 isActive *atomicB.Bool 81 isSealed *atomicB.Bool 82 isPerfStandby *atomicB.Bool 83 isInitialized *atomicB.Bool 84} 85 86// NewConsulServiceRegistration constructs a Consul-based ServiceRegistration. 87func NewServiceRegistration(conf map[string]string, logger log.Logger, state sr.State) (sr.ServiceRegistration, error) { 88 // Allow admins to disable consul integration 89 disableReg, ok := conf["disable_registration"] 90 var disableRegistration bool 91 if ok && disableReg != "" { 92 b, err := parseutil.ParseBool(disableReg) 93 if err != nil { 94 return nil, fmt.Errorf("failed parsing disable_registration parameter: %w", err) 95 } 96 disableRegistration = b 97 } 98 if logger.IsDebug() { 99 logger.Debug("config disable_registration set", "disable_registration", disableRegistration) 100 } 101 102 // Get the service name to advertise in Consul 103 service, ok := conf["service"] 104 if !ok { 105 service = DefaultServiceName 106 } 107 if !hostnameRegex.MatchString(service) { 108 return nil, errors.New("service name must be valid per RFC 1123 and can contain only alphanumeric characters or dashes") 109 } 110 if logger.IsDebug() { 111 logger.Debug("config service set", "service", service) 112 } 113 114 // Get the additional tags to attach to the registered service name 115 tags := conf["service_tags"] 116 if logger.IsDebug() { 117 logger.Debug("config service_tags set", "service_tags", tags) 118 } 119 120 // Get the service-specific address to override the use of the HA redirect address 121 var serviceAddr *string 122 serviceAddrStr, ok := 
conf["service_address"] 123 if ok { 124 serviceAddr = &serviceAddrStr 125 } 126 if logger.IsDebug() { 127 logger.Debug("config service_address set", "service_address", serviceAddr) 128 } 129 130 checkTimeout := defaultCheckTimeout 131 checkTimeoutStr, ok := conf["check_timeout"] 132 if ok { 133 d, err := parseutil.ParseDurationSecond(checkTimeoutStr) 134 if err != nil { 135 return nil, err 136 } 137 138 min, _ := durationMinusBufferDomain(d, checkMinBuffer, checkJitterFactor) 139 if min < checkMinBuffer { 140 return nil, fmt.Errorf("consul check_timeout must be greater than %v", min) 141 } 142 143 checkTimeout = d 144 if logger.IsDebug() { 145 logger.Debug("config check_timeout set", "check_timeout", d) 146 } 147 } 148 149 // Configure the client 150 consulConf := api.DefaultConfig() 151 // Set MaxIdleConnsPerHost to the number of processes used in expiration.Restore 152 consulConf.Transport.MaxIdleConnsPerHost = consts.ExpirationRestoreWorkerCount 153 154 SetupSecureTLS(context.Background(), consulConf, conf, logger, false) 155 156 consulConf.HttpClient = &http.Client{Transport: consulConf.Transport} 157 client, err := api.NewClient(consulConf) 158 if err != nil { 159 return nil, fmt.Errorf("client setup failed: %w", err) 160 } 161 162 // Setup the backend 163 c := &serviceRegistration{ 164 Client: client, 165 166 logger: logger, 167 serviceName: service, 168 serviceTags: strutil.ParseDedupLowercaseAndSortStrings(tags, ","), 169 serviceAddress: serviceAddr, 170 checkTimeout: checkTimeout, 171 disableRegistration: disableRegistration, 172 173 notifyActiveCh: make(chan struct{}), 174 notifySealedCh: make(chan struct{}), 175 notifyPerfStandbyCh: make(chan struct{}), 176 notifyInitializedCh: make(chan struct{}), 177 178 isActive: atomicB.NewBool(state.IsActive), 179 isSealed: atomicB.NewBool(state.IsSealed), 180 isPerfStandby: atomicB.NewBool(state.IsPerformanceStandby), 181 isInitialized: atomicB.NewBool(state.IsInitialized), 182 } 183 return c, nil 184} 185 186func 
SetupSecureTLS(ctx context.Context, consulConf *api.Config, conf map[string]string, logger log.Logger, isDiagnose bool) error { 187 if addr, ok := conf["address"]; ok { 188 consulConf.Address = addr 189 if logger.IsDebug() { 190 logger.Debug("config address set", "address", addr) 191 } 192 193 // Copied from the Consul API module; set the Scheme based on 194 // the protocol field if address looks ike a URL. 195 // This can enable the TLS configuration below. 196 parts := strings.SplitN(addr, "://", 2) 197 if len(parts) == 2 { 198 if parts[0] == "http" || parts[0] == "https" { 199 consulConf.Scheme = parts[0] 200 consulConf.Address = parts[1] 201 if logger.IsDebug() { 202 logger.Debug("config address parsed", "scheme", parts[0]) 203 logger.Debug("config scheme parsed", "address", parts[1]) 204 } 205 } // allow "unix:" or whatever else consul supports in the future 206 } 207 } 208 if scheme, ok := conf["scheme"]; ok { 209 consulConf.Scheme = scheme 210 if logger.IsDebug() { 211 logger.Debug("config scheme set", "scheme", scheme) 212 } 213 } 214 if token, ok := conf["token"]; ok { 215 consulConf.Token = token 216 logger.Debug("config token set") 217 } 218 219 if consulConf.Scheme == "https" { 220 if isDiagnose { 221 certPath, okCert := conf["tls_cert_file"] 222 keyPath, okKey := conf["tls_key_file"] 223 if okCert && okKey { 224 warnings, err := diagnose.TLSFileChecks(certPath, keyPath) 225 for _, warning := range warnings { 226 diagnose.Warn(ctx, warning) 227 } 228 if err != nil { 229 return err 230 } 231 return nil 232 } 233 return fmt.Errorf("key or cert path: %s, %s, cannot be loaded from consul config file", certPath, keyPath) 234 } 235 236 // Use the parsed Address instead of the raw conf['address'] 237 tlsClientConfig, err := tlsutil.SetupTLSConfig(conf, consulConf.Address) 238 if err != nil { 239 return err 240 } 241 242 consulConf.Transport.TLSClientConfig = tlsClientConfig 243 if err := http2.ConfigureTransport(consulConf.Transport); err != nil { 244 return 
err 245 } 246 logger.Debug("configured TLS") 247 } else { 248 if isDiagnose { 249 diagnose.Skipped(ctx, "HTTPS is not used, Skipping TLS verification.") 250 } 251 } 252 return nil 253} 254 255func (c *serviceRegistration) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup, redirectAddr string) error { 256 go func() { 257 if err := c.runServiceRegistration(wait, shutdownCh, redirectAddr); err != nil { 258 if c.logger.IsError() { 259 c.logger.Error(fmt.Sprintf("error running service registration: %s", err)) 260 } 261 } 262 }() 263 return nil 264} 265 266func (c *serviceRegistration) NotifyActiveStateChange(isActive bool) error { 267 c.isActive.Store(isActive) 268 select { 269 case c.notifyActiveCh <- struct{}{}: 270 default: 271 // NOTE: If this occurs Vault's active status could be out of 272 // sync with Consul until reconcileTimer expires. 273 c.logger.Warn("concurrent state change notify dropped") 274 } 275 276 return nil 277} 278 279func (c *serviceRegistration) NotifyPerformanceStandbyStateChange(isStandby bool) error { 280 c.isPerfStandby.Store(isStandby) 281 select { 282 case c.notifyPerfStandbyCh <- struct{}{}: 283 default: 284 // NOTE: If this occurs Vault's active status could be out of 285 // sync with Consul until reconcileTimer expires. 286 c.logger.Warn("concurrent state change notify dropped") 287 } 288 289 return nil 290} 291 292func (c *serviceRegistration) NotifySealedStateChange(isSealed bool) error { 293 c.isSealed.Store(isSealed) 294 select { 295 case c.notifySealedCh <- struct{}{}: 296 default: 297 // NOTE: If this occurs Vault's sealed status could be out of 298 // sync with Consul until checkTimer expires. 
299 c.logger.Warn("concurrent sealed state change notify dropped") 300 } 301 302 return nil 303} 304 305func (c *serviceRegistration) NotifyInitializedStateChange(isInitialized bool) error { 306 c.isInitialized.Store(isInitialized) 307 select { 308 case c.notifyInitializedCh <- struct{}{}: 309 default: 310 // NOTE: If this occurs Vault's initialized status could be out of 311 // sync with Consul until checkTimer expires. 312 c.logger.Warn("concurrent initalize state change notify dropped") 313 } 314 315 return nil 316} 317 318func (c *serviceRegistration) checkDuration() time.Duration { 319 return durationMinusBuffer(c.checkTimeout, checkMinBuffer, checkJitterFactor) 320} 321 322func (c *serviceRegistration) runServiceRegistration(waitGroup *sync.WaitGroup, shutdownCh <-chan struct{}, redirectAddr string) (err error) { 323 if err := c.setRedirectAddr(redirectAddr); err != nil { 324 return err 325 } 326 327 // 'server' command will wait for the below goroutine to complete 328 waitGroup.Add(1) 329 330 go c.runEventDemuxer(waitGroup, shutdownCh) 331 332 return nil 333} 334 335func (c *serviceRegistration) runEventDemuxer(waitGroup *sync.WaitGroup, shutdownCh <-chan struct{}) { 336 // This defer statement should be executed last. So push it first. 337 defer waitGroup.Done() 338 339 // Fire the reconcileTimer immediately upon starting the event demuxer 340 reconcileTimer := time.NewTimer(0) 341 defer reconcileTimer.Stop() 342 343 // Schedule the first check. Consul TTL checks are passing by 344 // default, checkTimer does not need to be run immediately. 345 checkTimer := time.NewTimer(c.checkDuration()) 346 defer checkTimer.Stop() 347 348 // Use a reactor pattern to handle and dispatch events to singleton 349 // goroutine handlers for execution. It is not acceptable to drop 350 // inbound events from Notify*(). 351 // 352 // goroutines are dispatched if the demuxer can acquire a lock (via 353 // an atomic CAS incr) on the handler. 
Handlers are responsible for 354 // deregistering themselves (atomic CAS decr). Handlers and the 355 // demuxer share a lock to synchronize information at the beginning 356 // and end of a handler's life (or after a handler wakes up from 357 // sleeping during a back-off/retry). 358 var shutdown atomicB.Bool 359 var registeredServiceID string 360 checkLock := new(int32) 361 serviceRegLock := new(int32) 362 363 for !shutdown.Load() { 364 select { 365 case <-c.notifyActiveCh: 366 // Run reconcile immediately upon active state change notification 367 reconcileTimer.Reset(0) 368 case <-c.notifySealedCh: 369 // Run check timer immediately upon a seal state change notification 370 checkTimer.Reset(0) 371 case <-c.notifyPerfStandbyCh: 372 // Run check timer immediately upon a perfstandby state change notification 373 checkTimer.Reset(0) 374 case <-c.notifyInitializedCh: 375 // Run check timer immediately upon an initialized state change notification 376 checkTimer.Reset(0) 377 case <-reconcileTimer.C: 378 // Unconditionally rearm the reconcileTimer 379 reconcileTimer.Reset(reconcileTimeout - randomStagger(reconcileTimeout/checkJitterFactor)) 380 381 // Abort if service discovery is disabled or a 382 // reconcile handler is already active 383 if !c.disableRegistration && atomic.CompareAndSwapInt32(serviceRegLock, 0, 1) { 384 // Enter handler with serviceRegLock held 385 go func() { 386 defer atomic.CompareAndSwapInt32(serviceRegLock, 1, 0) 387 for !shutdown.Load() { 388 serviceID, err := c.reconcileConsul(registeredServiceID) 389 if err != nil { 390 if c.logger.IsWarn() { 391 c.logger.Warn("reconcile unable to talk with Consul backend", "error", err) 392 } 393 time.Sleep(consulRetryInterval) 394 continue 395 } 396 397 c.serviceLock.Lock() 398 defer c.serviceLock.Unlock() 399 400 registeredServiceID = serviceID 401 return 402 } 403 }() 404 } 405 case <-checkTimer.C: 406 checkTimer.Reset(c.checkDuration()) 407 // Abort if service discovery is disabled or a 408 // reconcile 
handler is active 409 if !c.disableRegistration && atomic.CompareAndSwapInt32(checkLock, 0, 1) { 410 // Enter handler with checkLock held 411 go func() { 412 defer atomic.CompareAndSwapInt32(checkLock, 1, 0) 413 for !shutdown.Load() { 414 if err := c.runCheck(c.isSealed.Load()); err != nil { 415 if c.logger.IsWarn() { 416 c.logger.Warn("check unable to talk with Consul backend", "error", err) 417 } 418 time.Sleep(consulRetryInterval) 419 continue 420 } 421 return 422 } 423 }() 424 } 425 case <-shutdownCh: 426 c.logger.Info("shutting down consul backend") 427 shutdown.Store(true) 428 } 429 } 430 431 c.serviceLock.RLock() 432 defer c.serviceLock.RUnlock() 433 if err := c.Client.Agent().ServiceDeregister(registeredServiceID); err != nil { 434 if c.logger.IsWarn() { 435 c.logger.Warn("service deregistration failed", "error", err) 436 } 437 } 438} 439 440// checkID returns the ID used for a Consul Check. Assume at least a read 441// lock is held. 442func (c *serviceRegistration) checkID() string { 443 return fmt.Sprintf("%s:vault-sealed-check", c.serviceID()) 444} 445 446// serviceID returns the Vault ServiceID for use in Consul. Assume at least 447// a read lock is held. 448func (c *serviceRegistration) serviceID() string { 449 return fmt.Sprintf("%s:%s:%d", c.serviceName, c.redirectHost, c.redirectPort) 450} 451 452// reconcileConsul queries the state of Vault Core and Consul and fixes up 453// Consul's state according to what's in Vault. reconcileConsul is called 454// without any locks held and can be run concurrently, therefore no changes 455// to serviceRegistration can be made in this method (i.e. wtb const receiver for 456// compiler enforced safety). 
func (c *serviceRegistration) reconcileConsul(registeredServiceID string) (serviceID string, err error) {
	agent := c.Client.Agent()
	catalog := c.Client.Catalog()

	serviceID = c.serviceID()

	// Look up this instance in the Consul catalog. A failed lookup is
	// deliberately treated as "not registered" so that the registration is
	// (re)written below; the error is logged rather than silently dropped.
	var currentVaultService *api.CatalogService
	services, _, lookupErr := catalog.Service(c.serviceName, "", &api.QueryOptions{AllowStale: true})
	if lookupErr != nil {
		if c.logger.IsDebug() {
			c.logger.Debug("consul catalog lookup failed, forcing re-registration", "service", c.serviceName, "error", lookupErr)
		}
	} else {
		for _, service := range services {
			if service.ServiceID == serviceID {
				currentVaultService = service
				break
			}
		}
	}

	tags := c.fetchServiceTags(c.isActive.Load(), c.isPerfStandby.Load(), c.isInitialized.Load())

	// Re-register when Consul has no record of this instance, when this
	// process has not completed a registration yet, or when the advertised
	// tags have drifted from Vault's current state. Short-circuiting keeps
	// currentVaultService from being dereferenced while nil.
	reregister := currentVaultService == nil ||
		registeredServiceID == "" ||
		!strutil.EquivalentSlices(currentVaultService.ServiceTags, tags)
	if !reregister {
		// When re-registration is not required, return a valid serviceID
		// to avoid registration in the next cycle.
		return serviceID, nil
	}

	// If service address was set explicitly in configuration, use that
	// as the service-specific address instead of the HA redirect address.
	serviceAddress := c.redirectHost
	if c.serviceAddress != nil {
		serviceAddress = *c.serviceAddress
	}

	service := &api.AgentServiceRegistration{
		ID:                serviceID,
		Name:              c.serviceName,
		Tags:              tags,
		Port:              int(c.redirectPort),
		Address:           serviceAddress,
		EnableTagOverride: false,
		Meta: map[string]string{
			"external-source": metaExternalSource,
		},
	}

	// The check starts critical while sealed so Consul does not report a
	// sealed node as healthy before the first TTL push.
	checkStatus := api.HealthCritical
	if !c.isSealed.Load() {
		checkStatus = api.HealthPassing
	}

	sealedCheck := &api.AgentCheckRegistration{
		ID:        c.checkID(),
		Name:      "Vault Sealed Status",
		Notes:     "Vault service is healthy when Vault is in an unsealed status and can become an active Vault server",
		ServiceID: serviceID,
		AgentServiceCheck: api.AgentServiceCheck{
			TTL:    c.checkTimeout.String(),
			Status: checkStatus,
		},
	}

	if err := agent.ServiceRegister(service); err != nil {
		return "", fmt.Errorf(`service registration failed: %w`, err)
	}

	if err := agent.CheckRegister(sealedCheck); err != nil {
		return serviceID, fmt.Errorf(`service check registration failed: %w`, err)
	}

	return serviceID, nil
}

// runCheck immediately pushes the sealed-status TTL check: it reports
// failure while sealed and passing otherwise.
func (c *serviceRegistration) runCheck(sealed bool) error {
	agent := c.Client.Agent()
	if sealed {
		return agent.FailTTL(c.checkID(), "Vault Sealed")
	}
	return agent.PassTTL(c.checkID(), "Vault Unsealed")
}

// fetchServiceTags returns the tags to advertise in Consul: the configured
// service tags followed by tags describing Vault's current state.
554func (c *serviceRegistration) fetchServiceTags(active, perfStandby, initialized bool) []string { 555 activeTag := "standby" 556 if active { 557 activeTag = "active" 558 } 559 560 result := append(c.serviceTags, activeTag) 561 562 if perfStandby { 563 result = append(c.serviceTags, "performance-standby") 564 } 565 566 if initialized { 567 result = append(result, "initialized") 568 } 569 570 return result 571} 572 573func (c *serviceRegistration) setRedirectAddr(addr string) (err error) { 574 if addr == "" { 575 return fmt.Errorf("redirect address must not be empty") 576 } 577 578 url, err := url.Parse(addr) 579 if err != nil { 580 return fmt.Errorf("failed to parse redirect URL %q: %w", addr, err) 581 } 582 583 var portStr string 584 c.redirectHost, portStr, err = net.SplitHostPort(url.Host) 585 if err != nil { 586 if url.Scheme == "http" { 587 portStr = "80" 588 } else if url.Scheme == "https" { 589 portStr = "443" 590 } else if url.Scheme == "unix" { 591 portStr = "-1" 592 c.redirectHost = url.Path 593 } else { 594 return fmt.Errorf("failed to find a host:port in redirect address %q: %w", url.Host, err) 595 } 596 } 597 c.redirectPort, err = strconv.ParseInt(portStr, 10, 0) 598 if err != nil || c.redirectPort < -1 || c.redirectPort > 65535 { 599 return fmt.Errorf("failed to parse valid port %q: %w", portStr, err) 600 } 601 602 return nil 603} 604 605// durationMinusBuffer returns a duration, minus a buffer and jitter 606// subtracted from the duration. This function is used primarily for 607// servicing Consul TTL Checks in advance of the TTL. 608func durationMinusBuffer(intv time.Duration, buffer time.Duration, jitter int64) time.Duration { 609 d := intv - buffer 610 if jitter == 0 { 611 d -= randomStagger(d) 612 } else { 613 d -= randomStagger(time.Duration(int64(d) / jitter)) 614 } 615 return d 616} 617 618// durationMinusBufferDomain returns the domain of valid durations from a 619// call to durationMinusBuffer. 
This function is used to check user 620// specified input values to durationMinusBuffer. 621func durationMinusBufferDomain(intv time.Duration, buffer time.Duration, jitter int64) (min time.Duration, max time.Duration) { 622 max = intv - buffer 623 if jitter == 0 { 624 min = max 625 } else { 626 min = max - time.Duration(int64(max)/jitter) 627 } 628 return min, max 629} 630 631// randomStagger returns an interval between 0 and the duration 632func randomStagger(intv time.Duration) time.Duration { 633 if intv == 0 { 634 return 0 635 } 636 return time.Duration(uint64(rand.Int63()) % uint64(intv)) 637} 638