1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package protocmp provides protobuf specific options for the
6// "github.com/google/go-cmp/cmp" package.
7//
// The primary feature is the Transform option, which transforms proto.Message
9// types into a Message map that is suitable for cmp to introspect upon.
10// All other options in this package must be used in conjunction with Transform.
11package protocmp
12
13import (
14	"reflect"
15	"strconv"
16
17	"github.com/google/go-cmp/cmp"
18
19	"google.golang.org/protobuf/encoding/protowire"
20	"google.golang.org/protobuf/internal/genid"
21	"google.golang.org/protobuf/internal/msgfmt"
22	"google.golang.org/protobuf/proto"
23	"google.golang.org/protobuf/reflect/protoreflect"
24	"google.golang.org/protobuf/reflect/protoregistry"
25	"google.golang.org/protobuf/runtime/protoiface"
26	"google.golang.org/protobuf/runtime/protoimpl"
27)
28
29var (
30	enumV2Type    = reflect.TypeOf((*protoreflect.Enum)(nil)).Elem()
31	messageV1Type = reflect.TypeOf((*protoiface.MessageV1)(nil)).Elem()
32	messageV2Type = reflect.TypeOf((*proto.Message)(nil)).Elem()
33)
34
35// Enum is a dynamic representation of a protocol buffer enum that is
36// suitable for cmp.Equal and cmp.Diff to compare upon.
37type Enum struct {
38	num protoreflect.EnumNumber
39	ed  protoreflect.EnumDescriptor
40}
41
42// Descriptor returns the enum descriptor.
43// It returns nil for a zero Enum value.
44func (e Enum) Descriptor() protoreflect.EnumDescriptor {
45	return e.ed
46}
47
48// Number returns the enum value as an integer.
49func (e Enum) Number() protoreflect.EnumNumber {
50	return e.num
51}
52
53// Equal reports whether e1 and e2 represent the same enum value.
54func (e1 Enum) Equal(e2 Enum) bool {
55	if e1.ed.FullName() != e2.ed.FullName() {
56		return false
57	}
58	return e1.num == e2.num
59}
60
61// String returns the name of the enum value if known (e.g., "ENUM_VALUE"),
62// otherwise it returns the formatted decimal enum number (e.g., "14").
63func (e Enum) String() string {
64	if ev := e.ed.Values().ByNumber(e.num); ev != nil {
65		return string(ev.Name())
66	}
67	return strconv.Itoa(int(e.num))
68}
69
// Reserved keys in a Message map that hold metadata rather than field values.
const (
	// messageTypeKey maps to the messageType describing the message.
	messageTypeKey = "@type"

	// messageInvalidKey is set to true for a message whose reflective view
	// reports that it is not valid.
	messageInvalidKey = "@invalid"
)
74
75type messageType struct {
76	md  protoreflect.MessageDescriptor
77	xds map[string]protoreflect.ExtensionDescriptor
78}
79
80func (t messageType) String() string {
81	return string(t.md.FullName())
82}
83
84func (t1 messageType) Equal(t2 messageType) bool {
85	return t1.md.FullName() == t2.md.FullName()
86}
87
88// Message is a dynamic representation of a protocol buffer message that is
89// suitable for cmp.Equal and cmp.Diff to directly operate upon.
90//
91// Every populated known field (excluding extension fields) is stored in the map
92// with the key being the short name of the field (e.g., "field_name") and
93// the value determined by the kind and cardinality of the field.
94//
95// Singular scalars are represented by the same Go type as protoreflect.Value,
96// singular messages are represented by the Message type,
97// singular enums are represented by the Enum type,
98// list fields are represented as a Go slice, and
99// map fields are represented as a Go map.
100//
101// Every populated extension field is stored in the map with the key being the
102// full name of the field surrounded by brackets (e.g., "[extension.full.name]")
103// and the value determined according to the same rules as known fields.
104//
105// Every unknown field is stored in the map with the key being the field number
106// encoded as a decimal string (e.g., "132") and the value being the raw bytes
107// of the encoded field (as the protoreflect.RawFields type).
108//
109// Message values must not be created by or mutated by users.
110type Message map[string]interface{}
111
112// Descriptor return the message descriptor.
113// It returns nil for a zero Message value.
114func (m Message) Descriptor() protoreflect.MessageDescriptor {
115	mt, _ := m[messageTypeKey].(messageType)
116	return mt.md
117}
118
119// ProtoReflect returns a reflective view of m.
120// It only implements the read-only operations of protoreflect.Message.
121// Calling any mutating operations on m panics.
122func (m Message) ProtoReflect() protoreflect.Message {
123	return (reflectMessage)(m)
124}
125
126// ProtoMessage is a marker method from the legacy message interface.
127func (m Message) ProtoMessage() {}
128
129// Reset is the required Reset method from the legacy message interface.
130func (m Message) Reset() {
131	panic("invalid mutation of a read-only message")
132}
133
134// String returns a formatted string for the message.
135// It is intended for human debugging and has no guarantees about its
136// exact format or the stability of its output.
137func (m Message) String() string {
138	switch {
139	case m == nil:
140		return "<nil>"
141	case !m.ProtoReflect().IsValid():
142		return "<invalid>"
143	default:
144		return msgfmt.Format(m)
145	}
146}
147
148type option struct{}
149
150// Transform returns a cmp.Option that converts each proto.Message to a Message.
151// The transformation does not mutate nor alias any converted messages.
152//
153// The google.protobuf.Any message is automatically unmarshaled such that the
154// "value" field is a Message representing the underlying message value
155// assuming it could be resolved and properly unmarshaled.
156//
157// This does not directly transform higher-order composite Go types.
158// For example, []*foopb.Message is not transformed into []Message,
159// but rather the individual message elements of the slice are transformed.
160//
161// Note that there are currently no custom options for Transform,
162// but the use of an unexported type keeps the future open.
163func Transform(...option) cmp.Option {
164	// addrType returns a pointer to t if t isn't a pointer or interface.
165	addrType := func(t reflect.Type) reflect.Type {
166		if k := t.Kind(); k == reflect.Interface || k == reflect.Ptr {
167			return t
168		}
169		return reflect.PtrTo(t)
170	}
171
172	// TODO: Should this transform protoreflect.Enum types to Enum as well?
173	return cmp.FilterPath(func(p cmp.Path) bool {
174		ps := p.Last()
175		if isMessageType(addrType(ps.Type())) {
176			return true
177		}
178
179		// Check whether the concrete values of an interface both satisfy
180		// the Message interface.
181		if ps.Type().Kind() == reflect.Interface {
182			vx, vy := ps.Values()
183			if !vx.IsValid() || vx.IsNil() || !vy.IsValid() || vy.IsNil() {
184				return false
185			}
186			return isMessageType(addrType(vx.Elem().Type())) && isMessageType(addrType(vy.Elem().Type()))
187		}
188
189		return false
190	}, cmp.Transformer("protocmp.Transform", func(v interface{}) Message {
191		// For user convenience, shallow copy the message value if necessary
192		// in order for it to implement the message interface.
193		if rv := reflect.ValueOf(v); rv.IsValid() && rv.Kind() != reflect.Ptr && !isMessageType(rv.Type()) {
194			pv := reflect.New(rv.Type())
195			pv.Elem().Set(rv)
196			v = pv.Interface()
197		}
198
199		m := protoimpl.X.MessageOf(v)
200		switch {
201		case m == nil:
202			return nil
203		case !m.IsValid():
204			return Message{messageTypeKey: messageType{md: m.Descriptor()}, messageInvalidKey: true}
205		default:
206			return transformMessage(m)
207		}
208	}))
209}
210
211func isMessageType(t reflect.Type) bool {
212	// Avoid tranforming the Message itself.
213	if t == reflect.TypeOf(Message(nil)) || t == reflect.TypeOf((*Message)(nil)) {
214		return false
215	}
216	return t.Implements(messageV1Type) || t.Implements(messageV2Type)
217}
218
219func transformMessage(m protoreflect.Message) Message {
220	mx := Message{}
221	mt := messageType{md: m.Descriptor(), xds: make(map[string]protoreflect.FieldDescriptor)}
222
223	// Handle known and extension fields.
224	m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
225		s := fd.TextName()
226		if fd.IsExtension() {
227			mt.xds[s] = fd
228		}
229		switch {
230		case fd.IsList():
231			mx[s] = transformList(fd, v.List())
232		case fd.IsMap():
233			mx[s] = transformMap(fd, v.Map())
234		default:
235			mx[s] = transformSingular(fd, v)
236		}
237		return true
238	})
239
240	// Handle unknown fields.
241	for b := m.GetUnknown(); len(b) > 0; {
242		num, _, n := protowire.ConsumeField(b)
243		s := strconv.Itoa(int(num))
244		b2, _ := mx[s].(protoreflect.RawFields)
245		mx[s] = append(b2, b[:n]...)
246		b = b[n:]
247	}
248
249	// Expand Any messages.
250	if mt.md.FullName() == genid.Any_message_fullname {
251		// TODO: Expose Transform option to specify a custom resolver?
252		s, _ := mx[string(genid.Any_TypeUrl_field_name)].(string)
253		b, _ := mx[string(genid.Any_Value_field_name)].([]byte)
254		mt, err := protoregistry.GlobalTypes.FindMessageByURL(s)
255		if mt != nil && err == nil {
256			m2 := mt.New()
257			err := proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(b, m2.Interface())
258			if err == nil {
259				mx[string(genid.Any_Value_field_name)] = transformMessage(m2)
260			}
261		}
262	}
263
264	mx[messageTypeKey] = mt
265	return mx
266}
267
268func transformList(fd protoreflect.FieldDescriptor, lv protoreflect.List) interface{} {
269	t := protoKindToGoType(fd.Kind())
270	rv := reflect.MakeSlice(reflect.SliceOf(t), lv.Len(), lv.Len())
271	for i := 0; i < lv.Len(); i++ {
272		v := reflect.ValueOf(transformSingular(fd, lv.Get(i)))
273		rv.Index(i).Set(v)
274	}
275	return rv.Interface()
276}
277
278func transformMap(fd protoreflect.FieldDescriptor, mv protoreflect.Map) interface{} {
279	kfd := fd.MapKey()
280	vfd := fd.MapValue()
281	kt := protoKindToGoType(kfd.Kind())
282	vt := protoKindToGoType(vfd.Kind())
283	rv := reflect.MakeMapWithSize(reflect.MapOf(kt, vt), mv.Len())
284	mv.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
285		kv := reflect.ValueOf(transformSingular(kfd, k.Value()))
286		vv := reflect.ValueOf(transformSingular(vfd, v))
287		rv.SetMapIndex(kv, vv)
288		return true
289	})
290	return rv.Interface()
291}
292
293func transformSingular(fd protoreflect.FieldDescriptor, v protoreflect.Value) interface{} {
294	switch fd.Kind() {
295	case protoreflect.EnumKind:
296		return Enum{num: v.Enum(), ed: fd.Enum()}
297	case protoreflect.MessageKind, protoreflect.GroupKind:
298		return transformMessage(v.Message())
299	case protoreflect.BytesKind:
300		// The protoreflect API does not specify whether an empty bytes is
301		// guaranteed to be nil or not. Always return non-nil bytes to avoid
302		// leaking information about the concrete proto.Message implementation.
303		if len(v.Bytes()) == 0 {
304			return []byte{}
305		}
306		return v.Bytes()
307	default:
308		return v.Interface()
309	}
310}
311
312func protoKindToGoType(k protoreflect.Kind) reflect.Type {
313	switch k {
314	case protoreflect.BoolKind:
315		return reflect.TypeOf(bool(false))
316	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
317		return reflect.TypeOf(int32(0))
318	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
319		return reflect.TypeOf(int64(0))
320	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
321		return reflect.TypeOf(uint32(0))
322	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
323		return reflect.TypeOf(uint64(0))
324	case protoreflect.FloatKind:
325		return reflect.TypeOf(float32(0))
326	case protoreflect.DoubleKind:
327		return reflect.TypeOf(float64(0))
328	case protoreflect.StringKind:
329		return reflect.TypeOf(string(""))
330	case protoreflect.BytesKind:
331		return reflect.TypeOf([]byte(nil))
332	case protoreflect.EnumKind:
333		return reflect.TypeOf(Enum{})
334	case protoreflect.MessageKind, protoreflect.GroupKind:
335		return reflect.TypeOf(Message{})
336	default:
337		panic("invalid kind")
338	}
339}
340