1//+build ignore 2 3// types_generate.go is meant to run with go generate. It will use 4// go/{importer,types} to track down all the RR struct types. Then for each type 5// it will generate conversion tables (TypeToRR and TypeToString) and banal 6// methods (len, Header, copy) based on the struct tags. The generated source is 7// written to ztypes.go, and is meant to be checked into git. 8package main 9 10import ( 11 "bytes" 12 "fmt" 13 "go/format" 14 "go/types" 15 "log" 16 "os" 17 "strings" 18 "text/template" 19 20 "golang.org/x/tools/go/packages" 21) 22 23var skipLen = map[string]struct{}{ 24 "NSEC": {}, 25 "NSEC3": {}, 26 "OPT": {}, 27 "CSYNC": {}, 28} 29 30var packageHdr = ` 31// Code generated by "go run types_generate.go"; DO NOT EDIT. 32 33package dns 34 35import ( 36 "encoding/base64" 37 "net" 38) 39 40` 41 42var TypeToRR = template.Must(template.New("TypeToRR").Parse(` 43// TypeToRR is a map of constructors for each RR type. 44var TypeToRR = map[uint16]func() RR{ 45{{range .}}{{if ne . "RFC3597"}} Type{{.}}: func() RR { return new({{.}}) }, 46{{end}}{{end}} } 47 48`)) 49 50var typeToString = template.Must(template.New("typeToString").Parse(` 51// TypeToString is a map of strings for each RR type. 52var TypeToString = map[uint16]string{ 53{{range .}}{{if ne . "NSAPPTR"}} Type{{.}}: "{{.}}", 54{{end}}{{end}} TypeNSAPPTR: "NSAP-PTR", 55} 56 57`)) 58 59var headerFunc = template.Must(template.New("headerFunc").Parse(` 60{{range .}} func (rr *{{.}}) Header() *RR_Header { return &rr.Hdr } 61{{end}} 62 63`)) 64 65// getTypeStruct will take a type and the package scope, and return the 66// (innermost) struct if the type is considered a RR type (currently defined as 67// those structs beginning with a RR_Header, could be redefined as implementing 68// the RR interface). The bool return value indicates if embedded structs were 69// resolved. 70func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) { 71 st, ok := t.Underlying().(*types.Struct) 72 if !ok { 73 return nil, false 74 } 75 if st.NumFields() == 0 { 76 return nil, false 77 } 78 if st.Field(0).Type() == scope.Lookup("RR_Header").Type() { 79 return st, false 80 } 81 if st.Field(0).Anonymous() { 82 st, _ := getTypeStruct(st.Field(0).Type(), scope) 83 return st, true 84 } 85 return nil, false 86} 87 88// loadModule retrieves package description for a given module. 89func loadModule(name string) (*types.Package, error) { 90 conf := packages.Config{Mode: packages.NeedTypes | packages.NeedTypesInfo} 91 pkgs, err := packages.Load(&conf, name) 92 if err != nil { 93 return nil, err 94 } 95 return pkgs[0].Types, nil 96} 97 98func main() { 99 // Import and type-check the package 100 pkg, err := loadModule("github.com/miekg/dns") 101 fatalIfErr(err) 102 scope := pkg.Scope() 103 104 // Collect constants like TypeX 105 var numberedTypes []string 106 for _, name := range scope.Names() { 107 o := scope.Lookup(name) 108 if o == nil || !o.Exported() { 109 continue 110 } 111 b, ok := o.Type().(*types.Basic) 112 if !ok || b.Kind() != types.Uint16 { 113 continue 114 } 115 if !strings.HasPrefix(o.Name(), "Type") { 116 continue 117 } 118 name := strings.TrimPrefix(o.Name(), "Type") 119 if name == "PrivateRR" { 120 continue 121 } 122 numberedTypes = append(numberedTypes, name) 123 } 124 125 // Collect actual types (*X) 126 var namedTypes []string 127 for _, name := range scope.Names() { 128 o := scope.Lookup(name) 129 if o == nil || !o.Exported() { 130 continue 131 } 132 if st, _ := getTypeStruct(o.Type(), scope); st == nil { 133 continue 134 } 135 if name == "PrivateRR" { 136 continue 137 } 138 139 // Check if corresponding TypeX exists 140 if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" { 141 log.Fatalf("Constant Type%s does not exist.", o.Name()) 142 } 143 144 namedTypes = append(namedTypes, o.Name()) 145 } 146 147 b := &bytes.Buffer{} 148 b.WriteString(packageHdr) 149 150 // Generate TypeToRR 151 fatalIfErr(TypeToRR.Execute(b, namedTypes)) 152 153 // Generate typeToString 154 fatalIfErr(typeToString.Execute(b, numberedTypes)) 155 156 // Generate headerFunc 157 fatalIfErr(headerFunc.Execute(b, namedTypes)) 158 159 // Generate len() 160 fmt.Fprint(b, "// len() functions\n") 161 for _, name := range namedTypes { 162 if _, ok := skipLen[name]; ok { 163 continue 164 } 165 o := scope.Lookup(name) 166 st, isEmbedded := getTypeStruct(o.Type(), scope) 167 if isEmbedded { 168 continue 169 } 170 fmt.Fprintf(b, "func (rr *%s) len(off int, compression map[string]struct{}) int {\n", name) 171 fmt.Fprintf(b, "l := rr.Hdr.len(off, compression)\n") 172 for i := 1; i < st.NumFields(); i++ { 173 o := func(s string) { fmt.Fprintf(b, s, st.Field(i).Name()) } 174 175 if _, ok := st.Field(i).Type().(*types.Slice); ok { 176 switch st.Tag(i) { 177 case `dns:"-"`: 178 // ignored 179 case `dns:"cdomain-name"`: 180 o("for _, x := range rr.%s { l += domainNameLen(x, off+l, compression, true) }\n") 181 case `dns:"domain-name"`: 182 o("for _, x := range rr.%s { l += domainNameLen(x, off+l, compression, false) }\n") 183 case `dns:"txt"`: 184 o("for _, x := range rr.%s { l += len(x) + 1 }\n") 185 case `dns:"apl"`: 186 o("for _, x := range rr.%s { l += x.len() }\n") 187 case `dns:"pairs"`: 188 o("for _, x := range rr.%s { l += 4 + int(x.len()) }\n") 189 default: 190 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 191 } 192 continue 193 } 194 195 switch { 196 case st.Tag(i) == `dns:"-"`: 197 // ignored 198 case st.Tag(i) == `dns:"cdomain-name"`: 199 o("l += domainNameLen(rr.%s, off+l, compression, true)\n") 200 case st.Tag(i) == `dns:"domain-name"`: 201 o("l += domainNameLen(rr.%s, off+l, compression, false)\n") 202 case st.Tag(i) == `dns:"octet"`: 203 o("l += len(rr.%s)\n") 204 case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): 205 fallthrough 206 case st.Tag(i) == `dns:"base64"`: 207 o("l += base64.StdEncoding.DecodedLen(len(rr.%s))\n") 208 case strings.HasPrefix(st.Tag(i), `dns:"size-hex:`): // this has an extra field where the length is stored 209 o("l += len(rr.%s)/2\n") 210 case st.Tag(i) == `dns:"hex"`: 211 o("l += len(rr.%s)/2\n") 212 case st.Tag(i) == `dns:"any"`: 213 o("l += len(rr.%s)\n") 214 case st.Tag(i) == `dns:"a"`: 215 o("if len(rr.%s) != 0 { l += net.IPv4len }\n") 216 case st.Tag(i) == `dns:"aaaa"`: 217 o("if len(rr.%s) != 0 { l += net.IPv6len }\n") 218 case st.Tag(i) == `dns:"txt"`: 219 o("for _, t := range rr.%s { l += len(t) + 1 }\n") 220 case st.Tag(i) == `dns:"uint48"`: 221 o("l += 6 // %s\n") 222 case st.Tag(i) == "": 223 switch st.Field(i).Type().(*types.Basic).Kind() { 224 case types.Uint8: 225 o("l++ // %s\n") 226 case types.Uint16: 227 o("l += 2 // %s\n") 228 case types.Uint32: 229 o("l += 4 // %s\n") 230 case types.Uint64: 231 o("l += 8 // %s\n") 232 case types.String: 233 o("l += len(rr.%s) + 1\n") 234 default: 235 log.Fatalln(name, st.Field(i).Name()) 236 } 237 default: 238 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 239 } 240 } 241 fmt.Fprintf(b, "return l }\n") 242 } 243 244 // Generate copy() 245 fmt.Fprint(b, "// copy() functions\n") 246 for _, name := range namedTypes { 247 o := scope.Lookup(name) 248 st, isEmbedded := getTypeStruct(o.Type(), scope) 249 fmt.Fprintf(b, "func (rr *%s) copy() RR {\n", name) 250 fields := make([]string, 0, st.NumFields()) 251 if isEmbedded { 252 a, _ := o.Type().Underlying().(*types.Struct) 253 parent := a.Field(0).Name() 254 fields = append(fields, "*rr."+parent+".copy().(*"+parent+")") 255 goto WriteCopy 256 } 257 fields = append(fields, "rr.Hdr") 258 for i := 1; i < st.NumFields(); i++ { 259 f := st.Field(i).Name() 260 if sl, ok := st.Field(i).Type().(*types.Slice); ok { 261 t := sl.Underlying().String() 262 t = strings.TrimPrefix(t, "[]") 263 if strings.Contains(t, ".") { 264 splits := strings.Split(t, ".") 265 t = splits[len(splits)-1] 266 } 267 // For the EDNS0 interface (used in the OPT RR), we need to call the copy method on each element. 268 if t == "EDNS0" { 269 fmt.Fprintf(b, "%s := make([]%s, len(rr.%s));\nfor i,e := range rr.%s {\n %s[i] = e.copy()\n}\n", 270 f, t, f, f, f) 271 fields = append(fields, f) 272 continue 273 } 274 if t == "APLPrefix" { 275 fmt.Fprintf(b, "%s := make([]%s, len(rr.%s));\nfor i,e := range rr.%s {\n %s[i] = e.copy()\n}\n", 276 f, t, f, f, f) 277 fields = append(fields, f) 278 continue 279 } 280 if t == "SVCBKeyValue" { 281 fmt.Fprintf(b, "%s := make([]%s, len(rr.%s));\nfor i,e := range rr.%s {\n %s[i] = e.copy()\n}\n", 282 f, t, f, f, f) 283 fields = append(fields, f) 284 continue 285 } 286 fmt.Fprintf(b, "%s := make([]%s, len(rr.%s)); copy(%s, rr.%s)\n", 287 f, t, f, f, f) 288 fields = append(fields, f) 289 continue 290 } 291 if st.Field(i).Type().String() == "net.IP" { 292 fields = append(fields, "copyIP(rr."+f+")") 293 continue 294 } 295 fields = append(fields, "rr."+f) 296 } 297 WriteCopy: 298 fmt.Fprintf(b, "return &%s{%s}\n", name, strings.Join(fields, ",")) 299 fmt.Fprintf(b, "}\n") 300 } 301 302 // gofmt 303 res, err := format.Source(b.Bytes()) 304 if err != nil { 305 b.WriteTo(os.Stderr) 306 log.Fatal(err) 307 } 308 309 // write result 310 f, err := os.Create("ztypes.go") 311 fatalIfErr(err) 312 defer f.Close() 313 f.Write(res) 314} 315 316func fatalIfErr(err error) { 317 if err != nil { 318 log.Fatal(err) 319 } 320} 321