1// Copyright 2019 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package proto 6 7import ( 8 "bytes" 9 "math" 10 "reflect" 11 12 "google.golang.org/protobuf/encoding/protowire" 13 pref "google.golang.org/protobuf/reflect/protoreflect" 14) 15 16// Equal reports whether two messages are equal. 17// If two messages marshal to the same bytes under deterministic serialization, 18// then Equal is guaranteed to report true. 19// 20// Two messages are equal if they belong to the same message descriptor, 21// have the same set of populated known and extension field values, 22// and the same set of unknown fields values. If either of the top-level 23// messages are invalid, then Equal reports true only if both are invalid. 24// 25// Scalar values are compared with the equivalent of the == operator in Go, 26// except bytes values which are compared using bytes.Equal and 27// floating point values which specially treat NaNs as equal. 28// Message values are compared by recursively calling Equal. 29// Lists are equal if each element value is also equal. 30// Maps are equal if they have the same set of keys, where the pair of values 31// for each key is also equal. 32func Equal(x, y Message) bool { 33 if x == nil || y == nil { 34 return x == nil && y == nil 35 } 36 mx := x.ProtoReflect() 37 my := y.ProtoReflect() 38 if mx.IsValid() != my.IsValid() { 39 return false 40 } 41 return equalMessage(mx, my) 42} 43 44// equalMessage compares two messages. 45func equalMessage(mx, my pref.Message) bool { 46 if mx.Descriptor() != my.Descriptor() { 47 return false 48 } 49 50 nx := 0 51 equal := true 52 mx.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool { 53 nx++ 54 vy := my.Get(fd) 55 equal = my.Has(fd) && equalField(fd, vx, vy) 56 return equal 57 }) 58 if !equal { 59 return false 60 } 61 ny := 0 62 my.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool { 63 ny++ 64 return true 65 }) 66 if nx != ny { 67 return false 68 } 69 70 return equalUnknown(mx.GetUnknown(), my.GetUnknown()) 71} 72 73// equalField compares two fields. 74func equalField(fd pref.FieldDescriptor, x, y pref.Value) bool { 75 switch { 76 case fd.IsList(): 77 return equalList(fd, x.List(), y.List()) 78 case fd.IsMap(): 79 return equalMap(fd, x.Map(), y.Map()) 80 default: 81 return equalValue(fd, x, y) 82 } 83} 84 85// equalMap compares two maps. 86func equalMap(fd pref.FieldDescriptor, x, y pref.Map) bool { 87 if x.Len() != y.Len() { 88 return false 89 } 90 equal := true 91 x.Range(func(k pref.MapKey, vx pref.Value) bool { 92 vy := y.Get(k) 93 equal = y.Has(k) && equalValue(fd.MapValue(), vx, vy) 94 return equal 95 }) 96 return equal 97} 98 99// equalList compares two lists. 100func equalList(fd pref.FieldDescriptor, x, y pref.List) bool { 101 if x.Len() != y.Len() { 102 return false 103 } 104 for i := x.Len() - 1; i >= 0; i-- { 105 if !equalValue(fd, x.Get(i), y.Get(i)) { 106 return false 107 } 108 } 109 return true 110} 111 112// equalValue compares two singular values. 113func equalValue(fd pref.FieldDescriptor, x, y pref.Value) bool { 114 switch fd.Kind() { 115 case pref.BoolKind: 116 return x.Bool() == y.Bool() 117 case pref.EnumKind: 118 return x.Enum() == y.Enum() 119 case pref.Int32Kind, pref.Sint32Kind, 120 pref.Int64Kind, pref.Sint64Kind, 121 pref.Sfixed32Kind, pref.Sfixed64Kind: 122 return x.Int() == y.Int() 123 case pref.Uint32Kind, pref.Uint64Kind, 124 pref.Fixed32Kind, pref.Fixed64Kind: 125 return x.Uint() == y.Uint() 126 case pref.FloatKind, pref.DoubleKind: 127 fx := x.Float() 128 fy := y.Float() 129 if math.IsNaN(fx) || math.IsNaN(fy) { 130 return math.IsNaN(fx) && math.IsNaN(fy) 131 } 132 return fx == fy 133 case pref.StringKind: 134 return x.String() == y.String() 135 case pref.BytesKind: 136 return bytes.Equal(x.Bytes(), y.Bytes()) 137 case pref.MessageKind, pref.GroupKind: 138 return equalMessage(x.Message(), y.Message()) 139 default: 140 return x.Interface() == y.Interface() 141 } 142} 143 144// equalUnknown compares unknown fields by direct comparison on the raw bytes 145// of each individual field number. 146func equalUnknown(x, y pref.RawFields) bool { 147 if len(x) != len(y) { 148 return false 149 } 150 if bytes.Equal([]byte(x), []byte(y)) { 151 return true 152 } 153 154 mx := make(map[pref.FieldNumber]pref.RawFields) 155 my := make(map[pref.FieldNumber]pref.RawFields) 156 for len(x) > 0 { 157 fnum, _, n := protowire.ConsumeField(x) 158 mx[fnum] = append(mx[fnum], x[:n]...) 159 x = x[n:] 160 } 161 for len(y) > 0 { 162 fnum, _, n := protowire.ConsumeField(y) 163 my[fnum] = append(my[fnum], y[:n]...) 164 y = y[n:] 165 } 166 return reflect.DeepEqual(mx, my) 167} 168