1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package packagestest
6
7import (
8	"fmt"
9	"go/token"
10	"io/ioutil"
11	"os"
12	"path/filepath"
13	"reflect"
14	"regexp"
15	"strings"
16
17	"golang.org/x/tools/go/expect"
18	"golang.org/x/tools/go/packages"
19	"golang.org/x/tools/internal/span"
20)
21
// Names with special meaning inside expectation notes.
const (
	// eofIdentifier, when used as a position argument, denotes the end of
	// the file that contains the note.
	eofIdentifier = "EOF"
	// markMethod is the note name that registers a marker; notes written in
	// the bare @name form are rewritten into calls to it.
	markMethod = "mark"
)
26
27// Expect invokes the supplied methods for all expectation notes found in
28// the exported source files.
29//
30// All exported go source files are parsed to collect the expectation
31// notes.
32// See the documentation for expect.Parse for how the notes are collected
33// and parsed.
34//
35// The methods are supplied as a map of name to function, and those functions
36// will be matched against the expectations by name.
37// Notes with no matching function will be skipped, and functions with no
38// matching notes will not be invoked.
39// If there are no registered markers yet, a special pass will be run first
40// which adds any markers declared with @mark(Name, pattern) or @name. These
41// call the Mark method to add the marker to the global set.
42// You can register the "mark" method to override these in your own call to
43// Expect. The bound Mark function is usable directly in your method map, so
44//    exported.Expect(map[string]interface{}{"mark": exported.Mark})
45// replicates the built in behavior.
46//
47// Method invocation
48//
49// When invoking a method the expressions in the parameter list need to be
50// converted to values to be passed to the method.
51// There are a very limited set of types the arguments are allowed to be.
52//   expect.Note : passed the Note instance being evaluated.
53//   string : can be supplied either a string literal or an identifier.
54//   int : can only be supplied an integer literal.
55//   *regexp.Regexp : can only be supplied a regular expression literal
56//   token.Pos : has a file position calculated as described below.
57//   token.Position : has a file position calculated as described below.
58//   expect.Range: has a start and end position as described below.
59//   interface{} : will be passed any value
60//
61// Position calculation
62//
63// There is some extra handling when a parameter is being coerced into a
64// token.Pos, token.Position or Range type argument.
65//
66// If the parameter is an identifier, it will be treated as the name of an
67// marker to look up (as if markers were global variables).
68//
69// If it is a string or regular expression, then it will be passed to
70// expect.MatchBefore to look up a match in the line at which it was declared.
71//
72// It is safe to call this repeatedly with different method sets, but it is
73// not safe to call it concurrently.
74func (e *Exported) Expect(methods map[string]interface{}) error {
75	if err := e.getNotes(); err != nil {
76		return err
77	}
78	if err := e.getMarkers(); err != nil {
79		return err
80	}
81	var err error
82	ms := make(map[string]method, len(methods))
83	for name, f := range methods {
84		mi := method{f: reflect.ValueOf(f)}
85		mi.converters = make([]converter, mi.f.Type().NumIn())
86		for i := 0; i < len(mi.converters); i++ {
87			mi.converters[i], err = e.buildConverter(mi.f.Type().In(i))
88			if err != nil {
89				return fmt.Errorf("invalid method %v: %v", name, err)
90			}
91		}
92		ms[name] = mi
93	}
94	for _, n := range e.notes {
95		if n.Args == nil {
96			// simple identifier form, convert to a call to mark
97			n = &expect.Note{
98				Pos:  n.Pos,
99				Name: markMethod,
100				Args: []interface{}{n.Name, n.Name},
101			}
102		}
103		mi, ok := ms[n.Name]
104		if !ok {
105			continue
106		}
107		params := make([]reflect.Value, len(mi.converters))
108		args := n.Args
109		for i, convert := range mi.converters {
110			params[i], args, err = convert(n, args)
111			if err != nil {
112				return fmt.Errorf("%v: %v", e.ExpectFileSet.Position(n.Pos), err)
113			}
114		}
115		if len(args) > 0 {
116			return fmt.Errorf("%v: unwanted args got %+v extra", e.ExpectFileSet.Position(n.Pos), args)
117		}
118		//TODO: catch the error returned from the method
119		mi.f.Call(params)
120	}
121	return nil
122}
123
// Range is a type alias for span.Range, kept for backwards compatibility.
// Prefer using span.Range directly.
type Range = span.Range
127
128// Mark adds a new marker to the known set.
129func (e *Exported) Mark(name string, r Range) {
130	if e.markers == nil {
131		e.markers = make(map[string]span.Range)
132	}
133	e.markers[name] = r
134}
135
136func (e *Exported) getNotes() error {
137	if e.notes != nil {
138		return nil
139	}
140	notes := []*expect.Note{}
141	var dirs []string
142	for _, module := range e.written {
143		for _, filename := range module {
144			dirs = append(dirs, filepath.Dir(filename))
145		}
146	}
147	for filename := range e.Config.Overlay {
148		dirs = append(dirs, filepath.Dir(filename))
149	}
150	pkgs, err := packages.Load(e.Config, dirs...)
151	if err != nil {
152		return fmt.Errorf("unable to load packages for directories %s: %v", dirs, err)
153	}
154	seen := make(map[token.Position]struct{})
155	for _, pkg := range pkgs {
156		for _, filename := range pkg.GoFiles {
157			content, err := e.FileContents(filename)
158			if err != nil {
159				return err
160			}
161			l, err := expect.Parse(e.ExpectFileSet, filename, content)
162			if err != nil {
163				return fmt.Errorf("failed to extract expectations: %v", err)
164			}
165			for _, note := range l {
166				pos := e.ExpectFileSet.Position(note.Pos)
167				if _, ok := seen[pos]; ok {
168					continue
169				}
170				notes = append(notes, note)
171				seen[pos] = struct{}{}
172			}
173		}
174	}
175	if _, ok := e.written[e.primary]; !ok {
176		e.notes = notes
177		return nil
178	}
179	// Check go.mod markers regardless of mode, we need to do this so that our marker count
180	// matches the counts in the summary.txt.golden file for the test directory.
181	if gomod, found := e.written[e.primary]["go.mod"]; found {
182		// If we are in Modules mode, then we need to check the contents of the go.mod.temp.
183		if e.Exporter == Modules {
184			gomod += ".temp"
185		}
186		l, err := goModMarkers(e, gomod)
187		if err != nil {
188			return fmt.Errorf("failed to extract expectations for go.mod: %v", err)
189		}
190		notes = append(notes, l...)
191	}
192	e.notes = notes
193	return nil
194}
195
196func goModMarkers(e *Exported, gomod string) ([]*expect.Note, error) {
197	if _, err := os.Stat(gomod); os.IsNotExist(err) {
198		// If there is no go.mod file, we want to be able to continue.
199		return nil, nil
200	}
201	content, err := e.FileContents(gomod)
202	if err != nil {
203		return nil, err
204	}
205	if e.Exporter == GOPATH {
206		return expect.Parse(e.ExpectFileSet, gomod, content)
207	}
208	gomod = strings.TrimSuffix(gomod, ".temp")
209	// If we are in Modules mode, copy the original contents file back into go.mod
210	if err := ioutil.WriteFile(gomod, content, 0644); err != nil {
211		return nil, nil
212	}
213	return expect.Parse(e.ExpectFileSet, gomod, content)
214}
215
216func (e *Exported) getMarkers() error {
217	if e.markers != nil {
218		return nil
219	}
220	// set markers early so that we don't call getMarkers again from Expect
221	e.markers = make(map[string]span.Range)
222	return e.Expect(map[string]interface{}{
223		markMethod: e.Mark,
224	})
225}
226
227var (
228	noteType       = reflect.TypeOf((*expect.Note)(nil))
229	identifierType = reflect.TypeOf(expect.Identifier(""))
230	posType        = reflect.TypeOf(token.Pos(0))
231	positionType   = reflect.TypeOf(token.Position{})
232	rangeType      = reflect.TypeOf(span.Range{})
233	spanType       = reflect.TypeOf(span.Span{})
234	fsetType       = reflect.TypeOf((*token.FileSet)(nil))
235	regexType      = reflect.TypeOf((*regexp.Regexp)(nil))
236	exportedType   = reflect.TypeOf((*Exported)(nil))
237)
238
// converter converts a note's arguments, as parsed from the comment, into the
// reflect.Value passed for one method parameter when Expect invokes a method.
// It receives the note being evaluated and the args remaining, and returns
// the args it did not consume.
// This allows a converter to consume 0 args for well known types, or multiple
// args for compound types.
type converter func(*expect.Note, []interface{}) (reflect.Value, []interface{}, error)
245
// method caches what Expect needs to invoke one registered method, so that
// the parameter converters are built once per method rather than once per
// note.
type method struct {
	f          reflect.Value // the reflect value of the passed in method
	converters []converter   // the parameter converters for the method, one per parameter
}
252
253// buildConverter works out what function should be used to go from an ast expressions to a reflect
254// value of the type expected by a method.
255// It is called when only the target type is know, it returns converters that are flexible across
256// all supported expression types for that target type.
257func (e *Exported) buildConverter(pt reflect.Type) (converter, error) {
258	switch {
259	case pt == noteType:
260		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
261			return reflect.ValueOf(n), args, nil
262		}, nil
263	case pt == fsetType:
264		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
265			return reflect.ValueOf(e.ExpectFileSet), args, nil
266		}, nil
267	case pt == exportedType:
268		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
269			return reflect.ValueOf(e), args, nil
270		}, nil
271	case pt == posType:
272		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
273			r, remains, err := e.rangeConverter(n, args)
274			if err != nil {
275				return reflect.Value{}, nil, err
276			}
277			return reflect.ValueOf(r.Start), remains, nil
278		}, nil
279	case pt == positionType:
280		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
281			r, remains, err := e.rangeConverter(n, args)
282			if err != nil {
283				return reflect.Value{}, nil, err
284			}
285			return reflect.ValueOf(e.ExpectFileSet.Position(r.Start)), remains, nil
286		}, nil
287	case pt == rangeType:
288		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
289			r, remains, err := e.rangeConverter(n, args)
290			if err != nil {
291				return reflect.Value{}, nil, err
292			}
293			return reflect.ValueOf(r), remains, nil
294		}, nil
295	case pt == spanType:
296		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
297			r, remains, err := e.rangeConverter(n, args)
298			if err != nil {
299				return reflect.Value{}, nil, err
300			}
301			spn, err := r.Span()
302			if err != nil {
303				return reflect.Value{}, nil, err
304			}
305			return reflect.ValueOf(spn), remains, nil
306		}, nil
307	case pt == identifierType:
308		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
309			if len(args) < 1 {
310				return reflect.Value{}, nil, fmt.Errorf("missing argument")
311			}
312			arg := args[0]
313			args = args[1:]
314			switch arg := arg.(type) {
315			case expect.Identifier:
316				return reflect.ValueOf(arg), args, nil
317			default:
318				return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to string", arg)
319			}
320		}, nil
321
322	case pt == regexType:
323		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
324			if len(args) < 1 {
325				return reflect.Value{}, nil, fmt.Errorf("missing argument")
326			}
327			arg := args[0]
328			args = args[1:]
329			if _, ok := arg.(*regexp.Regexp); !ok {
330				return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to *regexp.Regexp", arg)
331			}
332			return reflect.ValueOf(arg), args, nil
333		}, nil
334
335	case pt.Kind() == reflect.String:
336		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
337			if len(args) < 1 {
338				return reflect.Value{}, nil, fmt.Errorf("missing argument")
339			}
340			arg := args[0]
341			args = args[1:]
342			switch arg := arg.(type) {
343			case expect.Identifier:
344				return reflect.ValueOf(string(arg)), args, nil
345			case string:
346				return reflect.ValueOf(arg), args, nil
347			default:
348				return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to string", arg)
349			}
350		}, nil
351	case pt.Kind() == reflect.Int64:
352		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
353			if len(args) < 1 {
354				return reflect.Value{}, nil, fmt.Errorf("missing argument")
355			}
356			arg := args[0]
357			args = args[1:]
358			switch arg := arg.(type) {
359			case int64:
360				return reflect.ValueOf(arg), args, nil
361			default:
362				return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to int", arg)
363			}
364		}, nil
365	case pt.Kind() == reflect.Bool:
366		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
367			if len(args) < 1 {
368				return reflect.Value{}, nil, fmt.Errorf("missing argument")
369			}
370			arg := args[0]
371			args = args[1:]
372			b, ok := arg.(bool)
373			if !ok {
374				return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to bool", arg)
375			}
376			return reflect.ValueOf(b), args, nil
377		}, nil
378	case pt.Kind() == reflect.Slice:
379		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
380			converter, err := e.buildConverter(pt.Elem())
381			if err != nil {
382				return reflect.Value{}, nil, err
383			}
384			result := reflect.MakeSlice(reflect.SliceOf(pt.Elem()), 0, len(args))
385			for range args {
386				value, remains, err := converter(n, args)
387				if err != nil {
388					return reflect.Value{}, nil, err
389				}
390				result = reflect.Append(result, value)
391				args = remains
392			}
393			return result, args, nil
394		}, nil
395	default:
396		if pt.Kind() == reflect.Interface && pt.NumMethod() == 0 {
397			return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
398				if len(args) < 1 {
399					return reflect.Value{}, nil, fmt.Errorf("missing argument")
400				}
401				return reflect.ValueOf(args[0]), args[1:], nil
402			}, nil
403		}
404		return nil, fmt.Errorf("param has unexpected type %v (kind %v)", pt, pt.Kind())
405	}
406}
407
408func (e *Exported) rangeConverter(n *expect.Note, args []interface{}) (span.Range, []interface{}, error) {
409	if len(args) < 1 {
410		return span.Range{}, nil, fmt.Errorf("missing argument")
411	}
412	arg := args[0]
413	args = args[1:]
414	switch arg := arg.(type) {
415	case expect.Identifier:
416		// handle the special identifiers
417		switch arg {
418		case eofIdentifier:
419			// end of file identifier, look up the current file
420			f := e.ExpectFileSet.File(n.Pos)
421			eof := f.Pos(f.Size())
422			return span.Range{FileSet: e.ExpectFileSet, Start: eof, End: token.NoPos}, args, nil
423		default:
424			// look up an marker by name
425			mark, ok := e.markers[string(arg)]
426			if !ok {
427				return span.Range{}, nil, fmt.Errorf("cannot find marker %v", arg)
428			}
429			return mark, args, nil
430		}
431	case string:
432		start, end, err := expect.MatchBefore(e.ExpectFileSet, e.FileContents, n.Pos, arg)
433		if err != nil {
434			return span.Range{}, nil, err
435		}
436		if start == token.NoPos {
437			return span.Range{}, nil, fmt.Errorf("%v: pattern %s did not match", e.ExpectFileSet.Position(n.Pos), arg)
438		}
439		return span.Range{FileSet: e.ExpectFileSet, Start: start, End: end}, args, nil
440	case *regexp.Regexp:
441		start, end, err := expect.MatchBefore(e.ExpectFileSet, e.FileContents, n.Pos, arg)
442		if err != nil {
443			return span.Range{}, nil, err
444		}
445		if start == token.NoPos {
446			return span.Range{}, nil, fmt.Errorf("%v: pattern %s did not match", e.ExpectFileSet.Position(n.Pos), arg)
447		}
448		return span.Range{FileSet: e.ExpectFileSet, Start: start, End: end}, args, nil
449	default:
450		return span.Range{}, nil, fmt.Errorf("cannot convert %v to pos", arg)
451	}
452}
453