1package internal
2
3import (
4	"reflect"
5
6	"github.com/golang/protobuf/proto"
7)
8
// typeOfBytes is the reflect.Type of []byte. It is used to recognize legacy
// XXX_unrecognized fields and to convert unknown-field values to and from
// plain byte slices.
var typeOfBytes = reflect.TypeOf((*[]byte)(nil)).Elem()
10
11// GetUnrecognized fetches the bytes of unrecognized fields for the given message.
12func GetUnrecognized(msg proto.Message) []byte {
13	val := reflect.Indirect(reflect.ValueOf(msg))
14	u := val.FieldByName("XXX_unrecognized")
15	if u.IsValid() && u.Type() == typeOfBytes {
16		return u.Interface().([]byte)
17	}
18
19	// Fallback to reflection for API v2 messages
20	get, _, _, ok := unrecognizedGetSetMethods(val)
21	if !ok {
22		return nil
23	}
24
25	return get.Call([]reflect.Value(nil))[0].Convert(typeOfBytes).Interface().([]byte)
26}
27
28// SetUnrecognized adds the given bytes to the unrecognized fields for the given message.
29func SetUnrecognized(msg proto.Message, data []byte) {
30	val := reflect.Indirect(reflect.ValueOf(msg))
31	u := val.FieldByName("XXX_unrecognized")
32	if u.IsValid() && u.Type() == typeOfBytes {
33		// Just store the bytes in the unrecognized field
34		ub := u.Interface().([]byte)
35		ub = append(ub, data...)
36		u.Set(reflect.ValueOf(ub))
37		return
38	}
39
40	// Fallback to reflection for API v2 messages
41	get, set, argType, ok := unrecognizedGetSetMethods(val)
42	if !ok {
43		return
44	}
45
46	existing := get.Call([]reflect.Value(nil))[0].Convert(typeOfBytes).Interface().([]byte)
47	if len(existing) > 0 {
48		data = append(existing, data...)
49	}
50	set.Call([]reflect.Value{reflect.ValueOf(data).Convert(argType)})
51}
52
53func unrecognizedGetSetMethods(val reflect.Value) (get reflect.Value, set reflect.Value, argType reflect.Type, ok bool) {
54	// val could be an APIv2 message. We use reflection to interact with
55	// this message so that we don't have a hard dependency on the new
56	// version of the protobuf package.
57	refMethod := val.MethodByName("ProtoReflect")
58	if !refMethod.IsValid() {
59		if val.CanAddr() {
60			refMethod = val.Addr().MethodByName("ProtoReflect")
61		}
62		if !refMethod.IsValid() {
63			return
64		}
65	}
66	refType := refMethod.Type()
67	if refType.NumIn() != 0 || refType.NumOut() != 1 {
68		return
69	}
70	ref := refMethod.Call([]reflect.Value(nil))
71	getMethod, setMethod := ref[0].MethodByName("GetUnknown"), ref[0].MethodByName("SetUnknown")
72	if !getMethod.IsValid() || !setMethod.IsValid() {
73		return
74	}
75	getType := getMethod.Type()
76	setType := setMethod.Type()
77	if getType.NumIn() != 0 || getType.NumOut() != 1 || setType.NumIn() != 1 || setType.NumOut() != 0 {
78		return
79	}
80	arg := setType.In(0)
81	if !arg.ConvertibleTo(typeOfBytes) || getType.Out(0) != arg {
82		return
83	}
84
85	return getMethod, setMethod, arg, true
86}
87