1package gen 2 3import ( 4 "encoding" 5 "encoding/json" 6 "fmt" 7 "reflect" 8 "strings" 9 "unicode" 10 11 "github.com/mailru/easyjson" 12) 13 14// Target this byte size for initial slice allocation to reduce garbage collection. 15const minSliceBytes = 64 16 17func (g *Generator) getDecoderName(t reflect.Type) string { 18 return g.functionName("decode", t) 19} 20 21var primitiveDecoders = map[reflect.Kind]string{ 22 reflect.String: "in.String()", 23 reflect.Bool: "in.Bool()", 24 reflect.Int: "in.Int()", 25 reflect.Int8: "in.Int8()", 26 reflect.Int16: "in.Int16()", 27 reflect.Int32: "in.Int32()", 28 reflect.Int64: "in.Int64()", 29 reflect.Uint: "in.Uint()", 30 reflect.Uint8: "in.Uint8()", 31 reflect.Uint16: "in.Uint16()", 32 reflect.Uint32: "in.Uint32()", 33 reflect.Uint64: "in.Uint64()", 34 reflect.Float32: "in.Float32()", 35 reflect.Float64: "in.Float64()", 36} 37 38var primitiveStringDecoders = map[reflect.Kind]string{ 39 reflect.String: "in.String()", 40 reflect.Int: "in.IntStr()", 41 reflect.Int8: "in.Int8Str()", 42 reflect.Int16: "in.Int16Str()", 43 reflect.Int32: "in.Int32Str()", 44 reflect.Int64: "in.Int64Str()", 45 reflect.Uint: "in.UintStr()", 46 reflect.Uint8: "in.Uint8Str()", 47 reflect.Uint16: "in.Uint16Str()", 48 reflect.Uint32: "in.Uint32Str()", 49 reflect.Uint64: "in.Uint64Str()", 50 reflect.Uintptr: "in.UintptrStr()", 51 reflect.Float32: "in.Float32Str()", 52 reflect.Float64: "in.Float64Str()", 53} 54 55var customDecoders = map[string]string{ 56 "json.Number": "in.JsonNumber()", 57} 58 59// genTypeDecoder generates decoding code for the type t, but uses unmarshaler interface if implemented by t. 60func (g *Generator) genTypeDecoder(t reflect.Type, out string, tags fieldTags, indent int) error { 61 ws := strings.Repeat(" ", indent) 62 63 unmarshalerIface := reflect.TypeOf((*easyjson.Unmarshaler)(nil)).Elem() 64 if reflect.PtrTo(t).Implements(unmarshalerIface) { 65 fmt.Fprintln(g.out, ws+"("+out+").UnmarshalEasyJSON(in)") 66 return nil 67 } 68 69 unmarshalerIface = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() 70 if reflect.PtrTo(t).Implements(unmarshalerIface) { 71 fmt.Fprintln(g.out, ws+"if data := in.Raw(); in.Ok() {") 72 fmt.Fprintln(g.out, ws+" in.AddError( ("+out+").UnmarshalJSON(data) )") 73 fmt.Fprintln(g.out, ws+"}") 74 return nil 75 } 76 77 unmarshalerIface = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() 78 if reflect.PtrTo(t).Implements(unmarshalerIface) { 79 fmt.Fprintln(g.out, ws+"if data := in.UnsafeBytes(); in.Ok() {") 80 fmt.Fprintln(g.out, ws+" in.AddError( ("+out+").UnmarshalText(data) )") 81 fmt.Fprintln(g.out, ws+"}") 82 return nil 83 } 84 85 err := g.genTypeDecoderNoCheck(t, out, tags, indent) 86 return err 87} 88 89// returns true of the type t implements one of the custom unmarshaler interfaces 90func hasCustomUnmarshaler(t reflect.Type) bool { 91 t = reflect.PtrTo(t) 92 return t.Implements(reflect.TypeOf((*easyjson.Unmarshaler)(nil)).Elem()) || 93 t.Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) || 94 t.Implements(reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()) 95} 96 97func hasUnknownsUnmarshaler(t reflect.Type) bool { 98 t = reflect.PtrTo(t) 99 return t.Implements(reflect.TypeOf((*easyjson.UnknownsUnmarshaler)(nil)).Elem()) 100} 101 102func hasUnknownsMarshaler(t reflect.Type) bool { 103 t = reflect.PtrTo(t) 104 return t.Implements(reflect.TypeOf((*easyjson.UnknownsMarshaler)(nil)).Elem()) 105} 106 107// genTypeDecoderNoCheck generates decoding code for the type t. 108func (g *Generator) genTypeDecoderNoCheck(t reflect.Type, out string, tags fieldTags, indent int) error { 109 ws := strings.Repeat(" ", indent) 110 // Check whether type is primitive, needs to be done after interface check. 111 if dec := customDecoders[t.String()]; dec != "" { 112 fmt.Fprintln(g.out, ws+out+" = "+dec) 113 return nil 114 } else if dec := primitiveStringDecoders[t.Kind()]; dec != "" && tags.asString { 115 fmt.Fprintln(g.out, ws+out+" = "+g.getType(t)+"("+dec+")") 116 return nil 117 } else if dec := primitiveDecoders[t.Kind()]; dec != "" { 118 fmt.Fprintln(g.out, ws+out+" = "+g.getType(t)+"("+dec+")") 119 return nil 120 } 121 122 switch t.Kind() { 123 case reflect.Slice: 124 tmpVar := g.uniqueVarName() 125 elem := t.Elem() 126 127 if elem.Kind() == reflect.Uint8 && elem.Name() == "uint8" { 128 fmt.Fprintln(g.out, ws+"if in.IsNull() {") 129 fmt.Fprintln(g.out, ws+" in.Skip()") 130 fmt.Fprintln(g.out, ws+" "+out+" = nil") 131 fmt.Fprintln(g.out, ws+"} else {") 132 fmt.Fprintln(g.out, ws+" "+out+" = in.Bytes()") 133 fmt.Fprintln(g.out, ws+"}") 134 135 } else { 136 137 capacity := minSliceBytes / elem.Size() 138 if capacity == 0 { 139 capacity = 1 140 } 141 142 fmt.Fprintln(g.out, ws+"if in.IsNull() {") 143 fmt.Fprintln(g.out, ws+" in.Skip()") 144 fmt.Fprintln(g.out, ws+" "+out+" = nil") 145 fmt.Fprintln(g.out, ws+"} else {") 146 fmt.Fprintln(g.out, ws+" in.Delim('[')") 147 fmt.Fprintln(g.out, ws+" if "+out+" == nil {") 148 fmt.Fprintln(g.out, ws+" if !in.IsDelim(']') {") 149 fmt.Fprintln(g.out, ws+" "+out+" = make("+g.getType(t)+", 0, "+fmt.Sprint(capacity)+")") 150 fmt.Fprintln(g.out, ws+" } else {") 151 fmt.Fprintln(g.out, ws+" "+out+" = "+g.getType(t)+"{}") 152 fmt.Fprintln(g.out, ws+" }") 153 fmt.Fprintln(g.out, ws+" } else { ") 154 fmt.Fprintln(g.out, ws+" "+out+" = ("+out+")[:0]") 155 fmt.Fprintln(g.out, ws+" }") 156 fmt.Fprintln(g.out, ws+" for !in.IsDelim(']') {") 157 fmt.Fprintln(g.out, ws+" var "+tmpVar+" "+g.getType(elem)) 158 159 if err := g.genTypeDecoder(elem, tmpVar, tags, indent+2); err != nil { 160 return err 161 } 162 163 fmt.Fprintln(g.out, ws+" "+out+" = append("+out+", "+tmpVar+")") 164 fmt.Fprintln(g.out, ws+" in.WantComma()") 165 fmt.Fprintln(g.out, ws+" }") 166 fmt.Fprintln(g.out, ws+" in.Delim(']')") 167 fmt.Fprintln(g.out, ws+"}") 168 } 169 170 case reflect.Array: 171 iterVar := g.uniqueVarName() 172 elem := t.Elem() 173 174 if elem.Kind() == reflect.Uint8 && elem.Name() == "uint8" { 175 fmt.Fprintln(g.out, ws+"if in.IsNull() {") 176 fmt.Fprintln(g.out, ws+" in.Skip()") 177 fmt.Fprintln(g.out, ws+"} else {") 178 fmt.Fprintln(g.out, ws+" copy("+out+"[:], in.Bytes())") 179 fmt.Fprintln(g.out, ws+"}") 180 181 } else { 182 183 length := t.Len() 184 185 fmt.Fprintln(g.out, ws+"if in.IsNull() {") 186 fmt.Fprintln(g.out, ws+" in.Skip()") 187 fmt.Fprintln(g.out, ws+"} else {") 188 fmt.Fprintln(g.out, ws+" in.Delim('[')") 189 fmt.Fprintln(g.out, ws+" "+iterVar+" := 0") 190 fmt.Fprintln(g.out, ws+" for !in.IsDelim(']') {") 191 fmt.Fprintln(g.out, ws+" if "+iterVar+" < "+fmt.Sprint(length)+" {") 192 193 if err := g.genTypeDecoder(elem, "("+out+")["+iterVar+"]", tags, indent+3); err != nil { 194 return err 195 } 196 197 fmt.Fprintln(g.out, ws+" "+iterVar+"++") 198 fmt.Fprintln(g.out, ws+" } else {") 199 fmt.Fprintln(g.out, ws+" in.SkipRecursive()") 200 fmt.Fprintln(g.out, ws+" }") 201 fmt.Fprintln(g.out, ws+" in.WantComma()") 202 fmt.Fprintln(g.out, ws+" }") 203 fmt.Fprintln(g.out, ws+" in.Delim(']')") 204 fmt.Fprintln(g.out, ws+"}") 205 } 206 207 case reflect.Struct: 208 dec := g.getDecoderName(t) 209 g.addType(t) 210 211 if len(out) > 0 && out[0] == '*' { 212 // NOTE: In order to remove an extra reference to a pointer 213 fmt.Fprintln(g.out, ws+dec+"(in, "+out[1:]+")") 214 } else { 215 fmt.Fprintln(g.out, ws+dec+"(in, &"+out+")") 216 } 217 218 case reflect.Ptr: 219 fmt.Fprintln(g.out, ws+"if in.IsNull() {") 220 fmt.Fprintln(g.out, ws+" in.Skip()") 221 fmt.Fprintln(g.out, ws+" "+out+" = nil") 222 fmt.Fprintln(g.out, ws+"} else {") 223 fmt.Fprintln(g.out, ws+" if "+out+" == nil {") 224 fmt.Fprintln(g.out, ws+" "+out+" = new("+g.getType(t.Elem())+")") 225 fmt.Fprintln(g.out, ws+" }") 226 227 if err := g.genTypeDecoder(t.Elem(), "*"+out, tags, indent+1); err != nil { 228 return err 229 } 230 231 fmt.Fprintln(g.out, ws+"}") 232 233 case reflect.Map: 234 key := t.Key() 235 keyDec, ok := primitiveStringDecoders[key.Kind()] 236 if !ok && !hasCustomUnmarshaler(key) { 237 return fmt.Errorf("map type %v not supported: only string and integer keys and types implementing json.Unmarshaler are allowed", key) 238 } // else assume the caller knows what they are doing and that the custom unmarshaler performs the translation from string or integer keys to the key type 239 elem := t.Elem() 240 tmpVar := g.uniqueVarName() 241 keepEmpty := tags.required || tags.noOmitEmpty || (!g.omitEmpty && !tags.omitEmpty) 242 243 fmt.Fprintln(g.out, ws+"if in.IsNull() {") 244 fmt.Fprintln(g.out, ws+" in.Skip()") 245 fmt.Fprintln(g.out, ws+"} else {") 246 fmt.Fprintln(g.out, ws+" in.Delim('{')") 247 if !keepEmpty { 248 fmt.Fprintln(g.out, ws+" if !in.IsDelim('}') {") 249 } 250 fmt.Fprintln(g.out, ws+" "+out+" = make("+g.getType(t)+")") 251 if !keepEmpty { 252 fmt.Fprintln(g.out, ws+" } else {") 253 fmt.Fprintln(g.out, ws+" "+out+" = nil") 254 fmt.Fprintln(g.out, ws+" }") 255 } 256 257 fmt.Fprintln(g.out, ws+" for !in.IsDelim('}') {") 258 // NOTE: extra check for TextUnmarshaler. It overrides default methods. 259 if reflect.PtrTo(key).Implements(reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()) { 260 fmt.Fprintln(g.out, ws+" var key "+g.getType(key)) 261 fmt.Fprintln(g.out, ws+"if data := in.UnsafeBytes(); in.Ok() {") 262 fmt.Fprintln(g.out, ws+" in.AddError(key.UnmarshalText(data) )") 263 fmt.Fprintln(g.out, ws+"}") 264 } else if keyDec != "" { 265 fmt.Fprintln(g.out, ws+" key := "+g.getType(key)+"("+keyDec+")") 266 } else { 267 fmt.Fprintln(g.out, ws+" var key "+g.getType(key)) 268 if err := g.genTypeDecoder(key, "key", tags, indent+2); err != nil { 269 return err 270 } 271 } 272 273 fmt.Fprintln(g.out, ws+" in.WantColon()") 274 fmt.Fprintln(g.out, ws+" var "+tmpVar+" "+g.getType(elem)) 275 276 if err := g.genTypeDecoder(elem, tmpVar, tags, indent+2); err != nil { 277 return err 278 } 279 280 fmt.Fprintln(g.out, ws+" ("+out+")[key] = "+tmpVar) 281 fmt.Fprintln(g.out, ws+" in.WantComma()") 282 fmt.Fprintln(g.out, ws+" }") 283 fmt.Fprintln(g.out, ws+" in.Delim('}')") 284 fmt.Fprintln(g.out, ws+"}") 285 286 case reflect.Interface: 287 if t.NumMethod() != 0 { 288 return fmt.Errorf("interface type %v not supported: only interface{} is allowed", t) 289 } 290 fmt.Fprintln(g.out, ws+"if m, ok := "+out+".(easyjson.Unmarshaler); ok {") 291 fmt.Fprintln(g.out, ws+"m.UnmarshalEasyJSON(in)") 292 fmt.Fprintln(g.out, ws+"} else if m, ok := "+out+".(json.Unmarshaler); ok {") 293 fmt.Fprintln(g.out, ws+"_ = m.UnmarshalJSON(in.Raw())") 294 fmt.Fprintln(g.out, ws+"} else {") 295 fmt.Fprintln(g.out, ws+" "+out+" = in.Interface()") 296 fmt.Fprintln(g.out, ws+"}") 297 default: 298 return fmt.Errorf("don't know how to decode %v", t) 299 } 300 return nil 301 302} 303 304func (g *Generator) genStructFieldDecoder(t reflect.Type, f reflect.StructField) error { 305 jsonName := g.fieldNamer.GetJSONFieldName(t, f) 306 tags := parseFieldTags(f) 307 308 if tags.omit { 309 return nil 310 } 311 312 fmt.Fprintf(g.out, " case %q:\n", jsonName) 313 if err := g.genTypeDecoder(f.Type, "out."+f.Name, tags, 3); err != nil { 314 return err 315 } 316 317 if tags.required { 318 fmt.Fprintf(g.out, "%sSet = true\n", f.Name) 319 } 320 321 return nil 322} 323 324func (g *Generator) genRequiredFieldSet(t reflect.Type, f reflect.StructField) { 325 tags := parseFieldTags(f) 326 327 if !tags.required { 328 return 329 } 330 331 fmt.Fprintf(g.out, "var %sSet bool\n", f.Name) 332} 333 334func (g *Generator) genRequiredFieldCheck(t reflect.Type, f reflect.StructField) { 335 jsonName := g.fieldNamer.GetJSONFieldName(t, f) 336 tags := parseFieldTags(f) 337 338 if !tags.required { 339 return 340 } 341 342 g.imports["fmt"] = "fmt" 343 344 fmt.Fprintf(g.out, "if !%sSet {\n", f.Name) 345 fmt.Fprintf(g.out, " in.AddError(fmt.Errorf(\"key '%s' is required\"))\n", jsonName) 346 fmt.Fprintf(g.out, "}\n") 347} 348 349func mergeStructFields(fields1, fields2 []reflect.StructField) (fields []reflect.StructField) { 350 used := map[string]bool{} 351 for _, f := range fields2 { 352 used[f.Name] = true 353 fields = append(fields, f) 354 } 355 356 for _, f := range fields1 { 357 if !used[f.Name] { 358 fields = append(fields, f) 359 } 360 } 361 return 362} 363 364func getStructFields(t reflect.Type) ([]reflect.StructField, error) { 365 if t.Kind() != reflect.Struct { 366 return nil, fmt.Errorf("got %v; expected a struct", t) 367 } 368 369 var efields []reflect.StructField 370 for i := 0; i < t.NumField(); i++ { 371 f := t.Field(i) 372 tags := parseFieldTags(f) 373 if !f.Anonymous || tags.name != "" { 374 continue 375 } 376 377 t1 := f.Type 378 if t1.Kind() == reflect.Ptr { 379 t1 = t1.Elem() 380 } 381 382 fs, err := getStructFields(t1) 383 if err != nil { 384 return nil, fmt.Errorf("error processing embedded field: %v", err) 385 } 386 efields = mergeStructFields(efields, fs) 387 } 388 389 var fields []reflect.StructField 390 for i := 0; i < t.NumField(); i++ { 391 f := t.Field(i) 392 tags := parseFieldTags(f) 393 if f.Anonymous && tags.name == "" { 394 continue 395 } 396 397 c := []rune(f.Name)[0] 398 if unicode.IsUpper(c) { 399 fields = append(fields, f) 400 } 401 } 402 return mergeStructFields(efields, fields), nil 403} 404 405func (g *Generator) genDecoder(t reflect.Type) error { 406 switch t.Kind() { 407 case reflect.Slice, reflect.Array, reflect.Map: 408 return g.genSliceArrayDecoder(t) 409 default: 410 return g.genStructDecoder(t) 411 } 412} 413 414func (g *Generator) genSliceArrayDecoder(t reflect.Type) error { 415 switch t.Kind() { 416 case reflect.Slice, reflect.Array, reflect.Map: 417 default: 418 return fmt.Errorf("cannot generate encoder/decoder for %v, not a slice/array/map type", t) 419 } 420 421 fname := g.getDecoderName(t) 422 typ := g.getType(t) 423 424 fmt.Fprintln(g.out, "func "+fname+"(in *jlexer.Lexer, out *"+typ+") {") 425 fmt.Fprintln(g.out, " isTopLevel := in.IsStart()") 426 err := g.genTypeDecoderNoCheck(t, "*out", fieldTags{}, 1) 427 if err != nil { 428 return err 429 } 430 fmt.Fprintln(g.out, " if isTopLevel {") 431 fmt.Fprintln(g.out, " in.Consumed()") 432 fmt.Fprintln(g.out, " }") 433 fmt.Fprintln(g.out, "}") 434 435 return nil 436} 437 438func (g *Generator) genStructDecoder(t reflect.Type) error { 439 if t.Kind() != reflect.Struct { 440 return fmt.Errorf("cannot generate encoder/decoder for %v, not a struct type", t) 441 } 442 443 fname := g.getDecoderName(t) 444 typ := g.getType(t) 445 446 fmt.Fprintln(g.out, "func "+fname+"(in *jlexer.Lexer, out *"+typ+") {") 447 fmt.Fprintln(g.out, " isTopLevel := in.IsStart()") 448 fmt.Fprintln(g.out, " if in.IsNull() {") 449 fmt.Fprintln(g.out, " if isTopLevel {") 450 fmt.Fprintln(g.out, " in.Consumed()") 451 fmt.Fprintln(g.out, " }") 452 fmt.Fprintln(g.out, " in.Skip()") 453 fmt.Fprintln(g.out, " return") 454 fmt.Fprintln(g.out, " }") 455 456 // Init embedded pointer fields. 457 for i := 0; i < t.NumField(); i++ { 458 f := t.Field(i) 459 if !f.Anonymous || f.Type.Kind() != reflect.Ptr { 460 continue 461 } 462 fmt.Fprintln(g.out, " out."+f.Name+" = new("+g.getType(f.Type.Elem())+")") 463 } 464 465 fs, err := getStructFields(t) 466 if err != nil { 467 return fmt.Errorf("cannot generate decoder for %v: %v", t, err) 468 } 469 470 for _, f := range fs { 471 g.genRequiredFieldSet(t, f) 472 } 473 474 fmt.Fprintln(g.out, " in.Delim('{')") 475 fmt.Fprintln(g.out, " for !in.IsDelim('}') {") 476 fmt.Fprintln(g.out, " key := in.UnsafeString()") 477 fmt.Fprintln(g.out, " in.WantColon()") 478 fmt.Fprintln(g.out, " if in.IsNull() {") 479 fmt.Fprintln(g.out, " in.Skip()") 480 fmt.Fprintln(g.out, " in.WantComma()") 481 fmt.Fprintln(g.out, " continue") 482 fmt.Fprintln(g.out, " }") 483 484 fmt.Fprintln(g.out, " switch key {") 485 for _, f := range fs { 486 if err := g.genStructFieldDecoder(t, f); err != nil { 487 return err 488 } 489 } 490 491 fmt.Fprintln(g.out, " default:") 492 if g.disallowUnknownFields { 493 fmt.Fprintln(g.out, ` in.AddError(&jlexer.LexerError{ 494 Offset: in.GetPos(), 495 Reason: "unknown field", 496 Data: key, 497 })`) 498 } else if hasUnknownsUnmarshaler(t) { 499 fmt.Fprintln(g.out, " out.UnmarshalUnknown(in, key)") 500 } else { 501 fmt.Fprintln(g.out, " in.SkipRecursive()") 502 } 503 fmt.Fprintln(g.out, " }") 504 fmt.Fprintln(g.out, " in.WantComma()") 505 fmt.Fprintln(g.out, " }") 506 fmt.Fprintln(g.out, " in.Delim('}')") 507 fmt.Fprintln(g.out, " if isTopLevel {") 508 fmt.Fprintln(g.out, " in.Consumed()") 509 fmt.Fprintln(g.out, " }") 510 511 for _, f := range fs { 512 g.genRequiredFieldCheck(t, f) 513 } 514 515 fmt.Fprintln(g.out, "}") 516 517 return nil 518} 519 520func (g *Generator) genStructUnmarshaler(t reflect.Type) error { 521 switch t.Kind() { 522 case reflect.Slice, reflect.Array, reflect.Map, reflect.Struct: 523 default: 524 return fmt.Errorf("cannot generate encoder/decoder for %v, not a struct/slice/array/map type", t) 525 } 526 527 fname := g.getDecoderName(t) 528 typ := g.getType(t) 529 530 if !g.noStdMarshalers { 531 fmt.Fprintln(g.out, "// UnmarshalJSON supports json.Unmarshaler interface") 532 fmt.Fprintln(g.out, "func (v *"+typ+") UnmarshalJSON(data []byte) error {") 533 fmt.Fprintln(g.out, " r := jlexer.Lexer{Data: data}") 534 fmt.Fprintln(g.out, " "+fname+"(&r, v)") 535 fmt.Fprintln(g.out, " return r.Error()") 536 fmt.Fprintln(g.out, "}") 537 } 538 539 fmt.Fprintln(g.out, "// UnmarshalEasyJSON supports easyjson.Unmarshaler interface") 540 fmt.Fprintln(g.out, "func (v *"+typ+") UnmarshalEasyJSON(l *jlexer.Lexer) {") 541 fmt.Fprintln(g.out, " "+fname+"(l, v)") 542 fmt.Fprintln(g.out, "}") 543 544 return nil 545} 546