1// Copyright 2021 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package adapt
16
17import (
18	"encoding/base64"
19	"fmt"
20	"strings"
21
22	storagepb "google.golang.org/genproto/googleapis/cloud/bigquery/storage/v1"
23	"google.golang.org/protobuf/proto"
24	"google.golang.org/protobuf/reflect/protodesc"
25	"google.golang.org/protobuf/reflect/protoreflect"
26	"google.golang.org/protobuf/types/descriptorpb"
27	"google.golang.org/protobuf/types/known/wrapperspb"
28)
29
// bqModeToFieldLabelMapProto2 maps BigQuery field modes to proto2 field labels.
// proto2 supports LABEL_REQUIRED, so REQUIRED columns map directly.
var bqModeToFieldLabelMapProto2 = map[storagepb.TableFieldSchema_Mode]descriptorpb.FieldDescriptorProto_Label{
	storagepb.TableFieldSchema_NULLABLE: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL,
	storagepb.TableFieldSchema_REPEATED: descriptorpb.FieldDescriptorProto_LABEL_REPEATED,
	storagepb.TableFieldSchema_REQUIRED: descriptorpb.FieldDescriptorProto_LABEL_REQUIRED,
}
35
// bqModeToFieldLabelMapProto3 maps BigQuery field modes to proto3 field labels.
// proto3 disallows LABEL_REQUIRED, so REQUIRED columns degrade to LABEL_OPTIONAL.
var bqModeToFieldLabelMapProto3 = map[storagepb.TableFieldSchema_Mode]descriptorpb.FieldDescriptorProto_Label{
	storagepb.TableFieldSchema_NULLABLE: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL,
	storagepb.TableFieldSchema_REPEATED: descriptorpb.FieldDescriptorProto_LABEL_REPEATED,
	storagepb.TableFieldSchema_REQUIRED: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL,
}
41
42func convertModeToLabel(mode storagepb.TableFieldSchema_Mode, useProto3 bool) *descriptorpb.FieldDescriptorProto_Label {
43	if useProto3 {
44		return bqModeToFieldLabelMapProto3[mode].Enum()
45	}
46	return bqModeToFieldLabelMapProto2[mode].Enum()
47}
48
// Allows conversion between BQ schema type and FieldDescriptorProto's type.
//
// NUMERIC/BIGNUMERIC are carried as BYTES, and the civil/instant time types
// (DATE/DATETIME/TIME/TIMESTAMP) are carried as integer encodings.
var bqTypeToFieldTypeMap = map[storagepb.TableFieldSchema_Type]descriptorpb.FieldDescriptorProto_Type{
	storagepb.TableFieldSchema_BIGNUMERIC: descriptorpb.FieldDescriptorProto_TYPE_BYTES,
	storagepb.TableFieldSchema_BOOL:       descriptorpb.FieldDescriptorProto_TYPE_BOOL,
	storagepb.TableFieldSchema_BYTES:      descriptorpb.FieldDescriptorProto_TYPE_BYTES,
	storagepb.TableFieldSchema_DATE:       descriptorpb.FieldDescriptorProto_TYPE_INT32,
	storagepb.TableFieldSchema_DATETIME:   descriptorpb.FieldDescriptorProto_TYPE_INT64,
	storagepb.TableFieldSchema_DOUBLE:     descriptorpb.FieldDescriptorProto_TYPE_DOUBLE,
	storagepb.TableFieldSchema_GEOGRAPHY:  descriptorpb.FieldDescriptorProto_TYPE_STRING,
	storagepb.TableFieldSchema_INT64:      descriptorpb.FieldDescriptorProto_TYPE_INT64,
	storagepb.TableFieldSchema_NUMERIC:    descriptorpb.FieldDescriptorProto_TYPE_BYTES,
	storagepb.TableFieldSchema_STRING:     descriptorpb.FieldDescriptorProto_TYPE_STRING,
	storagepb.TableFieldSchema_STRUCT:     descriptorpb.FieldDescriptorProto_TYPE_MESSAGE,
	storagepb.TableFieldSchema_TIME:       descriptorpb.FieldDescriptorProto_TYPE_INT64,
	storagepb.TableFieldSchema_TIMESTAMP:  descriptorpb.FieldDescriptorProto_TYPE_INT64,
}
65
// For TableFieldSchema OPTIONAL mode, we use the wrapper types to allow for the
// proper representation of NULL values, as proto3 semantics would just use default value.
//
// The keys here parallel bqTypeToFieldTypeMap; values are fully-qualified
// well-known wrapper message type names (leading dot marks them as absolute).
// STRUCT has no entry because struct fields always map to submessages.
var bqTypeToWrapperMap = map[storagepb.TableFieldSchema_Type]string{
	storagepb.TableFieldSchema_BIGNUMERIC: ".google.protobuf.BytesValue",
	storagepb.TableFieldSchema_BOOL:       ".google.protobuf.BoolValue",
	storagepb.TableFieldSchema_BYTES:      ".google.protobuf.BytesValue",
	storagepb.TableFieldSchema_DATE:       ".google.protobuf.Int32Value",
	storagepb.TableFieldSchema_DATETIME:   ".google.protobuf.Int64Value",
	storagepb.TableFieldSchema_DOUBLE:     ".google.protobuf.DoubleValue",
	storagepb.TableFieldSchema_GEOGRAPHY:  ".google.protobuf.StringValue",
	storagepb.TableFieldSchema_INT64:      ".google.protobuf.Int64Value",
	storagepb.TableFieldSchema_NUMERIC:    ".google.protobuf.BytesValue",
	storagepb.TableFieldSchema_STRING:     ".google.protobuf.StringValue",
	storagepb.TableFieldSchema_TIME:       ".google.protobuf.Int64Value",
	storagepb.TableFieldSchema_TIMESTAMP:  ".google.protobuf.Int64Value",
}
82
// wellKnownTypesWrapperName is the filename used by the well known types proto
// that declares the wrapper messages (e.g. google.protobuf.Int64Value).
var wellKnownTypesWrapperName = "google/protobuf/wrappers.proto"
85
86// dependencyCache is used to reduce the number of unique messages we generate by caching based on the tableschema.
87//
88// keys are based on the base64-encoded serialized tableschema value.
89type dependencyCache map[string]protoreflect.Descriptor
90
91func (dm dependencyCache) get(schema *storagepb.TableSchema) protoreflect.Descriptor {
92	if dm == nil {
93		return nil
94	}
95	b, err := proto.Marshal(schema)
96	if err != nil {
97		return nil
98	}
99	encoded := base64.StdEncoding.EncodeToString(b)
100	if desc, ok := (dm)[encoded]; ok {
101		return desc
102	}
103	return nil
104}
105
106func (dm dependencyCache) add(schema *storagepb.TableSchema, descriptor protoreflect.Descriptor) error {
107	if dm == nil {
108		return fmt.Errorf("cache is nil")
109	}
110	b, err := proto.Marshal(schema)
111	if err != nil {
112		return fmt.Errorf("failed to serialize tableschema: %v", err)
113	}
114	encoded := base64.StdEncoding.EncodeToString(b)
115	(dm)[encoded] = descriptor
116	return nil
117}
118
119// StorageSchemaToProto2Descriptor builds a protoreflect.Descriptor for a given table schema using proto2 syntax.
120func StorageSchemaToProto2Descriptor(inSchema *storagepb.TableSchema, scope string) (protoreflect.Descriptor, error) {
121	dc := make(dependencyCache)
122	// TODO: b/193064992 tracks support for wrapper types.  In the interim, disable wrapper usage.
123	return storageSchemaToDescriptorInternal(inSchema, scope, &dc, false)
124}
125
126// StorageSchemaToProto3Descriptor builds a protoreflect.Descriptor for a given table schema using proto3 syntax.
127//
128// NOTE: Currently the write API doesn't yet support proto3 behaviors (default value, wrapper types, etc), but this is provided for
129// completeness.
130func StorageSchemaToProto3Descriptor(inSchema *storagepb.TableSchema, scope string) (protoreflect.Descriptor, error) {
131	dc := make(dependencyCache)
132	return storageSchemaToDescriptorInternal(inSchema, scope, &dc, true)
133}
134
// internal implementation of the conversion code.
//
// storageSchemaToDescriptorInternal converts a TableSchema into a message
// descriptor named after scope.  STRUCT fields are converted recursively into
// submessages; cache is consulted/populated so that structurally-identical
// submessages are only built once.  useProto3 selects proto3 vs proto2 syntax
// for the generated file.
func storageSchemaToDescriptorInternal(inSchema *storagepb.TableSchema, scope string, cache *dependencyCache, useProto3 bool) (protoreflect.Descriptor, error) {
	if inSchema == nil {
		return nil, newConversionError(scope, fmt.Errorf("no input schema was provided"))
	}

	var fields []*descriptorpb.FieldDescriptorProto
	var deps []protoreflect.FileDescriptor
	var fNumber int32

	for _, f := range inSchema.GetFields() {
		// Field numbers are assigned sequentially starting at 1.
		fNumber = fNumber + 1
		// Each nested scope name is derived from the parent scope and field name.
		currentScope := fmt.Sprintf("%s__%s", scope, f.GetName())
		// If we're dealing with a STRUCT type, we must deal with sub messages.
		// As multiple submessages may share the same type definition, we use a dependency cache
		// and interrogate it / populate it as we're going.
		if f.Type == storagepb.TableFieldSchema_STRUCT {
			foundDesc := cache.get(&storagepb.TableSchema{Fields: f.GetFields()})
			if foundDesc != nil {
				// check to see if we already have this in current dependency list
				haveDep := false
				for _, curDep := range deps {
					if foundDesc.ParentFile().FullName() == curDep.FullName() {
						haveDep = true
						break
					}
				}
				// if dep is missing, add to current dependencies
				if !haveDep {
					deps = append(deps, foundDesc.ParentFile())
				}
				// construct field descriptor for the message
				fdp, err := tableFieldSchemaToFieldDescriptorProto(f, fNumber, string(foundDesc.FullName()), useProto3)
				if err != nil {
					return nil, newConversionError(scope, fmt.Errorf("couldn't convert field to FieldDescriptorProto: %v", err))
				}
				fields = append(fields, fdp)
			} else {
				// Wrap the current struct's fields in a TableSchema outer message, and then build the submessage.
				ts := &storagepb.TableSchema{
					Fields: f.GetFields(),
				}
				desc, err := storageSchemaToDescriptorInternal(ts, currentScope, cache, useProto3)
				if err != nil {
					return nil, newConversionError(currentScope, fmt.Errorf("couldn't convert message: %v", err))
				}
				// Now that we have the submessage definition, we append it both to the local dependencies, as well
				// as inserting it into the cache for possible reuse elsewhere.
				deps = append(deps, desc.ParentFile())
				err = cache.add(ts, desc)
				if err != nil {
					return nil, newConversionError(currentScope, fmt.Errorf("failed to add descriptor to dependency cache: %v", err))
				}
				fdp, err := tableFieldSchemaToFieldDescriptorProto(f, fNumber, currentScope, useProto3)
				if err != nil {
					return nil, newConversionError(currentScope, fmt.Errorf("couldn't compute field schema : %v", err))
				}
				fields = append(fields, fdp)
			}
		} else {
			// Non-STRUCT (scalar) field: convert directly.
			fd, err := tableFieldSchemaToFieldDescriptorProto(f, fNumber, currentScope, useProto3)
			if err != nil {
				return nil, newConversionError(currentScope, err)
			}
			fields = append(fields, fd)
		}
	}
	// Start constructing a DescriptorProto.
	dp := &descriptorpb.DescriptorProto{
		Name:  proto.String(scope),
		Field: fields,
	}

	// Use the local dependencies to generate a list of filenames.
	// NOTE: the well-known wrappers proto is always declared as a dependency,
	// regardless of syntax mode (proto2 or proto3).
	depNames := []string{
		wellKnownTypesWrapperName,
	}
	for _, d := range deps {
		depNames = append(depNames, d.ParentFile().Path())
	}

	// Now, construct a FileDescriptorProto.
	fdp := &descriptorpb.FileDescriptorProto{
		MessageType: []*descriptorpb.DescriptorProto{dp},
		Name:        proto.String(fmt.Sprintf("%s.proto", scope)),
		Syntax:      proto.String("proto3"),
		Dependency:  depNames,
	}
	if !useProto3 {
		fdp.Syntax = proto.String("proto2")
	}

	// We'll need a FileDescriptorSet as we have a FileDescriptorProto for the current
	// descriptor we're building, but we need to include all the referenced dependencies.
	fds := &descriptorpb.FileDescriptorSet{
		File: []*descriptorpb.FileDescriptorProto{
			fdp,
			protodesc.ToFileDescriptorProto(wrapperspb.File_google_protobuf_wrappers_proto),
		},
	}
	for _, d := range deps {
		fds.File = append(fds.File, protodesc.ToFileDescriptorProto(d))
	}

	// Load the set into a registry, then interrogate it for the descriptor corresponding to the top level message.
	files, err := protodesc.NewFiles(fds)
	if err != nil {
		return nil, err
	}
	return files.FindDescriptorByName(protoreflect.FullName(scope))
}
246
247// tableFieldSchemaToFieldDescriptorProto builds individual field descriptors for a proto message.
248//
249// For proto3, in cases where the mode is nullable we use the well known wrapper types.
250// For proto2, we propagate the mode->label annotation as expected.
251//
252// Messages are always nullable, and repeated fields are as well.
253func tableFieldSchemaToFieldDescriptorProto(field *storagepb.TableFieldSchema, idx int32, scope string, useProto3 bool) (*descriptorpb.FieldDescriptorProto, error) {
254	name := strings.ToLower(field.GetName())
255	if field.GetType() == storagepb.TableFieldSchema_STRUCT {
256		return &descriptorpb.FieldDescriptorProto{
257			Name:     proto.String(name),
258			Number:   proto.Int32(idx),
259			TypeName: proto.String(scope),
260			Label:    convertModeToLabel(field.GetMode(), useProto3),
261		}, nil
262	}
263
264	// For (REQUIRED||REPEATED) fields for proto3, or all cases for proto2, we can use the expected scalar types.
265	if field.GetMode() != storagepb.TableFieldSchema_NULLABLE || !useProto3 {
266		return &descriptorpb.FieldDescriptorProto{
267			Name:   proto.String(name),
268			Number: proto.Int32(idx),
269			Type:   bqTypeToFieldTypeMap[field.GetType()].Enum(),
270			Label:  convertModeToLabel(field.GetMode(), useProto3),
271		}, nil
272	}
273	// For NULLABLE proto3 fields, use a wrapper type.
274	return &descriptorpb.FieldDescriptorProto{
275		Name:     proto.String(name),
276		Number:   proto.Int32(idx),
277		Type:     descriptorpb.FieldDescriptorProto_TYPE_MESSAGE.Enum(),
278		TypeName: proto.String(bqTypeToWrapperMap[field.GetType()]),
279		Label:    descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
280	}, nil
281}
282
283// NormalizeDescriptor builds a self-contained DescriptorProto suitable for communicating schema
284// information with the BigQuery Storage write API.  It's primarily used for cases where users are
285// interested in sending data using a predefined protocol buffer message.
286//
287// The storage API accepts a single DescriptorProto for decoding message data.  In many cases, a message
288// is comprised of multiple independent messages, from the same .proto file or from multiple sources.  Rather
289// than being forced to communicate all these messages independently, what this method does is rewrite the
290// DescriptorProto to inline all messages as nested submessages.  As the backend only cares about the types
291// and not the namespaces when decoding, this is sufficient for the needs of the API's representation.
292//
293// In addition to nesting messages, this method also handles some encapsulation of enum types to avoid possible
294// conflicts due to ambiguities.
295func NormalizeDescriptor(in protoreflect.MessageDescriptor) (*descriptorpb.DescriptorProto, error) {
296	return normalizeDescriptorInternal(in, newStringSet(), newStringSet(), newStringSet(), nil)
297}
298
299func normalizeDescriptorInternal(in protoreflect.MessageDescriptor, visitedTypes, enumTypes, structTypes *stringSet, root *descriptorpb.DescriptorProto) (*descriptorpb.DescriptorProto, error) {
300	if in == nil {
301		return nil, fmt.Errorf("no messagedescriptor provided")
302	}
303	resultDP := &descriptorpb.DescriptorProto{}
304	if root == nil {
305		root = resultDP
306	}
307	fullProtoName := string(in.FullName())
308	resultDP.Name = proto.String(normalizeName(fullProtoName))
309	visitedTypes.add(fullProtoName)
310	for i := 0; i < in.Fields().Len(); i++ {
311		inField := in.Fields().Get(i)
312		resultFDP := protodesc.ToFieldDescriptorProto(inField)
313		if inField.Kind() == protoreflect.MessageKind || inField.Kind() == protoreflect.GroupKind {
314			// Handle fields that reference messages.
315			// Groups are a proto2-ism which predated nested messages.
316			msgFullName := string(inField.Message().FullName())
317			if !skipNormalization(msgFullName) {
318				// for everything but well known types, normalize.
319				normName := normalizeName(string(msgFullName))
320				if structTypes.contains(msgFullName) {
321					resultFDP.TypeName = proto.String(normName)
322				} else {
323					if visitedTypes.contains(msgFullName) {
324						return nil, fmt.Errorf("recursize type not supported: %s", inField.FullName())
325					}
326					visitedTypes.add(msgFullName)
327					dp, err := normalizeDescriptorInternal(inField.Message(), visitedTypes, enumTypes, structTypes, root)
328					if err != nil {
329						return nil, fmt.Errorf("error converting message %s: %v", inField.FullName(), err)
330					}
331					root.NestedType = append(root.NestedType, dp)
332					visitedTypes.delete(msgFullName)
333					lastNested := root.GetNestedType()[len(root.GetNestedType())-1].GetName()
334					resultFDP.TypeName = proto.String(lastNested)
335				}
336			}
337		}
338		if inField.Kind() == protoreflect.EnumKind {
339			// For enums, in order to avoid value conflict, we will always define
340			// a enclosing struct called enum_full_name_E that includes the actual
341			// enum.
342			enumFullName := string(inField.Enum().FullName())
343			enclosingTypeName := normalizeName(enumFullName) + "_E"
344			enumName := string(inField.Enum().Name())
345			actualFullName := fmt.Sprintf("%s.%s", enclosingTypeName, enumName)
346			if enumTypes.contains(enumFullName) {
347				resultFDP.TypeName = proto.String(actualFullName)
348			} else {
349				enumDP := protodesc.ToEnumDescriptorProto(inField.Enum())
350				enumDP.Name = proto.String(enumName)
351				resultDP.NestedType = append(resultDP.NestedType, &descriptorpb.DescriptorProto{
352					Name:     proto.String(enclosingTypeName),
353					EnumType: []*descriptorpb.EnumDescriptorProto{enumDP},
354				})
355				resultFDP.TypeName = proto.String(actualFullName)
356				enumTypes.add(enumFullName)
357			}
358		}
359		resultDP.Field = append(resultDP.Field, resultFDP)
360	}
361	structTypes.add(fullProtoName)
362	return resultDP, nil
363}
364
// stringSet is a minimal set of strings backed by a map.
type stringSet struct {
	m map[string]struct{}
}

