1// Protocol Buffers for Go with Gadgets
2//
3// Copyright (c) 2013, The GoGo Authors. All rights reserved.
4// http://github.com/gogo/protobuf
5//
6// Redistribution and use in source and binary forms, with or without
7// modification, are permitted provided that the following conditions are
8// met:
9//
10//     * Redistributions of source code must retain the above copyright
11// notice, this list of conditions and the following disclaimer.
12//     * Redistributions in binary form must reproduce the above
13// copyright notice, this list of conditions and the following disclaimer
14// in the documentation and/or other materials provided with the
15// distribution.
16//
17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
18// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
19// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
20// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
21// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
22// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
23// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
24// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
25// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
26// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
29package compare
30
31import (
32	"github.com/gogo/protobuf/gogoproto"
33	"github.com/gogo/protobuf/proto"
34	descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor"
35	"github.com/gogo/protobuf/protoc-gen-gogo/generator"
36	"github.com/gogo/protobuf/vanity"
37)
38
39type plugin struct {
40	*generator.Generator
41	generator.PluginImports
42	fmtPkg      generator.Single
43	bytesPkg    generator.Single
44	sortkeysPkg generator.Single
45	protoPkg    generator.Single
46}
47
48func NewPlugin() *plugin {
49	return &plugin{}
50}
51
52func (p *plugin) Name() string {
53	return "compare"
54}
55
56func (p *plugin) Init(g *generator.Generator) {
57	p.Generator = g
58}
59
60func (p *plugin) Generate(file *generator.FileDescriptor) {
61	p.PluginImports = generator.NewPluginImports(p.Generator)
62	p.fmtPkg = p.NewImport("fmt")
63	p.bytesPkg = p.NewImport("bytes")
64	p.sortkeysPkg = p.NewImport("github.com/gogo/protobuf/sortkeys")
65	p.protoPkg = p.NewImport("github.com/gogo/protobuf/proto")
66
67	for _, msg := range file.Messages() {
68		if msg.DescriptorProto.GetOptions().GetMapEntry() {
69			continue
70		}
71		if gogoproto.HasCompare(file.FileDescriptorProto, msg.DescriptorProto) {
72			p.generateMessage(file, msg)
73		}
74	}
75}
76
77func (p *plugin) generateNullableField(fieldname string) {
78	p.P(`if this.`, fieldname, ` != nil && that1.`, fieldname, ` != nil {`)
79	p.In()
80	p.P(`if *this.`, fieldname, ` != *that1.`, fieldname, `{`)
81	p.In()
82	p.P(`if *this.`, fieldname, ` < *that1.`, fieldname, `{`)
83	p.In()
84	p.P(`return -1`)
85	p.Out()
86	p.P(`}`)
87	p.P(`return 1`)
88	p.Out()
89	p.P(`}`)
90	p.Out()
91	p.P(`} else if this.`, fieldname, ` != nil {`)
92	p.In()
93	p.P(`return 1`)
94	p.Out()
95	p.P(`} else if that1.`, fieldname, ` != nil {`)
96	p.In()
97	p.P(`return -1`)
98	p.Out()
99	p.P(`}`)
100}
101
102func (p *plugin) generateMsgNullAndTypeCheck(ccTypeName string) {
103	p.P(`if that == nil {`)
104	p.In()
105	p.P(`if this == nil {`)
106	p.In()
107	p.P(`return 0`)
108	p.Out()
109	p.P(`}`)
110	p.P(`return 1`)
111	p.Out()
112	p.P(`}`)
113	p.P(``)
114	p.P(`that1, ok := that.(*`, ccTypeName, `)`)
115	p.P(`if !ok {`)
116	p.In()
117	p.P(`that2, ok := that.(`, ccTypeName, `)`)
118	p.P(`if ok {`)
119	p.In()
120	p.P(`that1 = &that2`)
121	p.Out()
122	p.P(`} else {`)
123	p.In()
124	p.P(`return 1`)
125	p.Out()
126	p.P(`}`)
127	p.Out()
128	p.P(`}`)
129	p.P(`if that1 == nil {`)
130	p.In()
131	p.P(`if this == nil {`)
132	p.In()
133	p.P(`return 0`)
134	p.Out()
135	p.P(`}`)
136	p.P(`return 1`)
137	p.Out()
138	p.P(`} else if this == nil {`)
139	p.In()
140	p.P(`return -1`)
141	p.Out()
142	p.P(`}`)
143}
144
145func (p *plugin) generateField(file *generator.FileDescriptor, message *generator.Descriptor, field *descriptor.FieldDescriptorProto) {
146	proto3 := gogoproto.IsProto3(file.FileDescriptorProto)
147	fieldname := p.GetOneOfFieldName(message, field)
148	repeated := field.IsRepeated()
149	ctype := gogoproto.IsCustomType(field)
150	nullable := gogoproto.IsNullable(field)
151	// oneof := field.OneofIndex != nil
152	if !repeated {
153		if ctype {
154			if nullable {
155				p.P(`if that1.`, fieldname, ` == nil {`)
156				p.In()
157				p.P(`if this.`, fieldname, ` != nil {`)
158				p.In()
159				p.P(`return 1`)
160				p.Out()
161				p.P(`}`)
162				p.Out()
163				p.P(`} else if this.`, fieldname, ` == nil {`)
164				p.In()
165				p.P(`return -1`)
166				p.Out()
167				p.P(`} else if c := this.`, fieldname, `.Compare(*that1.`, fieldname, `); c != 0 {`)
168			} else {
169				p.P(`if c := this.`, fieldname, `.Compare(that1.`, fieldname, `); c != 0 {`)
170			}
171			p.In()
172			p.P(`return c`)
173			p.Out()
174			p.P(`}`)
175		} else {
176			if field.IsMessage() || p.IsGroup(field) {
177				if nullable {
178					p.P(`if c := this.`, fieldname, `.Compare(that1.`, fieldname, `); c != 0 {`)
179				} else {
180					p.P(`if c := this.`, fieldname, `.Compare(&that1.`, fieldname, `); c != 0 {`)
181				}
182				p.In()
183				p.P(`return c`)
184				p.Out()
185				p.P(`}`)
186			} else if field.IsBytes() {
187				p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `, that1.`, fieldname, `); c != 0 {`)
188				p.In()
189				p.P(`return c`)
190				p.Out()
191				p.P(`}`)
192			} else if field.IsString() {
193				if nullable && !proto3 {
194					p.generateNullableField(fieldname)
195				} else {
196					p.P(`if this.`, fieldname, ` != that1.`, fieldname, `{`)
197					p.In()
198					p.P(`if this.`, fieldname, ` < that1.`, fieldname, `{`)
199					p.In()
200					p.P(`return -1`)
201					p.Out()
202					p.P(`}`)
203					p.P(`return 1`)
204					p.Out()
205					p.P(`}`)
206				}
207			} else if field.IsBool() {
208				if nullable && !proto3 {
209					p.P(`if this.`, fieldname, ` != nil && that1.`, fieldname, ` != nil {`)
210					p.In()
211					p.P(`if *this.`, fieldname, ` != *that1.`, fieldname, `{`)
212					p.In()
213					p.P(`if !*this.`, fieldname, ` {`)
214					p.In()
215					p.P(`return -1`)
216					p.Out()
217					p.P(`}`)
218					p.P(`return 1`)
219					p.Out()
220					p.P(`}`)
221					p.Out()
222					p.P(`} else if this.`, fieldname, ` != nil {`)
223					p.In()
224					p.P(`return 1`)
225					p.Out()
226					p.P(`} else if that1.`, fieldname, ` != nil {`)
227					p.In()
228					p.P(`return -1`)
229					p.Out()
230					p.P(`}`)
231				} else {
232					p.P(`if this.`, fieldname, ` != that1.`, fieldname, `{`)
233					p.In()
234					p.P(`if !this.`, fieldname, ` {`)
235					p.In()
236					p.P(`return -1`)
237					p.Out()
238					p.P(`}`)
239					p.P(`return 1`)
240					p.Out()
241					p.P(`}`)
242				}
243			} else {
244				if nullable && !proto3 {
245					p.generateNullableField(fieldname)
246				} else {
247					p.P(`if this.`, fieldname, ` != that1.`, fieldname, `{`)
248					p.In()
249					p.P(`if this.`, fieldname, ` < that1.`, fieldname, `{`)
250					p.In()
251					p.P(`return -1`)
252					p.Out()
253					p.P(`}`)
254					p.P(`return 1`)
255					p.Out()
256					p.P(`}`)
257				}
258			}
259		}
260	} else {
261		p.P(`if len(this.`, fieldname, `) != len(that1.`, fieldname, `) {`)
262		p.In()
263		p.P(`if len(this.`, fieldname, `) < len(that1.`, fieldname, `) {`)
264		p.In()
265		p.P(`return -1`)
266		p.Out()
267		p.P(`}`)
268		p.P(`return 1`)
269		p.Out()
270		p.P(`}`)
271		p.P(`for i := range this.`, fieldname, ` {`)
272		p.In()
273		if ctype {
274			p.P(`if c := this.`, fieldname, `[i].Compare(that1.`, fieldname, `[i]); c != 0 {`)
275			p.In()
276			p.P(`return c`)
277			p.Out()
278			p.P(`}`)
279		} else {
280			if p.IsMap(field) {
281				m := p.GoMapType(nil, field)
282				valuegoTyp, _ := p.GoType(nil, m.ValueField)
283				valuegoAliasTyp, _ := p.GoType(nil, m.ValueAliasField)
284				nullable, valuegoTyp, valuegoAliasTyp = generator.GoMapValueTypes(field, m.ValueField, valuegoTyp, valuegoAliasTyp)
285
286				mapValue := m.ValueAliasField
287				if mapValue.IsMessage() || p.IsGroup(mapValue) {
288					if nullable && valuegoTyp == valuegoAliasTyp {
289						p.P(`if c := this.`, fieldname, `[i].Compare(that1.`, fieldname, `[i]); c != 0 {`)
290					} else {
291						// Compare() has a pointer receiver, but map value is a value type
292						a := `this.` + fieldname + `[i]`
293						b := `that1.` + fieldname + `[i]`
294						if valuegoTyp != valuegoAliasTyp {
295							// cast back to the type that has the generated methods on it
296							a = `(` + valuegoTyp + `)(` + a + `)`
297							b = `(` + valuegoTyp + `)(` + b + `)`
298						}
299						p.P(`a := `, a)
300						p.P(`b := `, b)
301						if nullable {
302							p.P(`if c := a.Compare(b); c != 0 {`)
303						} else {
304							p.P(`if c := (&a).Compare(&b); c != 0 {`)
305						}
306					}
307					p.In()
308					p.P(`return c`)
309					p.Out()
310					p.P(`}`)
311				} else if mapValue.IsBytes() {
312					p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `[i], that1.`, fieldname, `[i]); c != 0 {`)
313					p.In()
314					p.P(`return c`)
315					p.Out()
316					p.P(`}`)
317				} else if mapValue.IsString() {
318					p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`)
319					p.In()
320					p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`)
321					p.In()
322					p.P(`return -1`)
323					p.Out()
324					p.P(`}`)
325					p.P(`return 1`)
326					p.Out()
327					p.P(`}`)
328				} else {
329					p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`)
330					p.In()
331					p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`)
332					p.In()
333					p.P(`return -1`)
334					p.Out()
335					p.P(`}`)
336					p.P(`return 1`)
337					p.Out()
338					p.P(`}`)
339				}
340			} else if field.IsMessage() || p.IsGroup(field) {
341				if nullable {
342					p.P(`if c := this.`, fieldname, `[i].Compare(that1.`, fieldname, `[i]); c != 0 {`)
343					p.In()
344					p.P(`return c`)
345					p.Out()
346					p.P(`}`)
347				} else {
348					p.P(`if c := this.`, fieldname, `[i].Compare(&that1.`, fieldname, `[i]); c != 0 {`)
349					p.In()
350					p.P(`return c`)
351					p.Out()
352					p.P(`}`)
353				}
354			} else if field.IsBytes() {
355				p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `[i], that1.`, fieldname, `[i]); c != 0 {`)
356				p.In()
357				p.P(`return c`)
358				p.Out()
359				p.P(`}`)
360			} else if field.IsString() {
361				p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`)
362				p.In()
363				p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`)
364				p.In()
365				p.P(`return -1`)
366				p.Out()
367				p.P(`}`)
368				p.P(`return 1`)
369				p.Out()
370				p.P(`}`)
371			} else if field.IsBool() {
372				p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`)
373				p.In()
374				p.P(`if !this.`, fieldname, `[i] {`)
375				p.In()
376				p.P(`return -1`)
377				p.Out()
378				p.P(`}`)
379				p.P(`return 1`)
380				p.Out()
381				p.P(`}`)
382			} else {
383				p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`)
384				p.In()
385				p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`)
386				p.In()
387				p.P(`return -1`)
388				p.Out()
389				p.P(`}`)
390				p.P(`return 1`)
391				p.Out()
392				p.P(`}`)
393			}
394		}
395		p.Out()
396		p.P(`}`)
397	}
398}
399
400func (p *plugin) generateMessage(file *generator.FileDescriptor, message *generator.Descriptor) {
401	ccTypeName := generator.CamelCaseSlice(message.TypeName())
402	p.P(`func (this *`, ccTypeName, `) Compare(that interface{}) int {`)
403	p.In()
404	p.generateMsgNullAndTypeCheck(ccTypeName)
405	oneofs := make(map[string]struct{})
406
407	for _, field := range message.Field {
408		oneof := field.OneofIndex != nil
409		if oneof {
410			fieldname := p.GetFieldName(message, field)
411			if _, ok := oneofs[fieldname]; ok {
412				continue
413			} else {
414				oneofs[fieldname] = struct{}{}
415			}
416			p.P(`if that1.`, fieldname, ` == nil {`)
417			p.In()
418			p.P(`if this.`, fieldname, ` != nil {`)
419			p.In()
420			p.P(`return 1`)
421			p.Out()
422			p.P(`}`)
423			p.Out()
424			p.P(`} else if this.`, fieldname, ` == nil {`)
425			p.In()
426			p.P(`return -1`)
427			p.Out()
428			p.P(`} else {`)
429			p.In()
430
431			// Generate two type switches in order to compare the
432			// types of the oneofs. If they are of the same type
433			// call Compare, otherwise return 1 or -1.
434			p.P(`thisType := -1`)
435			p.P(`switch this.`, fieldname, `.(type) {`)
436			for i, subfield := range message.Field {
437				if *subfield.OneofIndex == *field.OneofIndex {
438					ccTypeName := p.OneOfTypeName(message, subfield)
439					p.P(`case *`, ccTypeName, `:`)
440					p.In()
441					p.P(`thisType = `, i)
442					p.Out()
443				}
444			}
445			p.P(`default:`)
446			p.In()
447			p.P(`panic(fmt.Sprintf("compare: unexpected type %T in oneof", this.`, fieldname, `))`)
448			p.Out()
449			p.P(`}`)
450
451			p.P(`that1Type := -1`)
452			p.P(`switch that1.`, fieldname, `.(type) {`)
453			for i, subfield := range message.Field {
454				if *subfield.OneofIndex == *field.OneofIndex {
455					ccTypeName := p.OneOfTypeName(message, subfield)
456					p.P(`case *`, ccTypeName, `:`)
457					p.In()
458					p.P(`that1Type = `, i)
459					p.Out()
460				}
461			}
462			p.P(`default:`)
463			p.In()
464			p.P(`panic(fmt.Sprintf("compare: unexpected type %T in oneof", that1.`, fieldname, `))`)
465			p.Out()
466			p.P(`}`)
467
468			p.P(`if thisType == that1Type {`)
469			p.In()
470			p.P(`if c := this.`, fieldname, `.Compare(that1.`, fieldname, `); c != 0 {`)
471			p.In()
472			p.P(`return c`)
473			p.Out()
474			p.P(`}`)
475			p.Out()
476			p.P(`} else if thisType < that1Type {`)
477			p.In()
478			p.P(`return -1`)
479			p.Out()
480			p.P(`} else if thisType > that1Type {`)
481			p.In()
482			p.P(`return 1`)
483			p.Out()
484			p.P(`}`)
485			p.Out()
486			p.P(`}`)
487		} else {
488			p.generateField(file, message, field)
489		}
490	}
491	if message.DescriptorProto.HasExtension() {
492		if gogoproto.HasExtensionsMap(file.FileDescriptorProto, message.DescriptorProto) {
493			p.P(`thismap := `, p.protoPkg.Use(), `.GetUnsafeExtensionsMap(this)`)
494			p.P(`thatmap := `, p.protoPkg.Use(), `.GetUnsafeExtensionsMap(that1)`)
495			p.P(`extkeys := make([]int32, 0, len(thismap)+len(thatmap))`)
496			p.P(`for k, _ := range thismap {`)
497			p.In()
498			p.P(`extkeys = append(extkeys, k)`)
499			p.Out()
500			p.P(`}`)
501			p.P(`for k, _ := range thatmap {`)
502			p.In()
503			p.P(`if _, ok := thismap[k]; !ok {`)
504			p.In()
505			p.P(`extkeys = append(extkeys, k)`)
506			p.Out()
507			p.P(`}`)
508			p.Out()
509			p.P(`}`)
510			p.P(p.sortkeysPkg.Use(), `.Int32s(extkeys)`)
511			p.P(`for _, k := range extkeys {`)
512			p.In()
513			p.P(`if v, ok := thismap[k]; ok {`)
514			p.In()
515			p.P(`if v2, ok := thatmap[k]; ok {`)
516			p.In()
517			p.P(`if c := v.Compare(&v2); c != 0 {`)
518			p.In()
519			p.P(`return c`)
520			p.Out()
521			p.P(`}`)
522			p.Out()
523			p.P(`} else  {`)
524			p.In()
525			p.P(`return 1`)
526			p.Out()
527			p.P(`}`)
528			p.Out()
529			p.P(`} else {`)
530			p.In()
531			p.P(`return -1`)
532			p.Out()
533			p.P(`}`)
534			p.Out()
535			p.P(`}`)
536		} else {
537			fieldname := "XXX_extensions"
538			p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `, that1.`, fieldname, `); c != 0 {`)
539			p.In()
540			p.P(`return c`)
541			p.Out()
542			p.P(`}`)
543		}
544	}
545	if gogoproto.HasUnrecognized(file.FileDescriptorProto, message.DescriptorProto) {
546		fieldname := "XXX_unrecognized"
547		p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `, that1.`, fieldname, `); c != 0 {`)
548		p.In()
549		p.P(`return c`)
550		p.Out()
551		p.P(`}`)
552	}
553	p.P(`return 0`)
554	p.Out()
555	p.P(`}`)
556
557	//Generate Compare methods for oneof fields
558	m := proto.Clone(message.DescriptorProto).(*descriptor.DescriptorProto)
559	for _, field := range m.Field {
560		oneof := field.OneofIndex != nil
561		if !oneof {
562			continue
563		}
564		ccTypeName := p.OneOfTypeName(message, field)
565		p.P(`func (this *`, ccTypeName, `) Compare(that interface{}) int {`)
566		p.In()
567
568		p.generateMsgNullAndTypeCheck(ccTypeName)
569		vanity.TurnOffNullableForNativeTypes(field)
570		p.generateField(file, message, field)
571
572		p.P(`return 0`)
573		p.Out()
574		p.P(`}`)
575	}
576}
577
578func init() {
579	generator.RegisterPlugin(NewPlugin())
580}
581