1//+build ignore
2
3// msg_generate.go is meant to run with go generate. It will use
4// go/{importer,types} to track down all the RR struct types. Then for each type
5// it will generate pack/unpack methods based on the struct tags. The generated source is
6// written to zmsg.go, and is meant to be checked into git.
7package main
8
9import (
10	"bytes"
11	"fmt"
12	"go/format"
13	"go/importer"
14	"go/types"
15	"log"
16	"os"
17	"strings"
18)
19
// packageHdr is the preamble written at the top of zmsg.go: the standard
// "Code generated ... DO NOT EDIT." marker followed by the package clause.
var packageHdr = "\n" +
	"// Code generated by \"go run msg_generate.go\"; DO NOT EDIT.\n" +
	"\n" +
	"package dns\n" +
	"\n"
26
// getTypeStruct will take a type and the package scope, and return the
// (innermost) struct if the type is considered a RR type (currently defined as
// those structs beginning with a RR_Header, could be redefined as implementing
// the RR interface). The bool return value indicates if embedded structs were
// resolved.
//
// Non-struct types, empty structs, and structs whose first field is neither
// RR_Header nor an embedded RR struct yield (nil, false). If scope defines no
// RR_Header, no type is considered an RR.
func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
	st, ok := t.Underlying().(*types.Struct)
	if !ok || st.NumFields() == 0 {
		// Field(0) below would panic on an empty struct.
		return nil, false
	}
	hdr := scope.Lookup("RR_Header")
	if hdr == nil {
		return nil, false
	}
	first := st.Field(0)
	if types.Identical(first.Type(), hdr.Type()) {
		return st, false
	}
	if first.Anonymous() {
		// The type embeds another struct first; look for RR_Header inside it.
		if inner, _ := getTypeStruct(first.Type(), scope); inner != nil {
			return inner, true
		}
	}
	return nil, false
}
46
47func main() {
48	// Import and type-check the package
49	pkg, err := importer.Default().Import("github.com/miekg/dns")
50	fatalIfErr(err)
51	scope := pkg.Scope()
52
53	// Collect actual types (*X)
54	var namedTypes []string
55	for _, name := range scope.Names() {
56		o := scope.Lookup(name)
57		if o == nil || !o.Exported() {
58			continue
59		}
60		if st, _ := getTypeStruct(o.Type(), scope); st == nil {
61			continue
62		}
63		if name == "PrivateRR" {
64			continue
65		}
66
67		// Check if corresponding TypeX exists
68		if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" {
69			log.Fatalf("Constant Type%s does not exist.", o.Name())
70		}
71
72		namedTypes = append(namedTypes, o.Name())
73	}
74
75	b := &bytes.Buffer{}
76	b.WriteString(packageHdr)
77
78	fmt.Fprint(b, "// pack*() functions\n\n")
79	for _, name := range namedTypes {
80		o := scope.Lookup(name)
81		st, _ := getTypeStruct(o.Type(), scope)
82
83		fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {\n", name)
84		for i := 1; i < st.NumFields(); i++ {
85			o := func(s string) {
86				fmt.Fprintf(b, s, st.Field(i).Name())
87				fmt.Fprint(b, `if err != nil {
88return off, err
89}
90`)
91			}
92
93			if _, ok := st.Field(i).Type().(*types.Slice); ok {
94				switch st.Tag(i) {
95				case `dns:"-"`: // ignored
96				case `dns:"txt"`:
97					o("off, err = packStringTxt(rr.%s, msg, off)\n")
98				case `dns:"opt"`:
99					o("off, err = packDataOpt(rr.%s, msg, off)\n")
100				case `dns:"nsec"`:
101					o("off, err = packDataNsec(rr.%s, msg, off)\n")
102				case `dns:"domain-name"`:
103					o("off, err = packDataDomainNames(rr.%s, msg, off, compression, false)\n")
104				default:
105					log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
106				}
107				continue
108			}
109
110			switch {
111			case st.Tag(i) == `dns:"-"`: // ignored
112			case st.Tag(i) == `dns:"cdomain-name"`:
113				o("off, err = packDomainName(rr.%s, msg, off, compression, compress)\n")
114			case st.Tag(i) == `dns:"domain-name"`:
115				o("off, err = packDomainName(rr.%s, msg, off, compression, false)\n")
116			case st.Tag(i) == `dns:"a"`:
117				o("off, err = packDataA(rr.%s, msg, off)\n")
118			case st.Tag(i) == `dns:"aaaa"`:
119				o("off, err = packDataAAAA(rr.%s, msg, off)\n")
120			case st.Tag(i) == `dns:"uint48"`:
121				o("off, err = packUint48(rr.%s, msg, off)\n")
122			case st.Tag(i) == `dns:"txt"`:
123				o("off, err = packString(rr.%s, msg, off)\n")
124
125			case strings.HasPrefix(st.Tag(i), `dns:"size-base32`): // size-base32 can be packed just like base32
126				fallthrough
127			case st.Tag(i) == `dns:"base32"`:
128				o("off, err = packStringBase32(rr.%s, msg, off)\n")
129
130			case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): // size-base64 can be packed just like base64
131				fallthrough
132			case st.Tag(i) == `dns:"base64"`:
133				o("off, err = packStringBase64(rr.%s, msg, off)\n")
134
135			case strings.HasPrefix(st.Tag(i), `dns:"size-hex:SaltLength`):
136				// directly write instead of using o() so we get the error check in the correct place
137				field := st.Field(i).Name()
138				fmt.Fprintf(b, `// Only pack salt if value is not "-", i.e. empty
139if rr.%s != "-" {
140  off, err = packStringHex(rr.%s, msg, off)
141  if err != nil {
142    return off, err
143  }
144}
145`, field, field)
146				continue
147			case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): // size-hex can be packed just like hex
148				fallthrough
149			case st.Tag(i) == `dns:"hex"`:
150				o("off, err = packStringHex(rr.%s, msg, off)\n")
151			case st.Tag(i) == `dns:"any"`:
152				o("off, err = packStringAny(rr.%s, msg, off)\n")
153			case st.Tag(i) == `dns:"octet"`:
154				o("off, err = packStringOctet(rr.%s, msg, off)\n")
155			case st.Tag(i) == "":
156				switch st.Field(i).Type().(*types.Basic).Kind() {
157				case types.Uint8:
158					o("off, err = packUint8(rr.%s, msg, off)\n")
159				case types.Uint16:
160					o("off, err = packUint16(rr.%s, msg, off)\n")
161				case types.Uint32:
162					o("off, err = packUint32(rr.%s, msg, off)\n")
163				case types.Uint64:
164					o("off, err = packUint64(rr.%s, msg, off)\n")
165				case types.String:
166					o("off, err = packString(rr.%s, msg, off)\n")
167				default:
168					log.Fatalln(name, st.Field(i).Name())
169				}
170			default:
171				log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
172			}
173		}
174		fmt.Fprintln(b, "return off, nil }\n")
175	}
176
177	fmt.Fprint(b, "// unpack*() functions\n\n")
178	for _, name := range namedTypes {
179		o := scope.Lookup(name)
180		st, _ := getTypeStruct(o.Type(), scope)
181
182		fmt.Fprintf(b, "func (rr *%s) unpack(msg []byte, off int) (off1 int, err error) {\n", name)
183		fmt.Fprint(b, `rdStart := off
184_ = rdStart
185
186`)
187		for i := 1; i < st.NumFields(); i++ {
188			o := func(s string) {
189				fmt.Fprintf(b, s, st.Field(i).Name())
190				fmt.Fprint(b, `if err != nil {
191return off, err
192}
193`)
194			}
195
196			// size-* are special, because they reference a struct member we should use for the length.
197			if strings.HasPrefix(st.Tag(i), `dns:"size-`) {
198				structMember := structMember(st.Tag(i))
199				structTag := structTag(st.Tag(i))
200				switch structTag {
201				case "hex":
202					fmt.Fprintf(b, "rr.%s, off, err = unpackStringHex(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
203				case "base32":
204					fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase32(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
205				case "base64":
206					fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase64(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
207				default:
208					log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
209				}
210				fmt.Fprint(b, `if err != nil {
211return off, err
212}
213`)
214				continue
215			}
216
217			if _, ok := st.Field(i).Type().(*types.Slice); ok {
218				switch st.Tag(i) {
219				case `dns:"-"`: // ignored
220				case `dns:"txt"`:
221					o("rr.%s, off, err = unpackStringTxt(msg, off)\n")
222				case `dns:"opt"`:
223					o("rr.%s, off, err = unpackDataOpt(msg, off)\n")
224				case `dns:"nsec"`:
225					o("rr.%s, off, err = unpackDataNsec(msg, off)\n")
226				case `dns:"domain-name"`:
227					o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
228				default:
229					log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
230				}
231				continue
232			}
233
234			switch st.Tag(i) {
235			case `dns:"-"`: // ignored
236			case `dns:"cdomain-name"`:
237				fallthrough
238			case `dns:"domain-name"`:
239				o("rr.%s, off, err = UnpackDomainName(msg, off)\n")
240			case `dns:"a"`:
241				o("rr.%s, off, err = unpackDataA(msg, off)\n")
242			case `dns:"aaaa"`:
243				o("rr.%s, off, err = unpackDataAAAA(msg, off)\n")
244			case `dns:"uint48"`:
245				o("rr.%s, off, err = unpackUint48(msg, off)\n")
246			case `dns:"txt"`:
247				o("rr.%s, off, err = unpackString(msg, off)\n")
248			case `dns:"base32"`:
249				o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
250			case `dns:"base64"`:
251				o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
252			case `dns:"hex"`:
253				o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
254			case `dns:"any"`:
255				o("rr.%s, off, err = unpackStringAny(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
256			case `dns:"octet"`:
257				o("rr.%s, off, err = unpackStringOctet(msg, off)\n")
258			case "":
259				switch st.Field(i).Type().(*types.Basic).Kind() {
260				case types.Uint8:
261					o("rr.%s, off, err = unpackUint8(msg, off)\n")
262				case types.Uint16:
263					o("rr.%s, off, err = unpackUint16(msg, off)\n")
264				case types.Uint32:
265					o("rr.%s, off, err = unpackUint32(msg, off)\n")
266				case types.Uint64:
267					o("rr.%s, off, err = unpackUint64(msg, off)\n")
268				case types.String:
269					o("rr.%s, off, err = unpackString(msg, off)\n")
270				default:
271					log.Fatalln(name, st.Field(i).Name())
272				}
273			default:
274				log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
275			}
276			// If we've hit len(msg) we return without error.
277			if i < st.NumFields()-1 {
278				fmt.Fprintf(b, `if off == len(msg) {
279return off, nil
280	}
281`)
282			}
283		}
284		fmt.Fprintf(b, "return off, nil }\n\n")
285	}
286
287	// gofmt
288	res, err := format.Source(b.Bytes())
289	if err != nil {
290		b.WriteTo(os.Stderr)
291		log.Fatal(err)
292	}
293
294	// write result
295	f, err := os.Create("zmsg.go")
296	fatalIfErr(err)
297	defer f.Close()
298	f.Write(res)
299}
300
// structMember will take a tag like dns:"size-base32:SaltLength" and return the
// last part of this string with the closing quote removed (SaltLength), i.e.
// the name of the struct member holding the field's length. A tag whose last
// ':'-separated part is only the closing quote yields "".
func structMember(s string) string {
	// LastIndex returns -1 when there is no ':', so +1 selects the whole string.
	member := s[strings.LastIndex(s, ":")+1:]
	return strings.TrimSuffix(member, `"`)
}
314
315// structTag will take a tag like dns:"size-base32:SaltLength" and return base32.
316func structTag(s string) string {
317	fields := strings.Split(s, ":")
318	if len(fields) < 2 {
319		return ""
320	}
321	return fields[1][len("\"size-"):]
322}
323
// fatalIfErr terminates the generator through log.Fatal when err is non-nil
// and is a no-op otherwise.
func fatalIfErr(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
329