1package sink
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"io/ioutil"
8	"math/rand"
9	"os"
10	"sync/atomic"
11	"time"
12
13	hclog "github.com/hashicorp/go-hclog"
14	"github.com/hashicorp/vault/api"
15	"github.com/hashicorp/vault/helper/dhutil"
16	"github.com/hashicorp/vault/sdk/helper/jsonutil"
17)
18
type (
	// Sink is a destination for tokens. Run passes WriteToken the final form
	// of each token, after any wrapping or encryption configured on the
	// owning SinkConfig has been applied.
	Sink interface {
		WriteToken(token string) error
	}

	// SinkReader is implemented by types that can report a token through
	// Token.
	SinkReader interface {
		Token() string
	}
)
26
// SinkConfig pairs a Sink with the per-sink options that control how a token
// is transformed before Run passes it to WriteToken.
type SinkConfig struct {
	// Sink receives the final (possibly wrapped and/or encrypted) token.
	Sink
	// Logger is not read in this file; presumably used by the sink
	// implementation or its constructor — verify.
	Logger             hclog.Logger
	// Config is not read in this file; presumably sink-type-specific
	// settings — verify against the sink constructors.
	Config             map[string]interface{}
	// Client is not read in this file; token wrapping in Run uses the
	// SinkServer's client rather than this one.
	Client             *api.Client
	// WrapTTL, when non-zero, causes the token to be response-wrapped via
	// sys/wrapping/wrap with this TTL before it is written.
	WrapTTL            time.Duration
	// DHType, when non-empty, causes the token (after any wrapping) to be
	// encrypted. The only type encryptToken handles is "curve25519".
	DHType             string
	// DHPath is the path of a JSON file holding the remote curve25519 public
	// key used for encryption.
	DHPath             string
	// DeriveKey selects deriving the AES key from the shared secret with
	// dhutil.DeriveSharedKey instead of using the raw shared secret.
	DeriveKey          bool
	// AAD is passed as additional authenticated data to dhutil.EncryptAES.
	AAD                string
	// The cached* fields are filled lazily by encryptToken and reused on
	// later calls: the remote public key read from DHPath, and a locally
	// generated public/private key pair.
	cachedRemotePubKey []byte
	cachedPubKey       []byte
	cachedPriKey       []byte
}

// SinkServerConfig holds the parameters consumed by NewSinkServer.
type SinkServerConfig struct {
	Logger        hclog.Logger
	// Client is used by Run to response-wrap tokens for sinks with WrapTTL set.
	Client        *api.Client
	// Context is not read by NewSinkServer; Run takes its context as an
	// explicit argument.
	Context       context.Context
	// ExitAfterAuth makes Run return once every queued sink write has
	// succeeded, or as soon as a token arrives when there are no sinks.
	ExitAfterAuth bool
}

