1//+build ignore
2
3// msg_generate.go is meant to run with go generate. It will use
4// go/{importer,types} to track down all the RR struct types. Then for each type
5// it will generate pack/unpack methods based on the struct tags. The generated source is
6// written to zmsg.go, and is meant to be checked into git.
7package main
8
9import (
10	"bytes"
11	"fmt"
12	"go/format"
13	"go/importer"
14	"go/types"
15	"log"
16	"os"
17	"strings"
18)
19
// packageHdr is emitted verbatim at the top of zmsg.go: a leading blank line,
// the "Code generated ... DO NOT EDIT." marker, and the package clause.
var packageHdr = "\n" +
	"// Code generated by \"go run msg_generate.go\"; DO NOT EDIT.\n" +
	"\n" +
	"package dns\n" +
	"\n"
26
// getTypeStruct will take a type and the package scope, and return the
// (innermost) struct if the type is considered a RR type (currently defined as
// those structs beginning with a RR_Header, could be redefined as implementing
// the RR interface). The bool return value indicates if embedded structs were
// resolved.
//
// Types whose underlying type is not a struct, and structs without any
// fields, are not RR types and yield (nil, false). The scope must contain
// RR_Header.
func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
	st, ok := t.Underlying().(*types.Struct)
	if !ok || st.NumFields() == 0 {
		// Guard before Field(0): an empty struct (e.g. struct{}) has no
		// first field and indexing it would panic.
		return nil, false
	}
	first := st.Field(0)
	if first.Type() == scope.Lookup("RR_Header").Type() {
		return st, false
	}
	if first.Anonymous() {
		// The RR is wrapped in another type; look through the embedding.
		inner, _ := getTypeStruct(first.Type(), scope)
		return inner, true
	}
	return nil, false
}
46
47func main() {
48	// Import and type-check the package
49	pkg, err := importer.Default().Import("github.com/miekg/dns")
50	fatalIfErr(err)
51	scope := pkg.Scope()
52
53	// Collect actual types (*X)
54	var namedTypes []string
55	for _, name := range scope.Names() {
56		o := scope.Lookup(name)
57		if o == nil || !o.Exported() {
58			continue
59		}
60		if st, _ := getTypeStruct(o.Type(), scope); st == nil {
61			continue
62		}
63		if name == "PrivateRR" {
64			continue
65		}
66
67		// Check if corresponding TypeX exists
68		if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" {
69			log.Fatalf("Constant Type%s does not exist.", o.Name())
70		}
71
72		namedTypes = append(namedTypes, o.Name())
73	}
74
75	b := &bytes.Buffer{}
76	b.WriteString(packageHdr)
77
78	fmt.Fprint(b, "// pack*() functions\n\n")
79	for _, name := range namedTypes {
80		o := scope.Lookup(name)
81		st, _ := getTypeStruct(o.Type(), scope)
82
83		fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {\n", name)
84		fmt.Fprint(b, `off, err := rr.Hdr.pack(msg, off, compression, compress)
85if err != nil {
86	return off, err
87}
88headerEnd := off
89`)
90		for i := 1; i < st.NumFields(); i++ {
91			o := func(s string) {
92				fmt.Fprintf(b, s, st.Field(i).Name())
93				fmt.Fprint(b, `if err != nil {
94return off, err
95}
96`)
97			}
98
99			if _, ok := st.Field(i).Type().(*types.Slice); ok {
100				switch st.Tag(i) {
101				case `dns:"-"`: // ignored
102				case `dns:"txt"`:
103					o("off, err = packStringTxt(rr.%s, msg, off)\n")
104				case `dns:"opt"`:
105					o("off, err = packDataOpt(rr.%s, msg, off)\n")
106				case `dns:"nsec"`:
107					o("off, err = packDataNsec(rr.%s, msg, off)\n")
108				case `dns:"domain-name"`:
109					o("off, err = packDataDomainNames(rr.%s, msg, off, compression, compress)\n")
110				default:
111					log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
112				}
113				continue
114			}
115
116			switch {
117			case st.Tag(i) == `dns:"-"`: // ignored
118			case st.Tag(i) == `dns:"cdomain-name"`:
119				o("off, err = PackDomainName(rr.%s, msg, off, compression, compress)\n")
120			case st.Tag(i) == `dns:"domain-name"`:
121				o("off, err = PackDomainName(rr.%s, msg, off, compression, false)\n")
122			case st.Tag(i) == `dns:"a"`:
123				o("off, err = packDataA(rr.%s, msg, off)\n")
124			case st.Tag(i) == `dns:"aaaa"`:
125				o("off, err = packDataAAAA(rr.%s, msg, off)\n")
126			case st.Tag(i) == `dns:"uint48"`:
127				o("off, err = packUint48(rr.%s, msg, off)\n")
128			case st.Tag(i) == `dns:"txt"`:
129				o("off, err = packString(rr.%s, msg, off)\n")
130
131			case strings.HasPrefix(st.Tag(i), `dns:"size-base32`): // size-base32 can be packed just like base32
132				fallthrough
133			case st.Tag(i) == `dns:"base32"`:
134				o("off, err = packStringBase32(rr.%s, msg, off)\n")
135
136			case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): // size-base64 can be packed just like base64
137				fallthrough
138			case st.Tag(i) == `dns:"base64"`:
139				o("off, err = packStringBase64(rr.%s, msg, off)\n")
140
141			case strings.HasPrefix(st.Tag(i), `dns:"size-hex:SaltLength`):
142				// directly write instead of using o() so we get the error check in the correct place
143				field := st.Field(i).Name()
144				fmt.Fprintf(b, `// Only pack salt if value is not "-", i.e. empty
145if rr.%s != "-" {
146  off, err = packStringHex(rr.%s, msg, off)
147  if err != nil {
148    return off, err
149  }
150}
151`, field, field)
152				continue
153			case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): // size-hex can be packed just like hex
154				fallthrough
155			case st.Tag(i) == `dns:"hex"`:
156				o("off, err = packStringHex(rr.%s, msg, off)\n")
157
158			case st.Tag(i) == `dns:"octet"`:
159				o("off, err = packStringOctet(rr.%s, msg, off)\n")
160			case st.Tag(i) == "":
161				switch st.Field(i).Type().(*types.Basic).Kind() {
162				case types.Uint8:
163					o("off, err = packUint8(rr.%s, msg, off)\n")
164				case types.Uint16:
165					o("off, err = packUint16(rr.%s, msg, off)\n")
166				case types.Uint32:
167					o("off, err = packUint32(rr.%s, msg, off)\n")
168				case types.Uint64:
169					o("off, err = packUint64(rr.%s, msg, off)\n")
170				case types.String:
171					o("off, err = packString(rr.%s, msg, off)\n")
172				default:
173					log.Fatalln(name, st.Field(i).Name())
174				}
175			default:
176				log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
177			}
178		}
179		// We have packed everything, only now we know the rdlength of this RR
180		fmt.Fprintln(b, "rr.Header().Rdlength = uint16(off-headerEnd)")
181		fmt.Fprintln(b, "return off, nil }\n")
182	}
183
184	fmt.Fprint(b, "// unpack*() functions\n\n")
185	for _, name := range namedTypes {
186		o := scope.Lookup(name)
187		st, _ := getTypeStruct(o.Type(), scope)
188
189		fmt.Fprintf(b, "func unpack%s(h RR_Header, msg []byte, off int) (RR, int, error) {\n", name)
190		fmt.Fprintf(b, "rr := new(%s)\n", name)
191		fmt.Fprint(b, "rr.Hdr = h\n")
192		fmt.Fprint(b, `if noRdata(h) {
193return rr, off, nil
194	}
195var err error
196rdStart := off
197_ = rdStart
198
199`)
200		for i := 1; i < st.NumFields(); i++ {
201			o := func(s string) {
202				fmt.Fprintf(b, s, st.Field(i).Name())
203				fmt.Fprint(b, `if err != nil {
204return rr, off, err
205}
206`)
207			}
208
209			// size-* are special, because they reference a struct member we should use for the length.
210			if strings.HasPrefix(st.Tag(i), `dns:"size-`) {
211				structMember := structMember(st.Tag(i))
212				structTag := structTag(st.Tag(i))
213				switch structTag {
214				case "hex":
215					fmt.Fprintf(b, "rr.%s, off, err = unpackStringHex(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
216				case "base32":
217					fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase32(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
218				case "base64":
219					fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase64(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
220				default:
221					log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
222				}
223				fmt.Fprint(b, `if err != nil {
224return rr, off, err
225}
226`)
227				continue
228			}
229
230			if _, ok := st.Field(i).Type().(*types.Slice); ok {
231				switch st.Tag(i) {
232				case `dns:"-"`: // ignored
233				case `dns:"txt"`:
234					o("rr.%s, off, err = unpackStringTxt(msg, off)\n")
235				case `dns:"opt"`:
236					o("rr.%s, off, err = unpackDataOpt(msg, off)\n")
237				case `dns:"nsec"`:
238					o("rr.%s, off, err = unpackDataNsec(msg, off)\n")
239				case `dns:"domain-name"`:
240					o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
241				default:
242					log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
243				}
244				continue
245			}
246
247			switch st.Tag(i) {
248			case `dns:"-"`: // ignored
249			case `dns:"cdomain-name"`:
250				fallthrough
251			case `dns:"domain-name"`:
252				o("rr.%s, off, err = UnpackDomainName(msg, off)\n")
253			case `dns:"a"`:
254				o("rr.%s, off, err = unpackDataA(msg, off)\n")
255			case `dns:"aaaa"`:
256				o("rr.%s, off, err = unpackDataAAAA(msg, off)\n")
257			case `dns:"uint48"`:
258				o("rr.%s, off, err = unpackUint48(msg, off)\n")
259			case `dns:"txt"`:
260				o("rr.%s, off, err = unpackString(msg, off)\n")
261			case `dns:"base32"`:
262				o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
263			case `dns:"base64"`:
264				o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
265			case `dns:"hex"`:
266				o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
267			case `dns:"octet"`:
268				o("rr.%s, off, err = unpackStringOctet(msg, off)\n")
269			case "":
270				switch st.Field(i).Type().(*types.Basic).Kind() {
271				case types.Uint8:
272					o("rr.%s, off, err = unpackUint8(msg, off)\n")
273				case types.Uint16:
274					o("rr.%s, off, err = unpackUint16(msg, off)\n")
275				case types.Uint32:
276					o("rr.%s, off, err = unpackUint32(msg, off)\n")
277				case types.Uint64:
278					o("rr.%s, off, err = unpackUint64(msg, off)\n")
279				case types.String:
280					o("rr.%s, off, err = unpackString(msg, off)\n")
281				default:
282					log.Fatalln(name, st.Field(i).Name())
283				}
284			default:
285				log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
286			}
287			// If we've hit len(msg) we return without error.
288			if i < st.NumFields()-1 {
289				fmt.Fprintf(b, `if off == len(msg) {
290return rr, off, nil
291	}
292`)
293			}
294		}
295		fmt.Fprintf(b, "return rr, off, err }\n\n")
296	}
297	// Generate typeToUnpack map
298	fmt.Fprintln(b, "var typeToUnpack = map[uint16]func(RR_Header, []byte, int) (RR, int, error){")
299	for _, name := range namedTypes {
300		if name == "RFC3597" {
301			continue
302		}
303		fmt.Fprintf(b, "Type%s: unpack%s,\n", name, name)
304	}
305	fmt.Fprintln(b, "}\n")
306
307	// gofmt
308	res, err := format.Source(b.Bytes())
309	if err != nil {
310		b.WriteTo(os.Stderr)
311		log.Fatal(err)
312	}
313
314	// write result
315	f, err := os.Create("zmsg.go")
316	fatalIfErr(err)
317	defer f.Close()
318	f.Write(res)
319}
320
// structMember returns the text after the last ':' in s with its final byte
// (normally the closing quote) removed, e.g. dns:"size-base32:SaltLength"
// yields SaltLength. If that trailing text is at most one byte long it is
// returned unchanged; without any ':' the whole of s is used.
func structMember(s string) string {
	last := s[strings.LastIndex(s, ":")+1:]
	if len(last) <= 1 {
		return last
	}
	return last[:len(last)-1]
}
334
// structTag will take a tag like dns:"size-base32:SaltLength" and return base32.
//
// It returns "" when s has no ':' or when its second ':'-separated field does
// not start with `"size-`. In those cases the original version either panicked
// (slice out of range on a short field) or returned a meaningless suffix.
// Note that a tag without a member, such as dns:"size-hex", yields `hex"`.
// The callers' switch rejects that value.
func structTag(s string) string {
	fields := strings.Split(s, ":")
	if len(fields) < 2 {
		return ""
	}
	const prefix = `"size-`
	if !strings.HasPrefix(fields[1], prefix) {
		return ""
	}
	return fields[1][len(prefix):]
}
343
// fatalIfErr logs err and exits the generator (via log.Fatal) when err is
// non-nil; a nil error is a no-op.
func fatalIfErr(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
349