1//+build ignore 2 3// types_generate.go is meant to run with go generate. It will use 4// go/{importer,types} to track down all the RR struct types. Then for each type 5// it will generate conversion tables (TypeToRR and TypeToString) and banal 6// methods (len, Header, copy) based on the struct tags. The generated source is 7// written to ztypes.go, and is meant to be checked into git. 8package main 9 10import ( 11 "bytes" 12 "fmt" 13 "go/format" 14 "go/importer" 15 "go/types" 16 "log" 17 "os" 18 "strings" 19 "text/template" 20) 21 22var skipLen = map[string]struct{}{ 23 "NSEC": {}, 24 "NSEC3": {}, 25 "OPT": {}, 26 "CSYNC": {}, 27} 28 29var packageHdr = ` 30// Code generated by "go run types_generate.go"; DO NOT EDIT. 31 32package dns 33 34import ( 35 "encoding/base64" 36 "net" 37) 38 39` 40 41var TypeToRR = template.Must(template.New("TypeToRR").Parse(` 42// TypeToRR is a map of constructors for each RR type. 43var TypeToRR = map[uint16]func() RR{ 44{{range .}}{{if ne . "RFC3597"}} Type{{.}}: func() RR { return new({{.}}) }, 45{{end}}{{end}} } 46 47`)) 48 49var typeToString = template.Must(template.New("typeToString").Parse(` 50// TypeToString is a map of strings for each RR type. 51var TypeToString = map[uint16]string{ 52{{range .}}{{if ne . "NSAPPTR"}} Type{{.}}: "{{.}}", 53{{end}}{{end}} TypeNSAPPTR: "NSAP-PTR", 54} 55 56`)) 57 58var headerFunc = template.Must(template.New("headerFunc").Parse(` 59{{range .}} func (rr *{{.}}) Header() *RR_Header { return &rr.Hdr } 60{{end}} 61 62`)) 63 64// getTypeStruct will take a type and the package scope, and return the 65// (innermost) struct if the type is considered a RR type (currently defined as 66// those structs beginning with a RR_Header, could be redefined as implementing 67// the RR interface). The bool return value indicates if embedded structs were 68// resolved. 69func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) { 70 st, ok := t.Underlying().(*types.Struct) 71 if !ok { 72 return nil, false 73 } 74 if st.Field(0).Type() == scope.Lookup("RR_Header").Type() { 75 return st, false 76 } 77 if st.Field(0).Anonymous() { 78 st, _ := getTypeStruct(st.Field(0).Type(), scope) 79 return st, true 80 } 81 return nil, false 82} 83 84func main() { 85 // Import and type-check the package 86 pkg, err := importer.Default().Import("github.com/miekg/dns") 87 fatalIfErr(err) 88 scope := pkg.Scope() 89 90 // Collect constants like TypeX 91 var numberedTypes []string 92 for _, name := range scope.Names() { 93 o := scope.Lookup(name) 94 if o == nil || !o.Exported() { 95 continue 96 } 97 b, ok := o.Type().(*types.Basic) 98 if !ok || b.Kind() != types.Uint16 { 99 continue 100 } 101 if !strings.HasPrefix(o.Name(), "Type") { 102 continue 103 } 104 name := strings.TrimPrefix(o.Name(), "Type") 105 if name == "PrivateRR" { 106 continue 107 } 108 numberedTypes = append(numberedTypes, name) 109 } 110 111 // Collect actual types (*X) 112 var namedTypes []string 113 for _, name := range scope.Names() { 114 o := scope.Lookup(name) 115 if o == nil || !o.Exported() { 116 continue 117 } 118 if st, _ := getTypeStruct(o.Type(), scope); st == nil { 119 continue 120 } 121 if name == "PrivateRR" { 122 continue 123 } 124 125 // Check if corresponding TypeX exists 126 if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" { 127 log.Fatalf("Constant Type%s does not exist.", o.Name()) 128 } 129 130 namedTypes = append(namedTypes, o.Name()) 131 } 132 133 b := &bytes.Buffer{} 134 b.WriteString(packageHdr) 135 136 // Generate TypeToRR 137 fatalIfErr(TypeToRR.Execute(b, namedTypes)) 138 139 // Generate typeToString 140 fatalIfErr(typeToString.Execute(b, numberedTypes)) 141 142 // Generate headerFunc 143 fatalIfErr(headerFunc.Execute(b, namedTypes)) 144 145 // Generate len() 146 fmt.Fprint(b, "// len() functions\n") 147 for _, name := range namedTypes { 148 if _, ok := skipLen[name]; ok { 149 continue 150 } 151 o := scope.Lookup(name) 152 st, isEmbedded := getTypeStruct(o.Type(), scope) 153 if isEmbedded { 154 continue 155 } 156 fmt.Fprintf(b, "func (rr *%s) len() int {\n", name) 157 fmt.Fprintf(b, "l := rr.Hdr.len()\n") 158 for i := 1; i < st.NumFields(); i++ { 159 o := func(s string) { fmt.Fprintf(b, s, st.Field(i).Name()) } 160 161 if _, ok := st.Field(i).Type().(*types.Slice); ok { 162 switch st.Tag(i) { 163 case `dns:"-"`: 164 // ignored 165 case `dns:"cdomain-name"`, `dns:"domain-name"`, `dns:"txt"`: 166 o("for _, x := range rr.%s { l += len(x) + 1 }\n") 167 default: 168 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 169 } 170 continue 171 } 172 173 switch { 174 case st.Tag(i) == `dns:"-"`: 175 // ignored 176 case st.Tag(i) == `dns:"cdomain-name"`, st.Tag(i) == `dns:"domain-name"`: 177 o("l += len(rr.%s) + 1\n") 178 case st.Tag(i) == `dns:"octet"`: 179 o("l += len(rr.%s)\n") 180 case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): 181 fallthrough 182 case st.Tag(i) == `dns:"base64"`: 183 o("l += base64.StdEncoding.DecodedLen(len(rr.%s))\n") 184 case strings.HasPrefix(st.Tag(i), `dns:"size-hex:`): // this has an extra field where the length is stored 185 o("l += len(rr.%s)/2\n") 186 case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): 187 fallthrough 188 case st.Tag(i) == `dns:"hex"`: 189 o("l += len(rr.%s)/2 + 1\n") 190 case st.Tag(i) == `dns:"a"`: 191 o("l += net.IPv4len // %s\n") 192 case st.Tag(i) == `dns:"aaaa"`: 193 o("l += net.IPv6len // %s\n") 194 case st.Tag(i) == `dns:"txt"`: 195 o("for _, t := range rr.%s { l += len(t) + 1 }\n") 196 case st.Tag(i) == `dns:"uint48"`: 197 o("l += 6 // %s\n") 198 case st.Tag(i) == "": 199 switch st.Field(i).Type().(*types.Basic).Kind() { 200 case types.Uint8: 201 o("l++ // %s\n") 202 case types.Uint16: 203 o("l += 2 // %s\n") 204 case types.Uint32: 205 o("l += 4 // %s\n") 206 case types.Uint64: 207 o("l += 8 // %s\n") 208 case types.String: 209 o("l += len(rr.%s) + 1\n") 210 default: 211 log.Fatalln(name, st.Field(i).Name()) 212 } 213 default: 214 log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) 215 } 216 } 217 fmt.Fprintf(b, "return l }\n") 218 } 219 220 // Generate copy() 221 fmt.Fprint(b, "// copy() functions\n") 222 for _, name := range namedTypes { 223 o := scope.Lookup(name) 224 st, isEmbedded := getTypeStruct(o.Type(), scope) 225 if isEmbedded { 226 continue 227 } 228 fmt.Fprintf(b, "func (rr *%s) copy() RR {\n", name) 229 fields := []string{"*rr.Hdr.copyHeader()"} 230 for i := 1; i < st.NumFields(); i++ { 231 f := st.Field(i).Name() 232 if sl, ok := st.Field(i).Type().(*types.Slice); ok { 233 t := sl.Underlying().String() 234 t = strings.TrimPrefix(t, "[]") 235 if strings.Contains(t, ".") { 236 splits := strings.Split(t, ".") 237 t = splits[len(splits)-1] 238 } 239 fmt.Fprintf(b, "%s := make([]%s, len(rr.%s)); copy(%s, rr.%s)\n", 240 f, t, f, f, f) 241 fields = append(fields, f) 242 continue 243 } 244 if st.Field(i).Type().String() == "net.IP" { 245 fields = append(fields, "copyIP(rr."+f+")") 246 continue 247 } 248 fields = append(fields, "rr."+f) 249 } 250 fmt.Fprintf(b, "return &%s{%s}\n", name, strings.Join(fields, ",")) 251 fmt.Fprintf(b, "}\n") 252 } 253 254 // gofmt 255 res, err := format.Source(b.Bytes()) 256 if err != nil { 257 b.WriteTo(os.Stderr) 258 log.Fatal(err) 259 } 260 261 // write result 262 f, err := os.Create("ztypes.go") 263 fatalIfErr(err) 264 defer f.Close() 265 f.Write(res) 266} 267 268func fatalIfErr(err error) { 269 if err != nil { 270 log.Fatal(err) 271 } 272} 273