1//+build ignore 2 3// types_generate.go is meant to run with go generate. It will use 4// go/{importer,types} to track down all the RR struct types. Then for each type 5// it will generate conversion tables (TypeToRR and TypeToString) and banal 6// methods (len, Header, copy) based on the struct tags. The generated source is 7// written to ztypes.go, and is meant to be checked into git. 8package main 9 10import ( 11 "bytes" 12 "fmt" 13 "go/format" 14 "go/types" 15 "log" 16 "os" 17 18 "golang.org/x/tools/go/packages" 19) 20 21var packageHdr = ` 22// Code generated by "go run duplicate_generate.go"; DO NOT EDIT. 23 24package dns 25 26` 27 28func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) { 29 st, ok := t.Underlying().(*types.Struct) 30 if !ok { 31 return nil, false 32 } 33 if st.Field(0).Type() == scope.Lookup("RR_Header").Type() { 34 return st, false 35 } 36 if st.Field(0).Anonymous() { 37 st, _ := getTypeStruct(st.Field(0).Type(), scope) 38 return st, true 39 } 40 return nil, false 41} 42 43// loadModule retrieves package description for a given module. 44func loadModule(name string) (*types.Package, error) { 45 conf := packages.Config{Mode: packages.NeedTypes | packages.NeedTypesInfo} 46 pkgs, err := packages.Load(&conf, name) 47 if err != nil { 48 return nil, err 49 } 50 return pkgs[0].Types, nil 51} 52 53func main() { 54 // Import and type-check the package 55 pkg, err := loadModule("github.com/miekg/dns") 56 fatalIfErr(err) 57 scope := pkg.Scope() 58 59 // Collect actual types (*X) 60 var namedTypes []string 61 for _, name := range scope.Names() { 62 o := scope.Lookup(name) 63 if o == nil || !o.Exported() { 64 continue 65 } 66 67 if st, _ := getTypeStruct(o.Type(), scope); st == nil { 68 continue 69 } 70 71 if name == "PrivateRR" || name == "OPT" { 72 continue 73 } 74 75 namedTypes = append(namedTypes, o.Name()) 76 } 77 78 b := &bytes.Buffer{} 79 b.WriteString(packageHdr) 80 81 // Generate the duplicate check for each type. 82 fmt.Fprint(b, "// isDuplicate() functions\n\n") 83 for _, name := range namedTypes { 84 85 o := scope.Lookup(name) 86 st, isEmbedded := getTypeStruct(o.Type(), scope) 87 if isEmbedded { 88 continue 89 } 90 fmt.Fprintf(b, "func (r1 *%s) isDuplicate(_r2 RR) bool {\n", name) 91 fmt.Fprintf(b, "r2, ok := _r2.(*%s)\n", name) 92 fmt.Fprint(b, "if !ok { return false }\n") 93 fmt.Fprint(b, "_ = r2\n") 94 for i := 1; i < st.NumFields(); i++ { 95 field := st.Field(i).Name() 96 o2 := func(s string) { fmt.Fprintf(b, s+"\n", field, field) } 97 o3 := func(s string) { fmt.Fprintf(b, s+"\n", field, field, field) } 98 99 // For some reason, a and aaaa don't pop up as *types.Slice here (mostly like because the are 100 // *indirectly* defined as a slice in the net package). 101 if _, ok := st.Field(i).Type().(*types.Slice); ok { 102 o2("if len(r1.%s) != len(r2.%s) {\nreturn false\n}") 103 104 if st.Tag(i) == `dns:"cdomain-name"` || st.Tag(i) == `dns:"domain-name"` { 105 o3(`for i := 0; i < len(r1.%s); i++ { 106 if !isDuplicateName(r1.%s[i], r2.%s[i]) { 107 return false 108 } 109 }`) 110 111 continue 112 } 113 114 if st.Tag(i) == `dns:"apl"` { 115 o3(`for i := 0; i < len(r1.%s); i++ { 116 if !r1.%s[i].equals(&r2.%s[i]) { 117 return false 118 } 119 }`) 120 121 continue 122 } 123 124 o3(`for i := 0; i < len(r1.%s); i++ { 125 if r1.%s[i] != r2.%s[i] { 126 return false 127 } 128 }`) 129 130 continue 131 } 132 133 switch st.Tag(i) { 134 case `dns:"-"`: 135 // ignored 136 case `dns:"a"`, `dns:"aaaa"`: 137 o2("if !r1.%s.Equal(r2.%s) {\nreturn false\n}") 138 case `dns:"cdomain-name"`, `dns:"domain-name"`: 139 o2("if !isDuplicateName(r1.%s, r2.%s) {\nreturn false\n}") 140 default: 141 o2("if r1.%s != r2.%s {\nreturn false\n}") 142 } 143 } 144 fmt.Fprintf(b, "return true\n}\n\n") 145 } 146 147 // gofmt 148 res, err := format.Source(b.Bytes()) 149 if err != nil { 150 b.WriteTo(os.Stderr) 151 log.Fatal(err) 152 } 153 154 // write result 155 f, err := os.Create("zduplicate.go") 156 fatalIfErr(err) 157 defer f.Close() 158 f.Write(res) 159} 160 161func fatalIfErr(err error) { 162 if err != nil { 163 log.Fatal(err) 164 } 165} 166