1package vault
2
3import (
4	"context"
5	"errors"
6	"sync"
7	"sync/atomic"
8
9	"github.com/hashicorp/errwrap"
10	"github.com/hashicorp/vault/sdk/helper/consts"
11	"github.com/hashicorp/vault/sdk/helper/strutil"
12	"github.com/hashicorp/vault/sdk/logical"
13)
14
// CORS enablement states held in CORSConfig.Enabled, which is read and
// written with sync/atomic. CORSDisabled is the zero value, so a freshly
// allocated flag (new(uint32)) reads as disabled.
const (
	CORSDisabled uint32 = 0
	CORSEnabled  uint32 = 1
)
19
20var StdAllowedHeaders = []string{
21	"Content-Type",
22	"X-Requested-With",
23	"X-Vault-AWS-IAM-Server-ID",
24	"X-Vault-MFA",
25	"X-Vault-No-Request-Forwarding",
26	"X-Vault-Wrap-Format",
27	"X-Vault-Wrap-TTL",
28	"X-Vault-Policy-Override",
29	"Authorization",
30	consts.AuthHeaderName,
31}
32
// CORSConfig stores the state of the CORS configuration.
//
// The embedded RWMutex guards AllowedOrigins and AllowedHeaders. Enabled
// points at a flag holding CORSDisabled or CORSEnabled that is accessed with
// sync/atomic. The struct is persisted as JSON under the "cors" key of the
// core's "config/" system barrier view (see saveCORSConfig), so the exported
// field names and json tags define the stored format. The unexported core
// field is not serialized; it is the owning Core, which Enable and Disable
// use to save changes.
//
// Enabled may be nil after decoding an entry that lacks the "enabled" field;
// loadCORSConfig replaces such a nil with a zeroed (disabled) flag.
type CORSConfig struct {
	sync.RWMutex   `json:"-"`
	core           *Core
	Enabled        *uint32  `json:"enabled"`
	AllowedOrigins []string `json:"allowed_origins,omitempty"`
	AllowedHeaders []string `json:"allowed_headers,omitempty"`
}
41
42func (c *Core) saveCORSConfig(ctx context.Context) error {
43	view := c.systemBarrierView.SubView("config/")
44
45	enabled := atomic.LoadUint32(c.corsConfig.Enabled)
46	localConfig := &CORSConfig{
47		Enabled: &enabled,
48	}
49	c.corsConfig.RLock()
50	localConfig.AllowedOrigins = c.corsConfig.AllowedOrigins
51	localConfig.AllowedHeaders = c.corsConfig.AllowedHeaders
52	c.corsConfig.RUnlock()
53
54	entry, err := logical.StorageEntryJSON("cors", localConfig)
55	if err != nil {
56		return errwrap.Wrapf("failed to create CORS config entry: {{err}}", err)
57	}
58
59	if err := view.Put(ctx, entry); err != nil {
60		return errwrap.Wrapf("failed to save CORS config: {{err}}", err)
61	}
62
63	return nil
64}
65
66// This should only be called with the core state lock held for writing
67func (c *Core) loadCORSConfig(ctx context.Context) error {
68	view := c.systemBarrierView.SubView("config/")
69
70	// Load the config in
71	out, err := view.Get(ctx, "cors")
72	if err != nil {
73		return errwrap.Wrapf("failed to read CORS config: {{err}}", err)
74	}
75	if out == nil {
76		return nil
77	}
78
79	newConfig := new(CORSConfig)
80	err = out.DecodeJSON(newConfig)
81	if err != nil {
82		return err
83	}
84
85	if newConfig.Enabled == nil {
86		newConfig.Enabled = new(uint32)
87	}
88
89	newConfig.core = c
90
91	c.corsConfig = newConfig
92
93	return nil
94}
95
96// Enable takes either a '*' or a comma-separated list of URLs that can make
97// cross-origin requests to Vault.
98func (c *CORSConfig) Enable(ctx context.Context, urls []string, headers []string) error {
99	if len(urls) == 0 {
100		return errors.New("at least one origin or the wildcard must be provided")
101	}
102
103	if strutil.StrListContains(urls, "*") && len(urls) > 1 {
104		return errors.New("to allow all origins the '*' must be the only value for allowed_origins")
105	}
106
107	c.Lock()
108	c.AllowedOrigins = urls
109
110	// Start with the standard headers to Vault accepts.
111	c.AllowedHeaders = append([]string{}, StdAllowedHeaders...)
112
113	// Allow the user to add additional headers to the list of
114	// headers allowed on cross-origin requests.
115	if len(headers) > 0 {
116		c.AllowedHeaders = append(c.AllowedHeaders, headers...)
117	}
118	c.Unlock()
119
120	atomic.StoreUint32(c.Enabled, CORSEnabled)
121
122	return c.core.saveCORSConfig(ctx)
123}
124
125// IsEnabled returns the value of CORSConfig.isEnabled
126func (c *CORSConfig) IsEnabled() bool {
127	return atomic.LoadUint32(c.Enabled) == CORSEnabled
128}
129
130// Disable sets CORS to disabled and clears the allowed origins & headers.
131func (c *CORSConfig) Disable(ctx context.Context) error {
132	atomic.StoreUint32(c.Enabled, CORSDisabled)
133	c.Lock()
134
135	c.AllowedOrigins = nil
136	c.AllowedHeaders = nil
137
138	c.Unlock()
139
140	return c.core.saveCORSConfig(ctx)
141}
142
143// IsValidOrigin determines if the origin of the request is allowed to make
144// cross-origin requests based on the CORSConfig.
145func (c *CORSConfig) IsValidOrigin(origin string) bool {
146	// If we aren't enabling CORS then all origins are valid
147	if !c.IsEnabled() {
148		return true
149	}
150
151	c.RLock()
152	defer c.RUnlock()
153
154	if len(c.AllowedOrigins) == 0 {
155		return false
156	}
157
158	if len(c.AllowedOrigins) == 1 && (c.AllowedOrigins)[0] == "*" {
159		return true
160	}
161
162	return strutil.StrListContains(c.AllowedOrigins, origin)
163}
164