// SinkServer is responsible for pushing tokens to sinks.
type SinkServer struct {
	logger        hclog.Logger
	// client is handed to wrapToken, which clones it to perform wrapping.
	client        *api.Client
	// random supplies the jitter for Run's retry backoff.
	random        *rand.Rand
	exitAfterAuth bool
	// remaining counts sink writes queued by Run that have not yet
	// completed; it is only modified with sync/atomic operations.
	remaining     *int32
}
57
58func NewSinkServer(conf *SinkServerConfig) *SinkServer {
59	ss := &SinkServer{
60		logger:        conf.Logger,
61		client:        conf.Client,
62		random:        rand.New(rand.NewSource(int64(time.Now().Nanosecond()))),
63		exitAfterAuth: conf.ExitAfterAuth,
64		remaining:     new(int32),
65	}
66
67	return ss
68}
69
70// Run executes the server's run loop, which is responsible for reading
71// in new tokens and pushing them out to the various sinks.
72func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*SinkConfig) error {
73	latestToken := new(string)
74	writeSink := func(currSink *SinkConfig, currToken string) error {
75		if currToken != *latestToken {
76			return nil
77		}
78		var err error
79
80		if currSink.WrapTTL != 0 {
81			if currToken, err = currSink.wrapToken(ss.client, currSink.WrapTTL, currToken); err != nil {
82				return err
83			}
84		}
85
86		if currSink.DHType != "" {
87			if currToken, err = currSink.encryptToken(currToken); err != nil {
88				return err
89			}
90		}
91
92		return currSink.WriteToken(currToken)
93	}
94
95	if incoming == nil {
96		return errors.New("sink server: incoming channel is nil")
97	}
98
99	ss.logger.Info("starting sink server")
100	defer func() {
101		ss.logger.Info("sink server stopped")
102	}()
103
104	type sinkToken struct {
105		sink  *SinkConfig
106		token string
107	}
108	sinkCh := make(chan sinkToken, len(sinks))
109	for {
110		select {
111		case <-ctx.Done():
112			return nil
113
114		case token := <-incoming:
115			if len(sinks) > 0 {
116				if token != *latestToken {
117
118					// Drain the existing funcs
119				drainLoop:
120					for {
121						select {
122						case <-sinkCh:
123							atomic.AddInt32(ss.remaining, -1)
124						default:
125							break drainLoop
126						}
127					}
128
129					*latestToken = token
130
131					for _, s := range sinks {
132						atomic.AddInt32(ss.remaining, 1)
133						sinkCh <- sinkToken{s, token}
134					}
135				}
136			} else {
137				ss.logger.Trace("no sinks, ignoring new token")
138				if ss.exitAfterAuth {
139					ss.logger.Trace("no sinks, exitAfterAuth, bye")
140					return nil
141				}
142			}
143		case st := <-sinkCh:
144			atomic.AddInt32(ss.remaining, -1)
145			select {
146			case <-ctx.Done():
147				return nil
148			default:
149			}
150
151			if err := writeSink(st.sink, st.token); err != nil {
152				backoff := 2*time.Second + time.Duration(ss.random.Int63()%int64(time.Second*2)-int64(time.Second))
153				ss.logger.Error("error returned by sink function, retrying", "error", err, "backoff", backoff.String())
154				select {
155				case <-ctx.Done():
156					return nil
157				case <-time.After(backoff):
158					atomic.AddInt32(ss.remaining, 1)
159					sinkCh <- st
160				}
161			} else {
162				if atomic.LoadInt32(ss.remaining) == 0 && ss.exitAfterAuth {
163					return nil
164				}
165			}
166		}
167	}
168}
169
170func (s *SinkConfig) encryptToken(token string) (string, error) {
171	var aesKey []byte
172	var err error
173	resp := new(dhutil.Envelope)
174	switch s.DHType {
175	case "curve25519":
176		if len(s.cachedRemotePubKey) == 0 {
177			_, err = os.Lstat(s.DHPath)
178			if err != nil {
179				if !os.IsNotExist(err) {
180					return "", fmt.Errorf("error stat-ing dh parameters file: %w", err)
181				}
182				return "", errors.New("no dh parameters file found, and no cached pub key")
183			}
184			fileBytes, err := ioutil.ReadFile(s.DHPath)
185			if err != nil {
186				return "", fmt.Errorf("error reading file for dh parameters: %w", err)
187			}
188			theirPubKey := new(dhutil.PublicKeyInfo)
189			if err := jsonutil.DecodeJSON(fileBytes, theirPubKey); err != nil {
190				return "", fmt.Errorf("error decoding public key: %w", err)
191			}
192			if len(theirPubKey.Curve25519PublicKey) == 0 {
193				return "", errors.New("public key is nil")
194			}
195			s.cachedRemotePubKey = theirPubKey.Curve25519PublicKey
196		}
197		if len(s.cachedPubKey) == 0 {
198			s.cachedPubKey, s.cachedPriKey, err = dhutil.GeneratePublicPrivateKey()
199			if err != nil {
200				return "", fmt.Errorf("error generating pub/pri curve25519 keys: %w", err)
201			}
202		}
203		resp.Curve25519PublicKey = s.cachedPubKey
204	}
205
206	secret, err := dhutil.GenerateSharedSecret(s.cachedPriKey, s.cachedRemotePubKey)
207	if err != nil {
208		return "", fmt.Errorf("error calculating shared key: %w", err)
209	}
210	if s.DeriveKey {
211		aesKey, err = dhutil.DeriveSharedKey(secret, s.cachedPubKey, s.cachedRemotePubKey)
212	} else {
213		aesKey = secret
214	}
215
216	if err != nil {
217		return "", fmt.Errorf("error deriving shared key: %w", err)
218	}
219	if len(aesKey) == 0 {
220		return "", errors.New("derived AES key is empty")
221	}
222
223	resp.EncryptedPayload, resp.Nonce, err = dhutil.EncryptAES(aesKey, []byte(token), []byte(s.AAD))
224	if err != nil {
225		return "", fmt.Errorf("error encrypting with shared key: %w", err)
226	}
227	m, err := jsonutil.EncodeJSON(resp)
228	if err != nil {
229		return "", fmt.Errorf("error encoding encrypted payload: %w", err)
230	}
231
232	return string(m), nil
233}
234
235func (s *SinkConfig) wrapToken(client *api.Client, wrapTTL time.Duration, token string) (string, error) {
236	wrapClient, err := client.Clone()
237	if err != nil {
238		return "", fmt.Errorf("error deriving client for wrapping, not writing out to sink: %w)", err)
239	}
240	wrapClient.SetToken(token)
241	wrapClient.SetWrappingLookupFunc(func(string, string) string {
242		return wrapTTL.String()
243	})
244	secret, err := wrapClient.Logical().Write("sys/wrapping/wrap", map[string]interface{}{
245		"token": token,
246	})
247	if err != nil {
248		return "", fmt.Errorf("error wrapping token, not writing out to sink: %w)", err)
249	}
250	if secret == nil {
251		return "", errors.New("nil secret returned, not writing out to sink")
252	}
253	if secret.WrapInfo == nil {
254		return "", errors.New("nil wrap info returned, not writing out to sink")
255	}
256
257	m, err := jsonutil.EncodeJSON(secret.WrapInfo)
258	if err != nil {
259		return "", fmt.Errorf("error marshaling token, not writing out to sink: %w)", err)
260	}
261
262	return string(m), nil
263}
264