1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package filedesc
6
7import (
8	"fmt"
9	"math"
10	"sort"
11	"sync"
12
13	"google.golang.org/protobuf/internal/genid"
14
15	"google.golang.org/protobuf/encoding/protowire"
16	"google.golang.org/protobuf/internal/descfmt"
17	"google.golang.org/protobuf/internal/errors"
18	"google.golang.org/protobuf/internal/pragma"
19	"google.golang.org/protobuf/reflect/protoreflect"
20	pref "google.golang.org/protobuf/reflect/protoreflect"
21)
22
23type FileImports []pref.FileImport
24
25func (p *FileImports) Len() int                            { return len(*p) }
26func (p *FileImports) Get(i int) pref.FileImport           { return (*p)[i] }
27func (p *FileImports) Format(s fmt.State, r rune)          { descfmt.FormatList(s, r, p) }
28func (p *FileImports) ProtoInternal(pragma.DoNotImplement) {}
29
30type Names struct {
31	List []pref.Name
32	once sync.Once
33	has  map[pref.Name]int // protected by once
34}
35
36func (p *Names) Len() int                            { return len(p.List) }
37func (p *Names) Get(i int) pref.Name                 { return p.List[i] }
38func (p *Names) Has(s pref.Name) bool                { return p.lazyInit().has[s] > 0 }
39func (p *Names) Format(s fmt.State, r rune)          { descfmt.FormatList(s, r, p) }
40func (p *Names) ProtoInternal(pragma.DoNotImplement) {}
41func (p *Names) lazyInit() *Names {
42	p.once.Do(func() {
43		if len(p.List) > 0 {
44			p.has = make(map[pref.Name]int, len(p.List))
45			for _, s := range p.List {
46				p.has[s] = p.has[s] + 1
47			}
48		}
49	})
50	return p
51}
52
53// CheckValid reports any errors with the set of names with an error message
54// that completes the sentence: "ranges is invalid because it has ..."
55func (p *Names) CheckValid() error {
56	for s, n := range p.lazyInit().has {
57		switch {
58		case n > 1:
59			return errors.New("duplicate name: %q", s)
60		case false && !s.IsValid():
61			// NOTE: The C++ implementation does not validate the identifier.
62			// See https://github.com/protocolbuffers/protobuf/issues/6335.
63			return errors.New("invalid name: %q", s)
64		}
65	}
66	return nil
67}
68
69type EnumRanges struct {
70	List   [][2]pref.EnumNumber // start inclusive; end inclusive
71	once   sync.Once
72	sorted [][2]pref.EnumNumber // protected by once
73}
74
75func (p *EnumRanges) Len() int                     { return len(p.List) }
76func (p *EnumRanges) Get(i int) [2]pref.EnumNumber { return p.List[i] }
77func (p *EnumRanges) Has(n pref.EnumNumber) bool {
78	for ls := p.lazyInit().sorted; len(ls) > 0; {
79		i := len(ls) / 2
80		switch r := enumRange(ls[i]); {
81		case n < r.Start():
82			ls = ls[:i] // search lower
83		case n > r.End():
84			ls = ls[i+1:] // search upper
85		default:
86			return true
87		}
88	}
89	return false
90}
91func (p *EnumRanges) Format(s fmt.State, r rune)          { descfmt.FormatList(s, r, p) }
92func (p *EnumRanges) ProtoInternal(pragma.DoNotImplement) {}
93func (p *EnumRanges) lazyInit() *EnumRanges {
94	p.once.Do(func() {
95		p.sorted = append(p.sorted, p.List...)
96		sort.Slice(p.sorted, func(i, j int) bool {
97			return p.sorted[i][0] < p.sorted[j][0]
98		})
99	})
100	return p
101}
102
103// CheckValid reports any errors with the set of names with an error message
104// that completes the sentence: "ranges is invalid because it has ..."
105func (p *EnumRanges) CheckValid() error {
106	var rp enumRange
107	for i, r := range p.lazyInit().sorted {
108		r := enumRange(r)
109		switch {
110		case !(r.Start() <= r.End()):
111			return errors.New("invalid range: %v", r)
112		case !(rp.End() < r.Start()) && i > 0:
113			return errors.New("overlapping ranges: %v with %v", rp, r)
114		}
115		rp = r
116	}
117	return nil
118}
119
120type enumRange [2]protoreflect.EnumNumber
121
122func (r enumRange) Start() protoreflect.EnumNumber { return r[0] } // inclusive
123func (r enumRange) End() protoreflect.EnumNumber   { return r[1] } // inclusive
124func (r enumRange) String() string {
125	if r.Start() == r.End() {
126		return fmt.Sprintf("%d", r.Start())
127	}
128	return fmt.Sprintf("%d to %d", r.Start(), r.End())
129}
130
131type FieldRanges struct {
132	List   [][2]pref.FieldNumber // start inclusive; end exclusive
133	once   sync.Once
134	sorted [][2]pref.FieldNumber // protected by once
135}
136
137func (p *FieldRanges) Len() int                      { return len(p.List) }
138func (p *FieldRanges) Get(i int) [2]pref.FieldNumber { return p.List[i] }
139func (p *FieldRanges) Has(n pref.FieldNumber) bool {
140	for ls := p.lazyInit().sorted; len(ls) > 0; {
141		i := len(ls) / 2
142		switch r := fieldRange(ls[i]); {
143		case n < r.Start():
144			ls = ls[:i] // search lower
145		case n > r.End():
146			ls = ls[i+1:] // search upper
147		default:
148			return true
149		}
150	}
151	return false
152}
153func (p *FieldRanges) Format(s fmt.State, r rune)          { descfmt.FormatList(s, r, p) }
154func (p *FieldRanges) ProtoInternal(pragma.DoNotImplement) {}
155func (p *FieldRanges) lazyInit() *FieldRanges {
156	p.once.Do(func() {
157		p.sorted = append(p.sorted, p.List...)
158		sort.Slice(p.sorted, func(i, j int) bool {
159			return p.sorted[i][0] < p.sorted[j][0]
160		})
161	})
162	return p
163}
164
165// CheckValid reports any errors with the set of ranges with an error message
166// that completes the sentence: "ranges is invalid because it has ..."
167func (p *FieldRanges) CheckValid(isMessageSet bool) error {
168	var rp fieldRange
169	for i, r := range p.lazyInit().sorted {
170		r := fieldRange(r)
171		switch {
172		case !isValidFieldNumber(r.Start(), isMessageSet):
173			return errors.New("invalid field number: %d", r.Start())
174		case !isValidFieldNumber(r.End(), isMessageSet):
175			return errors.New("invalid field number: %d", r.End())
176		case !(r.Start() <= r.End()):
177			return errors.New("invalid range: %v", r)
178		case !(rp.End() < r.Start()) && i > 0:
179			return errors.New("overlapping ranges: %v with %v", rp, r)
180		}
181		rp = r
182	}
183	return nil
184}
185
186// isValidFieldNumber reports whether the field number is valid.
187// Unlike the FieldNumber.IsValid method, it allows ranges that cover the
188// reserved number range.
189func isValidFieldNumber(n protoreflect.FieldNumber, isMessageSet bool) bool {
190	return protowire.MinValidNumber <= n && (n <= protowire.MaxValidNumber || isMessageSet)
191}
192
193// CheckOverlap reports an error if p and q overlap.
194func (p *FieldRanges) CheckOverlap(q *FieldRanges) error {
195	rps := p.lazyInit().sorted
196	rqs := q.lazyInit().sorted
197	for pi, qi := 0, 0; pi < len(rps) && qi < len(rqs); {
198		rp := fieldRange(rps[pi])
199		rq := fieldRange(rqs[qi])
200		if !(rp.End() < rq.Start() || rq.End() < rp.Start()) {
201			return errors.New("overlapping ranges: %v with %v", rp, rq)
202		}
203		if rp.Start() < rq.Start() {
204			pi++
205		} else {
206			qi++
207		}
208	}
209	return nil
210}
211
212type fieldRange [2]protoreflect.FieldNumber
213
214func (r fieldRange) Start() protoreflect.FieldNumber { return r[0] }     // inclusive
215func (r fieldRange) End() protoreflect.FieldNumber   { return r[1] - 1 } // inclusive
216func (r fieldRange) String() string {
217	if r.Start() == r.End() {
218		return fmt.Sprintf("%d", r.Start())
219	}
220	return fmt.Sprintf("%d to %d", r.Start(), r.End())
221}
222
223type FieldNumbers struct {
224	List []pref.FieldNumber
225	once sync.Once
226	has  map[pref.FieldNumber]struct{} // protected by once
227}
228
229func (p *FieldNumbers) Len() int                   { return len(p.List) }
230func (p *FieldNumbers) Get(i int) pref.FieldNumber { return p.List[i] }
231func (p *FieldNumbers) Has(n pref.FieldNumber) bool {
232	p.once.Do(func() {
233		if len(p.List) > 0 {
234			p.has = make(map[pref.FieldNumber]struct{}, len(p.List))
235			for _, n := range p.List {
236				p.has[n] = struct{}{}
237			}
238		}
239	})
240	_, ok := p.has[n]
241	return ok
242}
243func (p *FieldNumbers) Format(s fmt.State, r rune)          { descfmt.FormatList(s, r, p) }
244func (p *FieldNumbers) ProtoInternal(pragma.DoNotImplement) {}
245
246type OneofFields struct {
247	List   []pref.FieldDescriptor
248	once   sync.Once
249	byName map[pref.Name]pref.FieldDescriptor        // protected by once
250	byJSON map[string]pref.FieldDescriptor           // protected by once
251	byText map[string]pref.FieldDescriptor           // protected by once
252	byNum  map[pref.FieldNumber]pref.FieldDescriptor // protected by once
253}
254
255func (p *OneofFields) Len() int                                         { return len(p.List) }
256func (p *OneofFields) Get(i int) pref.FieldDescriptor                   { return p.List[i] }
257func (p *OneofFields) ByName(s pref.Name) pref.FieldDescriptor          { return p.lazyInit().byName[s] }
258func (p *OneofFields) ByJSONName(s string) pref.FieldDescriptor         { return p.lazyInit().byJSON[s] }
259func (p *OneofFields) ByTextName(s string) pref.FieldDescriptor         { return p.lazyInit().byText[s] }
260func (p *OneofFields) ByNumber(n pref.FieldNumber) pref.FieldDescriptor { return p.lazyInit().byNum[n] }
261func (p *OneofFields) Format(s fmt.State, r rune)                       { descfmt.FormatList(s, r, p) }
262func (p *OneofFields) ProtoInternal(pragma.DoNotImplement)              {}
263
264func (p *OneofFields) lazyInit() *OneofFields {
265	p.once.Do(func() {
266		if len(p.List) > 0 {
267			p.byName = make(map[pref.Name]pref.FieldDescriptor, len(p.List))
268			p.byJSON = make(map[string]pref.FieldDescriptor, len(p.List))
269			p.byText = make(map[string]pref.FieldDescriptor, len(p.List))
270			p.byNum = make(map[pref.FieldNumber]pref.FieldDescriptor, len(p.List))
271			for _, f := range p.List {
272				// Field names and numbers are guaranteed to be unique.
273				p.byName[f.Name()] = f
274				p.byJSON[f.JSONName()] = f
275				p.byText[f.TextName()] = f
276				p.byNum[f.Number()] = f
277			}
278		}
279	})
280	return p
281}
282
283type SourceLocations struct {
284	// List is a list of SourceLocations.
285	// The SourceLocation.Next field does not need to be populated
286	// as it will be lazily populated upon first need.
287	List []pref.SourceLocation
288
289	// File is the parent file descriptor that these locations are relative to.
290	// If non-nil, ByDescriptor verifies that the provided descriptor
291	// is a child of this file descriptor.
292	File pref.FileDescriptor
293
294	once   sync.Once
295	byPath map[pathKey]int
296}
297
298func (p *SourceLocations) Len() int                      { return len(p.List) }
299func (p *SourceLocations) Get(i int) pref.SourceLocation { return p.lazyInit().List[i] }
300func (p *SourceLocations) byKey(k pathKey) pref.SourceLocation {
301	if i, ok := p.lazyInit().byPath[k]; ok {
302		return p.List[i]
303	}
304	return pref.SourceLocation{}
305}
306func (p *SourceLocations) ByPath(path pref.SourcePath) pref.SourceLocation {
307	return p.byKey(newPathKey(path))
308}
// ByDescriptor returns the first source location recorded for desc, or the
// zero SourceLocation if it cannot be found.
//
// The source path is rebuilt by walking from desc up through Parent() until
// a FileDescriptor is reached. Each step appends the descriptor's Index()
// followed by the descriptor.proto field number under which that kind of
// child is stored in its parent (the genid constants). The path is therefore
// assembled leaf-first and reversed once the file is reached.
//
// The zero SourceLocation is returned when:
//   - p.File is non-nil and desc belongs to a different parent file;
//   - desc (or an ancestor) is nil or of a kind not handled below;
//   - a descriptor's parent is of an unexpected kind.
func (p *SourceLocations) ByDescriptor(desc pref.Descriptor) pref.SourceLocation {
	if p.File != nil && desc != nil && p.File != desc.ParentFile() {
		return pref.SourceLocation{} // mismatching parent files
	}
	// Fixed-size backing array for the path; presumably chosen to avoid a
	// heap allocation for typical nesting depths (append grows beyond 16).
	var pathArr [16]int32
	path := pathArr[:0]
	for {
		switch desc.(type) {
		case pref.FileDescriptor:
			// Reached the root.
			// Reverse the path since it was constructed in reverse.
			for i, j := 0, len(path)-1; i < j; i, j = i+1, j-1 {
				path[i], path[j] = path[j], path[i]
			}
			return p.byKey(newPathKey(path))
		case pref.MessageDescriptor:
			// Messages live either at file scope or nested in a message.
			path = append(path, int32(desc.Index()))
			desc = desc.Parent()
			switch desc.(type) {
			case pref.FileDescriptor:
				path = append(path, int32(genid.FileDescriptorProto_MessageType_field_number))
			case pref.MessageDescriptor:
				path = append(path, int32(genid.DescriptorProto_NestedType_field_number))
			default:
				return pref.SourceLocation{}
			}
		case pref.FieldDescriptor:
			// Extensions may be declared at file or message scope; regular
			// fields only inside a message.
			isExtension := desc.(pref.FieldDescriptor).IsExtension()
			path = append(path, int32(desc.Index()))
			desc = desc.Parent()
			if isExtension {
				switch desc.(type) {
				case pref.FileDescriptor:
					path = append(path, int32(genid.FileDescriptorProto_Extension_field_number))
				case pref.MessageDescriptor:
					path = append(path, int32(genid.DescriptorProto_Extension_field_number))
				default:
					return pref.SourceLocation{}
				}
			} else {
				switch desc.(type) {
				case pref.MessageDescriptor:
					path = append(path, int32(genid.DescriptorProto_Field_field_number))
				default:
					return pref.SourceLocation{}
				}
			}
		case pref.OneofDescriptor:
			// Oneofs are only declared inside a message.
			path = append(path, int32(desc.Index()))
			desc = desc.Parent()
			switch desc.(type) {
			case pref.MessageDescriptor:
				path = append(path, int32(genid.DescriptorProto_OneofDecl_field_number))
			default:
				return pref.SourceLocation{}
			}
		case pref.EnumDescriptor:
			// Enums live either at file scope or nested in a message.
			path = append(path, int32(desc.Index()))
			desc = desc.Parent()
			switch desc.(type) {
			case pref.FileDescriptor:
				path = append(path, int32(genid.FileDescriptorProto_EnumType_field_number))
			case pref.MessageDescriptor:
				path = append(path, int32(genid.DescriptorProto_EnumType_field_number))
			default:
				return pref.SourceLocation{}
			}
		case pref.EnumValueDescriptor:
			// Enum values are only declared inside an enum.
			path = append(path, int32(desc.Index()))
			desc = desc.Parent()
			switch desc.(type) {
			case pref.EnumDescriptor:
				path = append(path, int32(genid.EnumDescriptorProto_Value_field_number))
			default:
				return pref.SourceLocation{}
			}
		case pref.ServiceDescriptor:
			// Services are only declared at file scope.
			path = append(path, int32(desc.Index()))
			desc = desc.Parent()
			switch desc.(type) {
			case pref.FileDescriptor:
				path = append(path, int32(genid.FileDescriptorProto_Service_field_number))
			default:
				return pref.SourceLocation{}
			}
		case pref.MethodDescriptor:
			// Methods are only declared inside a service.
			path = append(path, int32(desc.Index()))
			desc = desc.Parent()
			switch desc.(type) {
			case pref.ServiceDescriptor:
				path = append(path, int32(genid.ServiceDescriptorProto_Method_field_number))
			default:
				return pref.SourceLocation{}
			}
		default:
			// Unknown descriptor kind, or a nil desc.
			return pref.SourceLocation{}
		}
	}
}
407func (p *SourceLocations) lazyInit() *SourceLocations {
408	p.once.Do(func() {
409		if len(p.List) > 0 {
410			// Collect all the indexes for a given path.
411			pathIdxs := make(map[pathKey][]int, len(p.List))
412			for i, l := range p.List {
413				k := newPathKey(l.Path)
414				pathIdxs[k] = append(pathIdxs[k], i)
415			}
416
417			// Update the next index for all locations.
418			p.byPath = make(map[pathKey]int, len(p.List))
419			for k, idxs := range pathIdxs {
420				for i := 0; i < len(idxs)-1; i++ {
421					p.List[idxs[i]].Next = idxs[i+1]
422				}
423				p.List[idxs[len(idxs)-1]].Next = 0
424				p.byPath[k] = idxs[0] // record the first location for this path
425			}
426		}
427	})
428	return p
429}
430func (p *SourceLocations) ProtoInternal(pragma.DoNotImplement) {}
431
432// pathKey is a comparable representation of protoreflect.SourcePath.
433type pathKey struct {
434	arr [16]uint8 // first n-1 path segments; last element is the length
435	str string    // used if the path does not fit in arr
436}
437
438func newPathKey(p pref.SourcePath) (k pathKey) {
439	if len(p) < len(k.arr) {
440		for i, ps := range p {
441			if ps < 0 || math.MaxUint8 <= ps {
442				return pathKey{str: p.String()}
443			}
444			k.arr[i] = uint8(ps)
445		}
446		k.arr[len(k.arr)-1] = uint8(len(p))
447		return k
448	}
449	return pathKey{str: p.String()}
450}
451