1// Copyright 2019 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5// Package protocmp provides protobuf specific options for the 6// "github.com/google/go-cmp/cmp" package. 7// 8// The primary feature is the Transform option, which transform proto.Message 9// types into a Message map that is suitable for cmp to introspect upon. 10// All other options in this package must be used in conjunction with Transform. 11package protocmp 12 13import ( 14 "reflect" 15 "strconv" 16 17 "github.com/google/go-cmp/cmp" 18 19 "google.golang.org/protobuf/encoding/protowire" 20 "google.golang.org/protobuf/internal/genid" 21 "google.golang.org/protobuf/internal/msgfmt" 22 "google.golang.org/protobuf/proto" 23 "google.golang.org/protobuf/reflect/protoreflect" 24 "google.golang.org/protobuf/reflect/protoregistry" 25 "google.golang.org/protobuf/runtime/protoiface" 26 "google.golang.org/protobuf/runtime/protoimpl" 27) 28 29var ( 30 enumV2Type = reflect.TypeOf((*protoreflect.Enum)(nil)).Elem() 31 messageV1Type = reflect.TypeOf((*protoiface.MessageV1)(nil)).Elem() 32 messageV2Type = reflect.TypeOf((*proto.Message)(nil)).Elem() 33) 34 35// Enum is a dynamic representation of a protocol buffer enum that is 36// suitable for cmp.Equal and cmp.Diff to compare upon. 37type Enum struct { 38 num protoreflect.EnumNumber 39 ed protoreflect.EnumDescriptor 40} 41 42// Descriptor returns the enum descriptor. 43// It returns nil for a zero Enum value. 44func (e Enum) Descriptor() protoreflect.EnumDescriptor { 45 return e.ed 46} 47 48// Number returns the enum value as an integer. 49func (e Enum) Number() protoreflect.EnumNumber { 50 return e.num 51} 52 53// Equal reports whether e1 and e2 represent the same enum value. 54func (e1 Enum) Equal(e2 Enum) bool { 55 if e1.ed.FullName() != e2.ed.FullName() { 56 return false 57 } 58 return e1.num == e2.num 59} 60 61// String returns the name of the enum value if known (e.g., "ENUM_VALUE"), 62// otherwise it returns the formatted decimal enum number (e.g., "14"). 63func (e Enum) String() string { 64 if ev := e.ed.Values().ByNumber(e.num); ev != nil { 65 return string(ev.Name()) 66 } 67 return strconv.Itoa(int(e.num)) 68} 69 70const ( 71 messageTypeKey = "@type" 72 messageInvalidKey = "@invalid" 73) 74 75type messageType struct { 76 md protoreflect.MessageDescriptor 77 xds map[string]protoreflect.ExtensionDescriptor 78} 79 80func (t messageType) String() string { 81 return string(t.md.FullName()) 82} 83 84func (t1 messageType) Equal(t2 messageType) bool { 85 return t1.md.FullName() == t2.md.FullName() 86} 87 88// Message is a dynamic representation of a protocol buffer message that is 89// suitable for cmp.Equal and cmp.Diff to directly operate upon. 90// 91// Every populated known field (excluding extension fields) is stored in the map 92// with the key being the short name of the field (e.g., "field_name") and 93// the value determined by the kind and cardinality of the field. 94// 95// Singular scalars are represented by the same Go type as protoreflect.Value, 96// singular messages are represented by the Message type, 97// singular enums are represented by the Enum type, 98// list fields are represented as a Go slice, and 99// map fields are represented as a Go map. 100// 101// Every populated extension field is stored in the map with the key being the 102// full name of the field surrounded by brackets (e.g., "[extension.full.name]") 103// and the value determined according to the same rules as known fields. 104// 105// Every unknown field is stored in the map with the key being the field number 106// encoded as a decimal string (e.g., "132") and the value being the raw bytes 107// of the encoded field (as the protoreflect.RawFields type). 108// 109// Message values must not be created by or mutated by users. 110type Message map[string]interface{} 111 112// Descriptor return the message descriptor. 113// It returns nil for a zero Message value. 114func (m Message) Descriptor() protoreflect.MessageDescriptor { 115 mt, _ := m[messageTypeKey].(messageType) 116 return mt.md 117} 118 119// ProtoReflect returns a reflective view of m. 120// It only implements the read-only operations of protoreflect.Message. 121// Calling any mutating operations on m panics. 122func (m Message) ProtoReflect() protoreflect.Message { 123 return (reflectMessage)(m) 124} 125 126// ProtoMessage is a marker method from the legacy message interface. 127func (m Message) ProtoMessage() {} 128 129// Reset is the required Reset method from the legacy message interface. 130func (m Message) Reset() { 131 panic("invalid mutation of a read-only message") 132} 133 134// String returns a formatted string for the message. 135// It is intended for human debugging and has no guarantees about its 136// exact format or the stability of its output. 137func (m Message) String() string { 138 switch { 139 case m == nil: 140 return "<nil>" 141 case !m.ProtoReflect().IsValid(): 142 return "<invalid>" 143 default: 144 return msgfmt.Format(m) 145 } 146} 147 148type option struct{} 149 150// Transform returns a cmp.Option that converts each proto.Message to a Message. 151// The transformation does not mutate nor alias any converted messages. 152// 153// The google.protobuf.Any message is automatically unmarshaled such that the 154// "value" field is a Message representing the underlying message value 155// assuming it could be resolved and properly unmarshaled. 156// 157// This does not directly transform higher-order composite Go types. 158// For example, []*foopb.Message is not transformed into []Message, 159// but rather the individual message elements of the slice are transformed. 160// 161// Note that there are currently no custom options for Transform, 162// but the use of an unexported type keeps the future open. 163func Transform(...option) cmp.Option { 164 // addrType returns a pointer to t if t isn't a pointer or interface. 165 addrType := func(t reflect.Type) reflect.Type { 166 if k := t.Kind(); k == reflect.Interface || k == reflect.Ptr { 167 return t 168 } 169 return reflect.PtrTo(t) 170 } 171 172 // TODO: Should this transform protoreflect.Enum types to Enum as well? 173 return cmp.FilterPath(func(p cmp.Path) bool { 174 ps := p.Last() 175 if isMessageType(addrType(ps.Type())) { 176 return true 177 } 178 179 // Check whether the concrete values of an interface both satisfy 180 // the Message interface. 181 if ps.Type().Kind() == reflect.Interface { 182 vx, vy := ps.Values() 183 if !vx.IsValid() || vx.IsNil() || !vy.IsValid() || vy.IsNil() { 184 return false 185 } 186 return isMessageType(addrType(vx.Elem().Type())) && isMessageType(addrType(vy.Elem().Type())) 187 } 188 189 return false 190 }, cmp.Transformer("protocmp.Transform", func(v interface{}) Message { 191 // For user convenience, shallow copy the message value if necessary 192 // in order for it to implement the message interface. 193 if rv := reflect.ValueOf(v); rv.IsValid() && rv.Kind() != reflect.Ptr && !isMessageType(rv.Type()) { 194 pv := reflect.New(rv.Type()) 195 pv.Elem().Set(rv) 196 v = pv.Interface() 197 } 198 199 m := protoimpl.X.MessageOf(v) 200 switch { 201 case m == nil: 202 return nil 203 case !m.IsValid(): 204 return Message{messageTypeKey: messageType{md: m.Descriptor()}, messageInvalidKey: true} 205 default: 206 return transformMessage(m) 207 } 208 })) 209} 210 211func isMessageType(t reflect.Type) bool { 212 // Avoid tranforming the Message itself. 213 if t == reflect.TypeOf(Message(nil)) || t == reflect.TypeOf((*Message)(nil)) { 214 return false 215 } 216 return t.Implements(messageV1Type) || t.Implements(messageV2Type) 217} 218 219func transformMessage(m protoreflect.Message) Message { 220 mx := Message{} 221 mt := messageType{md: m.Descriptor(), xds: make(map[string]protoreflect.FieldDescriptor)} 222 223 // Handle known and extension fields. 224 m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { 225 s := fd.TextName() 226 if fd.IsExtension() { 227 mt.xds[s] = fd 228 } 229 switch { 230 case fd.IsList(): 231 mx[s] = transformList(fd, v.List()) 232 case fd.IsMap(): 233 mx[s] = transformMap(fd, v.Map()) 234 default: 235 mx[s] = transformSingular(fd, v) 236 } 237 return true 238 }) 239 240 // Handle unknown fields. 241 for b := m.GetUnknown(); len(b) > 0; { 242 num, _, n := protowire.ConsumeField(b) 243 s := strconv.Itoa(int(num)) 244 b2, _ := mx[s].(protoreflect.RawFields) 245 mx[s] = append(b2, b[:n]...) 246 b = b[n:] 247 } 248 249 // Expand Any messages. 250 if mt.md.FullName() == genid.Any_message_fullname { 251 // TODO: Expose Transform option to specify a custom resolver? 252 s, _ := mx[string(genid.Any_TypeUrl_field_name)].(string) 253 b, _ := mx[string(genid.Any_Value_field_name)].([]byte) 254 mt, err := protoregistry.GlobalTypes.FindMessageByURL(s) 255 if mt != nil && err == nil { 256 m2 := mt.New() 257 err := proto.UnmarshalOptions{AllowPartial: true}.Unmarshal(b, m2.Interface()) 258 if err == nil { 259 mx[string(genid.Any_Value_field_name)] = transformMessage(m2) 260 } 261 } 262 } 263 264 mx[messageTypeKey] = mt 265 return mx 266} 267 268func transformList(fd protoreflect.FieldDescriptor, lv protoreflect.List) interface{} { 269 t := protoKindToGoType(fd.Kind()) 270 rv := reflect.MakeSlice(reflect.SliceOf(t), lv.Len(), lv.Len()) 271 for i := 0; i < lv.Len(); i++ { 272 v := reflect.ValueOf(transformSingular(fd, lv.Get(i))) 273 rv.Index(i).Set(v) 274 } 275 return rv.Interface() 276} 277 278func transformMap(fd protoreflect.FieldDescriptor, mv protoreflect.Map) interface{} { 279 kfd := fd.MapKey() 280 vfd := fd.MapValue() 281 kt := protoKindToGoType(kfd.Kind()) 282 vt := protoKindToGoType(vfd.Kind()) 283 rv := reflect.MakeMapWithSize(reflect.MapOf(kt, vt), mv.Len()) 284 mv.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { 285 kv := reflect.ValueOf(transformSingular(kfd, k.Value())) 286 vv := reflect.ValueOf(transformSingular(vfd, v)) 287 rv.SetMapIndex(kv, vv) 288 return true 289 }) 290 return rv.Interface() 291} 292 293func transformSingular(fd protoreflect.FieldDescriptor, v protoreflect.Value) interface{} { 294 switch fd.Kind() { 295 case protoreflect.EnumKind: 296 return Enum{num: v.Enum(), ed: fd.Enum()} 297 case protoreflect.MessageKind, protoreflect.GroupKind: 298 return transformMessage(v.Message()) 299 case protoreflect.BytesKind: 300 // The protoreflect API does not specify whether an empty bytes is 301 // guaranteed to be nil or not. Always return non-nil bytes to avoid 302 // leaking information about the concrete proto.Message implementation. 303 if len(v.Bytes()) == 0 { 304 return []byte{} 305 } 306 return v.Bytes() 307 default: 308 return v.Interface() 309 } 310} 311 312func protoKindToGoType(k protoreflect.Kind) reflect.Type { 313 switch k { 314 case protoreflect.BoolKind: 315 return reflect.TypeOf(bool(false)) 316 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: 317 return reflect.TypeOf(int32(0)) 318 case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: 319 return reflect.TypeOf(int64(0)) 320 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: 321 return reflect.TypeOf(uint32(0)) 322 case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: 323 return reflect.TypeOf(uint64(0)) 324 case protoreflect.FloatKind: 325 return reflect.TypeOf(float32(0)) 326 case protoreflect.DoubleKind: 327 return reflect.TypeOf(float64(0)) 328 case protoreflect.StringKind: 329 return reflect.TypeOf(string("")) 330 case protoreflect.BytesKind: 331 return reflect.TypeOf([]byte(nil)) 332 case protoreflect.EnumKind: 333 return reflect.TypeOf(Enum{}) 334 case protoreflect.MessageKind, protoreflect.GroupKind: 335 return reflect.TypeOf(Message{}) 336 default: 337 panic("invalid kind") 338 } 339} 340