// contains reports whether k is a member of the set.
func (s *stringSet) contains(k string) bool {
	_, found := s.m[k]
	return found
}

// add inserts k into the set.
func (s *stringSet) add(k string) {
	s.m[k] = struct{}{}
}

// delete removes k from the set, if present.
func (s *stringSet) delete(k string) {
	delete(s.m, k)
}

// newStringSet constructs an empty stringSet.
func newStringSet() *stringSet {
	return &stringSet{m: map[string]struct{}{}}
}
387
// normalizeName flattens a qualified proto name (e.g. "a.b.Msg") into a single
// identifier by replacing every dot with an underscore, making it usable as a
// nested message name.
func normalizeName(in string) string {
	// strings.ReplaceAll is the idiomatic form of strings.Replace(..., -1).
	return strings.ReplaceAll(in, ".", "_")
}
391
// these types don't get normalized into the fully-contained structure.
var normalizationSkipList = []string{
	/*
		TODO: when backend supports resolving well known types, this list should be enabled.
		"google.protobuf.DoubleValue",
		"google.protobuf.FloatValue",
		"google.protobuf.Int64Value",
		"google.protobuf.UInt64Value",
		"google.protobuf.Int32Value",
		"google.protobuf.Uint32Value",
		"google.protobuf.BoolValue",
		"google.protobuf.StringValue",
		"google.protobuf.BytesValue",
	*/
}

// skipNormalization reports whether the fully qualified type name is in the
// normalization skip list (currently empty; see TODO above).
func skipNormalization(fullName string) bool {
	for _, skipped := range normalizationSkipList {
		if skipped == fullName {
			return true
		}
	}
	return false
}
416