1// Copyright 2019 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package proto 6 7import ( 8 "bytes" 9 "math" 10 "reflect" 11 12 "google.golang.org/protobuf/encoding/protowire" 13 pref "google.golang.org/protobuf/reflect/protoreflect" 14) 15 16// Equal reports whether two messages are equal. 17// If two messages marshal to the same bytes under deterministic serialization, 18// then Equal is guaranteed to report true. 19// 20// Two messages are equal if they belong to the same message descriptor, 21// have the same set of populated known and extension field values, 22// and the same set of unknown fields values. If either of the top-level 23// messages are invalid, then Equal reports true only if both are invalid. 24// 25// Scalar values are compared with the equivalent of the == operator in Go, 26// except bytes values which are compared using bytes.Equal and 27// floating point values which specially treat NaNs as equal. 28// Message values are compared by recursively calling Equal. 29// Lists are equal if each element value is also equal. 30// Maps are equal if they have the same set of keys, where the pair of values 31// for each key is also equal. 32func Equal(x, y Message) bool { 33 if x == nil || y == nil { 34 return x == nil && y == nil 35 } 36 mx := x.ProtoReflect() 37 my := y.ProtoReflect() 38 if mx.IsValid() != my.IsValid() { 39 return false 40 } 41 return equalMessage(mx, my) 42} 43 44// equalMessage compares two messages. 45func equalMessage(mx, my pref.Message) bool { 46 if mx.Descriptor() != my.Descriptor() { 47 return false 48 } 49 50 nx := 0 51 equal := true 52 mx.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool { 53 nx++ 54 vy := my.Get(fd) 55 equal = my.Has(fd) && equalField(fd, vx, vy) 56 return equal 57 }) 58 if !equal { 59 return false 60 } 61 ny := 0 62 my.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool { 63 ny++ 64 return true 65 }) 66 if nx != ny { 67 return false 68 } 69 70 return equalUnknown(mx.GetUnknown(), my.GetUnknown()) 71} 72 73// equalField compares two fields. 74func equalField(fd pref.FieldDescriptor, x, y pref.Value) bool { 75 switch { 76 case fd.IsList(): 77 return equalList(fd, x.List(), y.List()) 78 case fd.IsMap(): 79 return equalMap(fd, x.Map(), y.Map()) 80 default: 81 return equalValue(fd, x, y) 82 } 83} 84 85// equalMap compares two maps. 86func equalMap(fd pref.FieldDescriptor, x, y pref.Map) bool { 87 if x.Len() != y.Len() { 88 return false 89 } 90 equal := true 91 x.Range(func(k pref.MapKey, vx pref.Value) bool { 92 vy := y.Get(k) 93 equal = y.Has(k) && equalValue(fd.MapValue(), vx, vy) 94 return equal 95 }) 96 return equal 97} 98 99// equalList compares two lists. 100func equalList(fd pref.FieldDescriptor, x, y pref.List) bool { 101 if x.Len() != y.Len() { 102 return false 103 } 104 for i := x.Len() - 1; i >= 0; i-- { 105 if !equalValue(fd, x.Get(i), y.Get(i)) { 106 return false 107 } 108 } 109 return true 110} 111 112// equalValue compares two singular values. 113func equalValue(fd pref.FieldDescriptor, x, y pref.Value) bool { 114 switch { 115 case fd.Message() != nil: 116 return equalMessage(x.Message(), y.Message()) 117 case fd.Kind() == pref.BytesKind: 118 return bytes.Equal(x.Bytes(), y.Bytes()) 119 case fd.Kind() == pref.FloatKind, fd.Kind() == pref.DoubleKind: 120 fx := x.Float() 121 fy := y.Float() 122 if math.IsNaN(fx) || math.IsNaN(fy) { 123 return math.IsNaN(fx) && math.IsNaN(fy) 124 } 125 return fx == fy 126 default: 127 return x.Interface() == y.Interface() 128 } 129} 130 131// equalUnknown compares unknown fields by direct comparison on the raw bytes 132// of each individual field number. 133func equalUnknown(x, y pref.RawFields) bool { 134 if len(x) != len(y) { 135 return false 136 } 137 if bytes.Equal([]byte(x), []byte(y)) { 138 return true 139 } 140 141 mx := make(map[pref.FieldNumber]pref.RawFields) 142 my := make(map[pref.FieldNumber]pref.RawFields) 143 for len(x) > 0 { 144 fnum, _, n := protowire.ConsumeField(x) 145 mx[fnum] = append(mx[fnum], x[:n]...) 146 x = x[n:] 147 } 148 for len(y) > 0 { 149 fnum, _, n := protowire.ConsumeField(y) 150 my[fnum] = append(my[fnum], y[:n]...) 151 y = y[n:] 152 } 153 return reflect.DeepEqual(mx, my) 154} 155