1package etcd
2
3import (
4	"context"
5	"crypto/md5"
6	"encoding/json"
7	"fmt"
8	"sync"
9	"time"
10
11	"github.com/hashicorp/go-multierror"
12	"github.com/hashicorp/terraform/internal/states/remote"
13	"github.com/hashicorp/terraform/internal/states/statemgr"
14	etcdv3 "go.etcd.io/etcd/clientv3"
15	etcdv3sync "go.etcd.io/etcd/clientv3/concurrency"
16)
17
// Tunables shared by the etcd remote state client.
const (
	// lockInfoSuffix is appended to the state key to form the key under
	// which the lock metadata is stored.
	lockInfoSuffix = ".lockinfo"

	// lockAcquireTimeout bounds how long lock() waits on the etcd mutex
	// before giving up.
	lockAcquireTimeout time.Duration = 2 * time.Second
)
22
// RemoteClient is a remote client that will store data in etcd.
//
// The exported methods (Get, Put, Delete, Lock, Unlock) each hold mu for
// their whole duration; the unexported helpers are invoked from those methods
// and do not take mu themselves.
type RemoteClient struct {
	// Configuration supplied by whoever constructs the client:
	//   Client - etcd v3 client used for every KV, transaction and session call.
	//   DoLock - when false, Lock and Unlock return immediately without
	//            contacting etcd.
	//   Key    - etcd key holding the state payload; lock metadata lives under
	//            Key + lockInfoSuffix and the mutex is created with Key as
	//            its prefix.
	Client *etcdv3.Client
	DoLock bool
	Key    string

	// Internal state:
	//   etcdMutex / etcdSession - set by lock() once the mutex is acquired
	//                             and reset to nil by unlock().
	//   info                    - lock metadata passed to the latest Lock call.
	//   mu                      - serializes the exported methods.
	//   modRevision             - mod revision of Key recorded by Get and Put;
	//                             Put only writes if the key still carries it.
	etcdMutex   *etcdv3sync.Mutex
	etcdSession *etcdv3sync.Session
	info        *statemgr.LockInfo
	mu          sync.Mutex
	modRevision int64
}
35
36func (c *RemoteClient) Get() (*remote.Payload, error) {
37	c.mu.Lock()
38	defer c.mu.Unlock()
39
40	res, err := c.Client.KV.Get(context.TODO(), c.Key)
41	if err != nil {
42		return nil, err
43	}
44	if res.Count == 0 {
45		return nil, nil
46	}
47	if res.Count >= 2 {
48		return nil, fmt.Errorf("Expected a single result but got %d.", res.Count)
49	}
50
51	c.modRevision = res.Kvs[0].ModRevision
52
53	payload := res.Kvs[0].Value
54	md5 := md5.Sum(payload)
55
56	return &remote.Payload{
57		Data: payload,
58		MD5:  md5[:],
59	}, nil
60}
61
62func (c *RemoteClient) Put(data []byte) error {
63	c.mu.Lock()
64	defer c.mu.Unlock()
65
66	res, err := etcdv3.NewKV(c.Client).Txn(context.TODO()).If(
67		etcdv3.Compare(etcdv3.ModRevision(c.Key), "=", c.modRevision),
68	).Then(
69		etcdv3.OpPut(c.Key, string(data)),
70		etcdv3.OpGet(c.Key),
71	).Commit()
72
73	if err != nil {
74		return err
75	}
76	if !res.Succeeded {
77		return fmt.Errorf("The transaction did not succeed.")
78	}
79	if len(res.Responses) != 2 {
80		return fmt.Errorf("Expected two responses but got %d.", len(res.Responses))
81	}
82
83	c.modRevision = res.Responses[1].GetResponseRange().Kvs[0].ModRevision
84	return nil
85}
86
87func (c *RemoteClient) Delete() error {
88	c.mu.Lock()
89	defer c.mu.Unlock()
90
91	_, err := c.Client.KV.Delete(context.TODO(), c.Key)
92	return err
93}
94
95func (c *RemoteClient) Lock(info *statemgr.LockInfo) (string, error) {
96	c.mu.Lock()
97	defer c.mu.Unlock()
98
99	if !c.DoLock {
100		return "", nil
101	}
102	if c.etcdSession != nil {
103		return "", fmt.Errorf("state %q already locked", c.Key)
104	}
105
106	c.info = info
107	return c.lock()
108}
109
110func (c *RemoteClient) Unlock(id string) error {
111	c.mu.Lock()
112	defer c.mu.Unlock()
113
114	if !c.DoLock {
115		return nil
116	}
117
118	return c.unlock(id)
119}
120
121func (c *RemoteClient) deleteLockInfo(info *statemgr.LockInfo) error {
122	res, err := c.Client.KV.Delete(context.TODO(), c.Key+lockInfoSuffix)
123	if err != nil {
124		return err
125	}
126	if res.Deleted == 0 {
127		return fmt.Errorf("No keys deleted for %s when deleting lock info.", c.Key+lockInfoSuffix)
128	}
129	return nil
130}
131
132func (c *RemoteClient) getLockInfo() (*statemgr.LockInfo, error) {
133	res, err := c.Client.KV.Get(context.TODO(), c.Key+lockInfoSuffix)
134	if err != nil {
135		return nil, err
136	}
137	if res.Count == 0 {
138		return nil, nil
139	}
140
141	li := &statemgr.LockInfo{}
142	err = json.Unmarshal(res.Kvs[0].Value, li)
143	if err != nil {
144		return nil, fmt.Errorf("Error unmarshaling lock info: %s.", err)
145	}
146
147	return li, nil
148}
149
150func (c *RemoteClient) putLockInfo(info *statemgr.LockInfo) error {
151	c.info.Path = c.etcdMutex.Key()
152	c.info.Created = time.Now().UTC()
153
154	_, err := c.Client.KV.Put(context.TODO(), c.Key+lockInfoSuffix, string(c.info.Marshal()))
155	return err
156}
157
158func (c *RemoteClient) lock() (string, error) {
159	session, err := etcdv3sync.NewSession(c.Client)
160	if err != nil {
161		return "", nil
162	}
163
164	ctx, cancel := context.WithTimeout(context.TODO(), lockAcquireTimeout)
165	defer cancel()
166
167	mutex := etcdv3sync.NewMutex(session, c.Key)
168	if err1 := mutex.Lock(ctx); err1 != nil {
169		lockInfo, err2 := c.getLockInfo()
170		if err2 != nil {
171			return "", &statemgr.LockError{Err: err2}
172		}
173		return "", &statemgr.LockError{Info: lockInfo, Err: err1}
174	}
175
176	c.etcdMutex = mutex
177	c.etcdSession = session
178
179	err = c.putLockInfo(c.info)
180	if err != nil {
181		if unlockErr := c.unlock(c.info.ID); unlockErr != nil {
182			err = multierror.Append(err, unlockErr)
183		}
184		return "", err
185	}
186
187	return c.info.ID, nil
188}
189
190func (c *RemoteClient) unlock(id string) error {
191	if c.etcdMutex == nil {
192		return nil
193	}
194
195	var errs error
196
197	if err := c.deleteLockInfo(c.info); err != nil {
198		errs = multierror.Append(errs, err)
199	}
200	if err := c.etcdMutex.Unlock(context.TODO()); err != nil {
201		errs = multierror.Append(errs, err)
202	}
203	if err := c.etcdSession.Close(); err != nil {
204		errs = multierror.Append(errs, err)
205	}
206
207	c.etcdMutex = nil
208	c.etcdSession = nil
209
210	return errs
211}
212