1// Copyright 2019 The Go Cloud Development Kit Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package awsdynamodb
16
17import (
18	"fmt"
19	"net/http"
20	"strconv"
21	"testing"
22
23	"github.com/aws/aws-sdk-go/aws"
24	awscreds "github.com/aws/aws-sdk-go/aws/credentials"
25	"github.com/aws/aws-sdk-go/aws/session"
26	"github.com/aws/aws-sdk-go/service/dynamodb"
27	dyn "github.com/aws/aws-sdk-go/service/dynamodb"
28	"github.com/aws/aws-sdk-go/service/dynamodb/expression"
29)
30
31var benchmarkTableName = collectionName3
32
33func BenchmarkPutVSTransact(b *testing.B) {
34	// This benchmark compares two ways to replace N items and retrieve their previous values.
35	// The first way makes N calls to PutItem with ReturnValues set to ALL_OLD.
36	// The second way calls BatchGetItem followed by TransactWriteItem.
37	//
38	// The results show that separate PutItems are faster for up to two items.
39	sess, err := awsSession(region, http.DefaultClient)
40	if err != nil {
41		b.Fatal(err)
42	}
43	db := dynamodb.New(sess)
44
45	for nItems := 1; nItems <= 5; nItems++ {
46		b.Run(fmt.Sprintf("%d-Items", nItems), func(b *testing.B) {
47			var items []map[string]*dynamodb.AttributeValue
48			for i := 0; i < nItems; i++ {
49				items = append(items, map[string]*dynamodb.AttributeValue{
50					"name": new(dyn.AttributeValue).SetS(fmt.Sprintf("pt-vs-transact-%d", i)),
51					"x":    new(dyn.AttributeValue).SetN(strconv.Itoa(i)),
52					"rev":  new(dyn.AttributeValue).SetN("1"),
53				})
54			}
55			for _, item := range items {
56				_, err := db.PutItem(&dynamodb.PutItemInput{
57					TableName: &benchmarkTableName,
58					Item:      item,
59				})
60				if err != nil {
61					b.Fatal(err)
62				}
63			}
64			b.Run("PutItem", func(b *testing.B) {
65				for n := 0; n < b.N; n++ {
66					putItems(b, db, items)
67				}
68			})
69			b.Run("TransactWrite", func(b *testing.B) {
70				for n := 0; n < b.N; n++ {
71					batchGetTransactWrite(b, db, items)
72				}
73			})
74		})
75
76	}
77}
78
79func putItems(b *testing.B, db *dynamodb.DynamoDB, items []map[string]*dynamodb.AttributeValue) {
80	for i, item := range items {
81		item["x"].SetN(strconv.Itoa(i + 1))
82		in := &dynamodb.PutItemInput{
83			TableName:    &benchmarkTableName,
84			Item:         item,
85			ReturnValues: aws.String("ALL_OLD"),
86		}
87		ce, err := expression.NewBuilder().
88			WithCondition(expression.Name("rev").Equal(expression.Value(1))).
89			Build()
90		if err != nil {
91			b.Fatal(err)
92		}
93		in.ExpressionAttributeNames = ce.Names()
94		in.ExpressionAttributeValues = ce.Values()
95		in.ConditionExpression = ce.Condition()
96		out, err := db.PutItem(in)
97		if err != nil {
98			b.Fatal(err)
99		}
100		if got, want := len(out.Attributes), 3; got != want {
101			b.Fatalf("got %d attributes, want %d", got, want)
102		}
103	}
104}
105
106func batchGetTransactWrite(b *testing.B, db *dynamodb.DynamoDB, items []map[string]*dynamodb.AttributeValue) {
107	keys := make([]map[string]*dynamodb.AttributeValue, len(items))
108	tws := make([]*dyn.TransactWriteItem, len(items))
109	for i, item := range items {
110		keys[i] = map[string]*dynamodb.AttributeValue{"name": items[i]["name"]}
111		item["x"].SetN(strconv.Itoa(i + 2))
112		put := &dynamodb.Put{TableName: &benchmarkTableName, Item: items[i]}
113		ce, err := expression.NewBuilder().
114			WithCondition(expression.Name("rev").Equal(expression.Value(1))).
115			Build()
116		if err != nil {
117			b.Fatal(err)
118		}
119		put.ExpressionAttributeNames = ce.Names()
120		put.ExpressionAttributeValues = ce.Values()
121		put.ConditionExpression = ce.Condition()
122		tws[i] = &dynamodb.TransactWriteItem{Put: put}
123	}
124	_, err := db.BatchGetItem(&dynamodb.BatchGetItemInput{
125		RequestItems: map[string]*dynamodb.KeysAndAttributes{
126			benchmarkTableName: {Keys: keys},
127		},
128	})
129	if err != nil {
130		b.Fatal(err)
131	}
132	_, err = db.TransactWriteItems(&dynamodb.TransactWriteItemsInput{TransactItems: tws})
133	if err != nil {
134		b.Fatal(err)
135	}
136}
137
138func awsSession(region string, client *http.Client) (*session.Session, error) {
139	// Provide fake creds if running in replay mode.
140	var creds *awscreds.Credentials
141	return session.NewSession(&aws.Config{
142		HTTPClient:  client,
143		Region:      aws.String(region),
144		Credentials: creds,
145		MaxRetries:  aws.Int(0),
146	})
147}
148