1package socket 2 3import ( 4 "bytes" 5 "context" 6 "fmt" 7 "net" 8 "strconv" 9 "sync" 10 "time" 11 12 multierror "github.com/hashicorp/go-multierror" 13 "github.com/hashicorp/vault/audit" 14 "github.com/hashicorp/vault/sdk/helper/parseutil" 15 "github.com/hashicorp/vault/sdk/helper/salt" 16 "github.com/hashicorp/vault/sdk/logical" 17) 18 19func Factory(ctx context.Context, conf *audit.BackendConfig) (audit.Backend, error) { 20 if conf.SaltConfig == nil { 21 return nil, fmt.Errorf("nil salt config") 22 } 23 if conf.SaltView == nil { 24 return nil, fmt.Errorf("nil salt view") 25 } 26 27 address, ok := conf.Config["address"] 28 if !ok { 29 return nil, fmt.Errorf("address is required") 30 } 31 32 socketType, ok := conf.Config["socket_type"] 33 if !ok { 34 socketType = "tcp" 35 } 36 37 writeDeadline, ok := conf.Config["write_timeout"] 38 if !ok { 39 writeDeadline = "2s" 40 } 41 writeDuration, err := parseutil.ParseDurationSecond(writeDeadline) 42 if err != nil { 43 return nil, err 44 } 45 46 format, ok := conf.Config["format"] 47 if !ok { 48 format = "json" 49 } 50 switch format { 51 case "json", "jsonx": 52 default: 53 return nil, fmt.Errorf("unknown format type %q", format) 54 } 55 56 // Check if hashing of accessor is disabled 57 hmacAccessor := true 58 if hmacAccessorRaw, ok := conf.Config["hmac_accessor"]; ok { 59 value, err := strconv.ParseBool(hmacAccessorRaw) 60 if err != nil { 61 return nil, err 62 } 63 hmacAccessor = value 64 } 65 66 // Check if raw logging is enabled 67 logRaw := false 68 if raw, ok := conf.Config["log_raw"]; ok { 69 b, err := strconv.ParseBool(raw) 70 if err != nil { 71 return nil, err 72 } 73 logRaw = b 74 } 75 76 b := &Backend{ 77 saltConfig: conf.SaltConfig, 78 saltView: conf.SaltView, 79 formatConfig: audit.FormatterConfig{ 80 Raw: logRaw, 81 HMACAccessor: hmacAccessor, 82 }, 83 84 writeDuration: writeDuration, 85 address: address, 86 socketType: socketType, 87 } 88 89 switch format { 90 case "json": 91 b.formatter.AuditFormatWriter = &audit.JSONFormatWriter{ 92 Prefix: conf.Config["prefix"], 93 SaltFunc: b.Salt, 94 } 95 case "jsonx": 96 b.formatter.AuditFormatWriter = &audit.JSONxFormatWriter{ 97 Prefix: conf.Config["prefix"], 98 SaltFunc: b.Salt, 99 } 100 } 101 102 return b, nil 103} 104 105// Backend is the audit backend for the socket audit transport. 106type Backend struct { 107 connection net.Conn 108 109 formatter audit.AuditFormatter 110 formatConfig audit.FormatterConfig 111 112 writeDuration time.Duration 113 address string 114 socketType string 115 116 sync.Mutex 117 118 saltMutex sync.RWMutex 119 salt *salt.Salt 120 saltConfig *salt.Config 121 saltView logical.Storage 122} 123 124var _ audit.Backend = (*Backend)(nil) 125 126func (b *Backend) GetHash(ctx context.Context, data string) (string, error) { 127 salt, err := b.Salt(ctx) 128 if err != nil { 129 return "", err 130 } 131 return audit.HashString(salt, data), nil 132} 133 134func (b *Backend) LogRequest(ctx context.Context, in *logical.LogInput) error { 135 var buf bytes.Buffer 136 if err := b.formatter.FormatRequest(ctx, &buf, b.formatConfig, in); err != nil { 137 return err 138 } 139 140 b.Lock() 141 defer b.Unlock() 142 143 err := b.write(ctx, buf.Bytes()) 144 if err != nil { 145 rErr := b.reconnect(ctx) 146 if rErr != nil { 147 err = multierror.Append(err, rErr) 148 } else { 149 // Try once more after reconnecting 150 err = b.write(ctx, buf.Bytes()) 151 } 152 } 153 154 return err 155} 156 157func (b *Backend) LogResponse(ctx context.Context, in *logical.LogInput) error { 158 var buf bytes.Buffer 159 if err := b.formatter.FormatResponse(ctx, &buf, b.formatConfig, in); err != nil { 160 return err 161 } 162 163 b.Lock() 164 defer b.Unlock() 165 166 err := b.write(ctx, buf.Bytes()) 167 if err != nil { 168 rErr := b.reconnect(ctx) 169 if rErr != nil { 170 err = multierror.Append(err, rErr) 171 } else { 172 // Try once more after reconnecting 173 err = b.write(ctx, buf.Bytes()) 174 } 175 } 176 177 return err 178} 179 180func (b *Backend) LogTestMessage(ctx context.Context, in *logical.LogInput, config map[string]string) error { 181 var buf bytes.Buffer 182 temporaryFormatter := audit.NewTemporaryFormatter(config["format"], config["prefix"]) 183 if err := temporaryFormatter.FormatRequest(ctx, &buf, b.formatConfig, in); err != nil { 184 return err 185 } 186 187 b.Lock() 188 defer b.Unlock() 189 190 err := b.write(ctx, buf.Bytes()) 191 if err != nil { 192 rErr := b.reconnect(ctx) 193 if rErr != nil { 194 err = multierror.Append(err, rErr) 195 } else { 196 // Try once more after reconnecting 197 err = b.write(ctx, buf.Bytes()) 198 } 199 } 200 201 return err 202} 203 204func (b *Backend) write(ctx context.Context, buf []byte) error { 205 if b.connection == nil { 206 if err := b.reconnect(ctx); err != nil { 207 return err 208 } 209 } 210 211 err := b.connection.SetWriteDeadline(time.Now().Add(b.writeDuration)) 212 if err != nil { 213 return err 214 } 215 216 _, err = b.connection.Write(buf) 217 if err != nil { 218 return err 219 } 220 221 return err 222} 223 224func (b *Backend) reconnect(ctx context.Context) error { 225 if b.connection != nil { 226 b.connection.Close() 227 b.connection = nil 228 } 229 230 timeoutContext, cancel := context.WithTimeout(ctx, b.writeDuration) 231 defer cancel() 232 233 dialer := net.Dialer{} 234 conn, err := dialer.DialContext(timeoutContext, b.socketType, b.address) 235 if err != nil { 236 return err 237 } 238 239 b.connection = conn 240 241 return nil 242} 243 244func (b *Backend) Reload(ctx context.Context) error { 245 b.Lock() 246 defer b.Unlock() 247 248 err := b.reconnect(ctx) 249 250 return err 251} 252 253func (b *Backend) Salt(ctx context.Context) (*salt.Salt, error) { 254 b.saltMutex.RLock() 255 if b.salt != nil { 256 defer b.saltMutex.RUnlock() 257 return b.salt, nil 258 } 259 b.saltMutex.RUnlock() 260 b.saltMutex.Lock() 261 defer b.saltMutex.Unlock() 262 if b.salt != nil { 263 return b.salt, nil 264 } 265 salt, err := salt.NewSalt(ctx, b.saltView, b.saltConfig) 266 if err != nil { 267 return nil, err 268 } 269 b.salt = salt 270 return salt, nil 271} 272 273func (b *Backend) Invalidate(_ context.Context) { 274 b.saltMutex.Lock() 275 defer b.saltMutex.Unlock() 276 b.salt = nil 277} 278