1//+build ignore
2
// msg_generate.go is meant to run with go generate. It uses
// golang.org/x/tools/go/packages and go/types to track down all the RR struct
// types. Then for each type it generates pack/unpack methods based on the
// struct tags. The generated source is written to zmsg.go, and is meant to be
// checked into git.
7package main
8
9import (
10	"bytes"
11	"fmt"
12	"go/format"
13	"go/types"
14	"log"
15	"os"
16	"strings"
17
18	"golang.org/x/tools/go/packages"
19)
20
// packageHdr is written verbatim at the start of the generated zmsg.go: a
// leading blank line, the "Code generated ... DO NOT EDIT." marker, and the
// package clause.
var packageHdr = "\n// Code generated by \"go run msg_generate.go\"; DO NOT EDIT.\n\npackage dns\n\n"
27
// getTypeStruct will take a type and the package scope, and return the
// (innermost) struct if the type is considered a RR type (currently defined as
// those structs beginning with a RR_Header, could be redefined as implementing
// the RR interface).
//
// The bool result reports whether the struct was found by descending into an
// embedded (anonymous) first field. It is always false when the returned
// struct is nil.
//
// If scope has no RR_Header declaration, no type is considered an RR type.
func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
	st, ok := t.Underlying().(*types.Struct)
	if !ok || st.NumFields() == 0 {
		return nil, false
	}

	hdr := scope.Lookup("RR_Header")
	if hdr == nil {
		// Without an RR_Header declaration nothing can be an RR; bail out
		// instead of dereferencing a nil object.
		return nil, false
	}

	first := st.Field(0)
	// types.Identical (rather than ==) also matches when the field type is
	// spelled through an alias of RR_Header.
	if types.Identical(first.Type(), hdr.Type()) {
		return st, false
	}
	if first.Embedded() {
		// The RR fields live in the embedded type; only report success when
		// that embedded type actually is an RR.
		if inner, _ := getTypeStruct(first.Type(), scope); inner != nil {
			return inner, true
		}
	}
	return nil, false
}
50
// loadModule loads the package with the given import path via go/packages
// and returns its type information.
//
// An error is returned when loading itself fails, or when no package or no
// type information was produced (previously this panicked on pkgs[0]).
// Errors recorded in the loaded package's Errors field are not checked here.
// NOTE(review): this looks deliberate, so the generator can still run while
// generated files are stale and the package does not type-check. Confirm
// before tightening.
func loadModule(name string) (*types.Package, error) {
	conf := packages.Config{Mode: packages.NeedTypes | packages.NeedTypesInfo}
	pkgs, err := packages.Load(&conf, name)
	if err != nil {
		return nil, err
	}
	if len(pkgs) == 0 || pkgs[0].Types == nil {
		return nil, fmt.Errorf("loading %q: no type information produced", name)
	}
	return pkgs[0].Types, nil
}
60
61func main() {
62	// Import and type-check the package
63	pkg, err := loadModule("github.com/miekg/dns")
64	fatalIfErr(err)
65	scope := pkg.Scope()
66
67	// Collect actual types (*X)
68	var namedTypes []string
69	for _, name := range scope.Names() {
70		o := scope.Lookup(name)
71		if o == nil || !o.Exported() {
72			continue
73		}
74		if st, _ := getTypeStruct(o.Type(), scope); st == nil {
75			continue
76		}
77		if name == "PrivateRR" {
78			continue
79		}
80
81		// Check if corresponding TypeX exists
82		if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" {
83			log.Fatalf("Constant Type%s does not exist.", o.Name())
84		}
85
86		namedTypes = append(namedTypes, o.Name())
87	}
88
89	b := &bytes.Buffer{}
90	b.WriteString(packageHdr)
91
92	fmt.Fprint(b, "// pack*() functions\n\n")
93	for _, name := range namedTypes {
94		o := scope.Lookup(name)
95		st, _ := getTypeStruct(o.Type(), scope)
96
97		fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {\n", name)
98		for i := 1; i < st.NumFields(); i++ {
99			o := func(s string) {
100				fmt.Fprintf(b, s, st.Field(i).Name())
101				fmt.Fprint(b, `if err != nil {
102return off, err
103}
104`)
105			}
106
107			if _, ok := st.Field(i).Type().(*types.Slice); ok {
108				switch st.Tag(i) {
109				case `dns:"-"`: // ignored
110				case `dns:"txt"`:
111					o("off, err = packStringTxt(rr.%s, msg, off)\n")
112				case `dns:"opt"`:
113					o("off, err = packDataOpt(rr.%s, msg, off)\n")
114				case `dns:"nsec"`:
115					o("off, err = packDataNsec(rr.%s, msg, off)\n")
116				case `dns:"pairs"`:
117					o("off, err = packDataSVCB(rr.%s, msg, off)\n")
118				case `dns:"domain-name"`:
119					o("off, err = packDataDomainNames(rr.%s, msg, off, compression, false)\n")
120				case `dns:"apl"`:
121					o("off, err = packDataApl(rr.%s, msg, off)\n")
122				default:
123					log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
124				}
125				continue
126			}
127
128			switch {
129			case st.Tag(i) == `dns:"-"`: // ignored
130			case st.Tag(i) == `dns:"cdomain-name"`:
131				o("off, err = packDomainName(rr.%s, msg, off, compression, compress)\n")
132			case st.Tag(i) == `dns:"domain-name"`:
133				o("off, err = packDomainName(rr.%s, msg, off, compression, false)\n")
134			case st.Tag(i) == `dns:"a"`:
135				o("off, err = packDataA(rr.%s, msg, off)\n")
136			case st.Tag(i) == `dns:"aaaa"`:
137				o("off, err = packDataAAAA(rr.%s, msg, off)\n")
138			case st.Tag(i) == `dns:"uint48"`:
139				o("off, err = packUint48(rr.%s, msg, off)\n")
140			case st.Tag(i) == `dns:"txt"`:
141				o("off, err = packString(rr.%s, msg, off)\n")
142
143			case strings.HasPrefix(st.Tag(i), `dns:"size-base32`): // size-base32 can be packed just like base32
144				fallthrough
145			case st.Tag(i) == `dns:"base32"`:
146				o("off, err = packStringBase32(rr.%s, msg, off)\n")
147
148			case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): // size-base64 can be packed just like base64
149				fallthrough
150			case st.Tag(i) == `dns:"base64"`:
151				o("off, err = packStringBase64(rr.%s, msg, off)\n")
152
153			case strings.HasPrefix(st.Tag(i), `dns:"size-hex:SaltLength`):
154				// directly write instead of using o() so we get the error check in the correct place
155				field := st.Field(i).Name()
156				fmt.Fprintf(b, `// Only pack salt if value is not "-", i.e. empty
157if rr.%s != "-" {
158  off, err = packStringHex(rr.%s, msg, off)
159  if err != nil {
160    return off, err
161  }
162}
163`, field, field)
164				continue
165			case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): // size-hex can be packed just like hex
166				fallthrough
167			case st.Tag(i) == `dns:"hex"`:
168				o("off, err = packStringHex(rr.%s, msg, off)\n")
169			case st.Tag(i) == `dns:"any"`:
170				o("off, err = packStringAny(rr.%s, msg, off)\n")
171			case st.Tag(i) == `dns:"octet"`:
172				o("off, err = packStringOctet(rr.%s, msg, off)\n")
173			case st.Tag(i) == "":
174				switch st.Field(i).Type().(*types.Basic).Kind() {
175				case types.Uint8:
176					o("off, err = packUint8(rr.%s, msg, off)\n")
177				case types.Uint16:
178					o("off, err = packUint16(rr.%s, msg, off)\n")
179				case types.Uint32:
180					o("off, err = packUint32(rr.%s, msg, off)\n")
181				case types.Uint64:
182					o("off, err = packUint64(rr.%s, msg, off)\n")
183				case types.String:
184					o("off, err = packString(rr.%s, msg, off)\n")
185				default:
186					log.Fatalln(name, st.Field(i).Name())
187				}
188			default:
189				log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
190			}
191		}
192		fmt.Fprintln(b, "return off, nil }\n")
193	}
194
195	fmt.Fprint(b, "// unpack*() functions\n\n")
196	for _, name := range namedTypes {
197		o := scope.Lookup(name)
198		st, _ := getTypeStruct(o.Type(), scope)
199
200		fmt.Fprintf(b, "func (rr *%s) unpack(msg []byte, off int) (off1 int, err error) {\n", name)
201		fmt.Fprint(b, `rdStart := off
202_ = rdStart
203
204`)
205		for i := 1; i < st.NumFields(); i++ {
206			o := func(s string) {
207				fmt.Fprintf(b, s, st.Field(i).Name())
208				fmt.Fprint(b, `if err != nil {
209return off, err
210}
211`)
212			}
213
214			// size-* are special, because they reference a struct member we should use for the length.
215			if strings.HasPrefix(st.Tag(i), `dns:"size-`) {
216				structMember := structMember(st.Tag(i))
217				structTag := structTag(st.Tag(i))
218				switch structTag {
219				case "hex":
220					fmt.Fprintf(b, "rr.%s, off, err = unpackStringHex(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
221				case "base32":
222					fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase32(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
223				case "base64":
224					fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase64(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
225				default:
226					log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
227				}
228				fmt.Fprint(b, `if err != nil {
229return off, err
230}
231`)
232				continue
233			}
234
235			if _, ok := st.Field(i).Type().(*types.Slice); ok {
236				switch st.Tag(i) {
237				case `dns:"-"`: // ignored
238				case `dns:"txt"`:
239					o("rr.%s, off, err = unpackStringTxt(msg, off)\n")
240				case `dns:"opt"`:
241					o("rr.%s, off, err = unpackDataOpt(msg, off)\n")
242				case `dns:"nsec"`:
243					o("rr.%s, off, err = unpackDataNsec(msg, off)\n")
244				case `dns:"pairs"`:
245					o("rr.%s, off, err = unpackDataSVCB(msg, off)\n")
246				case `dns:"domain-name"`:
247					o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
248				case `dns:"apl"`:
249					o("rr.%s, off, err = unpackDataApl(msg, off)\n")
250				default:
251					log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
252				}
253				continue
254			}
255
256			switch st.Tag(i) {
257			case `dns:"-"`: // ignored
258			case `dns:"cdomain-name"`:
259				fallthrough
260			case `dns:"domain-name"`:
261				o("rr.%s, off, err = UnpackDomainName(msg, off)\n")
262			case `dns:"a"`:
263				o("rr.%s, off, err = unpackDataA(msg, off)\n")
264			case `dns:"aaaa"`:
265				o("rr.%s, off, err = unpackDataAAAA(msg, off)\n")
266			case `dns:"uint48"`:
267				o("rr.%s, off, err = unpackUint48(msg, off)\n")
268			case `dns:"txt"`:
269				o("rr.%s, off, err = unpackString(msg, off)\n")
270			case `dns:"base32"`:
271				o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
272			case `dns:"base64"`:
273				o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
274			case `dns:"hex"`:
275				o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
276			case `dns:"any"`:
277				o("rr.%s, off, err = unpackStringAny(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
278			case `dns:"octet"`:
279				o("rr.%s, off, err = unpackStringOctet(msg, off)\n")
280			case "":
281				switch st.Field(i).Type().(*types.Basic).Kind() {
282				case types.Uint8:
283					o("rr.%s, off, err = unpackUint8(msg, off)\n")
284				case types.Uint16:
285					o("rr.%s, off, err = unpackUint16(msg, off)\n")
286				case types.Uint32:
287					o("rr.%s, off, err = unpackUint32(msg, off)\n")
288				case types.Uint64:
289					o("rr.%s, off, err = unpackUint64(msg, off)\n")
290				case types.String:
291					o("rr.%s, off, err = unpackString(msg, off)\n")
292				default:
293					log.Fatalln(name, st.Field(i).Name())
294				}
295			default:
296				log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
297			}
298			// If we've hit len(msg) we return without error.
299			if i < st.NumFields()-1 {
300				fmt.Fprintf(b, `if off == len(msg) {
301return off, nil
302	}
303`)
304			}
305		}
306		fmt.Fprintf(b, "return off, nil }\n\n")
307	}
308
309	// gofmt
310	res, err := format.Source(b.Bytes())
311	if err != nil {
312		b.WriteTo(os.Stderr)
313		log.Fatal(err)
314	}
315
316	// write result
317	f, err := os.Create("zmsg.go")
318	fatalIfErr(err)
319	defer f.Close()
320	f.Write(res)
321}
322
// structMember extracts the referenced length field from a size tag: for
// dns:"size-base32:SaltLength" it yields SaltLength. It takes everything
// after the last ':' (or the whole string when there is none) and drops the
// final character, which is expected to be the tag's closing quote. Values
// of length 0 or 1 are returned unchanged.
func structMember(s string) string {
	member := s[strings.LastIndex(s, ":")+1:]
	if len(member) <= 1 {
		return member
	}
	return member[:len(member)-1]
}
336
337// structTag will take a tag like dns:"size-base32:SaltLength" and return base32.
338func structTag(s string) string {
339	fields := strings.Split(s, ":")
340	if len(fields) < 2 {
341		return ""
342	}
343	return fields[1][len("\"size-"):]
344}
345
// fatalIfErr logs err via log.Fatal, which terminates the program, when err
// is non-nil. A nil error is a no-op.
func fatalIfErr(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
351