1//+build ignore 2 3// msg_generate.go is meant to run with go generate. It will use 4// go/{importer,types} to track down all the RR struct types. Then for each type 5// it will generate pack/unpack methods based on the struct tags. The generated source is 6// written to zmsg.go, and is meant to be checked into git. 7package main 8 9import ( 10 "bytes" 11 "fmt" 12 "go/format" 13 "go/types" 14 "log" 15 "os" 16 "strings" 17 18 "golang.org/x/tools/go/packages" 19) 20 21var packageHdr = ` 22// Code generated by "go run msg_generate.go"; DO NOT EDIT. 23 24package dns 25 26` 27 28// getTypeStruct will take a type and the package scope, and return the 29// (innermost) struct if the type is considered a RR type (currently defined as 30// those structs beginning with a RR_Header, could be redefined as implementing 31// the RR interface). The bool return value indicates if embedded structs were 32// resolved. 33func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) { 34 st, ok := t.Underlying().(*types.Struct) 35 if !ok { 36 return nil, false 37 } 38 if st.NumFields() == 0 { 39 return nil, false 40 } 41 if st.Field(0).Type() == scope.Lookup("RR_Header").Type() { 42 return st, false 43 } 44 if st.Field(0).Anonymous() { 45 st, _ := getTypeStruct(st.Field(0).Type(), scope) 46 return st, true 47 } 48 return nil, false 49} 50 51// loadModule retrieves package description for a given module. 
52func loadModule(name string) (*types.Package, error) { 53 conf := packages.Config{Mode: packages.NeedTypes | packages.NeedTypesInfo} 54 pkgs, err := packages.Load(&conf, name) 55 if err != nil { 56 return nil, err 57 } 58 return pkgs[0].Types, nil 59} 60 61func main() { 62 // Import and type-check the package 63 pkg, err := loadModule("github.com/miekg/dns") 64 fatalIfErr(err) 65 scope := pkg.Scope() 66 67 // Collect actual types (*X) 68 var namedTypes []string 69 for _, name := range scope.Names() { 70 o := scope.Lookup(name) 71 if o == nil || !o.Exported() { 72 continue 73 } 74 if st, _ := getTypeStruct(o.Type(), scope); st == nil { 75 continue 76 } 77 if name == "PrivateRR" { 78 continue 79 } 80 81 // Check if corresponding TypeX exists 82 if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" { 83 log.Fatalf("Constant Type%s does not exist.", o.Name()) 84 } 85 86 namedTypes = append(namedTypes, o.Name()) 87 } 88 89 b := &bytes.Buffer{} 90 b.WriteString(packageHdr) 91 92 fmt.Fprint(b, "// pack*() functions\n\n") 93 for _, name := range namedTypes { 94 o := scope.Lookup(name) 95 st, _ := getTypeStruct(o.Type(), scope) 96 97 fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {\n", name) 98 for i := 1; i < st.NumFields(); i++ { 99 o := func(s string) { 100 fmt.Fprintf(b, s, st.Field(i).Name()) 101 fmt.Fprint(b, `if err != nil { 102return off, err 103} 104`) 105 } 106 107 if _, ok := st.Field(i).Type().(*types.Slice); ok { 108 switch st.Tag(i) { 109 case `dns:"-"`: // ignored 110 case `dns:"txt"`: 111 o("off, err = packStringTxt(rr.%s, msg, off)\n") 112 case `dns:"opt"`: 113 o("off, err = packDataOpt(rr.%s, msg, off)\n") 114 case `dns:"nsec"`: 115 o("off, err = packDataNsec(rr.%s, msg, off)\n") 116 case `dns:"pairs"`: 117 o("off, err = packDataSVCB(rr.%s, msg, off)\n") 118 case `dns:"domain-name"`: 119 o("off, err = packDataDomainNames(rr.%s, msg, off, compression, false)\n") 120 case 
`dns:"apl"`: 121 o("off, err = packDataApl(rr.%s, msg, off)\n") 122 default: 123 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 124 } 125 continue 126 } 127 128 switch { 129 case st.Tag(i) == `dns:"-"`: // ignored 130 case st.Tag(i) == `dns:"cdomain-name"`: 131 o("off, err = packDomainName(rr.%s, msg, off, compression, compress)\n") 132 case st.Tag(i) == `dns:"domain-name"`: 133 o("off, err = packDomainName(rr.%s, msg, off, compression, false)\n") 134 case st.Tag(i) == `dns:"a"`: 135 o("off, err = packDataA(rr.%s, msg, off)\n") 136 case st.Tag(i) == `dns:"aaaa"`: 137 o("off, err = packDataAAAA(rr.%s, msg, off)\n") 138 case st.Tag(i) == `dns:"uint48"`: 139 o("off, err = packUint48(rr.%s, msg, off)\n") 140 case st.Tag(i) == `dns:"txt"`: 141 o("off, err = packString(rr.%s, msg, off)\n") 142 143 case strings.HasPrefix(st.Tag(i), `dns:"size-base32`): // size-base32 can be packed just like base32 144 fallthrough 145 case st.Tag(i) == `dns:"base32"`: 146 o("off, err = packStringBase32(rr.%s, msg, off)\n") 147 148 case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): // size-base64 can be packed just like base64 149 fallthrough 150 case st.Tag(i) == `dns:"base64"`: 151 o("off, err = packStringBase64(rr.%s, msg, off)\n") 152 153 case strings.HasPrefix(st.Tag(i), `dns:"size-hex:SaltLength`): 154 // directly write instead of using o() so we get the error check in the correct place 155 field := st.Field(i).Name() 156 fmt.Fprintf(b, `// Only pack salt if value is not "-", i.e. 
empty 157if rr.%s != "-" { 158 off, err = packStringHex(rr.%s, msg, off) 159 if err != nil { 160 return off, err 161 } 162} 163`, field, field) 164 continue 165 case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): // size-hex can be packed just like hex 166 fallthrough 167 case st.Tag(i) == `dns:"hex"`: 168 o("off, err = packStringHex(rr.%s, msg, off)\n") 169 case st.Tag(i) == `dns:"any"`: 170 o("off, err = packStringAny(rr.%s, msg, off)\n") 171 case st.Tag(i) == `dns:"octet"`: 172 o("off, err = packStringOctet(rr.%s, msg, off)\n") 173 case st.Tag(i) == "": 174 switch st.Field(i).Type().(*types.Basic).Kind() { 175 case types.Uint8: 176 o("off, err = packUint8(rr.%s, msg, off)\n") 177 case types.Uint16: 178 o("off, err = packUint16(rr.%s, msg, off)\n") 179 case types.Uint32: 180 o("off, err = packUint32(rr.%s, msg, off)\n") 181 case types.Uint64: 182 o("off, err = packUint64(rr.%s, msg, off)\n") 183 case types.String: 184 o("off, err = packString(rr.%s, msg, off)\n") 185 default: 186 log.Fatalln(name, st.Field(i).Name()) 187 } 188 default: 189 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 190 } 191 } 192 fmt.Fprintln(b, "return off, nil }\n") 193 } 194 195 fmt.Fprint(b, "// unpack*() functions\n\n") 196 for _, name := range namedTypes { 197 o := scope.Lookup(name) 198 st, _ := getTypeStruct(o.Type(), scope) 199 200 fmt.Fprintf(b, "func (rr *%s) unpack(msg []byte, off int) (off1 int, err error) {\n", name) 201 fmt.Fprint(b, `rdStart := off 202_ = rdStart 203 204`) 205 for i := 1; i < st.NumFields(); i++ { 206 o := func(s string) { 207 fmt.Fprintf(b, s, st.Field(i).Name()) 208 fmt.Fprint(b, `if err != nil { 209return off, err 210} 211`) 212 } 213 214 // size-* are special, because they reference a struct member we should use for the length. 
215 if strings.HasPrefix(st.Tag(i), `dns:"size-`) { 216 structMember := structMember(st.Tag(i)) 217 structTag := structTag(st.Tag(i)) 218 switch structTag { 219 case "hex": 220 fmt.Fprintf(b, "rr.%s, off, err = unpackStringHex(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) 221 case "base32": 222 fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase32(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) 223 case "base64": 224 fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase64(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) 225 default: 226 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 227 } 228 fmt.Fprint(b, `if err != nil { 229return off, err 230} 231`) 232 continue 233 } 234 235 if _, ok := st.Field(i).Type().(*types.Slice); ok { 236 switch st.Tag(i) { 237 case `dns:"-"`: // ignored 238 case `dns:"txt"`: 239 o("rr.%s, off, err = unpackStringTxt(msg, off)\n") 240 case `dns:"opt"`: 241 o("rr.%s, off, err = unpackDataOpt(msg, off)\n") 242 case `dns:"nsec"`: 243 o("rr.%s, off, err = unpackDataNsec(msg, off)\n") 244 case `dns:"pairs"`: 245 o("rr.%s, off, err = unpackDataSVCB(msg, off)\n") 246 case `dns:"domain-name"`: 247 o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") 248 case `dns:"apl"`: 249 o("rr.%s, off, err = unpackDataApl(msg, off)\n") 250 default: 251 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 252 } 253 continue 254 } 255 256 switch st.Tag(i) { 257 case `dns:"-"`: // ignored 258 case `dns:"cdomain-name"`: 259 fallthrough 260 case `dns:"domain-name"`: 261 o("rr.%s, off, err = UnpackDomainName(msg, off)\n") 262 case `dns:"a"`: 263 o("rr.%s, off, err = unpackDataA(msg, off)\n") 264 case `dns:"aaaa"`: 265 o("rr.%s, off, err = unpackDataAAAA(msg, off)\n") 266 case `dns:"uint48"`: 267 o("rr.%s, off, err = unpackUint48(msg, off)\n") 268 case `dns:"txt"`: 269 o("rr.%s, off, err = unpackString(msg, off)\n") 270 case `dns:"base32"`: 271 o("rr.%s, off, err = 
unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") 272 case `dns:"base64"`: 273 o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") 274 case `dns:"hex"`: 275 o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") 276 case `dns:"any"`: 277 o("rr.%s, off, err = unpackStringAny(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") 278 case `dns:"octet"`: 279 o("rr.%s, off, err = unpackStringOctet(msg, off)\n") 280 case "": 281 switch st.Field(i).Type().(*types.Basic).Kind() { 282 case types.Uint8: 283 o("rr.%s, off, err = unpackUint8(msg, off)\n") 284 case types.Uint16: 285 o("rr.%s, off, err = unpackUint16(msg, off)\n") 286 case types.Uint32: 287 o("rr.%s, off, err = unpackUint32(msg, off)\n") 288 case types.Uint64: 289 o("rr.%s, off, err = unpackUint64(msg, off)\n") 290 case types.String: 291 o("rr.%s, off, err = unpackString(msg, off)\n") 292 default: 293 log.Fatalln(name, st.Field(i).Name()) 294 } 295 default: 296 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 297 } 298 // If we've hit len(msg) we return without error. 299 if i < st.NumFields()-1 { 300 fmt.Fprintf(b, `if off == len(msg) { 301return off, nil 302 } 303`) 304 } 305 } 306 fmt.Fprintf(b, "return off, nil }\n\n") 307 } 308 309 // gofmt 310 res, err := format.Source(b.Bytes()) 311 if err != nil { 312 b.WriteTo(os.Stderr) 313 log.Fatal(err) 314 } 315 316 // write result 317 f, err := os.Create("zmsg.go") 318 fatalIfErr(err) 319 defer f.Close() 320 f.Write(res) 321} 322 323// structMember will take a tag like dns:"size-base32:SaltLength" and return the last part of this string. 324func structMember(s string) string { 325 fields := strings.Split(s, ":") 326 if len(fields) == 0 { 327 return "" 328 } 329 f := fields[len(fields)-1] 330 // f should have a closing " 331 if len(f) > 1 { 332 return f[:len(f)-1] 333 } 334 return f 335} 336 337// structTag will take a tag like dns:"size-base32:SaltLength" and return base32. 
// The second ':'-separated part of s must start with `"size-`; otherwise (or
// when s contains no ':') the empty string is returned instead of slicing
// past the end of a short field.
func structTag(s string) string {
	fields := strings.Split(s, ":")
	if len(fields) < 2 {
		return ""
	}
	const prefix = `"size-`
	if !strings.HasPrefix(fields[1], prefix) {
		return ""
	}
	return fields[1][len(prefix):]
}

// fatalIfErr logs err and exits the generator via log.Fatal when err is
// non-nil; it does nothing for a nil error.
func fatalIfErr(err error) {
	if err != nil {
		log.Fatal(err)
	}
}