1// Copyright 2020 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package order provides ordered access to messages and maps.
6package order
7
8import (
9	"sort"
10	"sync"
11
12	pref "google.golang.org/protobuf/reflect/protoreflect"
13)
14
15type messageField struct {
16	fd pref.FieldDescriptor
17	v  pref.Value
18}
19
20var messageFieldPool = sync.Pool{
21	New: func() interface{} { return new([]messageField) },
22}
23
24type (
25	// FieldRnger is an interface for visiting all fields in a message.
26	// The protoreflect.Message type implements this interface.
27	FieldRanger interface{ Range(VisitField) }
28	// VisitField is called everytime a message field is visited.
29	VisitField = func(pref.FieldDescriptor, pref.Value) bool
30)
31
32// RangeFields iterates over the fields of fs according to the specified order.
33func RangeFields(fs FieldRanger, less FieldOrder, fn VisitField) {
34	if less == nil {
35		fs.Range(fn)
36		return
37	}
38
39	// Obtain a pre-allocated scratch buffer.
40	p := messageFieldPool.Get().(*[]messageField)
41	fields := (*p)[:0]
42	defer func() {
43		if cap(fields) < 1024 {
44			*p = fields
45			messageFieldPool.Put(p)
46		}
47	}()
48
49	// Collect all fields in the message and sort them.
50	fs.Range(func(fd pref.FieldDescriptor, v pref.Value) bool {
51		fields = append(fields, messageField{fd, v})
52		return true
53	})
54	sort.Slice(fields, func(i, j int) bool {
55		return less(fields[i].fd, fields[j].fd)
56	})
57
58	// Visit the fields in the specified ordering.
59	for _, f := range fields {
60		if !fn(f.fd, f.v) {
61			return
62		}
63	}
64}
65
66type mapEntry struct {
67	k pref.MapKey
68	v pref.Value
69}
70
71var mapEntryPool = sync.Pool{
72	New: func() interface{} { return new([]mapEntry) },
73}
74
75type (
76	// EntryRanger is an interface for visiting all fields in a message.
77	// The protoreflect.Map type implements this interface.
78	EntryRanger interface{ Range(VisitEntry) }
79	// VisitEntry is called everytime a map entry is visited.
80	VisitEntry = func(pref.MapKey, pref.Value) bool
81)
82
83// RangeEntries iterates over the entries of es according to the specified order.
84func RangeEntries(es EntryRanger, less KeyOrder, fn VisitEntry) {
85	if less == nil {
86		es.Range(fn)
87		return
88	}
89
90	// Obtain a pre-allocated scratch buffer.
91	p := mapEntryPool.Get().(*[]mapEntry)
92	entries := (*p)[:0]
93	defer func() {
94		if cap(entries) < 1024 {
95			*p = entries
96			mapEntryPool.Put(p)
97		}
98	}()
99
100	// Collect all entries in the map and sort them.
101	es.Range(func(k pref.MapKey, v pref.Value) bool {
102		entries = append(entries, mapEntry{k, v})
103		return true
104	})
105	sort.Slice(entries, func(i, j int) bool {
106		return less(entries[i].k, entries[j].k)
107	})
108
109	// Visit the entries in the specified ordering.
110	for _, e := range entries {
111		if !fn(e.k, e.v) {
112			return
113		}
114	}
115}
116