1package sink 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "io/ioutil" 8 "math/rand" 9 "os" 10 "sync/atomic" 11 "time" 12 13 hclog "github.com/hashicorp/go-hclog" 14 "github.com/hashicorp/vault/api" 15 "github.com/hashicorp/vault/helper/dhutil" 16 "github.com/hashicorp/vault/sdk/helper/jsonutil" 17) 18 19type Sink interface { 20 WriteToken(string) error 21} 22 23type SinkReader interface { 24 Token() string 25} 26 27type SinkConfig struct { 28 Sink 29 Logger hclog.Logger 30 Config map[string]interface{} 31 Client *api.Client 32 WrapTTL time.Duration 33 DHType string 34 DHPath string 35 DeriveKey bool 36 AAD string 37 cachedRemotePubKey []byte 38 cachedPubKey []byte 39 cachedPriKey []byte 40} 41 42type SinkServerConfig struct { 43 Logger hclog.Logger 44 Client *api.Client 45 Context context.Context 46 ExitAfterAuth bool 47} 48 49// SinkServer is responsible for pushing tokens to sinks 50type SinkServer struct { 51 logger hclog.Logger 52 client *api.Client 53 random *rand.Rand 54 exitAfterAuth bool 55 remaining *int32 56} 57 58func NewSinkServer(conf *SinkServerConfig) *SinkServer { 59 ss := &SinkServer{ 60 logger: conf.Logger, 61 client: conf.Client, 62 random: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))), 63 exitAfterAuth: conf.ExitAfterAuth, 64 remaining: new(int32), 65 } 66 67 return ss 68} 69 70// Run executes the server's run loop, which is responsible for reading 71// in new tokens and pushing them out to the various sinks. 72func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*SinkConfig) error { 73 latestToken := new(string) 74 writeSink := func(currSink *SinkConfig, currToken string) error { 75 if currToken != *latestToken { 76 return nil 77 } 78 var err error 79 80 if currSink.WrapTTL != 0 { 81 if currToken, err = currSink.wrapToken(ss.client, currSink.WrapTTL, currToken); err != nil { 82 return err 83 } 84 } 85 86 if currSink.DHType != "" { 87 if currToken, err = currSink.encryptToken(currToken); err != nil { 88 return err 89 } 90 } 91 92 return currSink.WriteToken(currToken) 93 } 94 95 if incoming == nil { 96 return errors.New("sink server: incoming channel is nil") 97 } 98 99 ss.logger.Info("starting sink server") 100 defer func() { 101 ss.logger.Info("sink server stopped") 102 }() 103 104 type sinkToken struct { 105 sink *SinkConfig 106 token string 107 } 108 sinkCh := make(chan sinkToken, len(sinks)) 109 for { 110 select { 111 case <-ctx.Done(): 112 return nil 113 114 case token := <-incoming: 115 if len(sinks) > 0 { 116 if token != *latestToken { 117 118 // Drain the existing funcs 119 drainLoop: 120 for { 121 select { 122 case <-sinkCh: 123 atomic.AddInt32(ss.remaining, -1) 124 default: 125 break drainLoop 126 } 127 } 128 129 *latestToken = token 130 131 for _, s := range sinks { 132 atomic.AddInt32(ss.remaining, 1) 133 sinkCh <- sinkToken{s, token} 134 } 135 } 136 } else { 137 ss.logger.Trace("no sinks, ignoring new token") 138 if ss.exitAfterAuth { 139 ss.logger.Trace("no sinks, exitAfterAuth, bye") 140 return nil 141 } 142 } 143 case st := <-sinkCh: 144 atomic.AddInt32(ss.remaining, -1) 145 select { 146 case <-ctx.Done(): 147 return nil 148 default: 149 } 150 151 if err := writeSink(st.sink, st.token); err != nil { 152 backoff := 2*time.Second + time.Duration(ss.random.Int63()%int64(time.Second*2)-int64(time.Second)) 153 ss.logger.Error("error returned by sink function, retrying", "error", err, "backoff", backoff.String()) 154 select { 155 case <-ctx.Done(): 156 return nil 157 case <-time.After(backoff): 158 atomic.AddInt32(ss.remaining, 1) 159 sinkCh <- st 160 } 161 } else { 162 if atomic.LoadInt32(ss.remaining) == 0 && ss.exitAfterAuth { 163 return nil 164 } 165 } 166 } 167 } 168} 169 170func (s *SinkConfig) encryptToken(token string) (string, error) { 171 var aesKey []byte 172 var err error 173 resp := new(dhutil.Envelope) 174 switch s.DHType { 175 case "curve25519": 176 if len(s.cachedRemotePubKey) == 0 { 177 _, err = os.Lstat(s.DHPath) 178 if err != nil { 179 if !os.IsNotExist(err) { 180 return "", fmt.Errorf("error stat-ing dh parameters file: %w", err) 181 } 182 return "", errors.New("no dh parameters file found, and no cached pub key") 183 } 184 fileBytes, err := ioutil.ReadFile(s.DHPath) 185 if err != nil { 186 return "", fmt.Errorf("error reading file for dh parameters: %w", err) 187 } 188 theirPubKey := new(dhutil.PublicKeyInfo) 189 if err := jsonutil.DecodeJSON(fileBytes, theirPubKey); err != nil { 190 return "", fmt.Errorf("error decoding public key: %w", err) 191 } 192 if len(theirPubKey.Curve25519PublicKey) == 0 { 193 return "", errors.New("public key is nil") 194 } 195 s.cachedRemotePubKey = theirPubKey.Curve25519PublicKey 196 } 197 if len(s.cachedPubKey) == 0 { 198 s.cachedPubKey, s.cachedPriKey, err = dhutil.GeneratePublicPrivateKey() 199 if err != nil { 200 return "", fmt.Errorf("error generating pub/pri curve25519 keys: %w", err) 201 } 202 } 203 resp.Curve25519PublicKey = s.cachedPubKey 204 } 205 206 secret, err := dhutil.GenerateSharedSecret(s.cachedPriKey, s.cachedRemotePubKey) 207 if err != nil { 208 return "", fmt.Errorf("error calculating shared key: %w", err) 209 } 210 if s.DeriveKey { 211 aesKey, err = dhutil.DeriveSharedKey(secret, s.cachedPubKey, s.cachedRemotePubKey) 212 } else { 213 aesKey = secret 214 } 215 216 if err != nil { 217 return "", fmt.Errorf("error deriving shared key: %w", err) 218 } 219 if len(aesKey) == 0 { 220 return "", errors.New("derived AES key is empty") 221 } 222 223 resp.EncryptedPayload, resp.Nonce, err = dhutil.EncryptAES(aesKey, []byte(token), []byte(s.AAD)) 224 if err != nil { 225 return "", fmt.Errorf("error encrypting with shared key: %w", err) 226 } 227 m, err := jsonutil.EncodeJSON(resp) 228 if err != nil { 229 return "", fmt.Errorf("error encoding encrypted payload: %w", err) 230 } 231 232 return string(m), nil 233} 234 235func (s *SinkConfig) wrapToken(client *api.Client, wrapTTL time.Duration, token string) (string, error) { 236 wrapClient, err := client.Clone() 237 if err != nil { 238 return "", fmt.Errorf("error deriving client for wrapping, not writing out to sink: %w)", err) 239 } 240 wrapClient.SetToken(token) 241 wrapClient.SetWrappingLookupFunc(func(string, string) string { 242 return wrapTTL.String() 243 }) 244 secret, err := wrapClient.Logical().Write("sys/wrapping/wrap", map[string]interface{}{ 245 "token": token, 246 }) 247 if err != nil { 248 return "", fmt.Errorf("error wrapping token, not writing out to sink: %w)", err) 249 } 250 if secret == nil { 251 return "", errors.New("nil secret returned, not writing out to sink") 252 } 253 if secret.WrapInfo == nil { 254 return "", errors.New("nil wrap info returned, not writing out to sink") 255 } 256 257 m, err := jsonutil.EncodeJSON(secret.WrapInfo) 258 if err != nil { 259 return "", fmt.Errorf("error marshaling token, not writing out to sink: %w)", err) 260 } 261 262 return string(m), nil 263} 264