1// Copyright 2020 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package protorange provides functionality to traverse a message value.
6package protorange
7
8import (
9	"bytes"
10	"errors"
11
12	"google.golang.org/protobuf/internal/genid"
13	"google.golang.org/protobuf/internal/order"
14	"google.golang.org/protobuf/proto"
15	"google.golang.org/protobuf/reflect/protopath"
16	"google.golang.org/protobuf/reflect/protoreflect"
17	"google.golang.org/protobuf/reflect/protoregistry"
18)
19
var (
	// Break breaks traversal of children in the current value.
	//
	// When returned by a push or pop callback, iteration over the remaining
	// children of the value whose children are currently being visited
	// (the parent of the last step) stops. When returned by push, the
	// children of the just-pushed value are skipped as well.
	// Every pushed step is still popped, and traversal then resumes after
	// that parent value. Break is never returned by Range.
	Break = errors.New("break traversal of children in current value")

	// Terminate terminates the entire range operation.
	// The pop callback is still called for every step that has been pushed.
	// Range then returns nil, unless one of those pop calls returns
	// an error other than Break or Terminate.
	Terminate = errors.New("terminate range operation")
)
30
31// Range performs a depth-first traversal over reachable values in a message.
32//
33// See Options.Range for details.
34func Range(m protoreflect.Message, f func(protopath.Values) error) error {
35	return Options{}.Range(m, f, nil)
36}
37
// Options configures traversal of a message value tree.
//
// The zero value is ready to use: fields and map entries are visited in an
// unspecified order, and google.protobuf.Any messages are expanded using
// protoregistry.GlobalTypes.
type Options struct {
	// Stable specifies whether to visit message fields and map entries
	// in a stable ordering. If false, then the ordering is undefined and
	// may be non-deterministic.
	//
	// Message fields are visited in ascending order by field number.
	// Map entries are visited in ascending order, where
	// boolean keys are ordered such that false sorts before true,
	// numeric keys are ordered based on the numeric value, and
	// string keys are lexicographically ordered by Unicode codepoints.
	Stable bool

	// Resolver is used for looking up types when expanding google.protobuf.Any
	// messages (by type URL). It is also passed to the unmarshaler when
	// decoding the embedded value. If nil, this defaults to using
	// protoregistry.GlobalTypes.
	// To prevent expansion of Any messages, pass a resolver that finds no
	// types, such as an empty (here: nil) *protoregistry.Types:
	//
	//	Options{Resolver: (*protoregistry.Types)(nil)}
	//
	Resolver interface {
		protoregistry.ExtensionTypeResolver
		protoregistry.MessageTypeResolver
	}
}
62
63// Range performs a depth-first traversal over reachable values in a message.
64// The first push and the last pop are to push/pop a protopath.Root step.
65// If push or pop return any non-nil error (other than Break or Terminate),
66// it terminates the traversal and is returned by Range.
67//
68// The rules for traversing a message is as follows:
69//
70// • For messages, iterate over every populated known and extension field.
71// Each field is preceded by a push of a protopath.FieldAccess step,
72// followed by recursive application of the rules on the field value,
73// and succeeded by a pop of that step.
74// If the message has unknown fields, then push an protopath.UnknownAccess step
75// followed immediately by pop of that step.
76//
77// • As an exception to the above rule, if the current message is a
78// google.protobuf.Any message, expand the underlying message (if resolvable).
79// The expanded message is preceded by a push of a protopath.AnyExpand step,
80// followed by recursive application of the rules on the underlying message,
81// and succeeded by a pop of that step. Mutations to the expanded message
82// are written back to the Any message when popping back out.
83//
84// • For lists, iterate over every element. Each element is preceded by a push
85// of a protopath.ListIndex step, followed by recursive application of the rules
86// on the list element, and succeeded by a pop of that step.
87//
88// • For maps, iterate over every entry. Each entry is preceded by a push
89// of a protopath.MapIndex step, followed by recursive application of the rules
90// on the map entry value, and succeeded by a pop of that step.
91//
92// Mutations should only be made to the last value, otherwise the effects on
93// traversal will be undefined. If the mutation is made to the last value
94// during to a push, then the effects of the mutation will affect traversal.
95// For example, if the last value is currently a message, and the push function
96// populates a few fields in that message, then the newly modified fields
97// will be traversed.
98//
99// The protopath.Values provided to push functions is only valid until the
100// corresponding pop call and the values provided to a pop call is only valid
101// for the duration of the pop call itself.
102func (o Options) Range(m protoreflect.Message, push, pop func(protopath.Values) error) error {
103	var err error
104	p := new(protopath.Values)
105	if o.Resolver == nil {
106		o.Resolver = protoregistry.GlobalTypes
107	}
108
109	pushStep(p, protopath.Root(m.Descriptor()), protoreflect.ValueOfMessage(m))
110	if push != nil {
111		err = amendError(err, push(*p))
112	}
113	if err == nil {
114		err = o.rangeMessage(p, m, push, pop)
115	}
116	if pop != nil {
117		err = amendError(err, pop(*p))
118	}
119	popStep(p)
120
121	if err == Break || err == Terminate {
122		err = nil
123	}
124	return err
125}
126
127func (o Options) rangeMessage(p *protopath.Values, m protoreflect.Message, push, pop func(protopath.Values) error) (err error) {
128	if ok, err := o.rangeAnyMessage(p, m, push, pop); ok {
129		return err
130	}
131
132	fieldOrder := order.AnyFieldOrder
133	if o.Stable {
134		fieldOrder = order.NumberFieldOrder
135	}
136	order.RangeFields(m, fieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
137		pushStep(p, protopath.FieldAccess(fd), v)
138		if push != nil {
139			err = amendError(err, push(*p))
140		}
141		if err == nil {
142			switch {
143			case fd.IsMap():
144				err = o.rangeMap(p, fd, v.Map(), push, pop)
145			case fd.IsList():
146				err = o.rangeList(p, fd, v.List(), push, pop)
147			case fd.Message() != nil:
148				err = o.rangeMessage(p, v.Message(), push, pop)
149			}
150		}
151		if pop != nil {
152			err = amendError(err, pop(*p))
153		}
154		popStep(p)
155		return err == nil
156	})
157
158	if b := m.GetUnknown(); len(b) > 0 && err == nil {
159		pushStep(p, protopath.UnknownAccess(), protoreflect.ValueOfBytes(b))
160		if push != nil {
161			err = amendError(err, push(*p))
162		}
163		if pop != nil {
164			err = amendError(err, pop(*p))
165		}
166		popStep(p)
167	}
168
169	if err == Break {
170		err = nil
171	}
172	return err
173}
174
175func (o Options) rangeAnyMessage(p *protopath.Values, m protoreflect.Message, push, pop func(protopath.Values) error) (ok bool, err error) {
176	md := m.Descriptor()
177	if md.FullName() != "google.protobuf.Any" {
178		return false, nil
179	}
180
181	fds := md.Fields()
182	url := m.Get(fds.ByNumber(genid.Any_TypeUrl_field_number)).String()
183	val := m.Get(fds.ByNumber(genid.Any_Value_field_number)).Bytes()
184	mt, errFind := o.Resolver.FindMessageByURL(url)
185	if errFind != nil {
186		return false, nil
187	}
188
189	// Unmarshal the raw encoded message value into a structured message value.
190	m2 := mt.New()
191	errUnmarshal := proto.UnmarshalOptions{
192		Merge:        true,
193		AllowPartial: true,
194		Resolver:     o.Resolver,
195	}.Unmarshal(val, m2.Interface())
196	if errUnmarshal != nil {
197		// If the the underlying message cannot be unmarshaled,
198		// then just treat this as an normal message type.
199		return false, nil
200	}
201
202	// Marshal Any before ranging to detect possible mutations.
203	b1, errMarshal := proto.MarshalOptions{
204		AllowPartial:  true,
205		Deterministic: true,
206	}.Marshal(m2.Interface())
207	if errMarshal != nil {
208		return true, errMarshal
209	}
210
211	pushStep(p, protopath.AnyExpand(m2.Descriptor()), protoreflect.ValueOfMessage(m2))
212	if push != nil {
213		err = amendError(err, push(*p))
214	}
215	if err == nil {
216		err = o.rangeMessage(p, m2, push, pop)
217	}
218	if pop != nil {
219		err = amendError(err, pop(*p))
220	}
221	popStep(p)
222
223	// Marshal Any after ranging to detect possible mutations.
224	b2, errMarshal := proto.MarshalOptions{
225		AllowPartial:  true,
226		Deterministic: true,
227	}.Marshal(m2.Interface())
228	if errMarshal != nil {
229		return true, errMarshal
230	}
231
232	// Mutations detected, write the new sequence of bytes to the Any message.
233	if !bytes.Equal(b1, b2) {
234		m.Set(fds.ByNumber(genid.Any_Value_field_number), protoreflect.ValueOfBytes(b2))
235	}
236
237	if err == Break {
238		err = nil
239	}
240	return true, err
241}
242
243func (o Options) rangeList(p *protopath.Values, fd protoreflect.FieldDescriptor, ls protoreflect.List, push, pop func(protopath.Values) error) (err error) {
244	for i := 0; i < ls.Len() && err == nil; i++ {
245		v := ls.Get(i)
246		pushStep(p, protopath.ListIndex(i), v)
247		if push != nil {
248			err = amendError(err, push(*p))
249		}
250		if err == nil && fd.Message() != nil {
251			err = o.rangeMessage(p, v.Message(), push, pop)
252		}
253		if pop != nil {
254			err = amendError(err, pop(*p))
255		}
256		popStep(p)
257	}
258
259	if err == Break {
260		err = nil
261	}
262	return err
263}
264
265func (o Options) rangeMap(p *protopath.Values, fd protoreflect.FieldDescriptor, ms protoreflect.Map, push, pop func(protopath.Values) error) (err error) {
266	keyOrder := order.AnyKeyOrder
267	if o.Stable {
268		keyOrder = order.GenericKeyOrder
269	}
270	order.RangeEntries(ms, keyOrder, func(k protoreflect.MapKey, v protoreflect.Value) bool {
271		pushStep(p, protopath.MapIndex(k), v)
272		if push != nil {
273			err = amendError(err, push(*p))
274		}
275		if err == nil && fd.MapValue().Message() != nil {
276			err = o.rangeMessage(p, v.Message(), push, pop)
277		}
278		if pop != nil {
279			err = amendError(err, pop(*p))
280		}
281		popStep(p)
282		return err == nil
283	})
284
285	if err == Break {
286		err = nil
287	}
288	return err
289}
290
291func pushStep(p *protopath.Values, s protopath.Step, v protoreflect.Value) {
292	p.Path = append(p.Path, s)
293	p.Values = append(p.Values, v)
294}
295
296func popStep(p *protopath.Values) {
297	p.Path = p.Path[:len(p.Path)-1]
298	p.Values = p.Values[:len(p.Values)-1]
299}
300
301// amendError amends the previous error with the current error if it is
302// considered more serious. The precedence order for errors is:
303//	nil < Break < Terminate < previous non-nil < current non-nil
304func amendError(prev, curr error) error {
305	switch {
306	case curr == nil:
307		return prev
308	case curr == Break && prev != nil:
309		return prev
310	case curr == Terminate && prev != nil && prev != Break:
311		return prev
312	default:
313		return curr
314	}
315}
316