1//+build ignore 2 3// msg_generate.go is meant to run with go generate. It will use 4// go/{importer,types} to track down all the RR struct types. Then for each type 5// it will generate pack/unpack methods based on the struct tags. The generated source is 6// written to zmsg.go, and is meant to be checked into git. 7package main 8 9import ( 10 "bytes" 11 "fmt" 12 "go/format" 13 "go/importer" 14 "go/types" 15 "log" 16 "os" 17 "strings" 18) 19 20var packageHdr = ` 21// Code generated by "go run msg_generate.go"; DO NOT EDIT. 22 23package dns 24 25` 26 27// getTypeStruct will take a type and the package scope, and return the 28// (innermost) struct if the type is considered a RR type (currently defined as 29// those structs beginning with a RR_Header, could be redefined as implementing 30// the RR interface). The bool return value indicates if embedded structs were 31// resolved. 32func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) { 33 st, ok := t.Underlying().(*types.Struct) 34 if !ok { 35 return nil, false 36 } 37 if st.Field(0).Type() == scope.Lookup("RR_Header").Type() { 38 return st, false 39 } 40 if st.Field(0).Anonymous() { 41 st, _ := getTypeStruct(st.Field(0).Type(), scope) 42 return st, true 43 } 44 return nil, false 45} 46 47func main() { 48 // Import and type-check the package 49 pkg, err := importer.Default().Import("github.com/miekg/dns") 50 fatalIfErr(err) 51 scope := pkg.Scope() 52 53 // Collect actual types (*X) 54 var namedTypes []string 55 for _, name := range scope.Names() { 56 o := scope.Lookup(name) 57 if o == nil || !o.Exported() { 58 continue 59 } 60 if st, _ := getTypeStruct(o.Type(), scope); st == nil { 61 continue 62 } 63 if name == "PrivateRR" { 64 continue 65 } 66 67 // Check if corresponding TypeX exists 68 if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" { 69 log.Fatalf("Constant Type%s does not exist.", o.Name()) 70 } 71 72 namedTypes = append(namedTypes, o.Name()) 73 } 74 75 b := &bytes.Buffer{} 76 b.WriteString(packageHdr) 77 78 fmt.Fprint(b, "// pack*() functions\n\n") 79 for _, name := range namedTypes { 80 o := scope.Lookup(name) 81 st, _ := getTypeStruct(o.Type(), scope) 82 83 fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {\n", name) 84 fmt.Fprint(b, `off, err := rr.Hdr.pack(msg, off, compression, compress) 85if err != nil { 86 return off, err 87} 88headerEnd := off 89`) 90 for i := 1; i < st.NumFields(); i++ { 91 o := func(s string) { 92 fmt.Fprintf(b, s, st.Field(i).Name()) 93 fmt.Fprint(b, `if err != nil { 94return off, err 95} 96`) 97 } 98 99 if _, ok := st.Field(i).Type().(*types.Slice); ok { 100 switch st.Tag(i) { 101 case `dns:"-"`: // ignored 102 case `dns:"txt"`: 103 o("off, err = packStringTxt(rr.%s, msg, off)\n") 104 case `dns:"opt"`: 105 o("off, err = packDataOpt(rr.%s, msg, off)\n") 106 case `dns:"nsec"`: 107 o("off, err = packDataNsec(rr.%s, msg, off)\n") 108 case `dns:"domain-name"`: 109 o("off, err = packDataDomainNames(rr.%s, msg, off, compression, compress)\n") 110 default: 111 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 112 } 113 continue 114 } 115 116 switch { 117 case st.Tag(i) == `dns:"-"`: // ignored 118 case st.Tag(i) == `dns:"cdomain-name"`: 119 o("off, err = PackDomainName(rr.%s, msg, off, compression, compress)\n") 120 case st.Tag(i) == `dns:"domain-name"`: 121 o("off, err = PackDomainName(rr.%s, msg, off, compression, false)\n") 122 case st.Tag(i) == `dns:"a"`: 123 o("off, err = packDataA(rr.%s, msg, off)\n") 124 case st.Tag(i) == `dns:"aaaa"`: 125 o("off, err = packDataAAAA(rr.%s, msg, off)\n") 126 case st.Tag(i) == `dns:"uint48"`: 127 o("off, err = packUint48(rr.%s, msg, off)\n") 128 case st.Tag(i) == `dns:"txt"`: 129 o("off, err = packString(rr.%s, msg, off)\n") 130 131 case strings.HasPrefix(st.Tag(i), `dns:"size-base32`): // size-base32 can be packed just like base32 132 fallthrough 133 case st.Tag(i) == `dns:"base32"`: 134 o("off, err = packStringBase32(rr.%s, msg, off)\n") 135 136 case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): // size-base64 can be packed just like base64 137 fallthrough 138 case st.Tag(i) == `dns:"base64"`: 139 o("off, err = packStringBase64(rr.%s, msg, off)\n") 140 141 case strings.HasPrefix(st.Tag(i), `dns:"size-hex:SaltLength`): 142 // directly write instead of using o() so we get the error check in the correct place 143 field := st.Field(i).Name() 144 fmt.Fprintf(b, `// Only pack salt if value is not "-", i.e. empty 145if rr.%s != "-" { 146 off, err = packStringHex(rr.%s, msg, off) 147 if err != nil { 148 return off, err 149 } 150} 151`, field, field) 152 continue 153 case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): // size-hex can be packed just like hex 154 fallthrough 155 case st.Tag(i) == `dns:"hex"`: 156 o("off, err = packStringHex(rr.%s, msg, off)\n") 157 158 case st.Tag(i) == `dns:"octet"`: 159 o("off, err = packStringOctet(rr.%s, msg, off)\n") 160 case st.Tag(i) == "": 161 switch st.Field(i).Type().(*types.Basic).Kind() { 162 case types.Uint8: 163 o("off, err = packUint8(rr.%s, msg, off)\n") 164 case types.Uint16: 165 o("off, err = packUint16(rr.%s, msg, off)\n") 166 case types.Uint32: 167 o("off, err = packUint32(rr.%s, msg, off)\n") 168 case types.Uint64: 169 o("off, err = packUint64(rr.%s, msg, off)\n") 170 case types.String: 171 o("off, err = packString(rr.%s, msg, off)\n") 172 default: 173 log.Fatalln(name, st.Field(i).Name()) 174 } 175 default: 176 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 177 } 178 } 179 // We have packed everything, only now we know the rdlength of this RR 180 fmt.Fprintln(b, "rr.Header().Rdlength = uint16(off-headerEnd)") 181 fmt.Fprintln(b, "return off, nil }\n") 182 } 183 184 fmt.Fprint(b, "// unpack*() functions\n\n") 185 for _, name := range namedTypes { 186 o := scope.Lookup(name) 187 st, _ := getTypeStruct(o.Type(), scope) 188 189 fmt.Fprintf(b, "func unpack%s(h RR_Header, msg []byte, off int) (RR, int, error) {\n", name) 190 fmt.Fprintf(b, "rr := new(%s)\n", name) 191 fmt.Fprint(b, "rr.Hdr = h\n") 192 fmt.Fprint(b, `if noRdata(h) { 193return rr, off, nil 194 } 195var err error 196rdStart := off 197_ = rdStart 198 199`) 200 for i := 1; i < st.NumFields(); i++ { 201 o := func(s string) { 202 fmt.Fprintf(b, s, st.Field(i).Name()) 203 fmt.Fprint(b, `if err != nil { 204return rr, off, err 205} 206`) 207 } 208 209 // size-* are special, because they reference a struct member we should use for the length. 210 if strings.HasPrefix(st.Tag(i), `dns:"size-`) { 211 structMember := structMember(st.Tag(i)) 212 structTag := structTag(st.Tag(i)) 213 switch structTag { 214 case "hex": 215 fmt.Fprintf(b, "rr.%s, off, err = unpackStringHex(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) 216 case "base32": 217 fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase32(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) 218 case "base64": 219 fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase64(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) 220 default: 221 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 222 } 223 fmt.Fprint(b, `if err != nil { 224return rr, off, err 225} 226`) 227 continue 228 } 229 230 if _, ok := st.Field(i).Type().(*types.Slice); ok { 231 switch st.Tag(i) { 232 case `dns:"-"`: // ignored 233 case `dns:"txt"`: 234 o("rr.%s, off, err = unpackStringTxt(msg, off)\n") 235 case `dns:"opt"`: 236 o("rr.%s, off, err = unpackDataOpt(msg, off)\n") 237 case `dns:"nsec"`: 238 o("rr.%s, off, err = unpackDataNsec(msg, off)\n") 239 case `dns:"domain-name"`: 240 o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") 241 default: 242 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 243 } 244 continue 245 } 246 247 switch st.Tag(i) { 248 case `dns:"-"`: // ignored 249 case `dns:"cdomain-name"`: 250 fallthrough 251 case `dns:"domain-name"`: 252 o("rr.%s, off, err = UnpackDomainName(msg, off)\n") 253 case `dns:"a"`: 254 o("rr.%s, off, err = unpackDataA(msg, off)\n") 255 case `dns:"aaaa"`: 256 o("rr.%s, off, err = unpackDataAAAA(msg, off)\n") 257 case `dns:"uint48"`: 258 o("rr.%s, off, err = unpackUint48(msg, off)\n") 259 case `dns:"txt"`: 260 o("rr.%s, off, err = unpackString(msg, off)\n") 261 case `dns:"base32"`: 262 o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") 263 case `dns:"base64"`: 264 o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") 265 case `dns:"hex"`: 266 o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") 267 case `dns:"octet"`: 268 o("rr.%s, off, err = unpackStringOctet(msg, off)\n") 269 case "": 270 switch st.Field(i).Type().(*types.Basic).Kind() { 271 case types.Uint8: 272 o("rr.%s, off, err = unpackUint8(msg, off)\n") 273 case types.Uint16: 274 o("rr.%s, off, err = unpackUint16(msg, off)\n") 275 case types.Uint32: 276 o("rr.%s, off, err = unpackUint32(msg, off)\n") 277 case types.Uint64: 278 o("rr.%s, off, err = unpackUint64(msg, off)\n") 279 case types.String: 280 o("rr.%s, off, err = unpackString(msg, off)\n") 281 default: 282 log.Fatalln(name, st.Field(i).Name()) 283 } 284 default: 285 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 286 } 287 // If we've hit len(msg) we return without error. 288 if i < st.NumFields()-1 { 289 fmt.Fprintf(b, `if off == len(msg) { 290return rr, off, nil 291 } 292`) 293 } 294 } 295 fmt.Fprintf(b, "return rr, off, err }\n\n") 296 } 297 // Generate typeToUnpack map 298 fmt.Fprintln(b, "var typeToUnpack = map[uint16]func(RR_Header, []byte, int) (RR, int, error){") 299 for _, name := range namedTypes { 300 if name == "RFC3597" { 301 continue 302 } 303 fmt.Fprintf(b, "Type%s: unpack%s,\n", name, name) 304 } 305 fmt.Fprintln(b, "}\n") 306 307 // gofmt 308 res, err := format.Source(b.Bytes()) 309 if err != nil { 310 b.WriteTo(os.Stderr) 311 log.Fatal(err) 312 } 313 314 // write result 315 f, err := os.Create("zmsg.go") 316 fatalIfErr(err) 317 defer f.Close() 318 f.Write(res) 319} 320 321// structMember will take a tag like dns:"size-base32:SaltLength" and return the last part of this string. 322func structMember(s string) string { 323 fields := strings.Split(s, ":") 324 if len(fields) == 0 { 325 return "" 326 } 327 f := fields[len(fields)-1] 328 // f should have a closing " 329 if len(f) > 1 { 330 return f[:len(f)-1] 331 } 332 return f 333} 334 335// structTag will take a tag like dns:"size-base32:SaltLength" and return base32. 336func structTag(s string) string { 337 fields := strings.Split(s, ":") 338 if len(fields) < 2 { 339 return "" 340 } 341 return fields[1][len("\"size-"):] 342} 343 344func fatalIfErr(err error) { 345 if err != nil { 346 log.Fatal(err) 347 } 348} 349