1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package internal_gengo
6
7import (
8	"unicode"
9	"unicode/utf8"
10
11	"google.golang.org/protobuf/compiler/protogen"
12	"google.golang.org/protobuf/encoding/protowire"
13
14	"google.golang.org/protobuf/types/descriptorpb"
15)
16
// fileInfo holds the per-file state used by the Go code generator: the
// wrapped protogen.File together with every enum, message, and extension
// declared in the file (including nested declarations), plus reverse-index
// maps over them.
type fileInfo struct {
	*protogen.File

	// Every enum, message, and extension declared in the file, in the
	// "flattened ordering" built by newFileInfo.
	allEnums      []*enumInfo
	allMessages   []*messageInfo
	allExtensions []*extensionInfo

	// Reverse lookups from an info pointer to its position in the slices
	// above. newFileInfo leaves these nil when the corresponding slice is empty.
	allEnumsByPtr         map[*enumInfo]int    // value is index into allEnums
	allMessagesByPtr      map[*messageInfo]int // value is index into allMessages
	// allMessageFieldsByPtr maps each message to a structFields record.
	// newFileInfo seeds every entry with an empty structFields; presumably it
	// is filled in while the message's struct is generated — confirm.
	allMessageFieldsByPtr map[*messageInfo]*structFields

	// needRawDesc specifies whether the generator should emit logic to provide
	// the legacy raw descriptor in GZIP'd form.
	// This is updated by enum and message generation logic as necessary,
	// and checked at the end of file generation.
	needRawDesc bool
}
34
// structFields records the sequence of field names appended to it, keeping
// a running count of all names and remembering which positions held names
// that do not start with an upper-case letter.
type structFields struct {
	// count is the number of names appended so far.
	count int
	// unexported maps a field's position (0-based append order) to its name,
	// for names whose first rune is not upper-case. It stays nil until the
	// first such name is appended.
	unexported map[int]string
}

// append registers name as the next field. Every call advances count; a
// name whose first rune is not upper-case (including the empty string) is
// also stored in unexported under its position.
func (sf *structFields) append(name string) {
	pos := sf.count
	sf.count = pos + 1

	first, _ := utf8.DecodeRuneInString(name)
	if unicode.IsUpper(first) {
		return
	}
	if sf.unexported == nil {
		sf.unexported = map[int]string{}
	}
	sf.unexported[pos] = name
}
49
50func newFileInfo(file *protogen.File) *fileInfo {
51	f := &fileInfo{File: file}
52
53	// Collect all enums, messages, and extensions in "flattened ordering".
54	// See filetype.TypeBuilder.
55	var walkMessages func([]*protogen.Message, func(*protogen.Message))
56	walkMessages = func(messages []*protogen.Message, f func(*protogen.Message)) {
57		for _, m := range messages {
58			f(m)
59			walkMessages(m.Messages, f)
60		}
61	}
62	initEnumInfos := func(enums []*protogen.Enum) {
63		for _, enum := range enums {
64			f.allEnums = append(f.allEnums, newEnumInfo(f, enum))
65		}
66	}
67	initMessageInfos := func(messages []*protogen.Message) {
68		for _, message := range messages {
69			f.allMessages = append(f.allMessages, newMessageInfo(f, message))
70		}
71	}
72	initExtensionInfos := func(extensions []*protogen.Extension) {
73		for _, extension := range extensions {
74			f.allExtensions = append(f.allExtensions, newExtensionInfo(f, extension))
75		}
76	}
77	initEnumInfos(f.Enums)
78	initMessageInfos(f.Messages)
79	initExtensionInfos(f.Extensions)
80	walkMessages(f.Messages, func(m *protogen.Message) {
81		initEnumInfos(m.Enums)
82		initMessageInfos(m.Messages)
83		initExtensionInfos(m.Extensions)
84	})
85
86	// Derive a reverse mapping of enum and message pointers to their index
87	// in allEnums and allMessages.
88	if len(f.allEnums) > 0 {
89		f.allEnumsByPtr = make(map[*enumInfo]int)
90		for i, e := range f.allEnums {
91			f.allEnumsByPtr[e] = i
92		}
93	}
94	if len(f.allMessages) > 0 {
95		f.allMessagesByPtr = make(map[*messageInfo]int)
96		f.allMessageFieldsByPtr = make(map[*messageInfo]*structFields)
97		for i, m := range f.allMessages {
98			f.allMessagesByPtr[m] = i
99			f.allMessageFieldsByPtr[m] = new(structFields)
100		}
101	}
102
103	return f
104}
105
// enumInfo wraps a protogen.Enum with per-enum code generation switches.
type enumInfo struct {
	*protogen.Enum

	// Flags presumably controlling whether the JSON and raw-descriptor
	// methods are emitted for this enum; newEnumInfo sets both to true.
	genJSONMethod    bool
	genRawDescMethod bool
}
112
113func newEnumInfo(f *fileInfo, enum *protogen.Enum) *enumInfo {
114	e := &enumInfo{Enum: enum}
115	e.genJSONMethod = true
116	e.genRawDescMethod = true
117	return e
118}
119
// messageInfo wraps a protogen.Message with per-message code generation state.
type messageInfo struct {
	*protogen.Message

	// Flags presumably controlling whether the raw-descriptor and
	// extension-range methods are emitted; newMessageInfo sets both to true.
	genRawDescMethod  bool
	genExtRangeMethod bool

	// isTracked reports whether field tracking is enabled (see isTrackedMessage).
	isTracked bool
	// hasWeak reports whether any field of the message is a weak field.
	hasWeak bool
}
129
130func newMessageInfo(f *fileInfo, message *protogen.Message) *messageInfo {
131	m := &messageInfo{Message: message}
132	m.genRawDescMethod = true
133	m.genExtRangeMethod = true
134	m.isTracked = isTrackedMessage(m)
135	for _, field := range m.Fields {
136		m.hasWeak = m.hasWeak || field.Desc.IsWeak()
137	}
138	return m
139}
140
141// isTrackedMessage reports whether field tracking is enabled on the message.
142func isTrackedMessage(m *messageInfo) (tracked bool) {
143	const trackFieldUse_fieldNumber = 37383685
144
145	// Decode the option from unknown fields to avoid a dependency on the
146	// annotation proto from protoc-gen-go.
147	b := m.Desc.Options().(*descriptorpb.MessageOptions).ProtoReflect().GetUnknown()
148	for len(b) > 0 {
149		num, typ, n := protowire.ConsumeTag(b)
150		b = b[n:]
151		if num == trackFieldUse_fieldNumber && typ == protowire.VarintType {
152			v, _ := protowire.ConsumeVarint(b)
153			tracked = protowire.DecodeBool(v)
154		}
155		m := protowire.ConsumeFieldValue(num, typ, b)
156		b = b[m:]
157	}
158	return tracked
159}
160
// extensionInfo wraps a protogen.Extension for code generation. It currently
// carries no state beyond the embedded extension.
type extensionInfo struct {
	*protogen.Extension
}
164
165func newExtensionInfo(f *fileInfo, extension *protogen.Extension) *extensionInfo {
166	x := &extensionInfo{Extension: extension}
167	return x
168}
169