1// Copyright 2016 The etcd Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package integration
16
17import (
18	"context"
19	"fmt"
20	"math/rand"
21	"strconv"
22	"testing"
23
24	v3 "github.com/coreos/etcd/clientv3"
25	"github.com/coreos/etcd/clientv3/concurrency"
26	"github.com/coreos/etcd/pkg/testutil"
27)
28
29// TestSTMConflict tests that conflicts are retried.
30func TestSTMConflict(t *testing.T) {
31	clus := NewClusterV3(t, &ClusterConfig{Size: 3})
32	defer clus.Terminate(t)
33
34	etcdc := clus.RandClient()
35	keys := make([]string, 5)
36	for i := 0; i < len(keys); i++ {
37		keys[i] = fmt.Sprintf("foo-%d", i)
38		if _, err := etcdc.Put(context.TODO(), keys[i], "100"); err != nil {
39			t.Fatalf("could not make key (%v)", err)
40		}
41	}
42
43	errc := make(chan error)
44	for i := range keys {
45		curEtcdc := clus.RandClient()
46		srcKey := keys[i]
47		applyf := func(stm concurrency.STM) error {
48			src := stm.Get(srcKey)
49			// must be different key to avoid double-adding
50			dstKey := srcKey
51			for dstKey == srcKey {
52				dstKey = keys[rand.Intn(len(keys))]
53			}
54			dst := stm.Get(dstKey)
55			srcV, _ := strconv.ParseInt(src, 10, 64)
56			dstV, _ := strconv.ParseInt(dst, 10, 64)
57			if srcV == 0 {
58				// can't rand.Intn on 0, so skip this transaction
59				return nil
60			}
61			xfer := int64(rand.Intn(int(srcV)) / 2)
62			stm.Put(srcKey, fmt.Sprintf("%d", srcV-xfer))
63			stm.Put(dstKey, fmt.Sprintf("%d", dstV+xfer))
64			return nil
65		}
66		go func() {
67			iso := concurrency.WithIsolation(concurrency.RepeatableReads)
68			_, err := concurrency.NewSTM(curEtcdc, applyf, iso)
69			errc <- err
70		}()
71	}
72
73	// wait for txns
74	for range keys {
75		if err := <-errc; err != nil {
76			t.Fatalf("apply failed (%v)", err)
77		}
78	}
79
80	// ensure sum matches initial sum
81	sum := 0
82	for _, oldkey := range keys {
83		rk, err := etcdc.Get(context.TODO(), oldkey)
84		if err != nil {
85			t.Fatalf("couldn't fetch key %s (%v)", oldkey, err)
86		}
87		v, _ := strconv.ParseInt(string(rk.Kvs[0].Value), 10, 64)
88		sum += int(v)
89	}
90	if sum != len(keys)*100 {
91		t.Fatalf("bad sum. got %d, expected %d", sum, len(keys)*100)
92	}
93}
94
95// TestSTMPutNewKey confirms a STM put on a new key is visible after commit.
96func TestSTMPutNewKey(t *testing.T) {
97	clus := NewClusterV3(t, &ClusterConfig{Size: 1})
98	defer clus.Terminate(t)
99
100	etcdc := clus.RandClient()
101	applyf := func(stm concurrency.STM) error {
102		stm.Put("foo", "bar")
103		return nil
104	}
105
106	iso := concurrency.WithIsolation(concurrency.RepeatableReads)
107	if _, err := concurrency.NewSTM(etcdc, applyf, iso); err != nil {
108		t.Fatalf("error on stm txn (%v)", err)
109	}
110
111	resp, err := etcdc.Get(context.TODO(), "foo")
112	if err != nil {
113		t.Fatalf("error fetching key (%v)", err)
114	}
115	if string(resp.Kvs[0].Value) != "bar" {
116		t.Fatalf("bad value. got %+v, expected 'bar' value", resp)
117	}
118}
119
120// TestSTMAbort tests that an aborted txn does not modify any keys.
121func TestSTMAbort(t *testing.T) {
122	clus := NewClusterV3(t, &ClusterConfig{Size: 1})
123	defer clus.Terminate(t)
124
125	etcdc := clus.RandClient()
126	ctx, cancel := context.WithCancel(context.TODO())
127	applyf := func(stm concurrency.STM) error {
128		stm.Put("foo", "baz")
129		cancel()
130		stm.Put("foo", "bap")
131		return nil
132	}
133
134	iso := concurrency.WithIsolation(concurrency.RepeatableReads)
135	sctx := concurrency.WithAbortContext(ctx)
136	if _, err := concurrency.NewSTM(etcdc, applyf, iso, sctx); err == nil {
137		t.Fatalf("no error on stm txn")
138	}
139
140	resp, err := etcdc.Get(context.TODO(), "foo")
141	if err != nil {
142		t.Fatalf("error fetching key (%v)", err)
143	}
144	if len(resp.Kvs) != 0 {
145		t.Fatalf("bad value. got %+v, expected nothing", resp)
146	}
147}
148
149// TestSTMSerialize tests that serialization is honored when serializable.
150func TestSTMSerialize(t *testing.T) {
151	clus := NewClusterV3(t, &ClusterConfig{Size: 3})
152	defer clus.Terminate(t)
153
154	etcdc := clus.RandClient()
155
156	// set up initial keys
157	keys := make([]string, 5)
158	for i := 0; i < len(keys); i++ {
159		keys[i] = fmt.Sprintf("foo-%d", i)
160	}
161
162	// update keys in full batches
163	updatec := make(chan struct{})
164	go func() {
165		defer close(updatec)
166		for i := 0; i < 5; i++ {
167			s := fmt.Sprintf("%d", i)
168			ops := []v3.Op{}
169			for _, k := range keys {
170				ops = append(ops, v3.OpPut(k, s))
171			}
172			if _, err := etcdc.Txn(context.TODO()).Then(ops...).Commit(); err != nil {
173				t.Fatalf("couldn't put keys (%v)", err)
174			}
175			updatec <- struct{}{}
176		}
177	}()
178
179	// read all keys in txn, make sure all values match
180	errc := make(chan error)
181	for range updatec {
182		curEtcdc := clus.RandClient()
183		applyf := func(stm concurrency.STM) error {
184			vs := []string{}
185			for i := range keys {
186				vs = append(vs, stm.Get(keys[i]))
187			}
188			for i := range vs {
189				if vs[0] != vs[i] {
190					return fmt.Errorf("got vs[%d] = %v, want %v", i, vs[i], vs[0])
191				}
192			}
193			return nil
194		}
195		go func() {
196			iso := concurrency.WithIsolation(concurrency.Serializable)
197			_, err := concurrency.NewSTM(curEtcdc, applyf, iso)
198			errc <- err
199		}()
200	}
201
202	for i := 0; i < 5; i++ {
203		if err := <-errc; err != nil {
204			t.Error(err)
205		}
206	}
207}
208
209// TestSTMApplyOnConcurrentDeletion ensures that concurrent key deletion
210// fails the first GET revision comparison within STM; trigger retry.
211func TestSTMApplyOnConcurrentDeletion(t *testing.T) {
212	clus := NewClusterV3(t, &ClusterConfig{Size: 1})
213	defer clus.Terminate(t)
214
215	etcdc := clus.RandClient()
216	if _, err := etcdc.Put(context.TODO(), "foo", "bar"); err != nil {
217		t.Fatal(err)
218	}
219	donec, readyc := make(chan struct{}), make(chan struct{})
220	go func() {
221		<-readyc
222		if _, err := etcdc.Delete(context.TODO(), "foo"); err != nil {
223			t.Fatal(err)
224		}
225		close(donec)
226	}()
227
228	try := 0
229	applyf := func(stm concurrency.STM) error {
230		try++
231		stm.Get("foo")
232		if try == 1 {
233			// trigger delete to make GET rev comparison outdated
234			close(readyc)
235			<-donec
236		}
237		stm.Put("foo2", "bar2")
238		return nil
239	}
240
241	iso := concurrency.WithIsolation(concurrency.RepeatableReads)
242	if _, err := concurrency.NewSTM(etcdc, applyf, iso); err != nil {
243		t.Fatalf("error on stm txn (%v)", err)
244	}
245	if try != 2 {
246		t.Fatalf("STM apply expected to run twice, got %d", try)
247	}
248
249	resp, err := etcdc.Get(context.TODO(), "foo2")
250	if err != nil {
251		t.Fatalf("error fetching key (%v)", err)
252	}
253	if string(resp.Kvs[0].Value) != "bar2" {
254		t.Fatalf("bad value. got %+v, expected 'bar2' value", resp)
255	}
256}
257
258func TestSTMSerializableSnapshotPut(t *testing.T) {
259	clus := NewClusterV3(t, &ClusterConfig{Size: 1})
260	defer clus.Terminate(t)
261
262	cli := clus.Client(0)
263	// key with lower create/mod revision than keys being updated
264	_, err := cli.Put(context.TODO(), "a", "0")
265	testutil.AssertNil(t, err)
266
267	tries := 0
268	applyf := func(stm concurrency.STM) error {
269		if tries > 2 {
270			return fmt.Errorf("too many retries")
271		}
272		tries++
273		stm.Get("a")
274		stm.Put("b", "1")
275		return nil
276	}
277
278	iso := concurrency.WithIsolation(concurrency.SerializableSnapshot)
279	_, err = concurrency.NewSTM(cli, applyf, iso)
280	testutil.AssertNil(t, err)
281	_, err = concurrency.NewSTM(cli, applyf, iso)
282	testutil.AssertNil(t, err)
283
284	resp, err := cli.Get(context.TODO(), "b")
285	testutil.AssertNil(t, err)
286	if resp.Kvs[0].Version != 2 {
287		t.Fatalf("bad version. got %+v, expected version 2", resp)
288	}
289}
290