1//+build ignore 2 3// msg_generate.go is meant to run with go generate. It will use 4// go/{importer,types} to track down all the RR struct types. Then for each type 5// it will generate pack/unpack methods based on the struct tags. The generated source is 6// written to zmsg.go, and is meant to be checked into git. 7package main 8 9import ( 10 "bytes" 11 "fmt" 12 "go/format" 13 "go/types" 14 "log" 15 "os" 16 "strings" 17 18 "golang.org/x/tools/go/packages" 19) 20 21var packageHdr = ` 22// Code generated by "go run msg_generate.go"; DO NOT EDIT. 23 24package dns 25 26` 27 28// getTypeStruct will take a type and the package scope, and return the 29// (innermost) struct if the type is considered a RR type (currently defined as 30// those structs beginning with a RR_Header, could be redefined as implementing 31// the RR interface). The bool return value indicates if embedded structs were 32// resolved. 33func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) { 34 st, ok := t.Underlying().(*types.Struct) 35 if !ok { 36 return nil, false 37 } 38 if st.Field(0).Type() == scope.Lookup("RR_Header").Type() { 39 return st, false 40 } 41 if st.Field(0).Anonymous() { 42 st, _ := getTypeStruct(st.Field(0).Type(), scope) 43 return st, true 44 } 45 return nil, false 46} 47 48// loadModule retrieves package description for a given module. 
49func loadModule(name string) (*types.Package, error) { 50 conf := packages.Config{Mode: packages.NeedTypes | packages.NeedTypesInfo} 51 pkgs, err := packages.Load(&conf, name) 52 if err != nil { 53 return nil, err 54 } 55 return pkgs[0].Types, nil 56} 57 58func main() { 59 // Import and type-check the package 60 pkg, err := loadModule("github.com/miekg/dns") 61 fatalIfErr(err) 62 scope := pkg.Scope() 63 64 // Collect actual types (*X) 65 var namedTypes []string 66 for _, name := range scope.Names() { 67 o := scope.Lookup(name) 68 if o == nil || !o.Exported() { 69 continue 70 } 71 if st, _ := getTypeStruct(o.Type(), scope); st == nil { 72 continue 73 } 74 if name == "PrivateRR" { 75 continue 76 } 77 78 // Check if corresponding TypeX exists 79 if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" { 80 log.Fatalf("Constant Type%s does not exist.", o.Name()) 81 } 82 83 namedTypes = append(namedTypes, o.Name()) 84 } 85 86 b := &bytes.Buffer{} 87 b.WriteString(packageHdr) 88 89 fmt.Fprint(b, "// pack*() functions\n\n") 90 for _, name := range namedTypes { 91 o := scope.Lookup(name) 92 st, _ := getTypeStruct(o.Type(), scope) 93 94 fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {\n", name) 95 for i := 1; i < st.NumFields(); i++ { 96 o := func(s string) { 97 fmt.Fprintf(b, s, st.Field(i).Name()) 98 fmt.Fprint(b, `if err != nil { 99return off, err 100} 101`) 102 } 103 104 if _, ok := st.Field(i).Type().(*types.Slice); ok { 105 switch st.Tag(i) { 106 case `dns:"-"`: // ignored 107 case `dns:"txt"`: 108 o("off, err = packStringTxt(rr.%s, msg, off)\n") 109 case `dns:"opt"`: 110 o("off, err = packDataOpt(rr.%s, msg, off)\n") 111 case `dns:"nsec"`: 112 o("off, err = packDataNsec(rr.%s, msg, off)\n") 113 case `dns:"domain-name"`: 114 o("off, err = packDataDomainNames(rr.%s, msg, off, compression, false)\n") 115 case `dns:"apl"`: 116 o("off, err = packDataApl(rr.%s, msg, off)\n") 117 default: 
118 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 119 } 120 continue 121 } 122 123 switch { 124 case st.Tag(i) == `dns:"-"`: // ignored 125 case st.Tag(i) == `dns:"cdomain-name"`: 126 o("off, err = packDomainName(rr.%s, msg, off, compression, compress)\n") 127 case st.Tag(i) == `dns:"domain-name"`: 128 o("off, err = packDomainName(rr.%s, msg, off, compression, false)\n") 129 case st.Tag(i) == `dns:"a"`: 130 o("off, err = packDataA(rr.%s, msg, off)\n") 131 case st.Tag(i) == `dns:"aaaa"`: 132 o("off, err = packDataAAAA(rr.%s, msg, off)\n") 133 case st.Tag(i) == `dns:"uint48"`: 134 o("off, err = packUint48(rr.%s, msg, off)\n") 135 case st.Tag(i) == `dns:"txt"`: 136 o("off, err = packString(rr.%s, msg, off)\n") 137 138 case strings.HasPrefix(st.Tag(i), `dns:"size-base32`): // size-base32 can be packed just like base32 139 fallthrough 140 case st.Tag(i) == `dns:"base32"`: 141 o("off, err = packStringBase32(rr.%s, msg, off)\n") 142 143 case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): // size-base64 can be packed just like base64 144 fallthrough 145 case st.Tag(i) == `dns:"base64"`: 146 o("off, err = packStringBase64(rr.%s, msg, off)\n") 147 148 case strings.HasPrefix(st.Tag(i), `dns:"size-hex:SaltLength`): 149 // directly write instead of using o() so we get the error check in the correct place 150 field := st.Field(i).Name() 151 fmt.Fprintf(b, `// Only pack salt if value is not "-", i.e. 
empty 152if rr.%s != "-" { 153 off, err = packStringHex(rr.%s, msg, off) 154 if err != nil { 155 return off, err 156 } 157} 158`, field, field) 159 continue 160 case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): // size-hex can be packed just like hex 161 fallthrough 162 case st.Tag(i) == `dns:"hex"`: 163 o("off, err = packStringHex(rr.%s, msg, off)\n") 164 case st.Tag(i) == `dns:"any"`: 165 o("off, err = packStringAny(rr.%s, msg, off)\n") 166 case st.Tag(i) == `dns:"octet"`: 167 o("off, err = packStringOctet(rr.%s, msg, off)\n") 168 case st.Tag(i) == "": 169 switch st.Field(i).Type().(*types.Basic).Kind() { 170 case types.Uint8: 171 o("off, err = packUint8(rr.%s, msg, off)\n") 172 case types.Uint16: 173 o("off, err = packUint16(rr.%s, msg, off)\n") 174 case types.Uint32: 175 o("off, err = packUint32(rr.%s, msg, off)\n") 176 case types.Uint64: 177 o("off, err = packUint64(rr.%s, msg, off)\n") 178 case types.String: 179 o("off, err = packString(rr.%s, msg, off)\n") 180 default: 181 log.Fatalln(name, st.Field(i).Name()) 182 } 183 default: 184 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 185 } 186 } 187 fmt.Fprintln(b, "return off, nil }\n") 188 } 189 190 fmt.Fprint(b, "// unpack*() functions\n\n") 191 for _, name := range namedTypes { 192 o := scope.Lookup(name) 193 st, _ := getTypeStruct(o.Type(), scope) 194 195 fmt.Fprintf(b, "func (rr *%s) unpack(msg []byte, off int) (off1 int, err error) {\n", name) 196 fmt.Fprint(b, `rdStart := off 197_ = rdStart 198 199`) 200 for i := 1; i < st.NumFields(); i++ { 201 o := func(s string) { 202 fmt.Fprintf(b, s, st.Field(i).Name()) 203 fmt.Fprint(b, `if err != nil { 204return off, err 205} 206`) 207 } 208 209 // size-* are special, because they reference a struct member we should use for the length. 
210 if strings.HasPrefix(st.Tag(i), `dns:"size-`) { 211 structMember := structMember(st.Tag(i)) 212 structTag := structTag(st.Tag(i)) 213 switch structTag { 214 case "hex": 215 fmt.Fprintf(b, "rr.%s, off, err = unpackStringHex(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) 216 case "base32": 217 fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase32(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) 218 case "base64": 219 fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase64(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) 220 default: 221 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 222 } 223 fmt.Fprint(b, `if err != nil { 224return off, err 225} 226`) 227 continue 228 } 229 230 if _, ok := st.Field(i).Type().(*types.Slice); ok { 231 switch st.Tag(i) { 232 case `dns:"-"`: // ignored 233 case `dns:"txt"`: 234 o("rr.%s, off, err = unpackStringTxt(msg, off)\n") 235 case `dns:"opt"`: 236 o("rr.%s, off, err = unpackDataOpt(msg, off)\n") 237 case `dns:"nsec"`: 238 o("rr.%s, off, err = unpackDataNsec(msg, off)\n") 239 case `dns:"domain-name"`: 240 o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") 241 case `dns:"apl"`: 242 o("rr.%s, off, err = unpackDataApl(msg, off)\n") 243 default: 244 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 245 } 246 continue 247 } 248 249 switch st.Tag(i) { 250 case `dns:"-"`: // ignored 251 case `dns:"cdomain-name"`: 252 fallthrough 253 case `dns:"domain-name"`: 254 o("rr.%s, off, err = UnpackDomainName(msg, off)\n") 255 case `dns:"a"`: 256 o("rr.%s, off, err = unpackDataA(msg, off)\n") 257 case `dns:"aaaa"`: 258 o("rr.%s, off, err = unpackDataAAAA(msg, off)\n") 259 case `dns:"uint48"`: 260 o("rr.%s, off, err = unpackUint48(msg, off)\n") 261 case `dns:"txt"`: 262 o("rr.%s, off, err = unpackString(msg, off)\n") 263 case `dns:"base32"`: 264 o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") 265 case `dns:"base64"`: 266 
o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") 267 case `dns:"hex"`: 268 o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") 269 case `dns:"any"`: 270 o("rr.%s, off, err = unpackStringAny(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") 271 case `dns:"octet"`: 272 o("rr.%s, off, err = unpackStringOctet(msg, off)\n") 273 case "": 274 switch st.Field(i).Type().(*types.Basic).Kind() { 275 case types.Uint8: 276 o("rr.%s, off, err = unpackUint8(msg, off)\n") 277 case types.Uint16: 278 o("rr.%s, off, err = unpackUint16(msg, off)\n") 279 case types.Uint32: 280 o("rr.%s, off, err = unpackUint32(msg, off)\n") 281 case types.Uint64: 282 o("rr.%s, off, err = unpackUint64(msg, off)\n") 283 case types.String: 284 o("rr.%s, off, err = unpackString(msg, off)\n") 285 default: 286 log.Fatalln(name, st.Field(i).Name()) 287 } 288 default: 289 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 290 } 291 // If we've hit len(msg) we return without error. 292 if i < st.NumFields()-1 { 293 fmt.Fprintf(b, `if off == len(msg) { 294return off, nil 295 } 296`) 297 } 298 } 299 fmt.Fprintf(b, "return off, nil }\n\n") 300 } 301 302 // gofmt 303 res, err := format.Source(b.Bytes()) 304 if err != nil { 305 b.WriteTo(os.Stderr) 306 log.Fatal(err) 307 } 308 309 // write result 310 f, err := os.Create("zmsg.go") 311 fatalIfErr(err) 312 defer f.Close() 313 f.Write(res) 314} 315 316// structMember will take a tag like dns:"size-base32:SaltLength" and return the last part of this string. 317func structMember(s string) string { 318 fields := strings.Split(s, ":") 319 if len(fields) == 0 { 320 return "" 321 } 322 f := fields[len(fields)-1] 323 // f should have a closing " 324 if len(f) > 1 { 325 return f[:len(f)-1] 326 } 327 return f 328} 329 330// structTag will take a tag like dns:"size-base32:SaltLength" and return base32. 
//
// It returns "" when the second ':'-separated part of s does not start with
// `"size-`, so non-size tags and short tags cannot cause an out-of-range slice.
func structTag(s string) string {
	const prefix = `"size-`
	fields := strings.Split(s, ":")
	if len(fields) < 2 || !strings.HasPrefix(fields[1], prefix) {
		return ""
	}
	return fields[1][len(prefix):]
}

// fatalIfErr logs err and exits the program (via log.Fatal) if err is non-nil.
func fatalIfErr(err error) {
	if err != nil {
		log.Fatal(err)
	}
}