1package main 2 3import ( 4 "fmt" 5 "go/ast" 6 "go/parser" 7 "go/token" 8 "io/ioutil" 9 "os" 10 "path/filepath" 11 "reflect" 12 "sort" 13 "testing" 14 "text/template" 15 16 "github.com/tinylib/msgp/gen" 17) 18 19// When stuff's going wrong, you'll be glad this is here! 20const debugTemp = false 21 22// Ensure that consistent identifiers are generated on a per-method basis by msgp. 23// 24// Also ensure that no duplicate identifiers appear in a method. 25// 26// structs are currently processed alphabetically by msgp. this test relies on 27// that property. 28// 29func TestIssue185Idents(t *testing.T) { 30 var identCases = []struct { 31 tpl *template.Template 32 expectedChanged []string 33 }{ 34 {tpl: issue185IdentsTpl, expectedChanged: []string{"Test1"}}, 35 {tpl: issue185ComplexIdentsTpl, expectedChanged: []string{"Test2"}}, 36 } 37 38 methods := []string{"DecodeMsg", "EncodeMsg", "Msgsize", "MarshalMsg", "UnmarshalMsg"} 39 40 for idx, identCase := range identCases { 41 // generate the code, extract the generated variable names, mapped to function name 42 var tplData issue185TplData 43 varsBefore, err := loadVars(identCase.tpl, tplData) 44 if err != nil { 45 t.Fatalf("%d: could not extract before vars: %v", idx, err) 46 } 47 48 // regenerate the code with extra field(s), extract the generated variable 49 // names, mapped to function name 50 tplData.Extra = true 51 varsAfter, err := loadVars(identCase.tpl, tplData) 52 if err != nil { 53 t.Fatalf("%d: could not extract after vars: %v", idx, err) 54 } 55 56 // ensure that all declared variable names inside each of the methods we 57 // expect to change have actually changed 58 for _, stct := range identCase.expectedChanged { 59 for _, method := range methods { 60 fn := fmt.Sprintf("%s.%s", stct, method) 61 62 bv, av := varsBefore.Value(fn), varsAfter.Value(fn) 63 if len(bv) > 0 && len(av) > 0 && reflect.DeepEqual(bv, av) { 64 t.Fatalf("%d vars identical! expected vars to change for %s", idx, fn) 65 } 66 delete(varsBefore, fn) 67 delete(varsAfter, fn) 68 } 69 } 70 71 // all of the remaining keys should not have changed 72 for bmethod, bvars := range varsBefore { 73 avars := varsAfter.Value(bmethod) 74 75 if !reflect.DeepEqual(bvars, avars) { 76 t.Fatalf("%d: vars changed! expected vars identical for %s", idx, bmethod) 77 } 78 delete(varsBefore, bmethod) 79 delete(varsAfter, bmethod) 80 } 81 82 if len(varsBefore) > 0 || len(varsAfter) > 0 { 83 t.Fatalf("%d: unexpected methods remaining", idx) 84 } 85 } 86} 87 88type issue185TplData struct { 89 Extra bool 90} 91 92func TestIssue185Overlap(t *testing.T) { 93 var overlapCases = []struct { 94 tpl *template.Template 95 data issue185TplData 96 }{ 97 {tpl: issue185IdentsTpl, data: issue185TplData{Extra: false}}, 98 {tpl: issue185IdentsTpl, data: issue185TplData{Extra: true}}, 99 {tpl: issue185ComplexIdentsTpl, data: issue185TplData{Extra: false}}, 100 {tpl: issue185ComplexIdentsTpl, data: issue185TplData{Extra: true}}, 101 } 102 103 for idx, o := range overlapCases { 104 // regenerate the code with extra field(s), extract the generated variable 105 // names, mapped to function name 106 mvars, err := loadVars(o.tpl, o.data) 107 if err != nil { 108 t.Fatalf("%d: could not extract after vars: %v", idx, err) 109 } 110 111 identCnt := 0 112 for fn, vars := range mvars { 113 sort.Strings(vars) 114 115 // Loose sanity check to make sure the tests expectations aren't broken. 116 // If the prefix ever changes, this needs to change. 117 for _, v := range vars { 118 if v[0] == 'z' { 119 identCnt++ 120 } 121 } 122 123 for i := 0; i < len(vars)-1; i++ { 124 if vars[i] == vars[i+1] { 125 t.Fatalf("%d: duplicate var %s in function %s", idx, vars[i], fn) 126 } 127 } 128 } 129 130 // one last sanity check: if there aren't any vars that start with 'z', 131 // this test's expectations are unsatisfiable. 132 if identCnt == 0 { 133 t.Fatalf("%d: no generated identifiers found", idx) 134 } 135 } 136} 137 138func loadVars(tpl *template.Template, tplData interface{}) (vars extractedVars, err error) { 139 tempDir, err := ioutil.TempDir("", "msgp-") 140 if err != nil { 141 err = fmt.Errorf("could not create temp dir: %v", err) 142 return 143 } 144 145 if !debugTemp { 146 defer os.RemoveAll(tempDir) 147 } else { 148 fmt.Println(tempDir) 149 } 150 tfile := filepath.Join(tempDir, "msg.go") 151 genFile := newFilename(tfile, "") 152 153 if err = goGenerateTpl(tempDir, tfile, tpl, tplData); err != nil { 154 err = fmt.Errorf("could not generate code: %v", err) 155 return 156 } 157 158 vars, err = extractVars(genFile) 159 if err != nil { 160 err = fmt.Errorf("could not extract after vars: %v", err) 161 return 162 } 163 164 return 165} 166 167type varVisitor struct { 168 vars []string 169 fset *token.FileSet 170} 171 172func (v *varVisitor) Visit(node ast.Node) (w ast.Visitor) { 173 gen, ok := node.(*ast.GenDecl) 174 if !ok { 175 return v 176 } 177 for _, spec := range gen.Specs { 178 if vspec, ok := spec.(*ast.ValueSpec); ok { 179 for _, n := range vspec.Names { 180 v.vars = append(v.vars, n.Name) 181 } 182 } 183 } 184 return v 185} 186 187type extractedVars map[string][]string 188 189func (e extractedVars) Value(key string) []string { 190 if v, ok := e[key]; ok { 191 return v 192 } 193 panic(fmt.Errorf("unknown key %s", key)) 194} 195 196func extractVars(file string) (extractedVars, error) { 197 fset := token.NewFileSet() 198 199 f, err := parser.ParseFile(fset, file, nil, 0) 200 if err != nil { 201 return nil, err 202 } 203 204 vars := make(map[string][]string) 205 for _, d := range f.Decls { 206 switch d := d.(type) { 207 case *ast.FuncDecl: 208 sn := "" 209 switch rt := d.Recv.List[0].Type.(type) { 210 case *ast.Ident: 211 sn = rt.Name 212 case *ast.StarExpr: 213 sn = rt.X.(*ast.Ident).Name 214 default: 215 panic("unknown receiver type") 216 } 217 218 key := fmt.Sprintf("%s.%s", sn, d.Name.Name) 219 vis := &varVisitor{fset: fset} 220 ast.Walk(vis, d.Body) 221 vars[key] = vis.vars 222 } 223 } 224 return vars, nil 225} 226 227func goGenerateTpl(cwd, tfile string, tpl *template.Template, tplData interface{}) error { 228 outf, err := os.OpenFile(tfile, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0600) 229 if err != nil { 230 return err 231 } 232 defer outf.Close() 233 234 if err := tpl.Execute(outf, tplData); err != nil { 235 return err 236 } 237 238 mode := gen.Encode | gen.Decode | gen.Size | gen.Marshal | gen.Unmarshal 239 240 return Run(tfile, mode, false) 241} 242 243var issue185IdentsTpl = template.Must(template.New("").Parse(` 244package issue185 245 246//go:generate msgp 247 248type Test1 struct { 249 Foo string 250 Bar string 251 {{ if .Extra }}Baz []string{{ end }} 252 Qux string 253} 254 255type Test2 struct { 256 Foo string 257 Bar string 258 Baz string 259} 260`)) 261 262var issue185ComplexIdentsTpl = template.Must(template.New("").Parse(` 263package issue185 264 265//go:generate msgp 266 267type Test1 struct { 268 Foo string 269 Bar string 270 Baz string 271} 272 273type Test2 struct { 274 Foo string 275 Bar string 276 Baz []string 277 Qux map[string]string 278 Yep map[string]map[string]string 279 Quack struct { 280 Quack struct { 281 Quack struct { 282 {{ if .Extra }}Extra []string{{ end }} 283 Quack string 284 } 285 } 286 } 287 Nup struct { 288 Foo string 289 Bar string 290 Baz []string 291 Qux map[string]string 292 Yep map[string]map[string]string 293 } 294 Ding struct { 295 Dong struct { 296 Dung struct { 297 Thing string 298 } 299 } 300 } 301} 302 303type Test3 struct { 304 Foo string 305 Bar string 306 Baz string 307} 308`)) 309