1//+build ignore
2
3// msg_generate.go is meant to run with go generate. It will use
4// go/{importer,types} to track down all the RR struct types. Then for each type
5// it will generate pack/unpack methods based on the struct tags. The generated source is
6// written to zmsg.go, and is meant to be checked into git.
7package main
8
9import (
10	"bytes"
11	"fmt"
12	"go/format"
13	"go/importer"
14	"go/types"
15	"log"
16	"os"
17	"strings"
18)
19
// packageHdr is emitted at the very top of the generated zmsg.go: a leading
// blank line, the do-not-modify banner, and the package clause for package dns
// followed by a blank line.
var packageHdr = "\n" +
	"// *** DO NOT MODIFY ***\n" +
	"// AUTOGENERATED BY go generate from msg_generate.go\n" +
	"\n" +
	"package dns\n" +
	"\n"
27
// getTypeStruct takes a type and the package scope and returns the
// (innermost) struct if the type is considered an RR type. An RR type is
// currently defined as a struct whose first field has the RR_Header type
// found in scope; it could be redefined as "implements the RR interface".
// If the first field is instead an embedded struct, that struct is searched
// recursively.
//
// The bool result reports whether an embedded struct was followed AND led to
// an RR struct. Non-struct types, structs without fields, non-RR structs and
// a scope lacking RR_Header all yield (nil, false) rather than panicking.
func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
	st, ok := t.Underlying().(*types.Struct)
	if !ok || st.NumFields() == 0 {
		// Field(0) below would panic on an empty struct.
		return nil, false
	}
	hdr := scope.Lookup("RR_Header")
	if hdr == nil {
		// Without RR_Header nothing can be classified as an RR.
		return nil, false
	}
	first := st.Field(0)
	if types.Identical(first.Type(), hdr.Type()) {
		return st, false
	}
	if first.Anonymous() {
		inner, _ := getTypeStruct(first.Type(), scope)
		return inner, inner != nil
	}
	return nil, false
}
47
48func main() {
49	// Import and type-check the package
50	pkg, err := importer.Default().Import("github.com/miekg/dns")
51	fatalIfErr(err)
52	scope := pkg.Scope()
53
54	// Collect actual types (*X)
55	var namedTypes []string
56	for _, name := range scope.Names() {
57		o := scope.Lookup(name)
58		if o == nil || !o.Exported() {
59			continue
60		}
61		if st, _ := getTypeStruct(o.Type(), scope); st == nil {
62			continue
63		}
64		if name == "PrivateRR" {
65			continue
66		}
67
68		// Check if corresponding TypeX exists
69		if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" {
70			log.Fatalf("Constant Type%s does not exist.", o.Name())
71		}
72
73		namedTypes = append(namedTypes, o.Name())
74	}
75
76	b := &bytes.Buffer{}
77	b.WriteString(packageHdr)
78
79	fmt.Fprint(b, "// pack*() functions\n\n")
80	for _, name := range namedTypes {
81		o := scope.Lookup(name)
82		st, _ := getTypeStruct(o.Type(), scope)
83
84		fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {\n", name)
85		fmt.Fprint(b, `off, err := rr.Hdr.pack(msg, off, compression, compress)
86if err != nil {
87	return off, err
88}
89headerEnd := off
90`)
91		for i := 1; i < st.NumFields(); i++ {
92			o := func(s string) {
93				fmt.Fprintf(b, s, st.Field(i).Name())
94				fmt.Fprint(b, `if err != nil {
95return off, err
96}
97`)
98			}
99
100			if _, ok := st.Field(i).Type().(*types.Slice); ok {
101				switch st.Tag(i) {
102				case `dns:"-"`: // ignored
103				case `dns:"txt"`:
104					o("off, err = packStringTxt(rr.%s, msg, off)\n")
105				case `dns:"opt"`:
106					o("off, err = packDataOpt(rr.%s, msg, off)\n")
107				case `dns:"nsec"`:
108					o("off, err = packDataNsec(rr.%s, msg, off)\n")
109				case `dns:"domain-name"`:
110					o("off, err = packDataDomainNames(rr.%s, msg, off, compression, compress)\n")
111				default:
112					log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
113				}
114				continue
115			}
116
117			switch {
118			case st.Tag(i) == `dns:"-"`: // ignored
119			case st.Tag(i) == `dns:"cdomain-name"`:
120				o("off, err = PackDomainName(rr.%s, msg, off, compression, compress)\n")
121			case st.Tag(i) == `dns:"domain-name"`:
122				o("off, err = PackDomainName(rr.%s, msg, off, compression, false)\n")
123			case st.Tag(i) == `dns:"a"`:
124				o("off, err = packDataA(rr.%s, msg, off)\n")
125			case st.Tag(i) == `dns:"aaaa"`:
126				o("off, err = packDataAAAA(rr.%s, msg, off)\n")
127			case st.Tag(i) == `dns:"uint48"`:
128				o("off, err = packUint48(rr.%s, msg, off)\n")
129			case st.Tag(i) == `dns:"txt"`:
130				o("off, err = packString(rr.%s, msg, off)\n")
131
132			case strings.HasPrefix(st.Tag(i), `dns:"size-base32`): // size-base32 can be packed just like base32
133				fallthrough
134			case st.Tag(i) == `dns:"base32"`:
135				o("off, err = packStringBase32(rr.%s, msg, off)\n")
136
137			case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): // size-base64 can be packed just like base64
138				fallthrough
139			case st.Tag(i) == `dns:"base64"`:
140				o("off, err = packStringBase64(rr.%s, msg, off)\n")
141
142			case strings.HasPrefix(st.Tag(i), `dns:"size-hex:SaltLength`): // Hack to fix empty salt length for NSEC3
143				o("if rr.%s == \"-\" { /* do nothing, empty salt */ }\n")
144				continue
145			case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): // size-hex can be packed just like hex
146				fallthrough
147			case st.Tag(i) == `dns:"hex"`:
148				o("off, err = packStringHex(rr.%s, msg, off)\n")
149
150			case st.Tag(i) == `dns:"octet"`:
151				o("off, err = packStringOctet(rr.%s, msg, off)\n")
152			case st.Tag(i) == "":
153				switch st.Field(i).Type().(*types.Basic).Kind() {
154				case types.Uint8:
155					o("off, err = packUint8(rr.%s, msg, off)\n")
156				case types.Uint16:
157					o("off, err = packUint16(rr.%s, msg, off)\n")
158				case types.Uint32:
159					o("off, err = packUint32(rr.%s, msg, off)\n")
160				case types.Uint64:
161					o("off, err = packUint64(rr.%s, msg, off)\n")
162				case types.String:
163					o("off, err = packString(rr.%s, msg, off)\n")
164				default:
165					log.Fatalln(name, st.Field(i).Name())
166				}
167			default:
168				log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
169			}
170		}
171		// We have packed everything, only now we know the rdlength of this RR
172		fmt.Fprintln(b, "rr.Header().Rdlength = uint16(off-headerEnd)")
173		fmt.Fprintln(b, "return off, nil }\n")
174	}
175
176	fmt.Fprint(b, "// unpack*() functions\n\n")
177	for _, name := range namedTypes {
178		o := scope.Lookup(name)
179		st, _ := getTypeStruct(o.Type(), scope)
180
181		fmt.Fprintf(b, "func unpack%s(h RR_Header, msg []byte, off int) (RR, int, error) {\n", name)
182		fmt.Fprintf(b, "rr := new(%s)\n", name)
183		fmt.Fprint(b, "rr.Hdr = h\n")
184		fmt.Fprint(b, `if noRdata(h) {
185return rr, off, nil
186	}
187var err error
188rdStart := off
189_ = rdStart
190
191`)
192		for i := 1; i < st.NumFields(); i++ {
193			o := func(s string) {
194				fmt.Fprintf(b, s, st.Field(i).Name())
195				fmt.Fprint(b, `if err != nil {
196return rr, off, err
197}
198`)
199			}
200
201			// size-* are special, because they reference a struct member we should use for the length.
202			if strings.HasPrefix(st.Tag(i), `dns:"size-`) {
203				structMember := structMember(st.Tag(i))
204				structTag := structTag(st.Tag(i))
205				switch structTag {
206				case "hex":
207					fmt.Fprintf(b, "rr.%s, off, err = unpackStringHex(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
208				case "base32":
209					fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase32(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
210				case "base64":
211					fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase64(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
212				default:
213					log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
214				}
215				fmt.Fprint(b, `if err != nil {
216return rr, off, err
217}
218`)
219				continue
220			}
221
222			if _, ok := st.Field(i).Type().(*types.Slice); ok {
223				switch st.Tag(i) {
224				case `dns:"-"`: // ignored
225				case `dns:"txt"`:
226					o("rr.%s, off, err = unpackStringTxt(msg, off)\n")
227				case `dns:"opt"`:
228					o("rr.%s, off, err = unpackDataOpt(msg, off)\n")
229				case `dns:"nsec"`:
230					o("rr.%s, off, err = unpackDataNsec(msg, off)\n")
231				case `dns:"domain-name"`:
232					o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
233				default:
234					log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
235				}
236				continue
237			}
238
239			switch st.Tag(i) {
240			case `dns:"-"`: // ignored
241			case `dns:"cdomain-name"`:
242				fallthrough
243			case `dns:"domain-name"`:
244				o("rr.%s, off, err = UnpackDomainName(msg, off)\n")
245			case `dns:"a"`:
246				o("rr.%s, off, err = unpackDataA(msg, off)\n")
247			case `dns:"aaaa"`:
248				o("rr.%s, off, err = unpackDataAAAA(msg, off)\n")
249			case `dns:"uint48"`:
250				o("rr.%s, off, err = unpackUint48(msg, off)\n")
251			case `dns:"txt"`:
252				o("rr.%s, off, err = unpackString(msg, off)\n")
253			case `dns:"base32"`:
254				o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
255			case `dns:"base64"`:
256				o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
257			case `dns:"hex"`:
258				o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
259			case `dns:"octet"`:
260				o("rr.%s, off, err = unpackStringOctet(msg, off)\n")
261			case "":
262				switch st.Field(i).Type().(*types.Basic).Kind() {
263				case types.Uint8:
264					o("rr.%s, off, err = unpackUint8(msg, off)\n")
265				case types.Uint16:
266					o("rr.%s, off, err = unpackUint16(msg, off)\n")
267				case types.Uint32:
268					o("rr.%s, off, err = unpackUint32(msg, off)\n")
269				case types.Uint64:
270					o("rr.%s, off, err = unpackUint64(msg, off)\n")
271				case types.String:
272					o("rr.%s, off, err = unpackString(msg, off)\n")
273				default:
274					log.Fatalln(name, st.Field(i).Name())
275				}
276			default:
277				log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
278			}
279			// If we've hit len(msg) we return without error.
280			if i < st.NumFields()-1 {
281				fmt.Fprintf(b, `if off == len(msg) {
282return rr, off, nil
283	}
284`)
285			}
286		}
287		fmt.Fprintf(b, "return rr, off, err }\n\n")
288	}
289	// Generate typeToUnpack map
290	fmt.Fprintln(b, "var typeToUnpack = map[uint16]func(RR_Header, []byte, int) (RR, int, error){")
291	for _, name := range namedTypes {
292		if name == "RFC3597" {
293			continue
294		}
295		fmt.Fprintf(b, "Type%s: unpack%s,\n", name, name)
296	}
297	fmt.Fprintln(b, "}\n")
298
299	// gofmt
300	res, err := format.Source(b.Bytes())
301	if err != nil {
302		b.WriteTo(os.Stderr)
303		log.Fatal(err)
304	}
305
306	// write result
307	f, err := os.Create("zmsg.go")
308	fatalIfErr(err)
309	defer f.Close()
310	f.Write(res)
311}
312
// structMember extracts the referenced length member from a size tag, e.g.
// dns:"size-base32:SaltLength" yields "SaltLength". It takes the text after
// the last ':' (the whole string if there is none) and drops its final
// character, the closing quote. Text of length 0 or 1 is returned unchanged.
func structMember(s string) string {
	tail := s[strings.LastIndex(s, ":")+1:]
	if len(tail) <= 1 {
		return tail
	}
	return tail[:len(tail)-1]
}
326
// structTag takes a tag like dns:"size-base32:SaltLength" and returns the
// encoding name between `"size-` and the member reference ("base32" here).
// It returns "" when s has no ':'-separated second part, or when that part
// does not start with `"size-`. The original code sliced it blindly, so it
// panicked on short input such as dns:"hex".
func structTag(s string) string {
	const prefix = `"size-`
	fields := strings.Split(s, ":")
	if len(fields) < 2 || !strings.HasPrefix(fields[1], prefix) {
		return ""
	}
	return fields[1][len(prefix):]
}
335
// fatalIfErr is a no-op for a nil error; otherwise it logs err and exits the
// generator via log.Fatal.
func fatalIfErr(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
341