1package gen 2 3import ( 4 "encoding" 5 "encoding/json" 6 "fmt" 7 "reflect" 8 "strings" 9 "unicode" 10 11 "github.com/mailru/easyjson" 12) 13 14// Target this byte size for initial slice allocation to reduce garbage collection. 15const minSliceBytes = 64 16 17func (g *Generator) getDecoderName(t reflect.Type) string { 18 return g.functionName("decode", t) 19} 20 21var primitiveDecoders = map[reflect.Kind]string{ 22 reflect.String: "in.String()", 23 reflect.Bool: "in.Bool()", 24 reflect.Int: "in.Int()", 25 reflect.Int8: "in.Int8()", 26 reflect.Int16: "in.Int16()", 27 reflect.Int32: "in.Int32()", 28 reflect.Int64: "in.Int64()", 29 reflect.Uint: "in.Uint()", 30 reflect.Uint8: "in.Uint8()", 31 reflect.Uint16: "in.Uint16()", 32 reflect.Uint32: "in.Uint32()", 33 reflect.Uint64: "in.Uint64()", 34 reflect.Float32: "in.Float32()", 35 reflect.Float64: "in.Float64()", 36} 37 38var primitiveStringDecoders = map[reflect.Kind]string{ 39 reflect.String: "in.String()", 40 reflect.Int: "in.IntStr()", 41 reflect.Int8: "in.Int8Str()", 42 reflect.Int16: "in.Int16Str()", 43 reflect.Int32: "in.Int32Str()", 44 reflect.Int64: "in.Int64Str()", 45 reflect.Uint: "in.UintStr()", 46 reflect.Uint8: "in.Uint8Str()", 47 reflect.Uint16: "in.Uint16Str()", 48 reflect.Uint32: "in.Uint32Str()", 49 reflect.Uint64: "in.Uint64Str()", 50 reflect.Uintptr: "in.UintptrStr()", 51 reflect.Float32: "in.Float32Str()", 52 reflect.Float64: "in.Float64Str()", 53} 54 55var customDecoders = map[string]string{ 56 "json.Number": "in.JsonNumber()", 57} 58 59// genTypeDecoder generates decoding code for the type t, but uses unmarshaler interface if implemented by t. 60func (g *Generator) genTypeDecoder(t reflect.Type, out string, tags fieldTags, indent int) error { 61 ws := strings.Repeat(" ", indent) 62 63 unmarshalerIface := reflect.TypeOf((*easyjson.Unmarshaler)(nil)).Elem() 64 if reflect.PtrTo(t).Implements(unmarshalerIface) { 65 fmt.Fprintln(g.out, ws+"("+out+").UnmarshalEasyJSON(in)") 66 return nil 67 } 68 69 unmarshalerIface = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() 70 if reflect.PtrTo(t).Implements(unmarshalerIface) { 71 fmt.Fprintln(g.out, ws+"if data := in.Raw(); in.Ok() {") 72 fmt.Fprintln(g.out, ws+" in.AddError( ("+out+").UnmarshalJSON(data) )") 73 fmt.Fprintln(g.out, ws+"}") 74 return nil 75 } 76 77 unmarshalerIface = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() 78 if reflect.PtrTo(t).Implements(unmarshalerIface) { 79 fmt.Fprintln(g.out, ws+"if data := in.UnsafeBytes(); in.Ok() {") 80 fmt.Fprintln(g.out, ws+" in.AddError( ("+out+").UnmarshalText(data) )") 81 fmt.Fprintln(g.out, ws+"}") 82 return nil 83 } 84 85 err := g.genTypeDecoderNoCheck(t, out, tags, indent) 86 return err 87} 88 89// returns true of the type t implements one of the custom unmarshaler interfaces 90func hasCustomUnmarshaler(t reflect.Type) bool { 91 t = reflect.PtrTo(t) 92 return t.Implements(reflect.TypeOf((*easyjson.Unmarshaler)(nil)).Elem()) || 93 t.Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) || 94 t.Implements(reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()) 95} 96 97// genTypeDecoderNoCheck generates decoding code for the type t. 98func (g *Generator) genTypeDecoderNoCheck(t reflect.Type, out string, tags fieldTags, indent int) error { 99 ws := strings.Repeat(" ", indent) 100 // Check whether type is primitive, needs to be done after interface check. 101 if dec := customDecoders[t.String()]; dec != "" { 102 fmt.Fprintln(g.out, ws+out+" = "+dec) 103 return nil 104 } else if dec := primitiveStringDecoders[t.Kind()]; dec != "" && tags.asString { 105 fmt.Fprintln(g.out, ws+out+" = "+g.getType(t)+"("+dec+")") 106 return nil 107 } else if dec := primitiveDecoders[t.Kind()]; dec != "" { 108 fmt.Fprintln(g.out, ws+out+" = "+g.getType(t)+"("+dec+")") 109 return nil 110 } 111 112 switch t.Kind() { 113 case reflect.Slice: 114 tmpVar := g.uniqueVarName() 115 elem := t.Elem() 116 117 if elem.Kind() == reflect.Uint8 && elem.Name() == "uint8" { 118 fmt.Fprintln(g.out, ws+"if in.IsNull() {") 119 fmt.Fprintln(g.out, ws+" in.Skip()") 120 fmt.Fprintln(g.out, ws+" "+out+" = nil") 121 fmt.Fprintln(g.out, ws+"} else {") 122 fmt.Fprintln(g.out, ws+" "+out+" = in.Bytes()") 123 fmt.Fprintln(g.out, ws+"}") 124 125 } else { 126 127 capacity := minSliceBytes / elem.Size() 128 if capacity == 0 { 129 capacity = 1 130 } 131 132 fmt.Fprintln(g.out, ws+"if in.IsNull() {") 133 fmt.Fprintln(g.out, ws+" in.Skip()") 134 fmt.Fprintln(g.out, ws+" "+out+" = nil") 135 fmt.Fprintln(g.out, ws+"} else {") 136 fmt.Fprintln(g.out, ws+" in.Delim('[')") 137 fmt.Fprintln(g.out, ws+" if "+out+" == nil {") 138 fmt.Fprintln(g.out, ws+" if !in.IsDelim(']') {") 139 fmt.Fprintln(g.out, ws+" "+out+" = make("+g.getType(t)+", 0, "+fmt.Sprint(capacity)+")") 140 fmt.Fprintln(g.out, ws+" } else {") 141 fmt.Fprintln(g.out, ws+" "+out+" = "+g.getType(t)+"{}") 142 fmt.Fprintln(g.out, ws+" }") 143 fmt.Fprintln(g.out, ws+" } else { ") 144 fmt.Fprintln(g.out, ws+" "+out+" = ("+out+")[:0]") 145 fmt.Fprintln(g.out, ws+" }") 146 fmt.Fprintln(g.out, ws+" for !in.IsDelim(']') {") 147 fmt.Fprintln(g.out, ws+" var "+tmpVar+" "+g.getType(elem)) 148 149 if err := g.genTypeDecoder(elem, tmpVar, tags, indent+2); err != nil { 150 return err 151 } 152 153 fmt.Fprintln(g.out, ws+" "+out+" = append("+out+", "+tmpVar+")") 154 fmt.Fprintln(g.out, ws+" in.WantComma()") 155 fmt.Fprintln(g.out, ws+" }") 156 fmt.Fprintln(g.out, ws+" in.Delim(']')") 157 fmt.Fprintln(g.out, ws+"}") 158 } 159 160 case reflect.Array: 161 iterVar := g.uniqueVarName() 162 elem := t.Elem() 163 164 if elem.Kind() == reflect.Uint8 && elem.Name() == "uint8" { 165 fmt.Fprintln(g.out, ws+"if in.IsNull() {") 166 fmt.Fprintln(g.out, ws+" in.Skip()") 167 fmt.Fprintln(g.out, ws+"} else {") 168 fmt.Fprintln(g.out, ws+" copy("+out+"[:], in.Bytes())") 169 fmt.Fprintln(g.out, ws+"}") 170 171 } else { 172 173 length := t.Len() 174 175 fmt.Fprintln(g.out, ws+"if in.IsNull() {") 176 fmt.Fprintln(g.out, ws+" in.Skip()") 177 fmt.Fprintln(g.out, ws+"} else {") 178 fmt.Fprintln(g.out, ws+" in.Delim('[')") 179 fmt.Fprintln(g.out, ws+" "+iterVar+" := 0") 180 fmt.Fprintln(g.out, ws+" for !in.IsDelim(']') {") 181 fmt.Fprintln(g.out, ws+" if "+iterVar+" < "+fmt.Sprint(length)+" {") 182 183 if err := g.genTypeDecoder(elem, "("+out+")["+iterVar+"]", tags, indent+3); err != nil { 184 return err 185 } 186 187 fmt.Fprintln(g.out, ws+" "+iterVar+"++") 188 fmt.Fprintln(g.out, ws+" } else {") 189 fmt.Fprintln(g.out, ws+" in.SkipRecursive()") 190 fmt.Fprintln(g.out, ws+" }") 191 fmt.Fprintln(g.out, ws+" in.WantComma()") 192 fmt.Fprintln(g.out, ws+" }") 193 fmt.Fprintln(g.out, ws+" in.Delim(']')") 194 fmt.Fprintln(g.out, ws+"}") 195 } 196 197 case reflect.Struct: 198 dec := g.getDecoderName(t) 199 g.addType(t) 200 201 if len(out) > 0 && out[0] == '*' { 202 // NOTE: In order to remove an extra reference to a pointer 203 fmt.Fprintln(g.out, ws+dec+"(in, "+out[1:]+")") 204 } else { 205 fmt.Fprintln(g.out, ws+dec+"(in, &"+out+")") 206 } 207 208 case reflect.Ptr: 209 fmt.Fprintln(g.out, ws+"if in.IsNull() {") 210 fmt.Fprintln(g.out, ws+" in.Skip()") 211 fmt.Fprintln(g.out, ws+" "+out+" = nil") 212 fmt.Fprintln(g.out, ws+"} else {") 213 fmt.Fprintln(g.out, ws+" if "+out+" == nil {") 214 fmt.Fprintln(g.out, ws+" "+out+" = new("+g.getType(t.Elem())+")") 215 fmt.Fprintln(g.out, ws+" }") 216 217 if err := g.genTypeDecoder(t.Elem(), "*"+out, tags, indent+1); err != nil { 218 return err 219 } 220 221 fmt.Fprintln(g.out, ws+"}") 222 223 case reflect.Map: 224 key := t.Key() 225 keyDec, ok := primitiveStringDecoders[key.Kind()] 226 if !ok && !hasCustomUnmarshaler(key) { 227 return fmt.Errorf("map type %v not supported: only string and integer keys and types implementing json.Unmarshaler are allowed", key) 228 } // else assume the caller knows what they are doing and that the custom unmarshaler performs the translation from string or integer keys to the key type 229 elem := t.Elem() 230 tmpVar := g.uniqueVarName() 231 232 fmt.Fprintln(g.out, ws+"if in.IsNull() {") 233 fmt.Fprintln(g.out, ws+" in.Skip()") 234 fmt.Fprintln(g.out, ws+"} else {") 235 fmt.Fprintln(g.out, ws+" in.Delim('{')") 236 fmt.Fprintln(g.out, ws+" if !in.IsDelim('}') {") 237 fmt.Fprintln(g.out, ws+" "+out+" = make("+g.getType(t)+")") 238 fmt.Fprintln(g.out, ws+" } else {") 239 fmt.Fprintln(g.out, ws+" "+out+" = nil") 240 fmt.Fprintln(g.out, ws+" }") 241 242 fmt.Fprintln(g.out, ws+" for !in.IsDelim('}') {") 243 // NOTE: extra check for TextUnmarshaler. It overrides default methods. 244 if reflect.PtrTo(key).Implements(reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()) { 245 fmt.Fprintln(g.out, ws+" var key "+g.getType(key)) 246 fmt.Fprintln(g.out, ws+"if data := in.UnsafeBytes(); in.Ok() {") 247 fmt.Fprintln(g.out, ws+" in.AddError(key.UnmarshalText(data) )") 248 fmt.Fprintln(g.out, ws+"}") 249 } else if keyDec != "" { 250 fmt.Fprintln(g.out, ws+" key := "+g.getType(key)+"("+keyDec+")") 251 } else { 252 fmt.Fprintln(g.out, ws+" var key "+g.getType(key)) 253 if err := g.genTypeDecoder(key, "key", tags, indent+2); err != nil { 254 return err 255 } 256 } 257 258 fmt.Fprintln(g.out, ws+" in.WantColon()") 259 fmt.Fprintln(g.out, ws+" var "+tmpVar+" "+g.getType(elem)) 260 261 if err := g.genTypeDecoder(elem, tmpVar, tags, indent+2); err != nil { 262 return err 263 } 264 265 fmt.Fprintln(g.out, ws+" ("+out+")[key] = "+tmpVar) 266 fmt.Fprintln(g.out, ws+" in.WantComma()") 267 fmt.Fprintln(g.out, ws+" }") 268 fmt.Fprintln(g.out, ws+" in.Delim('}')") 269 fmt.Fprintln(g.out, ws+"}") 270 271 case reflect.Interface: 272 if t.NumMethod() != 0 { 273 return fmt.Errorf("interface type %v not supported: only interface{} is allowed", t) 274 } 275 fmt.Fprintln(g.out, ws+"if m, ok := "+out+".(easyjson.Unmarshaler); ok {") 276 fmt.Fprintln(g.out, ws+"m.UnmarshalEasyJSON(in)") 277 fmt.Fprintln(g.out, ws+"} else if m, ok := "+out+".(json.Unmarshaler); ok {") 278 fmt.Fprintln(g.out, ws+"_ = m.UnmarshalJSON(in.Raw())") 279 fmt.Fprintln(g.out, ws+"} else {") 280 fmt.Fprintln(g.out, ws+" "+out+" = in.Interface()") 281 fmt.Fprintln(g.out, ws+"}") 282 default: 283 return fmt.Errorf("don't know how to decode %v", t) 284 } 285 return nil 286 287} 288 289func (g *Generator) genStructFieldDecoder(t reflect.Type, f reflect.StructField) error { 290 jsonName := g.fieldNamer.GetJSONFieldName(t, f) 291 tags := parseFieldTags(f) 292 293 if tags.omit { 294 return nil 295 } 296 297 fmt.Fprintf(g.out, " case %q:\n", jsonName) 298 if err := g.genTypeDecoder(f.Type, "out."+f.Name, tags, 3); err != nil { 299 return err 300 } 301 302 if tags.required { 303 fmt.Fprintf(g.out, "%sSet = true\n", f.Name) 304 } 305 306 return nil 307} 308 309func (g *Generator) genRequiredFieldSet(t reflect.Type, f reflect.StructField) { 310 tags := parseFieldTags(f) 311 312 if !tags.required { 313 return 314 } 315 316 fmt.Fprintf(g.out, "var %sSet bool\n", f.Name) 317} 318 319func (g *Generator) genRequiredFieldCheck(t reflect.Type, f reflect.StructField) { 320 jsonName := g.fieldNamer.GetJSONFieldName(t, f) 321 tags := parseFieldTags(f) 322 323 if !tags.required { 324 return 325 } 326 327 g.imports["fmt"] = "fmt" 328 329 fmt.Fprintf(g.out, "if !%sSet {\n", f.Name) 330 fmt.Fprintf(g.out, " in.AddError(fmt.Errorf(\"key '%s' is required\"))\n", jsonName) 331 fmt.Fprintf(g.out, "}\n") 332} 333 334func mergeStructFields(fields1, fields2 []reflect.StructField) (fields []reflect.StructField) { 335 used := map[string]bool{} 336 for _, f := range fields2 { 337 used[f.Name] = true 338 fields = append(fields, f) 339 } 340 341 for _, f := range fields1 { 342 if !used[f.Name] { 343 fields = append(fields, f) 344 } 345 } 346 return 347} 348 349func getStructFields(t reflect.Type) ([]reflect.StructField, error) { 350 if t.Kind() != reflect.Struct { 351 return nil, fmt.Errorf("got %v; expected a struct", t) 352 } 353 354 var efields []reflect.StructField 355 for i := 0; i < t.NumField(); i++ { 356 f := t.Field(i) 357 tags := parseFieldTags(f) 358 if !f.Anonymous || tags.name != "" { 359 continue 360 } 361 362 t1 := f.Type 363 if t1.Kind() == reflect.Ptr { 364 t1 = t1.Elem() 365 } 366 367 fs, err := getStructFields(t1) 368 if err != nil { 369 return nil, fmt.Errorf("error processing embedded field: %v", err) 370 } 371 efields = mergeStructFields(efields, fs) 372 } 373 374 var fields []reflect.StructField 375 for i := 0; i < t.NumField(); i++ { 376 f := t.Field(i) 377 tags := parseFieldTags(f) 378 if f.Anonymous && tags.name == "" { 379 continue 380 } 381 382 c := []rune(f.Name)[0] 383 if unicode.IsUpper(c) { 384 fields = append(fields, f) 385 } 386 } 387 return mergeStructFields(efields, fields), nil 388} 389 390func (g *Generator) genDecoder(t reflect.Type) error { 391 switch t.Kind() { 392 case reflect.Slice, reflect.Array, reflect.Map: 393 return g.genSliceArrayDecoder(t) 394 default: 395 return g.genStructDecoder(t) 396 } 397} 398 399func (g *Generator) genSliceArrayDecoder(t reflect.Type) error { 400 switch t.Kind() { 401 case reflect.Slice, reflect.Array, reflect.Map: 402 default: 403 return fmt.Errorf("cannot generate encoder/decoder for %v, not a slice/array/map type", t) 404 } 405 406 fname := g.getDecoderName(t) 407 typ := g.getType(t) 408 409 fmt.Fprintln(g.out, "func "+fname+"(in *jlexer.Lexer, out *"+typ+") {") 410 fmt.Fprintln(g.out, " isTopLevel := in.IsStart()") 411 err := g.genTypeDecoderNoCheck(t, "*out", fieldTags{}, 1) 412 if err != nil { 413 return err 414 } 415 fmt.Fprintln(g.out, " if isTopLevel {") 416 fmt.Fprintln(g.out, " in.Consumed()") 417 fmt.Fprintln(g.out, " }") 418 fmt.Fprintln(g.out, "}") 419 420 return nil 421} 422 423func (g *Generator) genStructDecoder(t reflect.Type) error { 424 if t.Kind() != reflect.Struct { 425 return fmt.Errorf("cannot generate encoder/decoder for %v, not a struct type", t) 426 } 427 428 fname := g.getDecoderName(t) 429 typ := g.getType(t) 430 431 fmt.Fprintln(g.out, "func "+fname+"(in *jlexer.Lexer, out *"+typ+") {") 432 fmt.Fprintln(g.out, " isTopLevel := in.IsStart()") 433 fmt.Fprintln(g.out, " if in.IsNull() {") 434 fmt.Fprintln(g.out, " if isTopLevel {") 435 fmt.Fprintln(g.out, " in.Consumed()") 436 fmt.Fprintln(g.out, " }") 437 fmt.Fprintln(g.out, " in.Skip()") 438 fmt.Fprintln(g.out, " return") 439 fmt.Fprintln(g.out, " }") 440 441 // Init embedded pointer fields. 442 for i := 0; i < t.NumField(); i++ { 443 f := t.Field(i) 444 if !f.Anonymous || f.Type.Kind() != reflect.Ptr { 445 continue 446 } 447 fmt.Fprintln(g.out, " out."+f.Name+" = new("+g.getType(f.Type.Elem())+")") 448 } 449 450 fs, err := getStructFields(t) 451 if err != nil { 452 return fmt.Errorf("cannot generate decoder for %v: %v", t, err) 453 } 454 455 for _, f := range fs { 456 g.genRequiredFieldSet(t, f) 457 } 458 459 fmt.Fprintln(g.out, " in.Delim('{')") 460 fmt.Fprintln(g.out, " for !in.IsDelim('}') {") 461 fmt.Fprintln(g.out, " key := in.UnsafeString()") 462 fmt.Fprintln(g.out, " in.WantColon()") 463 fmt.Fprintln(g.out, " if in.IsNull() {") 464 fmt.Fprintln(g.out, " in.Skip()") 465 fmt.Fprintln(g.out, " in.WantComma()") 466 fmt.Fprintln(g.out, " continue") 467 fmt.Fprintln(g.out, " }") 468 469 fmt.Fprintln(g.out, " switch key {") 470 for _, f := range fs { 471 if err := g.genStructFieldDecoder(t, f); err != nil { 472 return err 473 } 474 } 475 476 fmt.Fprintln(g.out, " default:") 477 if g.disallowUnknownFields { 478 fmt.Fprintln(g.out, ` in.AddError(&jlexer.LexerError{ 479 Offset: in.GetPos(), 480 Reason: "unknown field", 481 Data: key, 482 })`) 483 } else { 484 fmt.Fprintln(g.out, " in.SkipRecursive()") 485 } 486 fmt.Fprintln(g.out, " }") 487 fmt.Fprintln(g.out, " in.WantComma()") 488 fmt.Fprintln(g.out, " }") 489 fmt.Fprintln(g.out, " in.Delim('}')") 490 fmt.Fprintln(g.out, " if isTopLevel {") 491 fmt.Fprintln(g.out, " in.Consumed()") 492 fmt.Fprintln(g.out, " }") 493 494 for _, f := range fs { 495 g.genRequiredFieldCheck(t, f) 496 } 497 498 fmt.Fprintln(g.out, "}") 499 500 return nil 501} 502 503func (g *Generator) genStructUnmarshaler(t reflect.Type) error { 504 switch t.Kind() { 505 case reflect.Slice, reflect.Array, reflect.Map, reflect.Struct: 506 default: 507 return fmt.Errorf("cannot generate encoder/decoder for %v, not a struct/slice/array/map type", t) 508 } 509 510 fname := g.getDecoderName(t) 511 typ := g.getType(t) 512 513 if !g.noStdMarshalers { 514 fmt.Fprintln(g.out, "// UnmarshalJSON supports json.Unmarshaler interface") 515 fmt.Fprintln(g.out, "func (v *"+typ+") UnmarshalJSON(data []byte) error {") 516 fmt.Fprintln(g.out, " r := jlexer.Lexer{Data: data}") 517 fmt.Fprintln(g.out, " "+fname+"(&r, v)") 518 fmt.Fprintln(g.out, " return r.Error()") 519 fmt.Fprintln(g.out, "}") 520 } 521 522 fmt.Fprintln(g.out, "// UnmarshalEasyJSON supports easyjson.Unmarshaler interface") 523 fmt.Fprintln(g.out, "func (v *"+typ+") UnmarshalEasyJSON(l *jlexer.Lexer) {") 524 fmt.Fprintln(g.out, " "+fname+"(l, v)") 525 fmt.Fprintln(g.out, "}") 526 527 return nil 528} 529