1package runtime
2
3import (
4	"fmt"
5	"net/url"
6	"reflect"
7	"strconv"
8	"strings"
9	"time"
10
11	"github.com/golang/protobuf/proto"
12	"github.com/grpc-ecosystem/grpc-gateway/utilities"
13	"google.golang.org/grpc/grpclog"
14)
15
16// PopulateQueryParameters populates "values" into "msg".
17// A value is ignored if its key starts with one of the elements in "filter".
18func PopulateQueryParameters(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
19	for key, values := range values {
20		fieldPath := strings.Split(key, ".")
21		if filter.HasCommonPrefix(fieldPath) {
22			continue
23		}
24		if err := populateFieldValueFromPath(msg, fieldPath, values); err != nil {
25			return err
26		}
27	}
28	return nil
29}
30
31// PopulateFieldFromPath sets a value in a nested Protobuf structure.
32// It instantiates missing protobuf fields as it goes.
33func PopulateFieldFromPath(msg proto.Message, fieldPathString string, value string) error {
34	fieldPath := strings.Split(fieldPathString, ".")
35	return populateFieldValueFromPath(msg, fieldPath, []string{value})
36}
37
38func populateFieldValueFromPath(msg proto.Message, fieldPath []string, values []string) error {
39	m := reflect.ValueOf(msg)
40	if m.Kind() != reflect.Ptr {
41		return fmt.Errorf("unexpected type %T: %v", msg, msg)
42	}
43	var props *proto.Properties
44	m = m.Elem()
45	for i, fieldName := range fieldPath {
46		isLast := i == len(fieldPath)-1
47		if !isLast && m.Kind() != reflect.Struct {
48			return fmt.Errorf("non-aggregate type in the mid of path: %s", strings.Join(fieldPath, "."))
49		}
50		var f reflect.Value
51		var err error
52		f, props, err = fieldByProtoName(m, fieldName)
53		if err != nil {
54			return err
55		} else if !f.IsValid() {
56			grpclog.Printf("field not found in %T: %s", msg, strings.Join(fieldPath, "."))
57			return nil
58		}
59
60		switch f.Kind() {
61		case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64, reflect.String, reflect.Uint32, reflect.Uint64:
62			if !isLast {
63				return fmt.Errorf("unexpected nested field %s in %s", fieldPath[i+1], strings.Join(fieldPath[:i+1], "."))
64			}
65			m = f
66		case reflect.Slice:
67			// TODO(yugui) Support []byte
68			if !isLast {
69				return fmt.Errorf("unexpected repeated field in %s", strings.Join(fieldPath, "."))
70			}
71			return populateRepeatedField(f, values, props)
72		case reflect.Ptr:
73			if f.IsNil() {
74				m = reflect.New(f.Type().Elem())
75				f.Set(m.Convert(f.Type()))
76			}
77			m = f.Elem()
78			continue
79		case reflect.Struct:
80			m = f
81			continue
82		default:
83			return fmt.Errorf("unexpected type %s in %T", f.Type(), msg)
84		}
85	}
86	switch len(values) {
87	case 0:
88		return fmt.Errorf("no value of field: %s", strings.Join(fieldPath, "."))
89	case 1:
90	default:
91		grpclog.Printf("too many field values: %s", strings.Join(fieldPath, "."))
92	}
93	return populateField(m, values[0], props)
94}
95
96// fieldByProtoName looks up a field whose corresponding protobuf field name is "name".
97// "m" must be a struct value. It returns zero reflect.Value if no such field found.
98func fieldByProtoName(m reflect.Value, name string) (reflect.Value, *proto.Properties, error) {
99	props := proto.GetProperties(m.Type())
100
101	// look up field name in oneof map
102	if op, ok := props.OneofTypes[name]; ok {
103		v := reflect.New(op.Type.Elem())
104		field := m.Field(op.Field)
105		if !field.IsNil() {
106			return reflect.Value{}, nil, fmt.Errorf("field already set for %s oneof", props.Prop[op.Field].OrigName)
107		}
108		field.Set(v)
109		return v.Elem().Field(0), op.Prop, nil
110	}
111
112	for _, p := range props.Prop {
113		if p.OrigName == name {
114			return m.FieldByName(p.Name), p, nil
115		}
116		if p.JSONName == name {
117			return m.FieldByName(p.Name), p, nil
118		}
119	}
120	return reflect.Value{}, nil, nil
121}
122
123func populateRepeatedField(f reflect.Value, values []string, props *proto.Properties) error {
124	elemType := f.Type().Elem()
125
126	// is the destination field a slice of an enumeration type?
127	if enumValMap := proto.EnumValueMap(props.Enum); enumValMap != nil {
128		return populateFieldEnumRepeated(f, values, enumValMap)
129	}
130
131	conv, ok := convFromType[elemType.Kind()]
132	if !ok {
133		return fmt.Errorf("unsupported field type %s", elemType)
134	}
135	f.Set(reflect.MakeSlice(f.Type(), len(values), len(values)).Convert(f.Type()))
136	for i, v := range values {
137		result := conv.Call([]reflect.Value{reflect.ValueOf(v)})
138		if err := result[1].Interface(); err != nil {
139			return err.(error)
140		}
141		f.Index(i).Set(result[0].Convert(f.Index(i).Type()))
142	}
143	return nil
144}
145
146func populateField(f reflect.Value, value string, props *proto.Properties) error {
147	// Handle well known type
148	type wkt interface {
149		XXX_WellKnownType() string
150	}
151	if wkt, ok := f.Addr().Interface().(wkt); ok {
152		switch wkt.XXX_WellKnownType() {
153		case "Timestamp":
154			if value == "null" {
155				f.Field(0).SetInt(0)
156				f.Field(1).SetInt(0)
157				return nil
158			}
159
160			t, err := time.Parse(time.RFC3339Nano, value)
161			if err != nil {
162				return fmt.Errorf("bad Timestamp: %v", err)
163			}
164			f.Field(0).SetInt(int64(t.Unix()))
165			f.Field(1).SetInt(int64(t.Nanosecond()))
166			return nil
167		case "DoubleValue":
168			fallthrough
169		case "FloatValue":
170			float64Val, err := strconv.ParseFloat(value, 64)
171			if err != nil {
172				return fmt.Errorf("bad DoubleValue: %s", value)
173			}
174			f.Field(0).SetFloat(float64Val)
175			return nil
176		case "Int64Value":
177			fallthrough
178		case "Int32Value":
179			int64Val, err := strconv.ParseInt(value, 10, 64)
180			if err != nil {
181				return fmt.Errorf("bad DoubleValue: %s", value)
182			}
183			f.Field(0).SetInt(int64Val)
184			return nil
185		case "UInt64Value":
186			fallthrough
187		case "UInt32Value":
188			uint64Val, err := strconv.ParseUint(value, 10, 64)
189			if err != nil {
190				return fmt.Errorf("bad DoubleValue: %s", value)
191			}
192			f.Field(0).SetUint(uint64Val)
193			return nil
194		case "BoolValue":
195			if value == "true" {
196				f.Field(0).SetBool(true)
197			} else if value == "false" {
198				f.Field(0).SetBool(false)
199			} else {
200				return fmt.Errorf("bad BoolValue: %s", value)
201			}
202			return nil
203		case "StringValue":
204			f.Field(0).SetString(value)
205			return nil
206		}
207	}
208
209	// is the destination field an enumeration type?
210	if enumValMap := proto.EnumValueMap(props.Enum); enumValMap != nil {
211		return populateFieldEnum(f, value, enumValMap)
212	}
213
214	conv, ok := convFromType[f.Kind()]
215	if !ok {
216		return fmt.Errorf("unsupported field type %T", f)
217	}
218	result := conv.Call([]reflect.Value{reflect.ValueOf(value)})
219	if err := result[1].Interface(); err != nil {
220		return err.(error)
221	}
222	f.Set(result[0].Convert(f.Type()))
223	return nil
224}
225
// convertEnum turns "value" into a reflect.Value of enum type "t".
//
// "value" may be either an enum name (a key of "enumValMap") or the decimal
// form of one of the enum numbers (a value of "enumValMap"). Anything else,
// including a number that does not fit in an int32, yields an error.
func convertEnum(value string, t reflect.Type, enumValMap map[string]int32) (reflect.Value, error) {
	// see if it's an enumeration string
	if enumVal, ok := enumValMap[value]; ok {
		return reflect.ValueOf(enumVal).Convert(t), nil
	}

	// check for an integer that matches an enumeration value.
	// Parsing with a 32-bit limit rejects out-of-range input up front; parsing
	// into a wider int and narrowing afterwards would let e.g. 4294967297 wrap
	// around to 1 and be accepted as a valid enum number.
	parsed, err := strconv.ParseInt(value, 10, 32)
	if err != nil {
		return reflect.Value{}, fmt.Errorf("%s is not a valid %s", value, t)
	}
	number := int32(parsed)
	for _, v := range enumValMap {
		if v == number {
			return reflect.ValueOf(number).Convert(t), nil
		}
	}
	return reflect.Value{}, fmt.Errorf("%s is not a valid %s", value, t)
}
244
245func populateFieldEnum(f reflect.Value, value string, enumValMap map[string]int32) error {
246	cval, err := convertEnum(value, f.Type(), enumValMap)
247	if err != nil {
248		return err
249	}
250	f.Set(cval)
251	return nil
252}
253
254func populateFieldEnumRepeated(f reflect.Value, values []string, enumValMap map[string]int32) error {
255	elemType := f.Type().Elem()
256	f.Set(reflect.MakeSlice(f.Type(), len(values), len(values)).Convert(f.Type()))
257	for i, v := range values {
258		result, err := convertEnum(v, elemType, enumValMap)
259		if err != nil {
260			return err
261		}
262		f.Index(i).Set(result)
263	}
264	return nil
265}
266
267var (
268	convFromType = map[reflect.Kind]reflect.Value{
269		reflect.String:  reflect.ValueOf(String),
270		reflect.Bool:    reflect.ValueOf(Bool),
271		reflect.Float64: reflect.ValueOf(Float64),
272		reflect.Float32: reflect.ValueOf(Float32),
273		reflect.Int64:   reflect.ValueOf(Int64),
274		reflect.Int32:   reflect.ValueOf(Int32),
275		reflect.Uint64:  reflect.ValueOf(Uint64),
276		reflect.Uint32:  reflect.ValueOf(Uint32),
277		// TODO(yugui) Support []byte
278	}
279)
280