1package vault
2
3import (
4	"context"
5	"fmt"
6	"sync"
7	"time"
8
9	metrics "github.com/armon/go-metrics"
10	log "github.com/hashicorp/go-hclog"
11	multierror "github.com/hashicorp/go-multierror"
12	"github.com/hashicorp/vault/audit"
13	"github.com/hashicorp/vault/sdk/logical"
14)
15
16type backendEntry struct {
17	backend audit.Backend
18	view    *BarrierView
19	local   bool
20}
21
22// AuditBroker is used to provide a single ingest interface to auditable
23// events given that multiple backends may be configured.
24type AuditBroker struct {
25	sync.RWMutex
26	backends map[string]backendEntry
27	logger   log.Logger
28}
29
30// NewAuditBroker creates a new audit broker
31func NewAuditBroker(log log.Logger) *AuditBroker {
32	b := &AuditBroker{
33		backends: make(map[string]backendEntry),
34		logger:   log,
35	}
36	return b
37}
38
39// Register is used to add new audit backend to the broker
40func (a *AuditBroker) Register(name string, b audit.Backend, v *BarrierView, local bool) {
41	a.Lock()
42	defer a.Unlock()
43	a.backends[name] = backendEntry{
44		backend: b,
45		view:    v,
46		local:   local,
47	}
48}
49
50// Deregister is used to remove an audit backend from the broker
51func (a *AuditBroker) Deregister(name string) {
52	a.Lock()
53	defer a.Unlock()
54	delete(a.backends, name)
55}
56
57// IsRegistered is used to check if a given audit backend is registered
58func (a *AuditBroker) IsRegistered(name string) bool {
59	a.RLock()
60	defer a.RUnlock()
61	_, ok := a.backends[name]
62	return ok
63}
64
65// IsLocal is used to check if a given audit backend is registered
66func (a *AuditBroker) IsLocal(name string) (bool, error) {
67	a.RLock()
68	defer a.RUnlock()
69	be, ok := a.backends[name]
70	if ok {
71		return be.local, nil
72	}
73	return false, fmt.Errorf("unknown audit backend %q", name)
74}
75
76// GetHash returns a hash using the salt of the given backend
77func (a *AuditBroker) GetHash(ctx context.Context, name string, input string) (string, error) {
78	a.RLock()
79	defer a.RUnlock()
80	be, ok := a.backends[name]
81	if !ok {
82		return "", fmt.Errorf("unknown audit backend %q", name)
83	}
84
85	return be.backend.GetHash(ctx, input)
86}
87
88// LogRequest is used to ensure all the audit backends have an opportunity to
89// log the given request and that *at least one* succeeds.
90func (a *AuditBroker) LogRequest(ctx context.Context, in *logical.LogInput, headersConfig *AuditedHeadersConfig) (ret error) {
91	defer metrics.MeasureSince([]string{"audit", "log_request"}, time.Now())
92	a.RLock()
93	defer a.RUnlock()
94
95	var retErr *multierror.Error
96
97	defer func() {
98		if r := recover(); r != nil {
99			a.logger.Error("panic during logging", "request_path", in.Request.Path, "error", r)
100			retErr = multierror.Append(retErr, fmt.Errorf("panic generating audit log"))
101		}
102
103		ret = retErr.ErrorOrNil()
104		failure := float32(0.0)
105		if ret != nil {
106			failure = 1.0
107		}
108		metrics.IncrCounter([]string{"audit", "log_request_failure"}, failure)
109	}()
110
111	// All logged requests must have an identifier
112	//if req.ID == "" {
113	//	a.logger.Error("missing identifier in request object", "request_path", req.Path)
114	//	retErr = multierror.Append(retErr, fmt.Errorf("missing identifier in request object: %s", req.Path))
115	//	return
116	//}
117
118	headers := in.Request.Headers
119	defer func() {
120		in.Request.Headers = headers
121	}()
122
123	// Ensure at least one backend logs
124	anyLogged := false
125	for name, be := range a.backends {
126		in.Request.Headers = nil
127		transHeaders, thErr := headersConfig.ApplyConfig(ctx, headers, be.backend.GetHash)
128		if thErr != nil {
129			a.logger.Error("backend failed to include headers", "backend", name, "error", thErr)
130			continue
131		}
132		in.Request.Headers = transHeaders
133
134		start := time.Now()
135		lrErr := be.backend.LogRequest(ctx, in)
136		metrics.MeasureSince([]string{"audit", name, "log_request"}, start)
137		if lrErr != nil {
138			a.logger.Error("backend failed to log request", "backend", name, "error", lrErr)
139		} else {
140			anyLogged = true
141		}
142	}
143	if !anyLogged && len(a.backends) > 0 {
144		retErr = multierror.Append(retErr, fmt.Errorf("no audit backend succeeded in logging the request"))
145	}
146
147	return retErr.ErrorOrNil()
148}
149
150// LogResponse is used to ensure all the audit backends have an opportunity to
151// log the given response and that *at least one* succeeds.
152func (a *AuditBroker) LogResponse(ctx context.Context, in *logical.LogInput, headersConfig *AuditedHeadersConfig) (ret error) {
153	defer metrics.MeasureSince([]string{"audit", "log_response"}, time.Now())
154	a.RLock()
155	defer a.RUnlock()
156
157	var retErr *multierror.Error
158
159	defer func() {
160		if r := recover(); r != nil {
161			a.logger.Error("panic during logging", "request_path", in.Request.Path, "error", r)
162			retErr = multierror.Append(retErr, fmt.Errorf("panic generating audit log"))
163		}
164
165		ret = retErr.ErrorOrNil()
166
167		failure := float32(0.0)
168		if ret != nil {
169			failure = 1.0
170		}
171		metrics.IncrCounter([]string{"audit", "log_response_failure"}, failure)
172	}()
173
174	headers := in.Request.Headers
175	defer func() {
176		in.Request.Headers = headers
177	}()
178
179	// Ensure at least one backend logs
180	anyLogged := false
181	for name, be := range a.backends {
182		in.Request.Headers = nil
183		transHeaders, thErr := headersConfig.ApplyConfig(ctx, headers, be.backend.GetHash)
184		if thErr != nil {
185			a.logger.Error("backend failed to include headers", "backend", name, "error", thErr)
186			continue
187		}
188		in.Request.Headers = transHeaders
189
190		start := time.Now()
191		lrErr := be.backend.LogResponse(ctx, in)
192		metrics.MeasureSince([]string{"audit", name, "log_response"}, start)
193		if lrErr != nil {
194			a.logger.Error("backend failed to log response", "backend", name, "error", lrErr)
195		} else {
196			anyLogged = true
197		}
198	}
199	if !anyLogged && len(a.backends) > 0 {
200		retErr = multierror.Append(retErr, fmt.Errorf("no audit backend succeeded in logging the response"))
201	}
202
203	return retErr.ErrorOrNil()
204}
205
206func (a *AuditBroker) Invalidate(ctx context.Context, key string) {
207	// For now we ignore the key as this would only apply to salts. We just
208	// sort of brute force it on each one.
209	a.Lock()
210	defer a.Unlock()
211	for _, be := range a.backends {
212		be.backend.Invalidate(ctx)
213	}
214}
215