1// Protocol Buffers for Go with Gadgets
2//
3// Copyright (c) 2013, The GoGo Authors. All rights reserved.
4// http://github.com/gogo/protobuf
5//
6// Redistribution and use in source and binary forms, with or without
7// modification, are permitted provided that the following conditions are
8// met:
9//
10//     * Redistributions of source code must retain the above copyright
11// notice, this list of conditions and the following disclaimer.
12//     * Redistributions in binary form must reproduce the above
13// copyright notice, this list of conditions and the following disclaimer
14// in the documentation and/or other materials provided with the
15// distribution.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
18// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
19// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
20// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
21// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
22// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
23// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
24// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
25// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
26// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
29package descriptor
30
31import (
32	"strings"
33)
34
35func (msg *DescriptorProto) GetMapFields() (*FieldDescriptorProto, *FieldDescriptorProto) {
36	if !msg.GetOptions().GetMapEntry() {
37		return nil, nil
38	}
39	return msg.GetField()[0], msg.GetField()[1]
40}
41
// dotToUnderscore replaces '.' with '_' and returns every other rune
// unchanged. It is used with strings.Map so that package names such as
// "a.b" and "a_b" compare equal.
func dotToUnderscore(r rune) rune {
	switch r {
	case '.':
		return '_'
	default:
		return r
	}
}
48
49func (field *FieldDescriptorProto) WireType() (wire int) {
50	switch *field.Type {
51	case FieldDescriptorProto_TYPE_DOUBLE:
52		return 1
53	case FieldDescriptorProto_TYPE_FLOAT:
54		return 5
55	case FieldDescriptorProto_TYPE_INT64:
56		return 0
57	case FieldDescriptorProto_TYPE_UINT64:
58		return 0
59	case FieldDescriptorProto_TYPE_INT32:
60		return 0
61	case FieldDescriptorProto_TYPE_UINT32:
62		return 0
63	case FieldDescriptorProto_TYPE_FIXED64:
64		return 1
65	case FieldDescriptorProto_TYPE_FIXED32:
66		return 5
67	case FieldDescriptorProto_TYPE_BOOL:
68		return 0
69	case FieldDescriptorProto_TYPE_STRING:
70		return 2
71	case FieldDescriptorProto_TYPE_GROUP:
72		return 2
73	case FieldDescriptorProto_TYPE_MESSAGE:
74		return 2
75	case FieldDescriptorProto_TYPE_BYTES:
76		return 2
77	case FieldDescriptorProto_TYPE_ENUM:
78		return 0
79	case FieldDescriptorProto_TYPE_SFIXED32:
80		return 5
81	case FieldDescriptorProto_TYPE_SFIXED64:
82		return 1
83	case FieldDescriptorProto_TYPE_SINT32:
84		return 0
85	case FieldDescriptorProto_TYPE_SINT64:
86		return 0
87	}
88	panic("unreachable")
89}
90
91func (field *FieldDescriptorProto) GetKeyUint64() (x uint64) {
92	packed := field.IsPacked()
93	wireType := field.WireType()
94	fieldNumber := field.GetNumber()
95	if packed {
96		wireType = 2
97	}
98	x = uint64(uint32(fieldNumber)<<3 | uint32(wireType))
99	return x
100}
101
102func (field *FieldDescriptorProto) GetKey3Uint64() (x uint64) {
103	packed := field.IsPacked3()
104	wireType := field.WireType()
105	fieldNumber := field.GetNumber()
106	if packed {
107		wireType = 2
108	}
109	x = uint64(uint32(fieldNumber)<<3 | uint32(wireType))
110	return x
111}
112
113func (field *FieldDescriptorProto) GetKey() []byte {
114	x := field.GetKeyUint64()
115	i := 0
116	keybuf := make([]byte, 0)
117	for i = 0; x > 127; i++ {
118		keybuf = append(keybuf, 0x80|uint8(x&0x7F))
119		x >>= 7
120	}
121	keybuf = append(keybuf, uint8(x))
122	return keybuf
123}
124
125func (field *FieldDescriptorProto) GetKey3() []byte {
126	x := field.GetKey3Uint64()
127	i := 0
128	keybuf := make([]byte, 0)
129	for i = 0; x > 127; i++ {
130		keybuf = append(keybuf, 0x80|uint8(x&0x7F))
131		x >>= 7
132	}
133	keybuf = append(keybuf, uint8(x))
134	return keybuf
135}
136
137func (desc *FileDescriptorSet) GetField(packageName, messageName, fieldName string) *FieldDescriptorProto {
138	msg := desc.GetMessage(packageName, messageName)
139	if msg == nil {
140		return nil
141	}
142	for _, field := range msg.GetField() {
143		if field.GetName() == fieldName {
144			return field
145		}
146	}
147	return nil
148}
149
150func (file *FileDescriptorProto) GetMessage(typeName string) *DescriptorProto {
151	for _, msg := range file.GetMessageType() {
152		if msg.GetName() == typeName {
153			return msg
154		}
155		nes := file.GetNestedMessage(msg, strings.TrimPrefix(typeName, msg.GetName()+"."))
156		if nes != nil {
157			return nes
158		}
159	}
160	return nil
161}
162
163func (file *FileDescriptorProto) GetNestedMessage(msg *DescriptorProto, typeName string) *DescriptorProto {
164	for _, nes := range msg.GetNestedType() {
165		if nes.GetName() == typeName {
166			return nes
167		}
168		res := file.GetNestedMessage(nes, strings.TrimPrefix(typeName, nes.GetName()+"."))
169		if res != nil {
170			return res
171		}
172	}
173	return nil
174}
175
176func (desc *FileDescriptorSet) GetMessage(packageName string, typeName string) *DescriptorProto {
177	for _, file := range desc.GetFile() {
178		if strings.Map(dotToUnderscore, file.GetPackage()) != strings.Map(dotToUnderscore, packageName) {
179			continue
180		}
181		for _, msg := range file.GetMessageType() {
182			if msg.GetName() == typeName {
183				return msg
184			}
185		}
186		for _, msg := range file.GetMessageType() {
187			for _, nes := range msg.GetNestedType() {
188				if nes.GetName() == typeName {
189					return nes
190				}
191				if msg.GetName()+"."+nes.GetName() == typeName {
192					return nes
193				}
194			}
195		}
196	}
197	return nil
198}
199
200func (desc *FileDescriptorSet) IsProto3(packageName string, typeName string) bool {
201	for _, file := range desc.GetFile() {
202		if strings.Map(dotToUnderscore, file.GetPackage()) != strings.Map(dotToUnderscore, packageName) {
203			continue
204		}
205		for _, msg := range file.GetMessageType() {
206			if msg.GetName() == typeName {
207				return file.GetSyntax() == "proto3"
208			}
209		}
210		for _, msg := range file.GetMessageType() {
211			for _, nes := range msg.GetNestedType() {
212				if nes.GetName() == typeName {
213					return file.GetSyntax() == "proto3"
214				}
215				if msg.GetName()+"."+nes.GetName() == typeName {
216					return file.GetSyntax() == "proto3"
217				}
218			}
219		}
220	}
221	return false
222}
223
224func (msg *DescriptorProto) IsExtendable() bool {
225	return len(msg.GetExtensionRange()) > 0
226}
227
228func (desc *FileDescriptorSet) FindExtension(packageName string, typeName string, fieldName string) (extPackageName string, field *FieldDescriptorProto) {
229	parent := desc.GetMessage(packageName, typeName)
230	if parent == nil {
231		return "", nil
232	}
233	if !parent.IsExtendable() {
234		return "", nil
235	}
236	extendee := "." + packageName + "." + typeName
237	for _, file := range desc.GetFile() {
238		for _, ext := range file.GetExtension() {
239			if strings.Map(dotToUnderscore, file.GetPackage()) == strings.Map(dotToUnderscore, packageName) {
240				if !(ext.GetExtendee() == typeName || ext.GetExtendee() == extendee) {
241					continue
242				}
243			} else {
244				if ext.GetExtendee() != extendee {
245					continue
246				}
247			}
248			if ext.GetName() == fieldName {
249				return file.GetPackage(), ext
250			}
251		}
252	}
253	return "", nil
254}
255
256func (desc *FileDescriptorSet) FindExtensionByFieldNumber(packageName string, typeName string, fieldNum int32) (extPackageName string, field *FieldDescriptorProto) {
257	parent := desc.GetMessage(packageName, typeName)
258	if parent == nil {
259		return "", nil
260	}
261	if !parent.IsExtendable() {
262		return "", nil
263	}
264	extendee := "." + packageName + "." + typeName
265	for _, file := range desc.GetFile() {
266		for _, ext := range file.GetExtension() {
267			if strings.Map(dotToUnderscore, file.GetPackage()) == strings.Map(dotToUnderscore, packageName) {
268				if !(ext.GetExtendee() == typeName || ext.GetExtendee() == extendee) {
269					continue
270				}
271			} else {
272				if ext.GetExtendee() != extendee {
273					continue
274				}
275			}
276			if ext.GetNumber() == fieldNum {
277				return file.GetPackage(), ext
278			}
279		}
280	}
281	return "", nil
282}
283
284func (desc *FileDescriptorSet) FindMessage(packageName string, typeName string, fieldName string) (msgPackageName string, msgName string) {
285	parent := desc.GetMessage(packageName, typeName)
286	if parent == nil {
287		return "", ""
288	}
289	field := parent.GetFieldDescriptor(fieldName)
290	if field == nil {
291		var extPackageName string
292		extPackageName, field = desc.FindExtension(packageName, typeName, fieldName)
293		if field == nil {
294			return "", ""
295		}
296		packageName = extPackageName
297	}
298	typeNames := strings.Split(field.GetTypeName(), ".")
299	if len(typeNames) == 1 {
300		msg := desc.GetMessage(packageName, typeName)
301		if msg == nil {
302			return "", ""
303		}
304		return packageName, msg.GetName()
305	}
306	if len(typeNames) > 2 {
307		for i := 1; i < len(typeNames)-1; i++ {
308			packageName = strings.Join(typeNames[1:len(typeNames)-i], ".")
309			typeName = strings.Join(typeNames[len(typeNames)-i:], ".")
310			msg := desc.GetMessage(packageName, typeName)
311			if msg != nil {
312				typeNames := strings.Split(msg.GetName(), ".")
313				if len(typeNames) == 1 {
314					return packageName, msg.GetName()
315				}
316				return strings.Join(typeNames[1:len(typeNames)-1], "."), typeNames[len(typeNames)-1]
317			}
318		}
319	}
320	return "", ""
321}
322
323func (msg *DescriptorProto) GetFieldDescriptor(fieldName string) *FieldDescriptorProto {
324	for _, field := range msg.GetField() {
325		if field.GetName() == fieldName {
326			return field
327		}
328	}
329	return nil
330}
331
332func (desc *FileDescriptorSet) GetEnum(packageName string, typeName string) *EnumDescriptorProto {
333	for _, file := range desc.GetFile() {
334		if strings.Map(dotToUnderscore, file.GetPackage()) != strings.Map(dotToUnderscore, packageName) {
335			continue
336		}
337		for _, enum := range file.GetEnumType() {
338			if enum.GetName() == typeName {
339				return enum
340			}
341		}
342	}
343	return nil
344}
345
346func (f *FieldDescriptorProto) IsEnum() bool {
347	return *f.Type == FieldDescriptorProto_TYPE_ENUM
348}
349
350func (f *FieldDescriptorProto) IsMessage() bool {
351	return *f.Type == FieldDescriptorProto_TYPE_MESSAGE
352}
353
354func (f *FieldDescriptorProto) IsBytes() bool {
355	return *f.Type == FieldDescriptorProto_TYPE_BYTES
356}
357
358func (f *FieldDescriptorProto) IsRepeated() bool {
359	return f.Label != nil && *f.Label == FieldDescriptorProto_LABEL_REPEATED
360}
361
362func (f *FieldDescriptorProto) IsString() bool {
363	return *f.Type == FieldDescriptorProto_TYPE_STRING
364}
365
366func (f *FieldDescriptorProto) IsBool() bool {
367	return *f.Type == FieldDescriptorProto_TYPE_BOOL
368}
369
370func (f *FieldDescriptorProto) IsRequired() bool {
371	return f.Label != nil && *f.Label == FieldDescriptorProto_LABEL_REQUIRED
372}
373
374func (f *FieldDescriptorProto) IsPacked() bool {
375	return f.Options != nil && f.GetOptions().GetPacked()
376}
377
378func (f *FieldDescriptorProto) IsPacked3() bool {
379	if f.IsRepeated() && f.IsScalar() {
380		if f.Options == nil || f.GetOptions().Packed == nil {
381			return true
382		}
383		return f.Options != nil && f.GetOptions().GetPacked()
384	}
385	return false
386}
387
388func (m *DescriptorProto) HasExtension() bool {
389	return len(m.ExtensionRange) > 0
390}
391