1package runtime
2
3import (
4	"encoding/base64"
5	"fmt"
6	"net/url"
7	"reflect"
8	"regexp"
9	"strconv"
10	"strings"
11	"time"
12
13	"github.com/golang/protobuf/proto"
14	"github.com/grpc-ecosystem/grpc-gateway/utilities"
15	"google.golang.org/grpc/grpclog"
16)
17
18// PopulateQueryParameters populates "values" into "msg".
19// A value is ignored if its key starts with one of the elements in "filter".
20func PopulateQueryParameters(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
21	for key, values := range values {
22		re, err := regexp.Compile("^(.*)\\[(.*)\\]$")
23		if err != nil {
24			return err
25		}
26		match := re.FindStringSubmatch(key)
27		if len(match) == 3 {
28			key = match[1]
29			values = append([]string{match[2]}, values...)
30		}
31		fieldPath := strings.Split(key, ".")
32		if filter.HasCommonPrefix(fieldPath) {
33			continue
34		}
35		if err := populateFieldValueFromPath(msg, fieldPath, values); err != nil {
36			return err
37		}
38	}
39	return nil
40}
41
42// PopulateFieldFromPath sets a value in a nested Protobuf structure.
43// It instantiates missing protobuf fields as it goes.
44func PopulateFieldFromPath(msg proto.Message, fieldPathString string, value string) error {
45	fieldPath := strings.Split(fieldPathString, ".")
46	return populateFieldValueFromPath(msg, fieldPath, []string{value})
47}
48
49func populateFieldValueFromPath(msg proto.Message, fieldPath []string, values []string) error {
50	m := reflect.ValueOf(msg)
51	if m.Kind() != reflect.Ptr {
52		return fmt.Errorf("unexpected type %T: %v", msg, msg)
53	}
54	var props *proto.Properties
55	m = m.Elem()
56	for i, fieldName := range fieldPath {
57		isLast := i == len(fieldPath)-1
58		if !isLast && m.Kind() != reflect.Struct {
59			return fmt.Errorf("non-aggregate type in the mid of path: %s", strings.Join(fieldPath, "."))
60		}
61		var f reflect.Value
62		var err error
63		f, props, err = fieldByProtoName(m, fieldName)
64		if err != nil {
65			return err
66		} else if !f.IsValid() {
67			grpclog.Infof("field not found in %T: %s", msg, strings.Join(fieldPath, "."))
68			return nil
69		}
70
71		switch f.Kind() {
72		case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64, reflect.String, reflect.Uint32, reflect.Uint64:
73			if !isLast {
74				return fmt.Errorf("unexpected nested field %s in %s", fieldPath[i+1], strings.Join(fieldPath[:i+1], "."))
75			}
76			m = f
77		case reflect.Slice:
78			if !isLast {
79				return fmt.Errorf("unexpected repeated field in %s", strings.Join(fieldPath, "."))
80			}
81			// Handle []byte
82			if f.Type().Elem().Kind() == reflect.Uint8 {
83				m = f
84				break
85			}
86			return populateRepeatedField(f, values, props)
87		case reflect.Ptr:
88			if f.IsNil() {
89				m = reflect.New(f.Type().Elem())
90				f.Set(m.Convert(f.Type()))
91			}
92			m = f.Elem()
93			continue
94		case reflect.Struct:
95			m = f
96			continue
97		case reflect.Map:
98			if !isLast {
99				return fmt.Errorf("unexpected nested field %s in %s", fieldPath[i+1], strings.Join(fieldPath[:i+1], "."))
100			}
101			return populateMapField(f, values, props)
102		default:
103			return fmt.Errorf("unexpected type %s in %T", f.Type(), msg)
104		}
105	}
106	switch len(values) {
107	case 0:
108		return fmt.Errorf("no value of field: %s", strings.Join(fieldPath, "."))
109	case 1:
110	default:
111		grpclog.Infof("too many field values: %s", strings.Join(fieldPath, "."))
112	}
113	return populateField(m, values[0], props)
114}
115
116// fieldByProtoName looks up a field whose corresponding protobuf field name is "name".
117// "m" must be a struct value. It returns zero reflect.Value if no such field found.
118func fieldByProtoName(m reflect.Value, name string) (reflect.Value, *proto.Properties, error) {
119	props := proto.GetProperties(m.Type())
120
121	// look up field name in oneof map
122	if op, ok := props.OneofTypes[name]; ok {
123		v := reflect.New(op.Type.Elem())
124		field := m.Field(op.Field)
125		if !field.IsNil() {
126			return reflect.Value{}, nil, fmt.Errorf("field already set for %s oneof", props.Prop[op.Field].OrigName)
127		}
128		field.Set(v)
129		return v.Elem().Field(0), op.Prop, nil
130	}
131
132	for _, p := range props.Prop {
133		if p.OrigName == name {
134			return m.FieldByName(p.Name), p, nil
135		}
136		if p.JSONName == name {
137			return m.FieldByName(p.Name), p, nil
138		}
139	}
140	return reflect.Value{}, nil, nil
141}
142
143func populateMapField(f reflect.Value, values []string, props *proto.Properties) error {
144	if len(values) != 2 {
145		return fmt.Errorf("more than one value provided for key %s in map %s", values[0], props.Name)
146	}
147
148	key, value := values[0], values[1]
149	keyType := f.Type().Key()
150	valueType := f.Type().Elem()
151	if f.IsNil() {
152		f.Set(reflect.MakeMap(f.Type()))
153	}
154
155	keyConv, ok := convFromType[keyType.Kind()]
156	if !ok {
157		return fmt.Errorf("unsupported key type %s in map %s", keyType, props.Name)
158	}
159	valueConv, ok := convFromType[valueType.Kind()]
160	if !ok {
161		return fmt.Errorf("unsupported value type %s in map %s", valueType, props.Name)
162	}
163
164	keyV := keyConv.Call([]reflect.Value{reflect.ValueOf(key)})
165	if err := keyV[1].Interface(); err != nil {
166		return err.(error)
167	}
168	valueV := valueConv.Call([]reflect.Value{reflect.ValueOf(value)})
169	if err := valueV[1].Interface(); err != nil {
170		return err.(error)
171	}
172
173	f.SetMapIndex(keyV[0].Convert(keyType), valueV[0].Convert(valueType))
174
175	return nil
176}
177
178func populateRepeatedField(f reflect.Value, values []string, props *proto.Properties) error {
179	elemType := f.Type().Elem()
180
181	// is the destination field a slice of an enumeration type?
182	if enumValMap := proto.EnumValueMap(props.Enum); enumValMap != nil {
183		return populateFieldEnumRepeated(f, values, enumValMap)
184	}
185
186	conv, ok := convFromType[elemType.Kind()]
187	if !ok {
188		return fmt.Errorf("unsupported field type %s", elemType)
189	}
190	f.Set(reflect.MakeSlice(f.Type(), len(values), len(values)).Convert(f.Type()))
191	for i, v := range values {
192		result := conv.Call([]reflect.Value{reflect.ValueOf(v)})
193		if err := result[1].Interface(); err != nil {
194			return err.(error)
195		}
196		f.Index(i).Set(result[0].Convert(f.Index(i).Type()))
197	}
198	return nil
199}
200
201func populateField(f reflect.Value, value string, props *proto.Properties) error {
202	i := f.Addr().Interface()
203
204	// Handle protobuf well known types
205	type wkt interface {
206		XXX_WellKnownType() string
207	}
208	if wkt, ok := i.(wkt); ok {
209		switch wkt.XXX_WellKnownType() {
210		case "Timestamp":
211			if value == "null" {
212				f.Field(0).SetInt(0)
213				f.Field(1).SetInt(0)
214				return nil
215			}
216
217			t, err := time.Parse(time.RFC3339Nano, value)
218			if err != nil {
219				return fmt.Errorf("bad Timestamp: %v", err)
220			}
221			f.Field(0).SetInt(int64(t.Unix()))
222			f.Field(1).SetInt(int64(t.Nanosecond()))
223			return nil
224		case "Duration":
225			if value == "null" {
226				f.Field(0).SetInt(0)
227				f.Field(1).SetInt(0)
228				return nil
229			}
230			d, err := time.ParseDuration(value)
231			if err != nil {
232				return fmt.Errorf("bad Duration: %v", err)
233			}
234
235			ns := d.Nanoseconds()
236			s := ns / 1e9
237			ns %= 1e9
238			f.Field(0).SetInt(s)
239			f.Field(1).SetInt(ns)
240			return nil
241		case "DoubleValue":
242			fallthrough
243		case "FloatValue":
244			float64Val, err := strconv.ParseFloat(value, 64)
245			if err != nil {
246				return fmt.Errorf("bad DoubleValue: %s", value)
247			}
248			f.Field(0).SetFloat(float64Val)
249			return nil
250		case "Int64Value":
251			fallthrough
252		case "Int32Value":
253			int64Val, err := strconv.ParseInt(value, 10, 64)
254			if err != nil {
255				return fmt.Errorf("bad DoubleValue: %s", value)
256			}
257			f.Field(0).SetInt(int64Val)
258			return nil
259		case "UInt64Value":
260			fallthrough
261		case "UInt32Value":
262			uint64Val, err := strconv.ParseUint(value, 10, 64)
263			if err != nil {
264				return fmt.Errorf("bad DoubleValue: %s", value)
265			}
266			f.Field(0).SetUint(uint64Val)
267			return nil
268		case "BoolValue":
269			if value == "true" {
270				f.Field(0).SetBool(true)
271			} else if value == "false" {
272				f.Field(0).SetBool(false)
273			} else {
274				return fmt.Errorf("bad BoolValue: %s", value)
275			}
276			return nil
277		case "StringValue":
278			f.Field(0).SetString(value)
279			return nil
280		case "BytesValue":
281			bytesVal, err := base64.StdEncoding.DecodeString(value)
282			if err != nil {
283				return fmt.Errorf("bad BytesValue: %s", value)
284			}
285			f.Field(0).SetBytes(bytesVal)
286			return nil
287		}
288	}
289
290	// Handle google well known types
291	if gwkt, ok := i.(proto.Message); ok {
292		switch proto.MessageName(gwkt) {
293		case "google.protobuf.FieldMask":
294			p := f.Field(0)
295			for _, v := range strings.Split(value, ",") {
296				if v != "" {
297					p.Set(reflect.Append(p, reflect.ValueOf(v)))
298				}
299			}
300			return nil
301		}
302	}
303
304	// Handle Time and Duration stdlib types
305	switch t := i.(type) {
306	case *time.Time:
307		pt, err := time.Parse(time.RFC3339Nano, value)
308		if err != nil {
309			return fmt.Errorf("bad Timestamp: %v", err)
310		}
311		*t = pt
312		return nil
313	case *time.Duration:
314		d, err := time.ParseDuration(value)
315		if err != nil {
316			return fmt.Errorf("bad Duration: %v", err)
317		}
318		*t = d
319		return nil
320	}
321
322	// is the destination field an enumeration type?
323	if enumValMap := proto.EnumValueMap(props.Enum); enumValMap != nil {
324		return populateFieldEnum(f, value, enumValMap)
325	}
326
327	conv, ok := convFromType[f.Kind()]
328	if !ok {
329		return fmt.Errorf("field type %T is not supported in query parameters", i)
330	}
331	result := conv.Call([]reflect.Value{reflect.ValueOf(value)})
332	if err := result[1].Interface(); err != nil {
333		return err.(error)
334	}
335	f.Set(result[0].Convert(f.Type()))
336	return nil
337}
338
// convertEnum turns "value" into a reflect.Value of the enum type "t".
//
// "value" may be either a name present in "enumValMap" or the decimal form of
// a number that "enumValMap" contains. The number is parsed as a 32-bit
// integer, so input outside the int32 range is rejected rather than truncated
// into a value that happens to match an enum entry.
//
// It returns an error when "value" is neither a known name nor a known number.
func convertEnum(value string, t reflect.Type, enumValMap map[string]int32) (reflect.Value, error) {
	// see if it's an enumeration string
	if enumVal, ok := enumValMap[value]; ok {
		return reflect.ValueOf(enumVal).Convert(t), nil
	}

	// check for an integer that matches an enumeration value
	parsed, err := strconv.ParseInt(value, 10, 32)
	if err != nil {
		return reflect.Value{}, fmt.Errorf("%s is not a valid %s", value, t)
	}
	num := int32(parsed)
	for _, v := range enumValMap {
		if v == num {
			return reflect.ValueOf(num).Convert(t), nil
		}
	}
	return reflect.Value{}, fmt.Errorf("%s is not a valid %s", value, t)
}
357
358func populateFieldEnum(f reflect.Value, value string, enumValMap map[string]int32) error {
359	cval, err := convertEnum(value, f.Type(), enumValMap)
360	if err != nil {
361		return err
362	}
363	f.Set(cval)
364	return nil
365}
366
367func populateFieldEnumRepeated(f reflect.Value, values []string, enumValMap map[string]int32) error {
368	elemType := f.Type().Elem()
369	f.Set(reflect.MakeSlice(f.Type(), len(values), len(values)).Convert(f.Type()))
370	for i, v := range values {
371		result, err := convertEnum(v, elemType, enumValMap)
372		if err != nil {
373			return err
374		}
375		f.Index(i).Set(result)
376	}
377	return nil
378}
379
380var (
381	convFromType = map[reflect.Kind]reflect.Value{
382		reflect.String:  reflect.ValueOf(String),
383		reflect.Bool:    reflect.ValueOf(Bool),
384		reflect.Float64: reflect.ValueOf(Float64),
385		reflect.Float32: reflect.ValueOf(Float32),
386		reflect.Int64:   reflect.ValueOf(Int64),
387		reflect.Int32:   reflect.ValueOf(Int32),
388		reflect.Uint64:  reflect.ValueOf(Uint64),
389		reflect.Uint32:  reflect.ValueOf(Uint32),
390		reflect.Slice:   reflect.ValueOf(Bytes),
391	}
392)
393