1// Protocol Buffers for Go with Gadgets 2// 3// Copyright (c) 2013, The GoGo Authors. All rights reserved. 4// http://github.com/gogo/protobuf 5// 6// Redistribution and use in source and binary forms, with or without 7// modification, are permitted provided that the following conditions are 8// met: 9// 10// * Redistributions of source code must retain the above copyright 11// notice, this list of conditions and the following disclaimer. 12// * Redistributions in binary form must reproduce the above 13// copyright notice, this list of conditions and the following disclaimer 14// in the documentation and/or other materials provided with the 15// distribution. 16// 17// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 18// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 19// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 20// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 21// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 22// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 23// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 24// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 25// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 29package compare 30 31import ( 32 "github.com/gogo/protobuf/gogoproto" 33 "github.com/gogo/protobuf/proto" 34 descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" 35 "github.com/gogo/protobuf/protoc-gen-gogo/generator" 36 "github.com/gogo/protobuf/vanity" 37) 38 39type plugin struct { 40 *generator.Generator 41 generator.PluginImports 42 fmtPkg generator.Single 43 bytesPkg generator.Single 44 sortkeysPkg generator.Single 45 protoPkg generator.Single 46} 47 48func NewPlugin() *plugin { 49 return &plugin{} 50} 51 52func (p *plugin) Name() string { 53 return "compare" 54} 55 56func (p *plugin) Init(g *generator.Generator) { 57 p.Generator = g 58} 59 60func (p *plugin) Generate(file *generator.FileDescriptor) { 61 p.PluginImports = generator.NewPluginImports(p.Generator) 62 p.fmtPkg = p.NewImport("fmt") 63 p.bytesPkg = p.NewImport("bytes") 64 p.sortkeysPkg = p.NewImport("github.com/gogo/protobuf/sortkeys") 65 p.protoPkg = p.NewImport("github.com/gogo/protobuf/proto") 66 67 for _, msg := range file.Messages() { 68 if msg.DescriptorProto.GetOptions().GetMapEntry() { 69 continue 70 } 71 if gogoproto.HasCompare(file.FileDescriptorProto, msg.DescriptorProto) { 72 p.generateMessage(file, msg) 73 } 74 } 75} 76 77func (p *plugin) generateNullableField(fieldname string) { 78 p.P(`if this.`, fieldname, ` != nil && that1.`, fieldname, ` != nil {`) 79 p.In() 80 p.P(`if *this.`, fieldname, ` != *that1.`, fieldname, `{`) 81 p.In() 82 p.P(`if *this.`, fieldname, ` < *that1.`, fieldname, `{`) 83 p.In() 84 p.P(`return -1`) 85 p.Out() 86 p.P(`}`) 87 p.P(`return 1`) 88 p.Out() 89 p.P(`}`) 90 p.Out() 91 p.P(`} else if this.`, fieldname, ` != nil {`) 92 p.In() 93 p.P(`return 1`) 94 p.Out() 95 p.P(`} else if that1.`, fieldname, ` != nil {`) 96 p.In() 97 p.P(`return -1`) 98 p.Out() 99 p.P(`}`) 100} 101 102func (p *plugin) generateMsgNullAndTypeCheck(ccTypeName string) { 103 p.P(`if that == nil {`) 104 p.In() 105 p.P(`if this == nil {`) 106 p.In() 107 p.P(`return 0`) 108 p.Out() 109 p.P(`}`) 110 p.P(`return 1`) 111 p.Out() 112 p.P(`}`) 113 p.P(``) 114 p.P(`that1, ok := that.(*`, ccTypeName, `)`) 115 p.P(`if !ok {`) 116 p.In() 117 p.P(`that2, ok := that.(`, ccTypeName, `)`) 118 p.P(`if ok {`) 119 p.In() 120 p.P(`that1 = &that2`) 121 p.Out() 122 p.P(`} else {`) 123 p.In() 124 p.P(`return 1`) 125 p.Out() 126 p.P(`}`) 127 p.Out() 128 p.P(`}`) 129 p.P(`if that1 == nil {`) 130 p.In() 131 p.P(`if this == nil {`) 132 p.In() 133 p.P(`return 0`) 134 p.Out() 135 p.P(`}`) 136 p.P(`return 1`) 137 p.Out() 138 p.P(`} else if this == nil {`) 139 p.In() 140 p.P(`return -1`) 141 p.Out() 142 p.P(`}`) 143} 144 145func (p *plugin) generateField(file *generator.FileDescriptor, message *generator.Descriptor, field *descriptor.FieldDescriptorProto) { 146 proto3 := gogoproto.IsProto3(file.FileDescriptorProto) 147 fieldname := p.GetOneOfFieldName(message, field) 148 repeated := field.IsRepeated() 149 ctype := gogoproto.IsCustomType(field) 150 nullable := gogoproto.IsNullable(field) 151 // oneof := field.OneofIndex != nil 152 if !repeated { 153 if ctype { 154 if nullable { 155 p.P(`if that1.`, fieldname, ` == nil {`) 156 p.In() 157 p.P(`if this.`, fieldname, ` != nil {`) 158 p.In() 159 p.P(`return 1`) 160 p.Out() 161 p.P(`}`) 162 p.Out() 163 p.P(`} else if this.`, fieldname, ` == nil {`) 164 p.In() 165 p.P(`return -1`) 166 p.Out() 167 p.P(`} else if c := this.`, fieldname, `.Compare(*that1.`, fieldname, `); c != 0 {`) 168 } else { 169 p.P(`if c := this.`, fieldname, `.Compare(that1.`, fieldname, `); c != 0 {`) 170 } 171 p.In() 172 p.P(`return c`) 173 p.Out() 174 p.P(`}`) 175 } else { 176 if field.IsMessage() || p.IsGroup(field) { 177 if nullable { 178 p.P(`if c := this.`, fieldname, `.Compare(that1.`, fieldname, `); c != 0 {`) 179 } else { 180 p.P(`if c := this.`, fieldname, `.Compare(&that1.`, fieldname, `); c != 0 {`) 181 } 182 p.In() 183 p.P(`return c`) 184 p.Out() 185 p.P(`}`) 186 } else if field.IsBytes() { 187 p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `, that1.`, fieldname, `); c != 0 {`) 188 p.In() 189 p.P(`return c`) 190 p.Out() 191 p.P(`}`) 192 } else if field.IsString() { 193 if nullable && !proto3 { 194 p.generateNullableField(fieldname) 195 } else { 196 p.P(`if this.`, fieldname, ` != that1.`, fieldname, `{`) 197 p.In() 198 p.P(`if this.`, fieldname, ` < that1.`, fieldname, `{`) 199 p.In() 200 p.P(`return -1`) 201 p.Out() 202 p.P(`}`) 203 p.P(`return 1`) 204 p.Out() 205 p.P(`}`) 206 } 207 } else if field.IsBool() { 208 if nullable && !proto3 { 209 p.P(`if this.`, fieldname, ` != nil && that1.`, fieldname, ` != nil {`) 210 p.In() 211 p.P(`if *this.`, fieldname, ` != *that1.`, fieldname, `{`) 212 p.In() 213 p.P(`if !*this.`, fieldname, ` {`) 214 p.In() 215 p.P(`return -1`) 216 p.Out() 217 p.P(`}`) 218 p.P(`return 1`) 219 p.Out() 220 p.P(`}`) 221 p.Out() 222 p.P(`} else if this.`, fieldname, ` != nil {`) 223 p.In() 224 p.P(`return 1`) 225 p.Out() 226 p.P(`} else if that1.`, fieldname, ` != nil {`) 227 p.In() 228 p.P(`return -1`) 229 p.Out() 230 p.P(`}`) 231 } else { 232 p.P(`if this.`, fieldname, ` != that1.`, fieldname, `{`) 233 p.In() 234 p.P(`if !this.`, fieldname, ` {`) 235 p.In() 236 p.P(`return -1`) 237 p.Out() 238 p.P(`}`) 239 p.P(`return 1`) 240 p.Out() 241 p.P(`}`) 242 } 243 } else { 244 if nullable && !proto3 { 245 p.generateNullableField(fieldname) 246 } else { 247 p.P(`if this.`, fieldname, ` != that1.`, fieldname, `{`) 248 p.In() 249 p.P(`if this.`, fieldname, ` < that1.`, fieldname, `{`) 250 p.In() 251 p.P(`return -1`) 252 p.Out() 253 p.P(`}`) 254 p.P(`return 1`) 255 p.Out() 256 p.P(`}`) 257 } 258 } 259 } 260 } else { 261 p.P(`if len(this.`, fieldname, `) != len(that1.`, fieldname, `) {`) 262 p.In() 263 p.P(`if len(this.`, fieldname, `) < len(that1.`, fieldname, `) {`) 264 p.In() 265 p.P(`return -1`) 266 p.Out() 267 p.P(`}`) 268 p.P(`return 1`) 269 p.Out() 270 p.P(`}`) 271 p.P(`for i := range this.`, fieldname, ` {`) 272 p.In() 273 if ctype { 274 p.P(`if c := this.`, fieldname, `[i].Compare(that1.`, fieldname, `[i]); c != 0 {`) 275 p.In() 276 p.P(`return c`) 277 p.Out() 278 p.P(`}`) 279 } else { 280 if p.IsMap(field) { 281 m := p.GoMapType(nil, field) 282 valuegoTyp, _ := p.GoType(nil, m.ValueField) 283 valuegoAliasTyp, _ := p.GoType(nil, m.ValueAliasField) 284 nullable, valuegoTyp, valuegoAliasTyp = generator.GoMapValueTypes(field, m.ValueField, valuegoTyp, valuegoAliasTyp) 285 286 mapValue := m.ValueAliasField 287 if mapValue.IsMessage() || p.IsGroup(mapValue) { 288 if nullable && valuegoTyp == valuegoAliasTyp { 289 p.P(`if c := this.`, fieldname, `[i].Compare(that1.`, fieldname, `[i]); c != 0 {`) 290 } else { 291 // Compare() has a pointer receiver, but map value is a value type 292 a := `this.` + fieldname + `[i]` 293 b := `that1.` + fieldname + `[i]` 294 if valuegoTyp != valuegoAliasTyp { 295 // cast back to the type that has the generated methods on it 296 a = `(` + valuegoTyp + `)(` + a + `)` 297 b = `(` + valuegoTyp + `)(` + b + `)` 298 } 299 p.P(`a := `, a) 300 p.P(`b := `, b) 301 if nullable { 302 p.P(`if c := a.Compare(b); c != 0 {`) 303 } else { 304 p.P(`if c := (&a).Compare(&b); c != 0 {`) 305 } 306 } 307 p.In() 308 p.P(`return c`) 309 p.Out() 310 p.P(`}`) 311 } else if mapValue.IsBytes() { 312 p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `[i], that1.`, fieldname, `[i]); c != 0 {`) 313 p.In() 314 p.P(`return c`) 315 p.Out() 316 p.P(`}`) 317 } else if mapValue.IsString() { 318 p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`) 319 p.In() 320 p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`) 321 p.In() 322 p.P(`return -1`) 323 p.Out() 324 p.P(`}`) 325 p.P(`return 1`) 326 p.Out() 327 p.P(`}`) 328 } else { 329 p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`) 330 p.In() 331 p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`) 332 p.In() 333 p.P(`return -1`) 334 p.Out() 335 p.P(`}`) 336 p.P(`return 1`) 337 p.Out() 338 p.P(`}`) 339 } 340 } else if field.IsMessage() || p.IsGroup(field) { 341 if nullable { 342 p.P(`if c := this.`, fieldname, `[i].Compare(that1.`, fieldname, `[i]); c != 0 {`) 343 p.In() 344 p.P(`return c`) 345 p.Out() 346 p.P(`}`) 347 } else { 348 p.P(`if c := this.`, fieldname, `[i].Compare(&that1.`, fieldname, `[i]); c != 0 {`) 349 p.In() 350 p.P(`return c`) 351 p.Out() 352 p.P(`}`) 353 } 354 } else if field.IsBytes() { 355 p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `[i], that1.`, fieldname, `[i]); c != 0 {`) 356 p.In() 357 p.P(`return c`) 358 p.Out() 359 p.P(`}`) 360 } else if field.IsString() { 361 p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`) 362 p.In() 363 p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`) 364 p.In() 365 p.P(`return -1`) 366 p.Out() 367 p.P(`}`) 368 p.P(`return 1`) 369 p.Out() 370 p.P(`}`) 371 } else if field.IsBool() { 372 p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`) 373 p.In() 374 p.P(`if !this.`, fieldname, `[i] {`) 375 p.In() 376 p.P(`return -1`) 377 p.Out() 378 p.P(`}`) 379 p.P(`return 1`) 380 p.Out() 381 p.P(`}`) 382 } else { 383 p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`) 384 p.In() 385 p.P(`if this.`, fieldname, `[i] < that1.`, fieldname, `[i] {`) 386 p.In() 387 p.P(`return -1`) 388 p.Out() 389 p.P(`}`) 390 p.P(`return 1`) 391 p.Out() 392 p.P(`}`) 393 } 394 } 395 p.Out() 396 p.P(`}`) 397 } 398} 399 400func (p *plugin) generateMessage(file *generator.FileDescriptor, message *generator.Descriptor) { 401 ccTypeName := generator.CamelCaseSlice(message.TypeName()) 402 p.P(`func (this *`, ccTypeName, `) Compare(that interface{}) int {`) 403 p.In() 404 p.generateMsgNullAndTypeCheck(ccTypeName) 405 oneofs := make(map[string]struct{}) 406 407 for _, field := range message.Field { 408 oneof := field.OneofIndex != nil 409 if oneof { 410 fieldname := p.GetFieldName(message, field) 411 if _, ok := oneofs[fieldname]; ok { 412 continue 413 } else { 414 oneofs[fieldname] = struct{}{} 415 } 416 p.P(`if that1.`, fieldname, ` == nil {`) 417 p.In() 418 p.P(`if this.`, fieldname, ` != nil {`) 419 p.In() 420 p.P(`return 1`) 421 p.Out() 422 p.P(`}`) 423 p.Out() 424 p.P(`} else if this.`, fieldname, ` == nil {`) 425 p.In() 426 p.P(`return -1`) 427 p.Out() 428 p.P(`} else {`) 429 p.In() 430 431 // Generate two type switches in order to compare the 432 // types of the oneofs. If they are of the same type 433 // call Compare, otherwise return 1 or -1. 434 p.P(`thisType := -1`) 435 p.P(`switch this.`, fieldname, `.(type) {`) 436 for i, subfield := range message.Field { 437 if *subfield.OneofIndex == *field.OneofIndex { 438 ccTypeName := p.OneOfTypeName(message, subfield) 439 p.P(`case *`, ccTypeName, `:`) 440 p.In() 441 p.P(`thisType = `, i) 442 p.Out() 443 } 444 } 445 p.P(`default:`) 446 p.In() 447 p.P(`panic(fmt.Sprintf("compare: unexpected type %T in oneof", this.`, fieldname, `))`) 448 p.Out() 449 p.P(`}`) 450 451 p.P(`that1Type := -1`) 452 p.P(`switch that1.`, fieldname, `.(type) {`) 453 for i, subfield := range message.Field { 454 if *subfield.OneofIndex == *field.OneofIndex { 455 ccTypeName := p.OneOfTypeName(message, subfield) 456 p.P(`case *`, ccTypeName, `:`) 457 p.In() 458 p.P(`that1Type = `, i) 459 p.Out() 460 } 461 } 462 p.P(`default:`) 463 p.In() 464 p.P(`panic(fmt.Sprintf("compare: unexpected type %T in oneof", that1.`, fieldname, `))`) 465 p.Out() 466 p.P(`}`) 467 468 p.P(`if thisType == that1Type {`) 469 p.In() 470 p.P(`if c := this.`, fieldname, `.Compare(that1.`, fieldname, `); c != 0 {`) 471 p.In() 472 p.P(`return c`) 473 p.Out() 474 p.P(`}`) 475 p.Out() 476 p.P(`} else if thisType < that1Type {`) 477 p.In() 478 p.P(`return -1`) 479 p.Out() 480 p.P(`} else if thisType > that1Type {`) 481 p.In() 482 p.P(`return 1`) 483 p.Out() 484 p.P(`}`) 485 p.Out() 486 p.P(`}`) 487 } else { 488 p.generateField(file, message, field) 489 } 490 } 491 if message.DescriptorProto.HasExtension() { 492 if gogoproto.HasExtensionsMap(file.FileDescriptorProto, message.DescriptorProto) { 493 p.P(`thismap := `, p.protoPkg.Use(), `.GetUnsafeExtensionsMap(this)`) 494 p.P(`thatmap := `, p.protoPkg.Use(), `.GetUnsafeExtensionsMap(that1)`) 495 p.P(`extkeys := make([]int32, 0, len(thismap)+len(thatmap))`) 496 p.P(`for k, _ := range thismap {`) 497 p.In() 498 p.P(`extkeys = append(extkeys, k)`) 499 p.Out() 500 p.P(`}`) 501 p.P(`for k, _ := range thatmap {`) 502 p.In() 503 p.P(`if _, ok := thismap[k]; !ok {`) 504 p.In() 505 p.P(`extkeys = append(extkeys, k)`) 506 p.Out() 507 p.P(`}`) 508 p.Out() 509 p.P(`}`) 510 p.P(p.sortkeysPkg.Use(), `.Int32s(extkeys)`) 511 p.P(`for _, k := range extkeys {`) 512 p.In() 513 p.P(`if v, ok := thismap[k]; ok {`) 514 p.In() 515 p.P(`if v2, ok := thatmap[k]; ok {`) 516 p.In() 517 p.P(`if c := v.Compare(&v2); c != 0 {`) 518 p.In() 519 p.P(`return c`) 520 p.Out() 521 p.P(`}`) 522 p.Out() 523 p.P(`} else {`) 524 p.In() 525 p.P(`return 1`) 526 p.Out() 527 p.P(`}`) 528 p.Out() 529 p.P(`} else {`) 530 p.In() 531 p.P(`return -1`) 532 p.Out() 533 p.P(`}`) 534 p.Out() 535 p.P(`}`) 536 } else { 537 fieldname := "XXX_extensions" 538 p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `, that1.`, fieldname, `); c != 0 {`) 539 p.In() 540 p.P(`return c`) 541 p.Out() 542 p.P(`}`) 543 } 544 } 545 if gogoproto.HasUnrecognized(file.FileDescriptorProto, message.DescriptorProto) { 546 fieldname := "XXX_unrecognized" 547 p.P(`if c := `, p.bytesPkg.Use(), `.Compare(this.`, fieldname, `, that1.`, fieldname, `); c != 0 {`) 548 p.In() 549 p.P(`return c`) 550 p.Out() 551 p.P(`}`) 552 } 553 p.P(`return 0`) 554 p.Out() 555 p.P(`}`) 556 557 //Generate Compare methods for oneof fields 558 m := proto.Clone(message.DescriptorProto).(*descriptor.DescriptorProto) 559 for _, field := range m.Field { 560 oneof := field.OneofIndex != nil 561 if !oneof { 562 continue 563 } 564 ccTypeName := p.OneOfTypeName(message, field) 565 p.P(`func (this *`, ccTypeName, `) Compare(that interface{}) int {`) 566 p.In() 567 568 p.generateMsgNullAndTypeCheck(ccTypeName) 569 vanity.TurnOffNullableForNativeTypes(field) 570 p.generateField(file, message, field) 571 572 p.P(`return 0`) 573 p.Out() 574 p.P(`}`) 575 } 576} 577 578func init() { 579 generator.RegisterPlugin(NewPlugin()) 580} 581