1//+build ignore 2 3// msg_generate.go is meant to run with go generate. It will use 4// go/{importer,types} to track down all the RR struct types. Then for each type 5// it will generate pack/unpack methods based on the struct tags. The generated source is 6// written to zmsg.go, and is meant to be checked into git. 7package main 8 9import ( 10 "bytes" 11 "fmt" 12 "go/format" 13 "go/importer" 14 "go/types" 15 "log" 16 "os" 17 "strings" 18) 19 20var packageHdr = ` 21// Code generated by "go run msg_generate.go"; DO NOT EDIT. 22 23package dns 24 25` 26 27// getTypeStruct will take a type and the package scope, and return the 28// (innermost) struct if the type is considered a RR type (currently defined as 29// those structs beginning with a RR_Header, could be redefined as implementing 30// the RR interface). The bool return value indicates if embedded structs were 31// resolved. 32func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) { 33 st, ok := t.Underlying().(*types.Struct) 34 if !ok { 35 return nil, false 36 } 37 if st.Field(0).Type() == scope.Lookup("RR_Header").Type() { 38 return st, false 39 } 40 if st.Field(0).Anonymous() { 41 st, _ := getTypeStruct(st.Field(0).Type(), scope) 42 return st, true 43 } 44 return nil, false 45} 46 47func main() { 48 // Import and type-check the package 49 pkg, err := importer.Default().Import("github.com/miekg/dns") 50 fatalIfErr(err) 51 scope := pkg.Scope() 52 53 // Collect actual types (*X) 54 var namedTypes []string 55 for _, name := range scope.Names() { 56 o := scope.Lookup(name) 57 if o == nil || !o.Exported() { 58 continue 59 } 60 if st, _ := getTypeStruct(o.Type(), scope); st == nil { 61 continue 62 } 63 if name == "PrivateRR" { 64 continue 65 } 66 67 // Check if corresponding TypeX exists 68 if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" { 69 log.Fatalf("Constant Type%s does not exist.", o.Name()) 70 } 71 72 namedTypes = append(namedTypes, o.Name()) 73 } 74 75 b := 
&bytes.Buffer{} 76 b.WriteString(packageHdr) 77 78 fmt.Fprint(b, "// pack*() functions\n\n") 79 for _, name := range namedTypes { 80 o := scope.Lookup(name) 81 st, _ := getTypeStruct(o.Type(), scope) 82 83 fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {\n", name) 84 for i := 1; i < st.NumFields(); i++ { 85 o := func(s string) { 86 fmt.Fprintf(b, s, st.Field(i).Name()) 87 fmt.Fprint(b, `if err != nil { 88return off, err 89} 90`) 91 } 92 93 if _, ok := st.Field(i).Type().(*types.Slice); ok { 94 switch st.Tag(i) { 95 case `dns:"-"`: // ignored 96 case `dns:"txt"`: 97 o("off, err = packStringTxt(rr.%s, msg, off)\n") 98 case `dns:"opt"`: 99 o("off, err = packDataOpt(rr.%s, msg, off)\n") 100 case `dns:"nsec"`: 101 o("off, err = packDataNsec(rr.%s, msg, off)\n") 102 case `dns:"domain-name"`: 103 o("off, err = packDataDomainNames(rr.%s, msg, off, compression, false)\n") 104 default: 105 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 106 } 107 continue 108 } 109 110 switch { 111 case st.Tag(i) == `dns:"-"`: // ignored 112 case st.Tag(i) == `dns:"cdomain-name"`: 113 o("off, err = packDomainName(rr.%s, msg, off, compression, compress)\n") 114 case st.Tag(i) == `dns:"domain-name"`: 115 o("off, err = packDomainName(rr.%s, msg, off, compression, false)\n") 116 case st.Tag(i) == `dns:"a"`: 117 o("off, err = packDataA(rr.%s, msg, off)\n") 118 case st.Tag(i) == `dns:"aaaa"`: 119 o("off, err = packDataAAAA(rr.%s, msg, off)\n") 120 case st.Tag(i) == `dns:"uint48"`: 121 o("off, err = packUint48(rr.%s, msg, off)\n") 122 case st.Tag(i) == `dns:"txt"`: 123 o("off, err = packString(rr.%s, msg, off)\n") 124 125 case strings.HasPrefix(st.Tag(i), `dns:"size-base32`): // size-base32 can be packed just like base32 126 fallthrough 127 case st.Tag(i) == `dns:"base32"`: 128 o("off, err = packStringBase32(rr.%s, msg, off)\n") 129 130 case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): // size-base64 can be packed just 
like base64 131 fallthrough 132 case st.Tag(i) == `dns:"base64"`: 133 o("off, err = packStringBase64(rr.%s, msg, off)\n") 134 135 case strings.HasPrefix(st.Tag(i), `dns:"size-hex:SaltLength`): 136 // directly write instead of using o() so we get the error check in the correct place 137 field := st.Field(i).Name() 138 fmt.Fprintf(b, `// Only pack salt if value is not "-", i.e. empty 139if rr.%s != "-" { 140 off, err = packStringHex(rr.%s, msg, off) 141 if err != nil { 142 return off, err 143 } 144} 145`, field, field) 146 continue 147 case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): // size-hex can be packed just like hex 148 fallthrough 149 case st.Tag(i) == `dns:"hex"`: 150 o("off, err = packStringHex(rr.%s, msg, off)\n") 151 case st.Tag(i) == `dns:"any"`: 152 o("off, err = packStringAny(rr.%s, msg, off)\n") 153 case st.Tag(i) == `dns:"octet"`: 154 o("off, err = packStringOctet(rr.%s, msg, off)\n") 155 case st.Tag(i) == "": 156 switch st.Field(i).Type().(*types.Basic).Kind() { 157 case types.Uint8: 158 o("off, err = packUint8(rr.%s, msg, off)\n") 159 case types.Uint16: 160 o("off, err = packUint16(rr.%s, msg, off)\n") 161 case types.Uint32: 162 o("off, err = packUint32(rr.%s, msg, off)\n") 163 case types.Uint64: 164 o("off, err = packUint64(rr.%s, msg, off)\n") 165 case types.String: 166 o("off, err = packString(rr.%s, msg, off)\n") 167 default: 168 log.Fatalln(name, st.Field(i).Name()) 169 } 170 default: 171 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 172 } 173 } 174 fmt.Fprintln(b, "return off, nil }\n") 175 } 176 177 fmt.Fprint(b, "// unpack*() functions\n\n") 178 for _, name := range namedTypes { 179 o := scope.Lookup(name) 180 st, _ := getTypeStruct(o.Type(), scope) 181 182 fmt.Fprintf(b, "func (rr *%s) unpack(msg []byte, off int) (off1 int, err error) {\n", name) 183 fmt.Fprint(b, `rdStart := off 184_ = rdStart 185 186`) 187 for i := 1; i < st.NumFields(); i++ { 188 o := func(s string) { 189 fmt.Fprintf(b, s, st.Field(i).Name()) 190 fmt.Fprint(b, `if 
err != nil { 191return off, err 192} 193`) 194 } 195 196 // size-* are special, because they reference a struct member we should use for the length. 197 if strings.HasPrefix(st.Tag(i), `dns:"size-`) { 198 structMember := structMember(st.Tag(i)) 199 structTag := structTag(st.Tag(i)) 200 switch structTag { 201 case "hex": 202 fmt.Fprintf(b, "rr.%s, off, err = unpackStringHex(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) 203 case "base32": 204 fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase32(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) 205 case "base64": 206 fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase64(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember) 207 default: 208 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 209 } 210 fmt.Fprint(b, `if err != nil { 211return off, err 212} 213`) 214 continue 215 } 216 217 if _, ok := st.Field(i).Type().(*types.Slice); ok { 218 switch st.Tag(i) { 219 case `dns:"-"`: // ignored 220 case `dns:"txt"`: 221 o("rr.%s, off, err = unpackStringTxt(msg, off)\n") 222 case `dns:"opt"`: 223 o("rr.%s, off, err = unpackDataOpt(msg, off)\n") 224 case `dns:"nsec"`: 225 o("rr.%s, off, err = unpackDataNsec(msg, off)\n") 226 case `dns:"domain-name"`: 227 o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") 228 default: 229 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 230 } 231 continue 232 } 233 234 switch st.Tag(i) { 235 case `dns:"-"`: // ignored 236 case `dns:"cdomain-name"`: 237 fallthrough 238 case `dns:"domain-name"`: 239 o("rr.%s, off, err = UnpackDomainName(msg, off)\n") 240 case `dns:"a"`: 241 o("rr.%s, off, err = unpackDataA(msg, off)\n") 242 case `dns:"aaaa"`: 243 o("rr.%s, off, err = unpackDataAAAA(msg, off)\n") 244 case `dns:"uint48"`: 245 o("rr.%s, off, err = unpackUint48(msg, off)\n") 246 case `dns:"txt"`: 247 o("rr.%s, off, err = unpackString(msg, off)\n") 248 case `dns:"base32"`: 249 o("rr.%s, off, err = 
unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") 250 case `dns:"base64"`: 251 o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") 252 case `dns:"hex"`: 253 o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") 254 case `dns:"any"`: 255 o("rr.%s, off, err = unpackStringAny(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") 256 case `dns:"octet"`: 257 o("rr.%s, off, err = unpackStringOctet(msg, off)\n") 258 case "": 259 switch st.Field(i).Type().(*types.Basic).Kind() { 260 case types.Uint8: 261 o("rr.%s, off, err = unpackUint8(msg, off)\n") 262 case types.Uint16: 263 o("rr.%s, off, err = unpackUint16(msg, off)\n") 264 case types.Uint32: 265 o("rr.%s, off, err = unpackUint32(msg, off)\n") 266 case types.Uint64: 267 o("rr.%s, off, err = unpackUint64(msg, off)\n") 268 case types.String: 269 o("rr.%s, off, err = unpackString(msg, off)\n") 270 default: 271 log.Fatalln(name, st.Field(i).Name()) 272 } 273 default: 274 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 275 } 276 // If we've hit len(msg) we return without error. 277 if i < st.NumFields()-1 { 278 fmt.Fprintf(b, `if off == len(msg) { 279return off, nil 280 } 281`) 282 } 283 } 284 fmt.Fprintf(b, "return off, nil }\n\n") 285 } 286 287 // gofmt 288 res, err := format.Source(b.Bytes()) 289 if err != nil { 290 b.WriteTo(os.Stderr) 291 log.Fatal(err) 292 } 293 294 // write result 295 f, err := os.Create("zmsg.go") 296 fatalIfErr(err) 297 defer f.Close() 298 f.Write(res) 299} 300 301// structMember will take a tag like dns:"size-base32:SaltLength" and return the last part of this string. 302func structMember(s string) string { 303 fields := strings.Split(s, ":") 304 if len(fields) == 0 { 305 return "" 306 } 307 f := fields[len(fields)-1] 308 // f should have a closing " 309 if len(f) > 1 { 310 return f[:len(f)-1] 311 } 312 return f 313} 314 315// structTag will take a tag like dns:"size-base32:SaltLength" and return base32. 
// structTag will take a tag like dns:"size-base32:SaltLength" and return the
// encoding named after "size-" (base32 in this example). It returns "" when
// the tag has no second ':'-separated field or that field does not start with
// `"size-`; callers can treat "" as an unknown encoding.
func structTag(s string) string {
	const prefix = `"size-`
	fields := strings.Split(s, ":")
	// Without the prefix check a short second field (e.g. `"a"` from dns:"a")
	// makes the slice expression below panic with an out-of-range index.
	if len(fields) < 2 || !strings.HasPrefix(fields[1], prefix) {
		return ""
	}
	return fields[1][len(prefix):]
}

// fatalIfErr logs err and exits the program (via log.Fatal) if err is non-nil.
func fatalIfErr(err error) {
	if err != nil {
		log.Fatal(err)
	}
}