1// Copyright 2021 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package managedwriter
16
17import (
18	"context"
19	"fmt"
20	"runtime"
21	"strings"
22
23	storage "cloud.google.com/go/bigquery/storage/apiv1beta2"
24	"cloud.google.com/go/internal/detect"
25	"github.com/googleapis/gax-go/v2"
26	"google.golang.org/api/option"
27	storagepb "google.golang.org/genproto/googleapis/cloud/bigquery/storage/v1beta2"
28	"google.golang.org/grpc"
29	"google.golang.org/grpc/metadata"
30)
31
// DetectProjectID is a sentinel value that instructs NewClient to detect the
// project ID. It is given in place of the projectID argument. NewClient will
// use the project ID from the given credentials or the default credentials
// (https://developers.google.com/accounts/docs/application-default-credentials)
// if no credentials were provided. When providing credentials, not all
// options will allow NewClient to extract the project ID. Specifically a JWT
// does not have the project ID encoded.
const DetectProjectID = "*detect-project-id*"
40
// Client is a managed BigQuery Storage write client scoped to a single project.
type Client struct {
	// rawClient is the underlying generated gRPC write client.  It is set to
	// nil by Close to mark this Client as closed.
	rawClient *storage.BigQueryWriteClient
	// projectID is the resolved project, possibly via autodetection.
	projectID string
}
46
47// NewClient instantiates a new client.
48func NewClient(ctx context.Context, projectID string, opts ...option.ClientOption) (c *Client, err error) {
49	numConns := runtime.GOMAXPROCS(0)
50	if numConns > 4 {
51		numConns = 4
52	}
53	o := []option.ClientOption{
54		option.WithGRPCConnectionPool(numConns),
55	}
56	o = append(o, opts...)
57
58	rawClient, err := storage.NewBigQueryWriteClient(ctx, o...)
59	if err != nil {
60		return nil, err
61	}
62
63	// Handle project autodetection.
64	projectID, err = detect.ProjectID(ctx, projectID, "", opts...)
65	if err != nil {
66		return nil, err
67	}
68
69	return &Client{
70		rawClient: rawClient,
71		projectID: projectID,
72	}, nil
73}
74
75// Close releases resources held by the client.
76func (c *Client) Close() error {
77	// TODO: consider if we should propagate a cancellation from client to all associated managed streams.
78	if c.rawClient == nil {
79		return fmt.Errorf("already closed")
80	}
81	c.rawClient.Close()
82	c.rawClient = nil
83	return nil
84}
85
// NewManagedStream establishes a new managed stream for appending data into a table.
//
// Context here is retained for use by the underlying streaming connections the managed stream may create.
func (c *Client) NewManagedStream(ctx context.Context, opts ...WriterOption) (*ManagedStream, error) {
	// Delegate to buildManagedStream with the generated AppendRows bidi-stream
	// factory; skipSetup is false so option validation and stream creation run.
	return c.buildManagedStream(ctx, c.rawClient.AppendRows, false, opts...)
}
92
93func (c *Client) buildManagedStream(ctx context.Context, streamFunc streamClientFunc, skipSetup bool, opts ...WriterOption) (*ManagedStream, error) {
94	ctx, cancel := context.WithCancel(ctx)
95
96	ms := &ManagedStream{
97		streamSettings: defaultStreamSettings(),
98		c:              c,
99		ctx:            ctx,
100		cancel:         cancel,
101		open: func(streamID string) (storagepb.BigQueryWrite_AppendRowsClient, error) {
102			arc, err := streamFunc(
103				// Bidi Streaming doesn't append stream ID as request metadata, so we must inject it manually.
104				metadata.AppendToOutgoingContext(ctx, "x-goog-request-params", fmt.Sprintf("write_stream=%s", streamID)),
105				gax.WithGRPCOptions(grpc.MaxCallRecvMsgSize(10*1024*1024)))
106			if err != nil {
107				return nil, err
108			}
109			return arc, nil
110		},
111	}
112
113	// apply writer options
114	for _, opt := range opts {
115		opt(ms)
116	}
117
118	// skipSetup exists for testing scenarios.
119	if !skipSetup {
120		if err := c.validateOptions(ctx, ms); err != nil {
121			return nil, err
122		}
123
124		if ms.streamSettings.streamID == "" {
125			// not instantiated with a stream, construct one.
126			streamName := fmt.Sprintf("%s/_default", ms.destinationTable)
127			if ms.streamSettings.streamType != DefaultStream {
128				// For everything but a default stream, we create a new stream on behalf of the user.
129				req := &storagepb.CreateWriteStreamRequest{
130					Parent: ms.destinationTable,
131					WriteStream: &storagepb.WriteStream{
132						Type: streamTypeToEnum(ms.streamSettings.streamType),
133					}}
134				resp, err := ms.c.rawClient.CreateWriteStream(ctx, req)
135				if err != nil {
136					return nil, fmt.Errorf("couldn't create write stream: %v", err)
137				}
138				streamName = resp.GetName()
139			}
140			ms.streamSettings.streamID = streamName
141		}
142	}
143	if ms.streamSettings != nil {
144		if ms.ctx != nil {
145			ms.ctx = keyContextWithTags(ms.ctx, ms.streamSettings.streamID, ms.streamSettings.dataOrigin)
146		}
147		ms.fc = newFlowController(ms.streamSettings.MaxInflightRequests, ms.streamSettings.MaxInflightBytes)
148	} else {
149		ms.fc = newFlowController(0, 0)
150	}
151	return ms, nil
152}
153
154// validateOptions is used to validate that we received a sane/compatible set of WriterOptions
155// for constructing a new managed stream.
156func (c *Client) validateOptions(ctx context.Context, ms *ManagedStream) error {
157	if ms == nil {
158		return fmt.Errorf("no managed stream definition")
159	}
160	if ms.streamSettings.streamID != "" {
161		// User supplied a stream, we need to verify it exists.
162		info, err := c.getWriteStream(ctx, ms.streamSettings.streamID)
163		if err != nil {
164			return fmt.Errorf("a streamname was specified, but lookup of stream failed: %v", err)
165		}
166		// update type and destination based on stream metadata
167		ms.streamSettings.streamType = StreamType(info.Type.String())
168		ms.destinationTable = TableParentFromStreamName(ms.streamSettings.streamID)
169	}
170	if ms.destinationTable == "" {
171		return fmt.Errorf("no destination table specified")
172	}
173	// we could auto-select DEFAULT here, but let's force users to be specific for now.
174	if ms.StreamType() == "" {
175		return fmt.Errorf("stream type wasn't specified")
176	}
177	return nil
178}
179
180// BatchCommit is used to commit one or more PendingStream streams belonging to the same table
181// as a single transaction.  Streams must be finalized before committing.
182//
183// Format of the parentTable is: projects/{project}/datasets/{dataset}/tables/{table} and the utility
184// function TableParentFromStreamName can be used to derive this from a Stream's name.
185//
186// If the returned response contains stream errors, this indicates that the batch commit failed and no data was
187// committed.
188//
189// TODO: currently returns the raw response.  Determine how we want to surface StreamErrors.
190func (c *Client) BatchCommit(ctx context.Context, parentTable string, streamNames []string) (*storagepb.BatchCommitWriteStreamsResponse, error) {
191
192	// determine table from first streamName, as all must share the same table.
193	if len(streamNames) <= 0 {
194		return nil, fmt.Errorf("no streamnames provided")
195	}
196
197	req := &storagepb.BatchCommitWriteStreamsRequest{
198		Parent:       TableParentFromStreamName(streamNames[0]),
199		WriteStreams: streamNames,
200	}
201	return c.rawClient.BatchCommitWriteStreams(ctx, req)
202}
203
204// getWriteStream returns information about a given write stream.
205//
206// It's primarily used for setup validation, and not exposed directly to end users.
207func (c *Client) getWriteStream(ctx context.Context, streamName string) (*storagepb.WriteStream, error) {
208	req := &storagepb.GetWriteStreamRequest{
209		Name: streamName,
210	}
211	return c.rawClient.GetWriteStream(ctx, req)
212}
213
// TableParentFromStreamName is a utility function for extracting the parent table
// prefix from a stream name.  When an invalid stream ID is passed, this simply returns
// the original stream name.
func TableParentFromStreamName(streamName string) string {
	// Stream IDs look like:
	// projects/{project}/datasets/{dataset}/tables/{table}/blah
	// The table parent is everything before the sixth slash; if the name
	// has fewer than six slashes it isn't a valid stream ID.
	seen := 0
	for i, r := range streamName {
		if r != '/' {
			continue
		}
		seen++
		if seen == 6 {
			return streamName[:i]
		}
	}
	// invalid; just pass back the input
	return streamName
}
227