1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8	"bytes"
9	"math"
10	"reflect"
11
12	"google.golang.org/protobuf/encoding/protowire"
13	pref "google.golang.org/protobuf/reflect/protoreflect"
14)
15
16// Equal reports whether two messages are equal.
17// If two messages marshal to the same bytes under deterministic serialization,
18// then Equal is guaranteed to report true.
19//
20// Two messages are equal if they belong to the same message descriptor,
21// have the same set of populated known and extension field values,
22// and the same set of unknown fields values. If either of the top-level
23// messages are invalid, then Equal reports true only if both are invalid.
24//
25// Scalar values are compared with the equivalent of the == operator in Go,
26// except bytes values which are compared using bytes.Equal and
27// floating point values which specially treat NaNs as equal.
28// Message values are compared by recursively calling Equal.
29// Lists are equal if each element value is also equal.
30// Maps are equal if they have the same set of keys, where the pair of values
31// for each key is also equal.
32func Equal(x, y Message) bool {
33	if x == nil || y == nil {
34		return x == nil && y == nil
35	}
36	mx := x.ProtoReflect()
37	my := y.ProtoReflect()
38	if mx.IsValid() != my.IsValid() {
39		return false
40	}
41	return equalMessage(mx, my)
42}
43
44// equalMessage compares two messages.
45func equalMessage(mx, my pref.Message) bool {
46	if mx.Descriptor() != my.Descriptor() {
47		return false
48	}
49
50	nx := 0
51	equal := true
52	mx.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool {
53		nx++
54		vy := my.Get(fd)
55		equal = my.Has(fd) && equalField(fd, vx, vy)
56		return equal
57	})
58	if !equal {
59		return false
60	}
61	ny := 0
62	my.Range(func(fd pref.FieldDescriptor, vx pref.Value) bool {
63		ny++
64		return true
65	})
66	if nx != ny {
67		return false
68	}
69
70	return equalUnknown(mx.GetUnknown(), my.GetUnknown())
71}
72
73// equalField compares two fields.
74func equalField(fd pref.FieldDescriptor, x, y pref.Value) bool {
75	switch {
76	case fd.IsList():
77		return equalList(fd, x.List(), y.List())
78	case fd.IsMap():
79		return equalMap(fd, x.Map(), y.Map())
80	default:
81		return equalValue(fd, x, y)
82	}
83}
84
85// equalMap compares two maps.
86func equalMap(fd pref.FieldDescriptor, x, y pref.Map) bool {
87	if x.Len() != y.Len() {
88		return false
89	}
90	equal := true
91	x.Range(func(k pref.MapKey, vx pref.Value) bool {
92		vy := y.Get(k)
93		equal = y.Has(k) && equalValue(fd.MapValue(), vx, vy)
94		return equal
95	})
96	return equal
97}
98
99// equalList compares two lists.
100func equalList(fd pref.FieldDescriptor, x, y pref.List) bool {
101	if x.Len() != y.Len() {
102		return false
103	}
104	for i := x.Len() - 1; i >= 0; i-- {
105		if !equalValue(fd, x.Get(i), y.Get(i)) {
106			return false
107		}
108	}
109	return true
110}
111
112// equalValue compares two singular values.
113func equalValue(fd pref.FieldDescriptor, x, y pref.Value) bool {
114	switch {
115	case fd.Message() != nil:
116		return equalMessage(x.Message(), y.Message())
117	case fd.Kind() == pref.BytesKind:
118		return bytes.Equal(x.Bytes(), y.Bytes())
119	case fd.Kind() == pref.FloatKind, fd.Kind() == pref.DoubleKind:
120		fx := x.Float()
121		fy := y.Float()
122		if math.IsNaN(fx) || math.IsNaN(fy) {
123			return math.IsNaN(fx) && math.IsNaN(fy)
124		}
125		return fx == fy
126	default:
127		return x.Interface() == y.Interface()
128	}
129}
130
131// equalUnknown compares unknown fields by direct comparison on the raw bytes
132// of each individual field number.
133func equalUnknown(x, y pref.RawFields) bool {
134	if len(x) != len(y) {
135		return false
136	}
137	if bytes.Equal([]byte(x), []byte(y)) {
138		return true
139	}
140
141	mx := make(map[pref.FieldNumber]pref.RawFields)
142	my := make(map[pref.FieldNumber]pref.RawFields)
143	for len(x) > 0 {
144		fnum, _, n := protowire.ConsumeField(x)
145		mx[fnum] = append(mx[fnum], x[:n]...)
146		x = x[n:]
147	}
148	for len(y) > 0 {
149		fnum, _, n := protowire.ConsumeField(y)
150		my[fnum] = append(my[fnum], y[:n]...)
151		y = y[n:]
152	}
153	return reflect.DeepEqual(mx, my)
154}
155