1// Copyright 2017 The etcd Authors 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package concurrency_test 16 17import ( 18 "context" 19 "fmt" 20 "log" 21 "math/rand" 22 "sync" 23 24 "go.etcd.io/etcd/clientv3" 25 "go.etcd.io/etcd/clientv3/concurrency" 26) 27 28// ExampleSTM_apply shows how to use STM with a transactional 29// transfer between balances. 30func ExampleSTM_apply() { 31 cli, err := clientv3.New(clientv3.Config{Endpoints: endpoints}) 32 if err != nil { 33 log.Fatal(err) 34 } 35 defer cli.Close() 36 37 // set up "accounts" 38 totalAccounts := 5 39 for i := 0; i < totalAccounts; i++ { 40 k := fmt.Sprintf("accts/%d", i) 41 if _, err = cli.Put(context.TODO(), k, "100"); err != nil { 42 log.Fatal(err) 43 } 44 } 45 46 exchange := func(stm concurrency.STM) error { 47 from, to := rand.Intn(totalAccounts), rand.Intn(totalAccounts) 48 if from == to { 49 // nothing to do 50 return nil 51 } 52 // read values 53 fromK, toK := fmt.Sprintf("accts/%d", from), fmt.Sprintf("accts/%d", to) 54 fromV, toV := stm.Get(fromK), stm.Get(toK) 55 fromInt, toInt := 0, 0 56 fmt.Sscanf(fromV, "%d", &fromInt) 57 fmt.Sscanf(toV, "%d", &toInt) 58 59 // transfer amount 60 xfer := fromInt / 2 61 fromInt, toInt = fromInt-xfer, toInt+xfer 62 63 // write back 64 stm.Put(fromK, fmt.Sprintf("%d", fromInt)) 65 stm.Put(toK, fmt.Sprintf("%d", toInt)) 66 return nil 67 } 68 69 // concurrently exchange values between accounts 70 var wg sync.WaitGroup 71 wg.Add(10) 72 for i := 0; i < 10; i++ { 73 go func() { 74 defer wg.Done() 75 if _, serr := concurrency.NewSTM(cli, exchange); serr != nil { 76 log.Fatal(serr) 77 } 78 }() 79 } 80 wg.Wait() 81 82 // confirm account sum matches sum from beginning. 83 sum := 0 84 accts, err := cli.Get(context.TODO(), "accts/", clientv3.WithPrefix()) 85 if err != nil { 86 log.Fatal(err) 87 } 88 for _, kv := range accts.Kvs { 89 v := 0 90 fmt.Sscanf(string(kv.Value), "%d", &v) 91 sum += v 92 } 93 94 fmt.Println("account sum is", sum) 95 // Output: 96 // account sum is 500 97} 98