1package vault 2 3import ( 4 "context" 5 "errors" 6 "sync" 7 "sync/atomic" 8 9 "github.com/hashicorp/errwrap" 10 "github.com/hashicorp/vault/sdk/helper/consts" 11 "github.com/hashicorp/vault/sdk/helper/strutil" 12 "github.com/hashicorp/vault/sdk/logical" 13) 14 15const ( 16 CORSDisabled uint32 = iota 17 CORSEnabled 18) 19 20var StdAllowedHeaders = []string{ 21 "Content-Type", 22 "X-Requested-With", 23 "X-Vault-AWS-IAM-Server-ID", 24 "X-Vault-MFA", 25 "X-Vault-No-Request-Forwarding", 26 "X-Vault-Wrap-Format", 27 "X-Vault-Wrap-TTL", 28 "X-Vault-Policy-Override", 29 "Authorization", 30 consts.AuthHeaderName, 31} 32 33// CORSConfig stores the state of the CORS configuration. 34type CORSConfig struct { 35 sync.RWMutex `json:"-"` 36 core *Core 37 Enabled *uint32 `json:"enabled"` 38 AllowedOrigins []string `json:"allowed_origins,omitempty"` 39 AllowedHeaders []string `json:"allowed_headers,omitempty"` 40} 41 42func (c *Core) saveCORSConfig(ctx context.Context) error { 43 view := c.systemBarrierView.SubView("config/") 44 45 enabled := atomic.LoadUint32(c.corsConfig.Enabled) 46 localConfig := &CORSConfig{ 47 Enabled: &enabled, 48 } 49 c.corsConfig.RLock() 50 localConfig.AllowedOrigins = c.corsConfig.AllowedOrigins 51 localConfig.AllowedHeaders = c.corsConfig.AllowedHeaders 52 c.corsConfig.RUnlock() 53 54 entry, err := logical.StorageEntryJSON("cors", localConfig) 55 if err != nil { 56 return errwrap.Wrapf("failed to create CORS config entry: {{err}}", err) 57 } 58 59 if err := view.Put(ctx, entry); err != nil { 60 return errwrap.Wrapf("failed to save CORS config: {{err}}", err) 61 } 62 63 return nil 64} 65 66// This should only be called with the core state lock held for writing 67func (c *Core) loadCORSConfig(ctx context.Context) error { 68 view := c.systemBarrierView.SubView("config/") 69 70 // Load the config in 71 out, err := view.Get(ctx, "cors") 72 if err != nil { 73 return errwrap.Wrapf("failed to read CORS config: {{err}}", err) 74 } 75 if out == nil { 76 return nil 77 } 78 79 newConfig := new(CORSConfig) 80 err = out.DecodeJSON(newConfig) 81 if err != nil { 82 return err 83 } 84 85 if newConfig.Enabled == nil { 86 newConfig.Enabled = new(uint32) 87 } 88 89 newConfig.core = c 90 91 c.corsConfig = newConfig 92 93 return nil 94} 95 96// Enable takes either a '*' or a comma-separated list of URLs that can make 97// cross-origin requests to Vault. 98func (c *CORSConfig) Enable(ctx context.Context, urls []string, headers []string) error { 99 if len(urls) == 0 { 100 return errors.New("at least one origin or the wildcard must be provided") 101 } 102 103 if strutil.StrListContains(urls, "*") && len(urls) > 1 { 104 return errors.New("to allow all origins the '*' must be the only value for allowed_origins") 105 } 106 107 c.Lock() 108 c.AllowedOrigins = urls 109 110 // Start with the standard headers to Vault accepts. 111 c.AllowedHeaders = append([]string{}, StdAllowedHeaders...) 112 113 // Allow the user to add additional headers to the list of 114 // headers allowed on cross-origin requests. 115 if len(headers) > 0 { 116 c.AllowedHeaders = append(c.AllowedHeaders, headers...) 117 } 118 c.Unlock() 119 120 atomic.StoreUint32(c.Enabled, CORSEnabled) 121 122 return c.core.saveCORSConfig(ctx) 123} 124 125// IsEnabled returns the value of CORSConfig.isEnabled 126func (c *CORSConfig) IsEnabled() bool { 127 return atomic.LoadUint32(c.Enabled) == CORSEnabled 128} 129 130// Disable sets CORS to disabled and clears the allowed origins & headers. 131func (c *CORSConfig) Disable(ctx context.Context) error { 132 atomic.StoreUint32(c.Enabled, CORSDisabled) 133 c.Lock() 134 135 c.AllowedOrigins = nil 136 c.AllowedHeaders = nil 137 138 c.Unlock() 139 140 return c.core.saveCORSConfig(ctx) 141} 142 143// IsValidOrigin determines if the origin of the request is allowed to make 144// cross-origin requests based on the CORSConfig. 145func (c *CORSConfig) IsValidOrigin(origin string) bool { 146 // If we aren't enabling CORS then all origins are valid 147 if !c.IsEnabled() { 148 return true 149 } 150 151 c.RLock() 152 defer c.RUnlock() 153 154 if len(c.AllowedOrigins) == 0 { 155 return false 156 } 157 158 if len(c.AllowedOrigins) == 1 && (c.AllowedOrigins)[0] == "*" { 159 return true 160 } 161 162 return strutil.StrListContains(c.AllowedOrigins, origin) 163} 164