1package socket
2
3import (
4	"bytes"
5	"context"
6	"fmt"
7	"net"
8	"strconv"
9	"sync"
10	"time"
11
12	multierror "github.com/hashicorp/go-multierror"
13	"github.com/hashicorp/vault/audit"
14	"github.com/hashicorp/vault/sdk/helper/parseutil"
15	"github.com/hashicorp/vault/sdk/helper/salt"
16	"github.com/hashicorp/vault/sdk/logical"
17)
18
19func Factory(ctx context.Context, conf *audit.BackendConfig) (audit.Backend, error) {
20	if conf.SaltConfig == nil {
21		return nil, fmt.Errorf("nil salt config")
22	}
23	if conf.SaltView == nil {
24		return nil, fmt.Errorf("nil salt view")
25	}
26
27	address, ok := conf.Config["address"]
28	if !ok {
29		return nil, fmt.Errorf("address is required")
30	}
31
32	socketType, ok := conf.Config["socket_type"]
33	if !ok {
34		socketType = "tcp"
35	}
36
37	writeDeadline, ok := conf.Config["write_timeout"]
38	if !ok {
39		writeDeadline = "2s"
40	}
41	writeDuration, err := parseutil.ParseDurationSecond(writeDeadline)
42	if err != nil {
43		return nil, err
44	}
45
46	format, ok := conf.Config["format"]
47	if !ok {
48		format = "json"
49	}
50	switch format {
51	case "json", "jsonx":
52	default:
53		return nil, fmt.Errorf("unknown format type %q", format)
54	}
55
56	// Check if hashing of accessor is disabled
57	hmacAccessor := true
58	if hmacAccessorRaw, ok := conf.Config["hmac_accessor"]; ok {
59		value, err := strconv.ParseBool(hmacAccessorRaw)
60		if err != nil {
61			return nil, err
62		}
63		hmacAccessor = value
64	}
65
66	// Check if raw logging is enabled
67	logRaw := false
68	if raw, ok := conf.Config["log_raw"]; ok {
69		b, err := strconv.ParseBool(raw)
70		if err != nil {
71			return nil, err
72		}
73		logRaw = b
74	}
75
76	b := &Backend{
77		saltConfig: conf.SaltConfig,
78		saltView:   conf.SaltView,
79		formatConfig: audit.FormatterConfig{
80			Raw:          logRaw,
81			HMACAccessor: hmacAccessor,
82		},
83
84		writeDuration: writeDuration,
85		address:       address,
86		socketType:    socketType,
87	}
88
89	switch format {
90	case "json":
91		b.formatter.AuditFormatWriter = &audit.JSONFormatWriter{
92			Prefix:   conf.Config["prefix"],
93			SaltFunc: b.Salt,
94		}
95	case "jsonx":
96		b.formatter.AuditFormatWriter = &audit.JSONxFormatWriter{
97			Prefix:   conf.Config["prefix"],
98			SaltFunc: b.Salt,
99		}
100	}
101
102	return b, nil
103}
104
105// Backend is the audit backend for the socket audit transport.
106type Backend struct {
107	connection net.Conn
108
109	formatter    audit.AuditFormatter
110	formatConfig audit.FormatterConfig
111
112	writeDuration time.Duration
113	address       string
114	socketType    string
115
116	sync.Mutex
117
118	saltMutex  sync.RWMutex
119	salt       *salt.Salt
120	saltConfig *salt.Config
121	saltView   logical.Storage
122}
123
124var _ audit.Backend = (*Backend)(nil)
125
126func (b *Backend) GetHash(ctx context.Context, data string) (string, error) {
127	salt, err := b.Salt(ctx)
128	if err != nil {
129		return "", err
130	}
131	return audit.HashString(salt, data), nil
132}
133
134func (b *Backend) LogRequest(ctx context.Context, in *logical.LogInput) error {
135	var buf bytes.Buffer
136	if err := b.formatter.FormatRequest(ctx, &buf, b.formatConfig, in); err != nil {
137		return err
138	}
139
140	b.Lock()
141	defer b.Unlock()
142
143	err := b.write(ctx, buf.Bytes())
144	if err != nil {
145		rErr := b.reconnect(ctx)
146		if rErr != nil {
147			err = multierror.Append(err, rErr)
148		} else {
149			// Try once more after reconnecting
150			err = b.write(ctx, buf.Bytes())
151		}
152	}
153
154	return err
155}
156
157func (b *Backend) LogResponse(ctx context.Context, in *logical.LogInput) error {
158	var buf bytes.Buffer
159	if err := b.formatter.FormatResponse(ctx, &buf, b.formatConfig, in); err != nil {
160		return err
161	}
162
163	b.Lock()
164	defer b.Unlock()
165
166	err := b.write(ctx, buf.Bytes())
167	if err != nil {
168		rErr := b.reconnect(ctx)
169		if rErr != nil {
170			err = multierror.Append(err, rErr)
171		} else {
172			// Try once more after reconnecting
173			err = b.write(ctx, buf.Bytes())
174		}
175	}
176
177	return err
178}
179
180func (b *Backend) LogTestMessage(ctx context.Context, in *logical.LogInput, config map[string]string) error {
181	var buf bytes.Buffer
182	temporaryFormatter := audit.NewTemporaryFormatter(config["format"], config["prefix"])
183	if err := temporaryFormatter.FormatRequest(ctx, &buf, b.formatConfig, in); err != nil {
184		return err
185	}
186
187	b.Lock()
188	defer b.Unlock()
189
190	err := b.write(ctx, buf.Bytes())
191	if err != nil {
192		rErr := b.reconnect(ctx)
193		if rErr != nil {
194			err = multierror.Append(err, rErr)
195		} else {
196			// Try once more after reconnecting
197			err = b.write(ctx, buf.Bytes())
198		}
199	}
200
201	return err
202}
203
204func (b *Backend) write(ctx context.Context, buf []byte) error {
205	if b.connection == nil {
206		if err := b.reconnect(ctx); err != nil {
207			return err
208		}
209	}
210
211	err := b.connection.SetWriteDeadline(time.Now().Add(b.writeDuration))
212	if err != nil {
213		return err
214	}
215
216	_, err = b.connection.Write(buf)
217	if err != nil {
218		return err
219	}
220
221	return err
222}
223
224func (b *Backend) reconnect(ctx context.Context) error {
225	if b.connection != nil {
226		b.connection.Close()
227		b.connection = nil
228	}
229
230	timeoutContext, cancel := context.WithTimeout(ctx, b.writeDuration)
231	defer cancel()
232
233	dialer := net.Dialer{}
234	conn, err := dialer.DialContext(timeoutContext, b.socketType, b.address)
235	if err != nil {
236		return err
237	}
238
239	b.connection = conn
240
241	return nil
242}
243
244func (b *Backend) Reload(ctx context.Context) error {
245	b.Lock()
246	defer b.Unlock()
247
248	err := b.reconnect(ctx)
249
250	return err
251}
252
253func (b *Backend) Salt(ctx context.Context) (*salt.Salt, error) {
254	b.saltMutex.RLock()
255	if b.salt != nil {
256		defer b.saltMutex.RUnlock()
257		return b.salt, nil
258	}
259	b.saltMutex.RUnlock()
260	b.saltMutex.Lock()
261	defer b.saltMutex.Unlock()
262	if b.salt != nil {
263		return b.salt, nil
264	}
265	salt, err := salt.NewSalt(ctx, b.saltView, b.saltConfig)
266	if err != nil {
267		return nil, err
268	}
269	b.salt = salt
270	return salt, nil
271}
272
273func (b *Backend) Invalidate(_ context.Context) {
274	b.saltMutex.Lock()
275	defer b.saltMutex.Unlock()
276	b.salt = nil
277}
278