1// Copyright 2018 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package spanner
16
17import (
18	"bytes"
19	"context"
20	"io/ioutil"
21	"log"
22	"os"
23	"testing"
24	"time"
25
26	. "cloud.google.com/go/spanner/internal/testutil"
27	sppb "google.golang.org/genproto/googleapis/spanner/v1"
28	"google.golang.org/grpc/codes"
29	"google.golang.org/grpc/status"
30)
31
32func TestMockPartitionedUpdate(t *testing.T) {
33	t.Parallel()
34	ctx := context.Background()
35	_, client, teardown := setupMockedTestServer(t)
36	defer teardown()
37
38	stmt := NewStatement(UpdateBarSetFoo)
39	rowCount, err := client.PartitionedUpdate(ctx, stmt)
40	if err != nil {
41		t.Fatal(err)
42	}
43	want := int64(UpdateBarSetFooRowCount)
44	if rowCount != want {
45		t.Errorf("got %d, want %d", rowCount, want)
46	}
47}
48
49func TestMockPartitionedUpdateWithQuery(t *testing.T) {
50	t.Parallel()
51	ctx := context.Background()
52	_, client, teardown := setupMockedTestServer(t)
53	defer teardown()
54
55	stmt := NewStatement(SelectFooFromBar)
56	_, err := client.PartitionedUpdate(ctx, stmt)
57	wantCode := codes.InvalidArgument
58	var serr *Error
59	if !errorAs(err, &serr) {
60		t.Errorf("got error %v, want spanner.Error", err)
61	}
62	if ErrCode(serr) != wantCode {
63		t.Errorf("got error %v, want code %s", serr, wantCode)
64	}
65}
66
67// PDML should be retried if the transaction is aborted.
68func TestPartitionedUpdate_Aborted(t *testing.T) {
69	t.Parallel()
70	ctx := context.Background()
71	server, client, teardown := setupMockedTestServer(t)
72	defer teardown()
73
74	server.TestSpanner.PutExecutionTime(MethodExecuteSql,
75		SimulatedExecutionTime{
76			Errors: []error{status.Error(codes.Aborted, "Transaction aborted")},
77		})
78	stmt := NewStatement(UpdateBarSetFoo)
79	rowCount, err := client.PartitionedUpdate(ctx, stmt)
80	if err != nil {
81		t.Fatal(err)
82	}
83	want := int64(UpdateBarSetFooRowCount)
84	if rowCount != want {
85		t.Errorf("Row count mismatch\ngot: %d\nwant: %d", rowCount, want)
86	}
87
88	gotReqs, err := shouldHaveReceived(server.TestSpanner, []interface{}{
89		&sppb.CreateSessionRequest{},
90		&sppb.BeginTransactionRequest{},
91		&sppb.ExecuteSqlRequest{},
92		&sppb.BeginTransactionRequest{},
93		&sppb.ExecuteSqlRequest{},
94		&sppb.DeleteSessionRequest{},
95	})
96	if err != nil {
97		t.Fatal(err)
98	}
99	id1 := gotReqs[2].(*sppb.ExecuteSqlRequest).Transaction.GetId()
100	id2 := gotReqs[4].(*sppb.ExecuteSqlRequest).Transaction.GetId()
101	if bytes.Equal(id1, id2) {
102		t.Errorf("same transaction used twice, expected two different transactions\ngot tx1: %q\ngot tx2: %q", id1, id2)
103	}
104}
105
106// Test that a deadline is respected by PDML, and that the session that was
107// created is also deleted, even though the update timed out.
108func TestPartitionedUpdate_WithDeadline(t *testing.T) {
109	t.Parallel()
110	logger := log.New(os.Stderr, "", log.LstdFlags)
111	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
112		SessionPoolConfig: DefaultSessionPoolConfig,
113		logger:            logger,
114	})
115	defer teardown()
116
117	ctx := context.Background()
118	ctx, cancel := context.WithDeadline(ctx, time.Now().Add(50*time.Millisecond))
119	defer cancel()
120	server.TestSpanner.PutExecutionTime(MethodExecuteSql,
121		SimulatedExecutionTime{
122			MinimumExecutionTime: 100 * time.Millisecond,
123		})
124	stmt := NewStatement(UpdateBarSetFoo)
125	// The following update will cause a 'Failed to delete session' warning to
126	// be logged. This is expected. To avoid spamming the log, we temporarily
127	// set the output to be discarded.
128	logger.SetOutput(ioutil.Discard)
129	_, err := client.PartitionedUpdate(ctx, stmt)
130	logger.SetOutput(os.Stderr)
131	if err == nil {
132		t.Fatalf("missing expected error")
133	}
134	wantCode := codes.DeadlineExceeded
135	if status.Code(err) != wantCode {
136		t.Fatalf("got error %v, want code %s", err, wantCode)
137	}
138}
139