1package sink
2
3import (
4	"context"
5	"errors"
6	"io/ioutil"
7	"math/rand"
8	"os"
9	"sync/atomic"
10	"time"
11
12	"github.com/hashicorp/errwrap"
13	hclog "github.com/hashicorp/go-hclog"
14	"github.com/hashicorp/vault/api"
15	"github.com/hashicorp/vault/helper/dhutil"
16	"github.com/hashicorp/vault/sdk/helper/jsonutil"
17)
18
// Sink is a destination that a token can be written to.
type Sink interface {
	// WriteToken writes the given token to the sink and reports any failure.
	WriteToken(token string) error
}
22
// SinkReader exposes a token string through its Token method.
type SinkReader interface {
	// Token returns the token held by the implementation.
	Token() string
}
26
27type SinkConfig struct {
28	Sink
29	Logger             hclog.Logger
30	Config             map[string]interface{}
31	Client             *api.Client
32	WrapTTL            time.Duration
33	DHType             string
34	DHPath             string
35	AAD                string
36	cachedRemotePubKey []byte
37	cachedPubKey       []byte
38	cachedPriKey       []byte
39}
40
// SinkServerConfig holds the inputs used by NewSinkServer to build a
// SinkServer.
type SinkServerConfig struct {
	// Logger is used by the server for its log output.
	Logger hclog.Logger
	// Client is the API client the server uses when response-wrapping tokens.
	Client *api.Client
	// Context is not read by NewSinkServer; Run takes its own context.
	Context context.Context
	// ExitAfterAuth makes Run return once a successful write leaves no sink
	// writes queued.
	ExitAfterAuth bool
}
47
// SinkServer is responsible for pushing tokens to sinks
type SinkServer struct {
	// DoneCh is closed when Run returns.
	DoneCh chan struct{}
	logger hclog.Logger
	// client is cloned per write to response-wrap tokens for sinks that
	// request it.
	client *api.Client
	// random supplies jitter for retry backoff.
	random        *rand.Rand
	exitAfterAuth bool
	// remaining counts sink writes currently queued inside Run; it is updated
	// with sync/atomic.
	remaining *int32
}
57
58func NewSinkServer(conf *SinkServerConfig) *SinkServer {
59	ss := &SinkServer{
60		DoneCh:        make(chan struct{}),
61		logger:        conf.Logger,
62		client:        conf.Client,
63		random:        rand.New(rand.NewSource(int64(time.Now().Nanosecond()))),
64		exitAfterAuth: conf.ExitAfterAuth,
65		remaining:     new(int32),
66	}
67
68	return ss
69}
70
71// Run executes the server's run loop, which is responsible for reading
72// in new tokens and pushing them out to the various sinks.
73func (ss *SinkServer) Run(ctx context.Context, incoming chan string, sinks []*SinkConfig) {
74	latestToken := new(string)
75	writeSink := func(currSink *SinkConfig, currToken string) error {
76		if currToken != *latestToken {
77			return nil
78		}
79		var err error
80
81		if currSink.WrapTTL != 0 {
82			if currToken, err = currSink.wrapToken(ss.client, currSink.WrapTTL, currToken); err != nil {
83				return err
84			}
85		}
86
87		if currSink.DHType != "" {
88			if currToken, err = currSink.encryptToken(currToken); err != nil {
89				return err
90			}
91		}
92
93		return currSink.WriteToken(currToken)
94	}
95
96	if incoming == nil {
97		panic("incoming channel is nil")
98	}
99
100	ss.logger.Info("starting sink server")
101	defer func() {
102		ss.logger.Info("sink server stopped")
103		close(ss.DoneCh)
104	}()
105
106	type sinkToken struct {
107		sink  *SinkConfig
108		token string
109	}
110	sinkCh := make(chan sinkToken, len(sinks))
111	for {
112		select {
113		case <-ctx.Done():
114			return
115
116		case token := <-incoming:
117			if token != *latestToken {
118
119				// Drain the existing funcs
120			drainLoop:
121				for {
122					select {
123					case <-sinkCh:
124						atomic.AddInt32(ss.remaining, -1)
125					default:
126						break drainLoop
127					}
128				}
129
130				*latestToken = token
131
132				for _, s := range sinks {
133					atomic.AddInt32(ss.remaining, 1)
134					sinkCh <- sinkToken{s, token}
135				}
136			}
137
138		case st := <-sinkCh:
139			atomic.AddInt32(ss.remaining, -1)
140			select {
141			case <-ctx.Done():
142				return
143			default:
144			}
145
146			if err := writeSink(st.sink, st.token); err != nil {
147				backoff := 2*time.Second + time.Duration(ss.random.Int63()%int64(time.Second*2)-int64(time.Second))
148				ss.logger.Error("error returned by sink function, retrying", "error", err, "backoff", backoff.String())
149				select {
150				case <-ctx.Done():
151					return
152				case <-time.After(backoff):
153					atomic.AddInt32(ss.remaining, 1)
154					sinkCh <- st
155				}
156			} else {
157				if atomic.LoadInt32(ss.remaining) == 0 && ss.exitAfterAuth {
158					return
159				}
160			}
161		}
162	}
163}
164
165func (s *SinkConfig) encryptToken(token string) (string, error) {
166	var aesKey []byte
167	var err error
168	resp := new(dhutil.Envelope)
169	switch s.DHType {
170	case "curve25519":
171		if len(s.cachedRemotePubKey) == 0 {
172			_, err = os.Lstat(s.DHPath)
173			if err != nil {
174				if !os.IsNotExist(err) {
175					return "", errwrap.Wrapf("error stat-ing dh parameters file: {{err}}", err)
176				}
177				return "", errors.New("no dh parameters file found, and no cached pub key")
178			}
179			fileBytes, err := ioutil.ReadFile(s.DHPath)
180			if err != nil {
181				return "", errwrap.Wrapf("error reading file for dh parameters: {{err}}", err)
182			}
183			theirPubKey := new(dhutil.PublicKeyInfo)
184			if err := jsonutil.DecodeJSON(fileBytes, theirPubKey); err != nil {
185				return "", errwrap.Wrapf("error decoding public key: {{err}}", err)
186			}
187			if len(theirPubKey.Curve25519PublicKey) == 0 {
188				return "", errors.New("public key is nil")
189			}
190			s.cachedRemotePubKey = theirPubKey.Curve25519PublicKey
191		}
192		if len(s.cachedPubKey) == 0 {
193			s.cachedPubKey, s.cachedPriKey, err = dhutil.GeneratePublicPrivateKey()
194			if err != nil {
195				return "", errwrap.Wrapf("error generating pub/pri curve25519 keys: {{err}}", err)
196			}
197		}
198		resp.Curve25519PublicKey = s.cachedPubKey
199	}
200
201	aesKey, err = dhutil.GenerateSharedKey(s.cachedPriKey, s.cachedRemotePubKey)
202	if err != nil {
203		return "", errwrap.Wrapf("error deriving shared key: {{err}}", err)
204	}
205	if len(aesKey) == 0 {
206		return "", errors.New("derived AES key is empty")
207	}
208
209	resp.EncryptedPayload, resp.Nonce, err = dhutil.EncryptAES(aesKey, []byte(token), []byte(s.AAD))
210	if err != nil {
211		return "", errwrap.Wrapf("error encrypting with shared key: {{err}}", err)
212	}
213	m, err := jsonutil.EncodeJSON(resp)
214	if err != nil {
215		return "", errwrap.Wrapf("error encoding encrypted payload: {{err}}", err)
216	}
217
218	return string(m), nil
219}
220
221func (s *SinkConfig) wrapToken(client *api.Client, wrapTTL time.Duration, token string) (string, error) {
222	wrapClient, err := client.Clone()
223	if err != nil {
224		return "", errwrap.Wrapf("error deriving client for wrapping, not writing out to sink: {{err}})", err)
225	}
226	wrapClient.SetToken(token)
227	wrapClient.SetWrappingLookupFunc(func(string, string) string {
228		return wrapTTL.String()
229	})
230	secret, err := wrapClient.Logical().Write("sys/wrapping/wrap", map[string]interface{}{
231		"token": token,
232	})
233	if err != nil {
234		return "", errwrap.Wrapf("error wrapping token, not writing out to sink: {{err}})", err)
235	}
236	if secret == nil {
237		return "", errors.New("nil secret returned, not writing out to sink")
238	}
239	if secret.WrapInfo == nil {
240		return "", errors.New("nil wrap info returned, not writing out to sink")
241	}
242
243	m, err := jsonutil.EncodeJSON(secret.WrapInfo)
244	if err != nil {
245		return "", errwrap.Wrapf("error marshaling token, not writing out to sink: {{err}})", err)
246	}
247
248	return string(m), nil
249}
250