1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package packagestest
6
7import (
8	"fmt"
9	"go/token"
10	"path/filepath"
11	"reflect"
12	"regexp"
13
14	"golang.org/x/tools/go/expect"
15	"golang.org/x/tools/go/packages"
16	"golang.org/x/tools/internal/span"
17)
18
// markMethod is the note name whose handler records markers. Bare @name
// notes are rewritten by Expect into calls to this method.
const markMethod = "mark"

// eofIdentifier is the special identifier that resolves to the end of the
// file containing the note when used as a position argument.
const eofIdentifier = "EOF"
23
24// Expect invokes the supplied methods for all expectation notes found in
25// the exported source files.
26//
27// All exported go source files are parsed to collect the expectation
28// notes.
29// See the documentation for expect.Parse for how the notes are collected
30// and parsed.
31//
32// The methods are supplied as a map of name to function, and those functions
33// will be matched against the expectations by name.
34// Notes with no matching function will be skipped, and functions with no
35// matching notes will not be invoked.
36// If there are no registered markers yet, a special pass will be run first
37// which adds any markers declared with @mark(Name, pattern) or @name. These
38// call the Mark method to add the marker to the global set.
39// You can register the "mark" method to override these in your own call to
40// Expect. The bound Mark function is usable directly in your method map, so
41//    exported.Expect(map[string]interface{}{"mark": exported.Mark})
42// replicates the built in behavior.
43//
44// Method invocation
45//
46// When invoking a method the expressions in the parameter list need to be
47// converted to values to be passed to the method.
48// There are a very limited set of types the arguments are allowed to be.
49//   expect.Note : passed the Note instance being evaluated.
50//   string : can be supplied either a string literal or an identifier.
51//   int : can only be supplied an integer literal.
52//   *regexp.Regexp : can only be supplied a regular expression literal
53//   token.Pos : has a file position calculated as described below.
54//   token.Position : has a file position calculated as described below.
55//   expect.Range: has a start and end position as described below.
56//   interface{} : will be passed any value
57//
58// Position calculation
59//
60// There is some extra handling when a parameter is being coerced into a
61// token.Pos, token.Position or Range type argument.
62//
63// If the parameter is an identifier, it will be treated as the name of an
64// marker to look up (as if markers were global variables).
65//
66// If it is a string or regular expression, then it will be passed to
67// expect.MatchBefore to look up a match in the line at which it was declared.
68//
69// It is safe to call this repeatedly with different method sets, but it is
70// not safe to call it concurrently.
71func (e *Exported) Expect(methods map[string]interface{}) error {
72	if err := e.getNotes(); err != nil {
73		return err
74	}
75	if err := e.getMarkers(); err != nil {
76		return err
77	}
78	var err error
79	ms := make(map[string]method, len(methods))
80	for name, f := range methods {
81		mi := method{f: reflect.ValueOf(f)}
82		mi.converters = make([]converter, mi.f.Type().NumIn())
83		for i := 0; i < len(mi.converters); i++ {
84			mi.converters[i], err = e.buildConverter(mi.f.Type().In(i))
85			if err != nil {
86				return fmt.Errorf("invalid method %v: %v", name, err)
87			}
88		}
89		ms[name] = mi
90	}
91	for _, n := range e.notes {
92		if n.Args == nil {
93			// simple identifier form, convert to a call to mark
94			n = &expect.Note{
95				Pos:  n.Pos,
96				Name: markMethod,
97				Args: []interface{}{n.Name, n.Name},
98			}
99		}
100		mi, ok := ms[n.Name]
101		if !ok {
102			continue
103		}
104		params := make([]reflect.Value, len(mi.converters))
105		args := n.Args
106		for i, convert := range mi.converters {
107			params[i], args, err = convert(n, args)
108			if err != nil {
109				return fmt.Errorf("%v: %v", e.ExpectFileSet.Position(n.Pos), err)
110			}
111		}
112		if len(args) > 0 {
113			return fmt.Errorf("%v: unwanted args got %+v extra", e.ExpectFileSet.Position(n.Pos), args)
114		}
115		//TODO: catch the error returned from the method
116		mi.f.Call(params)
117	}
118	return nil
119}
120
// Range is a type alias for span.Range, kept for backwards compatibility;
// prefer using span.Range directly.
type Range = span.Range
124
125// Mark adds a new marker to the known set.
126func (e *Exported) Mark(name string, r Range) {
127	if e.markers == nil {
128		e.markers = make(map[string]span.Range)
129	}
130	e.markers[name] = r
131}
132
133func (e *Exported) getNotes() error {
134	if e.notes != nil {
135		return nil
136	}
137	notes := []*expect.Note{}
138	var dirs []string
139	for _, module := range e.written {
140		for _, filename := range module {
141			dirs = append(dirs, filepath.Dir(filename))
142		}
143	}
144	for filename := range e.Config.Overlay {
145		dirs = append(dirs, filepath.Dir(filename))
146	}
147	pkgs, err := packages.Load(e.Config, dirs...)
148	if err != nil {
149		return fmt.Errorf("unable to load packages for directories %s: %v", dirs, err)
150	}
151	for _, pkg := range pkgs {
152		for _, filename := range pkg.GoFiles {
153			content, err := e.FileContents(filename)
154			if err != nil {
155				return err
156			}
157			l, err := expect.Parse(e.ExpectFileSet, filename, content)
158			if err != nil {
159				return fmt.Errorf("Failed to extract expectations: %v", err)
160			}
161			notes = append(notes, l...)
162		}
163	}
164	e.notes = notes
165	return nil
166}
167
168func (e *Exported) getMarkers() error {
169	if e.markers != nil {
170		return nil
171	}
172	// set markers early so that we don't call getMarkers again from Expect
173	e.markers = make(map[string]span.Range)
174	return e.Expect(map[string]interface{}{
175		markMethod: e.Mark,
176	})
177}
178
179var (
180	noteType       = reflect.TypeOf((*expect.Note)(nil))
181	identifierType = reflect.TypeOf(expect.Identifier(""))
182	posType        = reflect.TypeOf(token.Pos(0))
183	positionType   = reflect.TypeOf(token.Position{})
184	rangeType      = reflect.TypeOf(span.Range{})
185	spanType       = reflect.TypeOf(span.Span{})
186	fsetType       = reflect.TypeOf((*token.FileSet)(nil))
187	regexType      = reflect.TypeOf((*regexp.Regexp)(nil))
188	exportedType   = reflect.TypeOf((*Exported)(nil))
189)
190
191// converter converts from a marker's argument parsed from the comment to
192// reflect values passed to the method during Invoke.
193// It takes the args remaining, and returns the args it did not consume.
194// This allows a converter to consume 0 args for well known types, or multiple
195// args for compound types.
196type converter func(*expect.Note, []interface{}) (reflect.Value, []interface{}, error)
197
198// method is used to track information about Invoke methods that is expensive to
199// calculate so that we can work it out once rather than per marker.
200type method struct {
201	f          reflect.Value // the reflect value of the passed in method
202	converters []converter   // the parameter converters for the method
203}
204
205// buildConverter works out what function should be used to go from an ast expressions to a reflect
206// value of the type expected by a method.
207// It is called when only the target type is know, it returns converters that are flexible across
208// all supported expression types for that target type.
209func (e *Exported) buildConverter(pt reflect.Type) (converter, error) {
210	switch {
211	case pt == noteType:
212		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
213			return reflect.ValueOf(n), args, nil
214		}, nil
215	case pt == fsetType:
216		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
217			return reflect.ValueOf(e.ExpectFileSet), args, nil
218		}, nil
219	case pt == exportedType:
220		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
221			return reflect.ValueOf(e), args, nil
222		}, nil
223	case pt == posType:
224		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
225			r, remains, err := e.rangeConverter(n, args)
226			if err != nil {
227				return reflect.Value{}, nil, err
228			}
229			return reflect.ValueOf(r.Start), remains, nil
230		}, nil
231	case pt == positionType:
232		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
233			r, remains, err := e.rangeConverter(n, args)
234			if err != nil {
235				return reflect.Value{}, nil, err
236			}
237			return reflect.ValueOf(e.ExpectFileSet.Position(r.Start)), remains, nil
238		}, nil
239	case pt == rangeType:
240		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
241			r, remains, err := e.rangeConverter(n, args)
242			if err != nil {
243				return reflect.Value{}, nil, err
244			}
245			return reflect.ValueOf(r), remains, nil
246		}, nil
247	case pt == spanType:
248		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
249			r, remains, err := e.rangeConverter(n, args)
250			if err != nil {
251				return reflect.Value{}, nil, err
252			}
253			spn, err := r.Span()
254			if err != nil {
255				return reflect.Value{}, nil, err
256			}
257			return reflect.ValueOf(spn), remains, nil
258		}, nil
259	case pt == identifierType:
260		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
261			if len(args) < 1 {
262				return reflect.Value{}, nil, fmt.Errorf("missing argument")
263			}
264			arg := args[0]
265			args = args[1:]
266			switch arg := arg.(type) {
267			case expect.Identifier:
268				return reflect.ValueOf(arg), args, nil
269			default:
270				return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to string", arg)
271			}
272		}, nil
273
274	case pt == regexType:
275		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
276			if len(args) < 1 {
277				return reflect.Value{}, nil, fmt.Errorf("missing argument")
278			}
279			arg := args[0]
280			args = args[1:]
281			if _, ok := arg.(*regexp.Regexp); !ok {
282				return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to *regexp.Regexp", arg)
283			}
284			return reflect.ValueOf(arg), args, nil
285		}, nil
286
287	case pt.Kind() == reflect.String:
288		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
289			if len(args) < 1 {
290				return reflect.Value{}, nil, fmt.Errorf("missing argument")
291			}
292			arg := args[0]
293			args = args[1:]
294			switch arg := arg.(type) {
295			case expect.Identifier:
296				return reflect.ValueOf(string(arg)), args, nil
297			case string:
298				return reflect.ValueOf(arg), args, nil
299			default:
300				return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to string", arg)
301			}
302		}, nil
303	case pt.Kind() == reflect.Int64:
304		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
305			if len(args) < 1 {
306				return reflect.Value{}, nil, fmt.Errorf("missing argument")
307			}
308			arg := args[0]
309			args = args[1:]
310			switch arg := arg.(type) {
311			case int64:
312				return reflect.ValueOf(arg), args, nil
313			default:
314				return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to int", arg)
315			}
316		}, nil
317	case pt.Kind() == reflect.Bool:
318		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
319			if len(args) < 1 {
320				return reflect.Value{}, nil, fmt.Errorf("missing argument")
321			}
322			arg := args[0]
323			args = args[1:]
324			b, ok := arg.(bool)
325			if !ok {
326				return reflect.Value{}, nil, fmt.Errorf("cannot convert %v to bool", arg)
327			}
328			return reflect.ValueOf(b), args, nil
329		}, nil
330	case pt.Kind() == reflect.Slice:
331		return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
332			converter, err := e.buildConverter(pt.Elem())
333			if err != nil {
334				return reflect.Value{}, nil, err
335			}
336			result := reflect.MakeSlice(reflect.SliceOf(pt.Elem()), 0, len(args))
337			for range args {
338				value, remains, err := converter(n, args)
339				if err != nil {
340					return reflect.Value{}, nil, err
341				}
342				result = reflect.Append(result, value)
343				args = remains
344			}
345			return result, args, nil
346		}, nil
347	default:
348		if pt.Kind() == reflect.Interface && pt.NumMethod() == 0 {
349			return func(n *expect.Note, args []interface{}) (reflect.Value, []interface{}, error) {
350				if len(args) < 1 {
351					return reflect.Value{}, nil, fmt.Errorf("missing argument")
352				}
353				return reflect.ValueOf(args[0]), args[1:], nil
354			}, nil
355		}
356		return nil, fmt.Errorf("param has unexpected type %v (kind %v)", pt, pt.Kind())
357	}
358}
359
360func (e *Exported) rangeConverter(n *expect.Note, args []interface{}) (span.Range, []interface{}, error) {
361	if len(args) < 1 {
362		return span.Range{}, nil, fmt.Errorf("missing argument")
363	}
364	arg := args[0]
365	args = args[1:]
366	switch arg := arg.(type) {
367	case expect.Identifier:
368		// handle the special identifiers
369		switch arg {
370		case eofIdentifier:
371			// end of file identifier, look up the current file
372			f := e.ExpectFileSet.File(n.Pos)
373			eof := f.Pos(f.Size())
374			return span.Range{FileSet: e.ExpectFileSet, Start: eof, End: token.NoPos}, args, nil
375		default:
376			// look up an marker by name
377			mark, ok := e.markers[string(arg)]
378			if !ok {
379				return span.Range{}, nil, fmt.Errorf("cannot find marker %v", arg)
380			}
381			return mark, args, nil
382		}
383	case string:
384		start, end, err := expect.MatchBefore(e.ExpectFileSet, e.FileContents, n.Pos, arg)
385		if err != nil {
386			return span.Range{}, nil, err
387		}
388		if start == token.NoPos {
389			return span.Range{}, nil, fmt.Errorf("%v: pattern %s did not match", e.ExpectFileSet.Position(n.Pos), arg)
390		}
391		return span.Range{FileSet: e.ExpectFileSet, Start: start, End: end}, args, nil
392	case *regexp.Regexp:
393		start, end, err := expect.MatchBefore(e.ExpectFileSet, e.FileContents, n.Pos, arg)
394		if err != nil {
395			return span.Range{}, nil, err
396		}
397		if start == token.NoPos {
398			return span.Range{}, nil, fmt.Errorf("%v: pattern %s did not match", e.ExpectFileSet.Position(n.Pos), arg)
399		}
400		return span.Range{FileSet: e.ExpectFileSet, Start: start, End: end}, args, nil
401	default:
402		return span.Range{}, nil, fmt.Errorf("cannot convert %v to pos", arg)
403	}
404}
405