1package sink 2 3import ( 4 "context" 5 "errors" 6 "io/ioutil" 7 "math/rand" 8 "os" 9 "sync/atomic" 10 "time" 11 12 "github.com/hashicorp/errwrap" 13 hclog "github.com/hashicorp/go-hclog" 14 "github.com/hashicorp/vault/api" 15 "github.com/hashicorp/vault/helper/dhutil" 16 "github.com/hashicorp/vault/sdk/helper/jsonutil" 17) 18 19type Sink interface { 20 WriteToken(string) error 21} 22 23type SinkReader interface { 24 Token() string 25} 26 27type SinkConfig struct { 28 Sink 29 Logger hclog.Logger 30 Config map[string]interface{} 31 Client *api.Client 32 WrapTTL time.Duration 33 DHType string 34 DHPath string 35 AAD string 36 cachedRemotePubKey []byte 37 cachedPubKey []byte 38 cachedPriKey []byte 39} 40 41type SinkServerConfig struct { 42 Logger hclog.Logger 43 Client *api.Client 44 Context context.Context 45 ExitAfterAuth bool 46} 47 48// SinkServer is responsible for pushing tokens to sinks 49type SinkServer struct { 50 DoneCh chan struct{} 51 logger hclog.Logger 52 client *api.Client 53 random *rand.Rand 54 exitAfterAuth bool 55 remaining *int32 56} 57 58func NewSinkServer(conf *SinkServerConfig) *SinkServer { 59 ss := &SinkServer{ 60 DoneCh: make(chan struct{}), 61 logger: conf.Logger, 62 client: conf.Client, 63 random: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))), 64 exitAfterAuth: conf.ExitAfterAuth, 65 remaining: new(int32), 66 } 67 68 return ss 69} 70 71// Run executes the server's run loop, which is responsible for reading 72// in new tokens and pushing them out to the various sinks. 73func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*SinkConfig) { 74 latestToken := new(string) 75 writeSink := func(currSink *SinkConfig, currToken string) error { 76 if currToken != *latestToken { 77 return nil 78 } 79 var err error 80 81 if currSink.WrapTTL != 0 { 82 if currToken, err = currSink.wrapToken(ss.client, currSink.WrapTTL, currToken); err != nil { 83 return err 84 } 85 } 86 87 if currSink.DHType != "" { 88 if currToken, err = currSink.encryptToken(currToken); err != nil { 89 return err 90 } 91 } 92 93 return currSink.WriteToken(currToken) 94 } 95 96 if incoming == nil { 97 panic("incoming channel is nil") 98 } 99 100 ss.logger.Info("starting sink server") 101 defer func() { 102 ss.logger.Info("sink server stopped") 103 close(ss.DoneCh) 104 }() 105 106 type sinkToken struct { 107 sink *SinkConfig 108 token string 109 } 110 sinkCh := make(chan sinkToken, len(sinks)) 111 for { 112 select { 113 case <-ctx.Done(): 114 return 115 116 case token := <-incoming: 117 if token != *latestToken { 118 119 // Drain the existing funcs 120 drainLoop: 121 for { 122 select { 123 case <-sinkCh: 124 atomic.AddInt32(ss.remaining, -1) 125 default: 126 break drainLoop 127 } 128 } 129 130 *latestToken = token 131 132 for _, s := range sinks { 133 atomic.AddInt32(ss.remaining, 1) 134 sinkCh <- sinkToken{s, token} 135 } 136 } 137 138 case st := <-sinkCh: 139 atomic.AddInt32(ss.remaining, -1) 140 select { 141 case <-ctx.Done(): 142 return 143 default: 144 } 145 146 if err := writeSink(st.sink, st.token); err != nil { 147 backoff := 2*time.Second + time.Duration(ss.random.Int63()%int64(time.Second*2)-int64(time.Second)) 148 ss.logger.Error("error returned by sink function, retrying", "error", err, "backoff", backoff.String()) 149 select { 150 case <-ctx.Done(): 151 return 152 case <-time.After(backoff): 153 atomic.AddInt32(ss.remaining, 1) 154 sinkCh <- st 155 } 156 } else { 157 if atomic.LoadInt32(ss.remaining) == 0 && ss.exitAfterAuth { 158 return 159 } 160 } 161 } 162 } 163} 164 165func (s *SinkConfig) encryptToken(token string) (string, error) { 166 var aesKey []byte 167 var err error 168 resp := new(dhutil.Envelope) 169 switch s.DHType { 170 case "curve25519": 171 if len(s.cachedRemotePubKey) == 0 { 172 _, err = os.Lstat(s.DHPath) 173 if err != nil { 174 if !os.IsNotExist(err) { 175 return "", errwrap.Wrapf("error stat-ing dh parameters file: {{err}}", err) 176 } 177 return "", errors.New("no dh parameters file found, and no cached pub key") 178 } 179 fileBytes, err := ioutil.ReadFile(s.DHPath) 180 if err != nil { 181 return "", errwrap.Wrapf("error reading file for dh parameters: {{err}}", err) 182 } 183 theirPubKey := new(dhutil.PublicKeyInfo) 184 if err := jsonutil.DecodeJSON(fileBytes, theirPubKey); err != nil { 185 return "", errwrap.Wrapf("error decoding public key: {{err}}", err) 186 } 187 if len(theirPubKey.Curve25519PublicKey) == 0 { 188 return "", errors.New("public key is nil") 189 } 190 s.cachedRemotePubKey = theirPubKey.Curve25519PublicKey 191 } 192 if len(s.cachedPubKey) == 0 { 193 s.cachedPubKey, s.cachedPriKey, err = dhutil.GeneratePublicPrivateKey() 194 if err != nil { 195 return "", errwrap.Wrapf("error generating pub/pri curve25519 keys: {{err}}", err) 196 } 197 } 198 resp.Curve25519PublicKey = s.cachedPubKey 199 } 200 201 aesKey, err = dhutil.GenerateSharedKey(s.cachedPriKey, s.cachedRemotePubKey) 202 if err != nil { 203 return "", errwrap.Wrapf("error deriving shared key: {{err}}", err) 204 } 205 if len(aesKey) == 0 { 206 return "", errors.New("derived AES key is empty") 207 } 208 209 resp.EncryptedPayload, resp.Nonce, err = dhutil.EncryptAES(aesKey, []byte(token), []byte(s.AAD)) 210 if err != nil { 211 return "", errwrap.Wrapf("error encrypting with shared key: {{err}}", err) 212 } 213 m, err := jsonutil.EncodeJSON(resp) 214 if err != nil { 215 return "", errwrap.Wrapf("error encoding encrypted payload: {{err}}", err) 216 } 217 218 return string(m), nil 219} 220 221func (s *SinkConfig) wrapToken(client *api.Client, wrapTTL time.Duration, token string) (string, error) { 222 wrapClient, err := client.Clone() 223 if err != nil { 224 return "", errwrap.Wrapf("error deriving client for wrapping, not writing out to sink: {{err}})", err) 225 } 226 wrapClient.SetToken(token) 227 wrapClient.SetWrappingLookupFunc(func(string, string) string { 228 return wrapTTL.String() 229 }) 230 secret, err := wrapClient.Logical().Write("sys/wrapping/wrap", map[string]interface{}{ 231 "token": token, 232 }) 233 if err != nil { 234 return "", errwrap.Wrapf("error wrapping token, not writing out to sink: {{err}})", err) 235 } 236 if secret == nil { 237 return "", errors.New("nil secret returned, not writing out to sink") 238 } 239 if secret.WrapInfo == nil { 240 return "", errors.New("nil wrap info returned, not writing out to sink") 241 } 242 243 m, err := jsonutil.EncodeJSON(secret.WrapInfo) 244 if err != nil { 245 return "", errwrap.Wrapf("error marshaling token, not writing out to sink: {{err}})", err) 246 } 247 248 return string(m), nil 249} 250