1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8	"bytes"
9	"math"
10	"reflect"
11
12	"google.golang.org/protobuf/encoding/protowire"
13	pref "google.golang.org/protobuf/reflect/protoreflect"
14)
15
16// Equal reports whether two messages are equal.
17// If two messages marshal to the same bytes under deterministic serialization,
18// then Equal is guaranteed to report true.
19//
20// Two messages are equal if they belong to the same message descriptor,
21// have the same set of populated known and extension field values,
22// and the same set of unknown fields values. If either of the top-level
23// messages are invalid, then Equal reports true only if both are invalid.
24//
25// Scalar values are compared with the equivalent of the == operator in Go,
26// except bytes values which are compared using bytes.Equal and
27// floating point values which specially treat NaNs as equal.
28// Message values are compared by recursively calling Equal.
29// Lists are equal if each element value is also equal.
30// Maps are equal if they have the same set of keys, where the pair of values
31// for each key is also equal.
32func Equal(x, y Message) bool {
33	if x == nil || y == nil {
34		return x == nil && y == nil
35	}
36	mx := x.ProtoReflect()
37	my := y.ProtoReflect()
38	if mx.IsValid() != my.IsValid() {
39		return false
40	}
41	return equalMessage(mx, my)
42}
43
44// equalMessage compares two messages.
45func equalMessage(mx, my pref.Message) bool {
46	if mx.Descriptor() != my.Descriptor() {
47		return false
48	}
49
50	nx := 0
51	equal := true
52	mx.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool {
53		nx++
54		vy := my.Get(fd)
55		equal = my.Has(fd) && equalField(fd, vx, vy)
56		return equal
57	})
58	if !equal {
59		return false
60	}
61	ny := 0
62	my.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool {
63		ny++
64		return true
65	})
66	if nx != ny {
67		return false
68	}
69
70	return equalUnknown(mx.GetUnknown(), my.GetUnknown())
71}
72
73// equalField compares two fields.
74func equalField(fd pref.FieldDescriptor, x, y pref.Value) bool {
75	switch {
76	case fd.IsList():
77		return equalList(fd, x.List(), y.List())
78	case fd.IsMap():
79		return equalMap(fd, x.Map(), y.Map())
80	default:
81		return equalValue(fd, x, y)
82	}
83}
84
85// equalMap compares two maps.
86func equalMap(fd pref.FieldDescriptor, x, y pref.Map) bool {
87	if x.Len() != y.Len() {
88		return false
89	}
90	equal := true
91	x.Range(func(k pref.MapKey, vx pref.Value) bool {
92		vy := y.Get(k)
93		equal = y.Has(k) && equalValue(fd.MapValue(), vx, vy)
94		return equal
95	})
96	return equal
97}
98
99// equalList compares two lists.
100func equalList(fd pref.FieldDescriptor, x, y pref.List) bool {
101	if x.Len() != y.Len() {
102		return false
103	}
104	for i := x.Len() - 1; i >= 0; i-- {
105		if !equalValue(fd, x.Get(i), y.Get(i)) {
106			return false
107		}
108	}
109	return true
110}
111
112// equalValue compares two singular values.
113func equalValue(fd pref.FieldDescriptor, x, y pref.Value) bool {
114	switch fd.Kind() {
115	case pref.BoolKind:
116		return x.Bool() == y.Bool()
117	case pref.EnumKind:
118		return x.Enum() == y.Enum()
119	case pref.Int32Kind, pref.Sint32Kind,
120		pref.Int64Kind, pref.Sint64Kind,
121		pref.Sfixed32Kind, pref.Sfixed64Kind:
122		return x.Int() == y.Int()
123	case pref.Uint32Kind, pref.Uint64Kind,
124		pref.Fixed32Kind, pref.Fixed64Kind:
125		return x.Uint() == y.Uint()
126	case pref.FloatKind, pref.DoubleKind:
127		fx := x.Float()
128		fy := y.Float()
129		if math.IsNaN(fx) || math.IsNaN(fy) {
130			return math.IsNaN(fx) && math.IsNaN(fy)
131		}
132		return fx == fy
133	case pref.StringKind:
134		return x.String() == y.String()
135	case pref.BytesKind:
136		return bytes.Equal(x.Bytes(), y.Bytes())
137	case pref.MessageKind, pref.GroupKind:
138		return equalMessage(x.Message(), y.Message())
139	default:
140		return x.Interface() == y.Interface()
141	}
142}
143
144// equalUnknown compares unknown fields by direct comparison on the raw bytes
145// of each individual field number.
146func equalUnknown(x, y pref.RawFields) bool {
147	if len(x) != len(y) {
148		return false
149	}
150	if bytes.Equal([]byte(x), []byte(y)) {
151		return true
152	}
153
154	mx := make(map[pref.FieldNumber]pref.RawFields)
155	my := make(map[pref.FieldNumber]pref.RawFields)
156	for len(x) > 0 {
157		fnum, _, n := protowire.ConsumeField(x)
158		mx[fnum] = append(mx[fnum], x[:n]...)
159		x = x[n:]
160	}
161	for len(y) > 0 {
162		fnum, _, n := protowire.ConsumeField(y)
163		my[fnum] = append(my[fnum], y[:n]...)
164		y = y[n:]
165	}
166	return reflect.DeepEqual(mx, my)
167}
168