1package vault
2
import (
	"context"
	"fmt"
	"sort"
	"strings"
	"sync"

	"github.com/hashicorp/errwrap"
	"github.com/hashicorp/vault/sdk/logical"
)
12
// N.B.: While we could use textproto to get the canonical MIME header key,
// HTTP/2 requires all header field names to be lower case, so we lowercase
// header names instead.
15
const (
	// auditedHeadersEntry is the storage key, inside the audited-headers
	// sub-view, under which the JSON-encoded header config is read and written.
	auditedHeadersEntry = "audited-headers"

	// auditedHeadersSubPath is the prefix used to carve the audited-headers
	// sub-view out of the system BarrierView.
	auditedHeadersSubPath = "audited-headers-config/"
)
22
// auditedHeaderSettings holds the per-header options for a header that is
// approved for inclusion in the audit log. It is persisted as JSON.
type auditedHeaderSettings struct {
	// HMAC, when true, makes ApplyConfig pass each of the header's values
	// through the supplied hash function instead of logging them in plaintext.
	HMAC bool `json:"hmac"`
}
26
// AuditedHeadersConfig is used by the Audit Broker to write only approved
// headers to the audit logs. It uses a BarrierView to persist the settings.
type AuditedHeadersConfig struct {
	// Headers maps an approved header name to its settings. add and
	// setupAuditedHeadersConfig store the names lowercased, and ApplyConfig
	// looks them up by lowercased name.
	Headers map[string]*auditedHeaderSettings

	// view is the barrier sub-view the config is persisted to by add/remove.
	view *BarrierView
	// RWMutex guards Headers: add/remove take the write lock, ApplyConfig
	// takes the read lock.
	sync.RWMutex
}
35
36// add adds or overwrites a header in the config and updates the barrier view
// add adds or overwrites a header in the config and persists the updated
// config to the barrier view. The header name is lowercased so later lookups
// are case-insensitive.
//
// The in-memory config is replaced only after the new config has been written
// successfully, so an encoding or storage failure leaves both the in-memory
// and persisted state unchanged.
func (a *AuditedHeadersConfig) add(ctx context.Context, header string, hmac bool) error {
	if header == "" {
		return fmt.Errorf("header value cannot be empty")
	}

	// Grab a write lock
	a.Lock()
	defer a.Unlock()

	// Stage the change on a copy; mutating a.Headers before the write would
	// leave memory and storage diverged if persisting fails.
	updated := make(map[string]*auditedHeaderSettings, len(a.Headers)+1)
	for name, settings := range a.Headers {
		updated[name] = settings
	}
	updated[strings.ToLower(header)] = &auditedHeaderSettings{HMAC: hmac}

	entry, err := logical.StorageEntryJSON(auditedHeadersEntry, updated)
	if err != nil {
		return errwrap.Wrapf("failed to persist audited headers config: {{err}}", err)
	}

	if err := a.view.Put(ctx, entry); err != nil {
		return errwrap.Wrapf("failed to persist audited headers config: {{err}}", err)
	}

	a.Headers = updated
	return nil
}
62
63// remove deletes a header out of the header config and updates the barrier view
// remove deletes a header (matched case-insensitively) from the header config
// and persists the updated config to the barrier view.
//
// If the header is not configured, nothing is written. The in-memory config
// is replaced only after the new config has been written successfully, so an
// encoding or storage failure leaves both the in-memory and persisted state
// unchanged.
func (a *AuditedHeadersConfig) remove(ctx context.Context, header string) error {
	if header == "" {
		return fmt.Errorf("header value cannot be empty")
	}

	// Grab a write lock
	a.Lock()
	defer a.Unlock()

	key := strings.ToLower(header)

	// Nothing to delete; avoid a pointless storage write.
	if _, ok := a.Headers[key]; !ok {
		return nil
	}

	// Stage the change on a copy so a failed write does not leave memory and
	// storage diverged.
	updated := make(map[string]*auditedHeaderSettings, len(a.Headers))
	for name, settings := range a.Headers {
		if name != key {
			updated[name] = settings
		}
	}

	entry, err := logical.StorageEntryJSON(auditedHeadersEntry, updated)
	if err != nil {
		return errwrap.Wrapf("failed to persist audited headers config: {{err}}", err)
	}

	if err := a.view.Put(ctx, entry); err != nil {
		return errwrap.Wrapf("failed to persist audited headers config: {{err}}", err)
	}

	a.Headers = updated
	return nil
}
90
91// ApplyConfig returns a map of approved headers and their values, either
92// hmac'ed or plaintext
93func (a *AuditedHeadersConfig) ApplyConfig(ctx context.Context, headers map[string][]string, hashFunc func(context.Context, string) (string, error)) (result map[string][]string, retErr error) {
94	// Grab a read lock
95	a.RLock()
96	defer a.RUnlock()
97
98	// Make a copy of the incoming headers with everything lower so we can
99	// case-insensitively compare
100	lowerHeaders := make(map[string][]string, len(headers))
101	for k, v := range headers {
102		lowerHeaders[strings.ToLower(k)] = v
103	}
104
105	result = make(map[string][]string, len(a.Headers))
106	for key, settings := range a.Headers {
107		if val, ok := lowerHeaders[key]; ok {
108			// copy the header values so we don't overwrite them
109			hVals := make([]string, len(val))
110			copy(hVals, val)
111
112			// Optionally hmac the values
113			if settings.HMAC {
114				for i, el := range hVals {
115					hVal, err := hashFunc(ctx, el)
116					if err != nil {
117						return nil, err
118					}
119					hVals[i] = hVal
120				}
121			}
122
123			result[key] = hVals
124		}
125	}
126
127	return result, nil
128}
129
130// Initialize the headers config by loading from the barrier view
// setupAuditedHeadersConfig initializes c.auditedHeaders by loading the
// persisted audited headers config from a sub-view of the system barrier view.
//
// Stored header names are lowercased on load so they can be looked up
// case-insensitively; this is needed when upgrading from configs that stored
// mixed-case names. If several stored names collapse to the same lowercase
// name, HMAC is enabled for it when any variant requested HMAC. This means
// an upgrade never causes a previously hashed header to be logged in
// plaintext. A stored null setting is treated as the zero-value settings.
func (c *Core) setupAuditedHeadersConfig(ctx context.Context) error {
	// Create a sub-view
	view := c.systemBarrierView.SubView(auditedHeadersSubPath)

	// Load the persisted config, if any
	out, err := view.Get(ctx, auditedHeadersEntry)
	if err != nil {
		return errwrap.Wrapf("failed to read config: {{err}}", err)
	}

	headers := make(map[string]*auditedHeaderSettings)
	if out != nil {
		if err := out.DecodeJSON(&headers); err != nil {
			return errwrap.Wrapf("failed to decode config: {{err}}", err)
		}
	}

	// Ensure that we are able to case-insensitively access the headers;
	// necessary for the upgrade case
	lowerHeaders := make(map[string]*auditedHeaderSettings, len(headers))
	for name, stored := range headers {
		settings := &auditedHeaderSettings{}
		if stored != nil {
			settings.HMAC = stored.HMAC
		}

		key := strings.ToLower(name)
		if existing, ok := lowerHeaders[key]; ok {
			// Prefer the more protective setting on a case collision.
			settings.HMAC = settings.HMAC || existing.HMAC
		}
		lowerHeaders[key] = settings
	}

	c.auditedHeaders = &AuditedHeadersConfig{
		Headers: lowerHeaders,
		view:    view,
	}

	return nil
}
163