1//+build ignore 2 3// msg_generate.go is meant to run with go generate. It will use 4// go/{importer,types} to track down all the RR struct types. Then for each type 5// it will generate pack/unpack methods based on the struct tags. The generated source is 6// written to zmsg.go, and is meant to be checked into git. 7package main 8 9import ( 10 "bytes" 11 "fmt" 12 "go/format" 13 "go/importer" 14 "go/types" 15 "log" 16 "os" 17 "strings" 18) 19 20var packageHdr = ` 21// *** DO NOT MODIFY *** 22// AUTOGENERATED BY go generate from msg_generate.go 23 24package dns 25 26` 27 28// getTypeStruct will take a type and the package scope, and return the 29// (innermost) struct if the type is considered a RR type (currently defined as 30// those structs beginning with a RR_Header, could be redefined as implementing 31// the RR interface). The bool return value indicates if embedded structs were 32// resolved. 33func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) { 34 st, ok := t.Underlying().(*types.Struct) 35 if !ok { 36 return nil, false 37 } 38 if st.Field(0).Type() == scope.Lookup("RR_Header").Type() { 39 return st, false 40 } 41 if st.Field(0).Anonymous() { 42 st, _ := getTypeStruct(st.Field(0).Type(), scope) 43 return st, true 44 } 45 return nil, false 46} 47 48func main() { 49 // Import and type-check the package 50 pkg, err := importer.Default().Import("github.com/miekg/dns") 51 fatalIfErr(err) 52 scope := pkg.Scope() 53 54 // Collect actual types (*X) 55 var namedTypes []string 56 for _, name := range scope.Names() { 57 o := scope.Lookup(name) 58 if o == nil || !o.Exported() { 59 continue 60 } 61 if st, _ := getTypeStruct(o.Type(), scope); st == nil { 62 continue 63 } 64 if name == "PrivateRR" { 65 continue 66 } 67 68 // Check if corresponding TypeX exists 69 if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" { 70 log.Fatalf("Constant Type%s does not exist.", o.Name()) 71 } 72 73 namedTypes = append(namedTypes, o.Name()) 74 } 75 
76 b := &bytes.Buffer{} 77 b.WriteString(packageHdr) 78 79 fmt.Fprint(b, "// pack*() functions\n\n") 80 for _, name := range namedTypes { 81 o := scope.Lookup(name) 82 st, _ := getTypeStruct(o.Type(), scope) 83 84 fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {\n", name) 85 fmt.Fprint(b, `off, err := rr.Hdr.pack(msg, off, compression, compress) 86if err != nil { 87 return off, err 88} 89headerEnd := off 90`) 91 for i := 1; i < st.NumFields(); i++ { 92 o := func(s string) { 93 fmt.Fprintf(b, s, st.Field(i).Name()) 94 fmt.Fprint(b, `if err != nil { 95return off, err 96} 97`) 98 } 99 100 if _, ok := st.Field(i).Type().(*types.Slice); ok { 101 switch st.Tag(i) { 102 case `dns:"-"`: // ignored 103 case `dns:"txt"`: 104 o("off, err = packStringTxt(rr.%s, msg, off)\n") 105 case `dns:"opt"`: 106 o("off, err = packDataOpt(rr.%s, msg, off)\n") 107 case `dns:"nsec"`: 108 o("off, err = packDataNsec(rr.%s, msg, off)\n") 109 case `dns:"domain-name"`: 110 o("off, err = packDataDomainNames(rr.%s, msg, off, compression, compress)\n") 111 default: 112 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 113 } 114 continue 115 } 116 117 switch { 118 case st.Tag(i) == `dns:"-"`: // ignored 119 case st.Tag(i) == `dns:"cdomain-name"`: 120 o("off, err = PackDomainName(rr.%s, msg, off, compression, compress)\n") 121 case st.Tag(i) == `dns:"domain-name"`: 122 o("off, err = PackDomainName(rr.%s, msg, off, compression, false)\n") 123 case st.Tag(i) == `dns:"a"`: 124 o("off, err = packDataA(rr.%s, msg, off)\n") 125 case st.Tag(i) == `dns:"aaaa"`: 126 o("off, err = packDataAAAA(rr.%s, msg, off)\n") 127 case st.Tag(i) == `dns:"uint48"`: 128 o("off, err = packUint48(rr.%s, msg, off)\n") 129 case st.Tag(i) == `dns:"txt"`: 130 o("off, err = packString(rr.%s, msg, off)\n") 131 132 case strings.HasPrefix(st.Tag(i), `dns:"size-base32`): // size-base32 can be packed just like base32 133 fallthrough 134 case st.Tag(i) == `dns:"base32"`: 135 
o("off, err = packStringBase32(rr.%s, msg, off)\n") 136 137 case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): // size-base64 can be packed just like base64 138 fallthrough 139 case st.Tag(i) == `dns:"base64"`: 140 o("off, err = packStringBase64(rr.%s, msg, off)\n") 141 142 case strings.HasPrefix(st.Tag(i), `dns:"size-hex:SaltLength`): // Hack to fix empty salt length for NSEC3 143 o("if rr.%s == \"-\" { /* do nothing, empty salt */ }\n") 144 continue 145 case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): // size-hex can be packed just like hex 146 fallthrough 147 case st.Tag(i) == `dns:"hex"`: 148 o("off, err = packStringHex(rr.%s, msg, off)\n") 149 150 case st.Tag(i) == `dns:"octet"`: 151 o("off, err = packStringOctet(rr.%s, msg, off)\n") 152 case st.Tag(i) == "": 153 switch st.Field(i).Type().(*types.Basic).Kind() { 154 case types.Uint8: 155 o("off, err = packUint8(rr.%s, msg, off)\n") 156 case types.Uint16: 157 o("off, err = packUint16(rr.%s, msg, off)\n") 158 case types.Uint32: 159 o("off, err = packUint32(rr.%s, msg, off)\n") 160 case types.Uint64: 161 o("off, err = packUint64(rr.%s, msg, off)\n") 162 case types.String: 163 o("off, err = packString(rr.%s, msg, off)\n") 164 default: 165 log.Fatalln(name, st.Field(i).Name()) 166 } 167 default: 168 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 169 } 170 } 171 // We have packed everything, only now we know the rdlength of this RR 172 fmt.Fprintln(b, "rr.Header().Rdlength = uint16(off-headerEnd)") 173 fmt.Fprintln(b, "return off, nil }\n") 174 } 175 176 fmt.Fprint(b, "// unpack*() functions\n\n") 177 for _, name := range namedTypes { 178 o := scope.Lookup(name) 179 st, _ := getTypeStruct(o.Type(), scope) 180 181 fmt.Fprintf(b, "func unpack%s(h RR_Header, msg []byte, off int) (RR, int, error) {\n", name) 182 fmt.Fprintf(b, "rr := new(%s)\n", name) 183 fmt.Fprint(b, "rr.Hdr = h\n") 184 fmt.Fprint(b, `if noRdata(h) { 185return rr, off, nil 186 } 187var err error 188rdStart := off 189_ = rdStart 190 191`) 192 
for i := 1; i < st.NumFields(); i++ { 193 o := func(s string) { 194 fmt.Fprintf(b, s, st.Field(i).Name()) 195 fmt.Fprint(b, `if err != nil { 196return rr, off, err 197} 198`) 199 } 200 201 // size-* are special, because they reference a struct member we should use for the length. 202 if strings.HasPrefix(st.Tag(i), `dns:"size-`) { 203 structMember := structMember(st.Tag(i)) 204 structTag := structTag(st.Tag(i)) 205 switch structTag { 206 case "hex": 207 fmt.Fprintf(b, "rr.%s, off, err = unpackStringHex(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) 208 case "base32": 209 fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase32(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) 210 case "base64": 211 fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase64(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) 212 default: 213 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 214 } 215 fmt.Fprint(b, `if err != nil { 216return rr, off, err 217} 218`) 219 continue 220 } 221 222 if _, ok := st.Field(i).Type().(*types.Slice); ok { 223 switch st.Tag(i) { 224 case `dns:"-"`: // ignored 225 case `dns:"txt"`: 226 o("rr.%s, off, err = unpackStringTxt(msg, off)\n") 227 case `dns:"opt"`: 228 o("rr.%s, off, err = unpackDataOpt(msg, off)\n") 229 case `dns:"nsec"`: 230 o("rr.%s, off, err = unpackDataNsec(msg, off)\n") 231 case `dns:"domain-name"`: 232 o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") 233 default: 234 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 235 } 236 continue 237 } 238 239 switch st.Tag(i) { 240 case `dns:"-"`: // ignored 241 case `dns:"cdomain-name"`: 242 fallthrough 243 case `dns:"domain-name"`: 244 o("rr.%s, off, err = UnpackDomainName(msg, off)\n") 245 case `dns:"a"`: 246 o("rr.%s, off, err = unpackDataA(msg, off)\n") 247 case `dns:"aaaa"`: 248 o("rr.%s, off, err = unpackDataAAAA(msg, off)\n") 249 case `dns:"uint48"`: 250 o("rr.%s, off, err = unpackUint48(msg, off)\n") 251 case 
`dns:"txt"`: 252 o("rr.%s, off, err = unpackString(msg, off)\n") 253 case `dns:"base32"`: 254 o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") 255 case `dns:"base64"`: 256 o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") 257 case `dns:"hex"`: 258 o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") 259 case `dns:"octet"`: 260 o("rr.%s, off, err = unpackStringOctet(msg, off)\n") 261 case "": 262 switch st.Field(i).Type().(*types.Basic).Kind() { 263 case types.Uint8: 264 o("rr.%s, off, err = unpackUint8(msg, off)\n") 265 case types.Uint16: 266 o("rr.%s, off, err = unpackUint16(msg, off)\n") 267 case types.Uint32: 268 o("rr.%s, off, err = unpackUint32(msg, off)\n") 269 case types.Uint64: 270 o("rr.%s, off, err = unpackUint64(msg, off)\n") 271 case types.String: 272 o("rr.%s, off, err = unpackString(msg, off)\n") 273 default: 274 log.Fatalln(name, st.Field(i).Name()) 275 } 276 default: 277 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 278 } 279 // If we've hit len(msg) we return without error. 280 if i < st.NumFields()-1 { 281 fmt.Fprintf(b, `if off == len(msg) { 282return rr, off, nil 283 } 284`) 285 } 286 } 287 fmt.Fprintf(b, "return rr, off, err }\n\n") 288 } 289 // Generate typeToUnpack map 290 fmt.Fprintln(b, "var typeToUnpack = map[uint16]func(RR_Header, []byte, int) (RR, int, error){") 291 for _, name := range namedTypes { 292 if name == "RFC3597" { 293 continue 294 } 295 fmt.Fprintf(b, "Type%s: unpack%s,\n", name, name) 296 } 297 fmt.Fprintln(b, "}\n") 298 299 // gofmt 300 res, err := format.Source(b.Bytes()) 301 if err != nil { 302 b.WriteTo(os.Stderr) 303 log.Fatal(err) 304 } 305 306 // write result 307 f, err := os.Create("zmsg.go") 308 fatalIfErr(err) 309 defer f.Close() 310 f.Write(res) 311} 312 313// structMember will take a tag like dns:"size-base32:SaltLength" and return the last part of this string. 
func structMember(s string) string {
	// strings.Split always returns at least one element, so indexing the
	// last one is safe.
	fields := strings.Split(s, ":")
	last := fields[len(fields)-1]
	// Drop only the closing quote of the tag value, if present.
	return strings.TrimSuffix(last, `"`)
}

// structTag will take a tag like dns:"size-base32:SaltLength" and return base32.
// It returns "" for tags that are not of the dns:"size-..." form instead of
// panicking on short tag values.
func structTag(s string) string {
	fields := strings.Split(s, ":")
	if len(fields) < 2 {
		return ""
	}
	const sizePrefix = `"size-`
	if !strings.HasPrefix(fields[1], sizePrefix) {
		return ""
	}
	return strings.TrimPrefix(fields[1], sizePrefix)
}

// fatalIfErr logs err and exits the generator when err is non-nil.
func fatalIfErr(err error) {
	if err != nil {
		log.Fatal(err)
	}
}