// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package managedwriter

import (
	"context"
	"testing"

	storagepb "google.golang.org/genproto/googleapis/cloud/bigquery/storage/v1"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/descriptorpb"
)

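// TestManagedStream_OpenWithRetry verifies that openWithRetry retries the
// stream open on transient (Unavailable) errors and surfaces terminal errors
// to the caller.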
func TestManagedStream_OpenWithRetry(t *testing.T) {

	testCases := []struct {
		desc     string
		errors   []error
		wantFail bool
	}{
		{
			desc:     "no error",
			errors:   []error{nil},
			wantFail: false,
		},
		{
			desc: "transient failures",
			errors: []error{
				status.Errorf(codes.Unavailable, "try 1"),
				status.Errorf(codes.Unavailable, "try 2"),
				nil},
			wantFail: false,
		},
		{
			desc:     "terminal error",
			errors:   []error{status.Errorf(codes.InvalidArgument, "bad args")},
			wantFail: true,
		},
	}

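	// For each case, the fake open function consumes the canned errors in
	// order; a nil entry yields a fake append client and a successful open.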
	for _, tc := range testCases {
		ms := &ManagedStream{
			ctx: context.Background(),
			open: func(s string) (storagepb.BigQueryWrite_AppendRowsClient, error) {
				if len(tc.errors) == 0 {
					panic("out of errors")
				}
				err := tc.errors[0]
				tc.errors = tc.errors[1:]
				if err == nil {
					return &testAppendRowsClient{}, nil
				}
				return nil, err
			},
		}
		arc, ch, err := ms.openWithRetry()
		if tc.wantFail && err == nil {
			t.Errorf("case %s: wanted failure, got success", tc.desc)
		}
		if !tc.wantFail && err != nil {
			t.Errorf("case %s: wanted success, got %v", tc.desc, err)
		}
		if err == nil {
			if arc == nil {
				t.Errorf("case %s: expected append client, got nil", tc.desc)
			}
			if ch == nil {
				t.Errorf("case %s: expected channel, got nil", tc.desc)
			}
		}
	}
}

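// TestManagedStream_FirstAppendBehavior verifies that only the first request
// sent on a connection carries the trace ID, write stream name, and writer
// schema, and that subsequent requests omit them.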
func TestManagedStream_FirstAppendBehavior(t *testing.T) {

	ctx := context.Background()

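	// testARC is declared before it is populated so that the sendF closure
	// below can capture and reference it.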
	var testARC *testAppendRowsClient
	testARC = &testAppendRowsClient{
		recvF: func() (*storagepb.AppendRowsResponse, error) {
			return &storagepb.AppendRowsResponse{
				Response: &storagepb.AppendRowsResponse_AppendResult_{},
			}, nil
		},
		sendF: func(req *storagepb.AppendRowsRequest) error {
			testARC.requests = append(testARC.requests, req)
			return nil
		},
	}
	schema := &descriptorpb.DescriptorProto{
		Name: proto.String("testDescriptor"),
	}

	ms := &ManagedStream{
		ctx: ctx,
		open: func(s string) (storagepb.BigQueryWrite_AppendRowsClient, error) {
			testARC.openCount++
			return testARC, nil
		},
		streamSettings: defaultStreamSettings(),
		fc:             newFlowController(0, 0),
	}
	ms.streamSettings.streamID = "FOO"
	ms.streamSettings.TraceID = "TRACE"
	ms.schemaDescriptor = schema

	fakeData := [][]byte{
		[]byte("foo"),
		[]byte("bar"),
	}

	wantReqs := 3

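	// Issue several appends on the same connection; only the first request
	// should carry the connection-level metadata validated below.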
	for i := 0; i < wantReqs; i++ {
		_, err := ms.AppendRows(ctx, fakeData, NoStreamOffset)
		if err != nil {
			t.Errorf("AppendRows: %v", err)
		}
	}

	if testARC.openCount != 1 {
		t.Errorf("expected a single open, got %d", testARC.openCount)
	}

	if len(testARC.requests) != wantReqs {
		t.Errorf("expected %d requests, got %d", wantReqs, len(testARC.requests))
	}

	for k, v := range testARC.requests {
		if v == nil {
			t.Errorf("request %d was nil", k)
		}
		if k == 0 {
			if v.GetTraceId() == "" {
				t.Errorf("expected TraceID on first request, was empty")
			}
			if v.GetWriteStream() == "" {
				t.Errorf("expected WriteStream on first request, was empty")
			}
			if v.GetProtoRows().GetWriterSchema().GetProtoDescriptor() == nil {
				t.Errorf("expected WriterSchema on first request, was empty")
			}

		} else {
			if v.GetTraceId() != "" {
				t.Errorf("expected no TraceID on request %d, got %s", k, v.GetTraceId())
			}
			if v.GetWriteStream() != "" {
				t.Errorf("expected no WriteStream on request %d, got %s", k, v.GetWriteStream())
			}
			if v.GetProtoRows().GetWriterSchema().GetProtoDescriptor() != nil {
				t.Errorf("expected no WriterSchema on request %d, got %s", k, v.GetProtoRows().GetWriterSchema().GetProtoDescriptor().String())
			}
		}
	}
}

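// testAppendRowsClient is a test double for storagepb.BigQueryWrite_AppendRowsClient.
// Send and Recv delegate to the configurable sendF and recvF hooks, while the
// openCount and requests fields let tests track connection opens and the
// requests that were sent.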
type testAppendRowsClient struct {
	storagepb.BigQueryWrite_AppendRowsClient
	openCount int
	requests  []*storagepb.AppendRowsRequest
	sendF     func(*storagepb.AppendRowsRequest) error
	recvF     func() (*storagepb.AppendRowsResponse, error)
}

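// Send delegates to the configured sendF hook.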
func (tarc *testAppendRowsClient) Send(req *storagepb.AppendRowsRequest) error {
	return tarc.sendF(req)
}

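// Recv delegates to the configured recvF hook.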
func (tarc *testAppendRowsClient) Recv() (*storagepb.AppendRowsResponse, error) {
	return tarc.recvF()
}