1//+build ignore 2 3// types_generate.go is meant to run with go generate. It will use 4// go/{importer,types} to track down all the RR struct types. Then for each type 5// it will generate conversion tables (TypeToRR and TypeToString) and banal 6// methods (len, Header, copy) based on the struct tags. The generated source is 7// written to ztypes.go, and is meant to be checked into git. 8package main 9 10import ( 11 "bytes" 12 "fmt" 13 "go/format" 14 "go/types" 15 "log" 16 "os" 17 "strings" 18 "text/template" 19 20 "golang.org/x/tools/go/packages" 21) 22 23var skipLen = map[string]struct{}{ 24 "NSEC": {}, 25 "NSEC3": {}, 26 "OPT": {}, 27 "CSYNC": {}, 28} 29 30var packageHdr = ` 31// Code generated by "go run types_generate.go"; DO NOT EDIT. 32 33package dns 34 35import ( 36 "encoding/base64" 37 "net" 38) 39 40` 41 42var TypeToRR = template.Must(template.New("TypeToRR").Parse(` 43// TypeToRR is a map of constructors for each RR type. 44var TypeToRR = map[uint16]func() RR{ 45{{range .}}{{if ne . "RFC3597"}} Type{{.}}: func() RR { return new({{.}}) }, 46{{end}}{{end}} } 47 48`)) 49 50var typeToString = template.Must(template.New("typeToString").Parse(` 51// TypeToString is a map of strings for each RR type. 52var TypeToString = map[uint16]string{ 53{{range .}}{{if ne . "NSAPPTR"}} Type{{.}}: "{{.}}", 54{{end}}{{end}} TypeNSAPPTR: "NSAP-PTR", 55} 56 57`)) 58 59var headerFunc = template.Must(template.New("headerFunc").Parse(` 60{{range .}} func (rr *{{.}}) Header() *RR_Header { return &rr.Hdr } 61{{end}} 62 63`)) 64 65// getTypeStruct will take a type and the package scope, and return the 66// (innermost) struct if the type is considered a RR type (currently defined as 67// those structs beginning with a RR_Header, could be redefined as implementing 68// the RR interface). The bool return value indicates if embedded structs were 69// resolved. 70func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) { 71 st, ok := t.Underlying().(*types.Struct) 72 if !ok { 73 return nil, false 74 } 75 if st.Field(0).Type() == scope.Lookup("RR_Header").Type() { 76 return st, false 77 } 78 if st.Field(0).Anonymous() { 79 st, _ := getTypeStruct(st.Field(0).Type(), scope) 80 return st, true 81 } 82 return nil, false 83} 84 85// loadModule retrieves package description for a given module. 86func loadModule(name string) (*types.Package, error) { 87 conf := packages.Config{Mode: packages.NeedTypes | packages.NeedTypesInfo} 88 pkgs, err := packages.Load(&conf, name) 89 if err != nil { 90 return nil, err 91 } 92 return pkgs[0].Types, nil 93} 94 95func main() { 96 // Import and type-check the package 97 pkg, err := loadModule("github.com/miekg/dns") 98 fatalIfErr(err) 99 scope := pkg.Scope() 100 101 // Collect constants like TypeX 102 var numberedTypes []string 103 for _, name := range scope.Names() { 104 o := scope.Lookup(name) 105 if o == nil || !o.Exported() { 106 continue 107 } 108 b, ok := o.Type().(*types.Basic) 109 if !ok || b.Kind() != types.Uint16 { 110 continue 111 } 112 if !strings.HasPrefix(o.Name(), "Type") { 113 continue 114 } 115 name := strings.TrimPrefix(o.Name(), "Type") 116 if name == "PrivateRR" { 117 continue 118 } 119 numberedTypes = append(numberedTypes, name) 120 } 121 122 // Collect actual types (*X) 123 var namedTypes []string 124 for _, name := range scope.Names() { 125 o := scope.Lookup(name) 126 if o == nil || !o.Exported() { 127 continue 128 } 129 if st, _ := getTypeStruct(o.Type(), scope); st == nil { 130 continue 131 } 132 if name == "PrivateRR" { 133 continue 134 } 135 136 // Check if corresponding TypeX exists 137 if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" { 138 log.Fatalf("Constant Type%s does not exist.", o.Name()) 139 } 140 141 namedTypes = append(namedTypes, o.Name()) 142 } 143 144 b := &bytes.Buffer{} 145 b.WriteString(packageHdr) 146 147 // Generate TypeToRR 148 fatalIfErr(TypeToRR.Execute(b, namedTypes)) 149 150 // Generate typeToString 151 fatalIfErr(typeToString.Execute(b, numberedTypes)) 152 153 // Generate headerFunc 154 fatalIfErr(headerFunc.Execute(b, namedTypes)) 155 156 // Generate len() 157 fmt.Fprint(b, "// len() functions\n") 158 for _, name := range namedTypes { 159 if _, ok := skipLen[name]; ok { 160 continue 161 } 162 o := scope.Lookup(name) 163 st, isEmbedded := getTypeStruct(o.Type(), scope) 164 if isEmbedded { 165 continue 166 } 167 fmt.Fprintf(b, "func (rr *%s) len(off int, compression map[string]struct{}) int {\n", name) 168 fmt.Fprintf(b, "l := rr.Hdr.len(off, compression)\n") 169 for i := 1; i < st.NumFields(); i++ { 170 o := func(s string) { fmt.Fprintf(b, s, st.Field(i).Name()) } 171 172 if _, ok := st.Field(i).Type().(*types.Slice); ok { 173 switch st.Tag(i) { 174 case `dns:"-"`: 175 // ignored 176 case `dns:"cdomain-name"`: 177 o("for _, x := range rr.%s { l += domainNameLen(x, off+l, compression, true) }\n") 178 case `dns:"domain-name"`: 179 o("for _, x := range rr.%s { l += domainNameLen(x, off+l, compression, false) }\n") 180 case `dns:"txt"`: 181 o("for _, x := range rr.%s { l += len(x) + 1 }\n") 182 case `dns:"apl"`: 183 o("for _, x := range rr.%s { l += x.len() }\n") 184 default: 185 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 186 } 187 continue 188 } 189 190 switch { 191 case st.Tag(i) == `dns:"-"`: 192 // ignored 193 case st.Tag(i) == `dns:"cdomain-name"`: 194 o("l += domainNameLen(rr.%s, off+l, compression, true)\n") 195 case st.Tag(i) == `dns:"domain-name"`: 196 o("l += domainNameLen(rr.%s, off+l, compression, false)\n") 197 case st.Tag(i) == `dns:"octet"`: 198 o("l += len(rr.%s)\n") 199 case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): 200 fallthrough 201 case st.Tag(i) == `dns:"base64"`: 202 o("l += base64.StdEncoding.DecodedLen(len(rr.%s))\n") 203 case strings.HasPrefix(st.Tag(i), `dns:"size-hex:`): // this has an extra field where the length is stored 204 o("l += len(rr.%s)/2\n") 205 case st.Tag(i) == `dns:"hex"`: 206 o("l += len(rr.%s)/2\n") 207 case st.Tag(i) == `dns:"any"`: 208 o("l += len(rr.%s)\n") 209 case st.Tag(i) == `dns:"a"`: 210 o("if len(rr.%s) != 0 { l += net.IPv4len }\n") 211 case st.Tag(i) == `dns:"aaaa"`: 212 o("if len(rr.%s) != 0 { l += net.IPv6len }\n") 213 case st.Tag(i) == `dns:"txt"`: 214 o("for _, t := range rr.%s { l += len(t) + 1 }\n") 215 case st.Tag(i) == `dns:"uint48"`: 216 o("l += 6 // %s\n") 217 case st.Tag(i) == "": 218 switch st.Field(i).Type().(*types.Basic).Kind() { 219 case types.Uint8: 220 o("l++ // %s\n") 221 case types.Uint16: 222 o("l += 2 // %s\n") 223 case types.Uint32: 224 o("l += 4 // %s\n") 225 case types.Uint64: 226 o("l += 8 // %s\n") 227 case types.String: 228 o("l += len(rr.%s) + 1\n") 229 default: 230 log.Fatalln(name, st.Field(i).Name()) 231 } 232 default: 233 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 234 } 235 } 236 fmt.Fprintf(b, "return l }\n") 237 } 238 239 // Generate copy() 240 fmt.Fprint(b, "// copy() functions\n") 241 for _, name := range namedTypes { 242 o := scope.Lookup(name) 243 st, isEmbedded := getTypeStruct(o.Type(), scope) 244 if isEmbedded { 245 continue 246 } 247 fmt.Fprintf(b, "func (rr *%s) copy() RR {\n", name) 248 fields := []string{"rr.Hdr"} 249 for i := 1; i < st.NumFields(); i++ { 250 f := st.Field(i).Name() 251 if sl, ok := st.Field(i).Type().(*types.Slice); ok { 252 t := sl.Underlying().String() 253 t = strings.TrimPrefix(t, "[]") 254 if strings.Contains(t, ".") { 255 splits := strings.Split(t, ".") 256 t = splits[len(splits)-1] 257 } 258 // For the EDNS0 interface (used in the OPT RR), we need to call the copy method on each element. 259 if t == "EDNS0" { 260 fmt.Fprintf(b, "%s := make([]%s, len(rr.%s));\nfor i,e := range rr.%s {\n %s[i] = e.copy()\n}\n", 261 f, t, f, f, f) 262 fields = append(fields, f) 263 continue 264 } 265 if t == "APLPrefix" { 266 fmt.Fprintf(b, "%s := make([]%s, len(rr.%s));\nfor i := range rr.%s {\n %s[i] = rr.%s[i].copy()\n}\n", 267 f, t, f, f, f, f) 268 fields = append(fields, f) 269 continue 270 } 271 fmt.Fprintf(b, "%s := make([]%s, len(rr.%s)); copy(%s, rr.%s)\n", 272 f, t, f, f, f) 273 fields = append(fields, f) 274 continue 275 } 276 if st.Field(i).Type().String() == "net.IP" { 277 fields = append(fields, "copyIP(rr."+f+")") 278 continue 279 } 280 fields = append(fields, "rr."+f) 281 } 282 fmt.Fprintf(b, "return &%s{%s}\n", name, strings.Join(fields, ",")) 283 fmt.Fprintf(b, "}\n") 284 } 285 286 // gofmt 287 res, err := format.Source(b.Bytes()) 288 if err != nil { 289 b.WriteTo(os.Stderr) 290 log.Fatal(err) 291 } 292 293 // write result 294 f, err := os.Create("ztypes.go") 295 fatalIfErr(err) 296 defer f.Close() 297 f.Write(res) 298} 299 300func fatalIfErr(err error) { 301 if err != nil { 302 log.Fatal(err) 303 } 304} 305