1package typematch
2
3import (
4	"fmt"
5	"go/ast"
6	"go/parser"
7	"go/token"
8	"go/types"
9	"strconv"
10	"strings"
11
12	"github.com/quasilyte/go-ruleguard/internal/xtypes"
13)
14
// patternOp identifies the kind of type expression that a pattern node
// represents. It determines how the node's value and subs are interpreted.
//
//go:generate stringer -type=patternOp
type patternOp int
17
// Pattern node kinds. For every op the comment describes what parseExpr
// stores in pattern.value and pattern.subs.
const (
	// opBuiltinType matches one fixed type; value is the types.Type to
	// compare against (a builtin type or the empty interface).
	opBuiltinType patternOp = iota
	// opPointer matches a pointer type; subs[0] is the pointee pattern.
	opPointer
	// opVar is a pattern variable ($name); value is the variable name.
	opVar
	// opVarSeq is a sequence variable ($*_); value is its name, which is
	// always "_" as only unnamed sequences are supported.
	opVarSeq
	// opSlice matches a slice type; subs[0] is the element pattern.
	opSlice
	// opArray matches an array type; subs[0] is the element pattern and
	// value is either the int64 length or the name of a length variable.
	opArray
	// opMap matches a map type; subs[0] is the key pattern and subs[1]
	// is the value pattern.
	opMap
	// opChan matches a channel type; value is the types.ChanDir and
	// subs[0] is the element pattern.
	opChan
	// opFunc matches a function signature; value is the number of
	// parameters, subs holds the parameter patterns followed by the
	// result patterns.
	opFunc
	// opStructNoSeq matches a struct type whose pattern contains no
	// sequence variables; subs holds one pattern per field.
	opStructNoSeq
	// opStruct matches a struct type whose pattern contains at least one
	// sequence variable; subs holds the member patterns.
	opStruct
	// opNamed matches a named type from a package; value is a [2]string
	// holding the package path and the type name.
	opNamed
)
32
// Pattern is a compiled type pattern produced by Parse.
//
// Besides the pattern tree it carries the variable bindings collected
// during the most recent MatchIdentical call.
type Pattern struct {
	// typeMatches maps a type variable name to the type bound to it
	// during the current match.
	typeMatches  map[string]types.Type
	// int64Matches maps an array length variable name to the length
	// bound to it during the current match.
	int64Matches map[string]int64

	// root is the top-level node of the pattern tree.
	root *pattern
}
39
// pattern is a single node of a type pattern tree.
type pattern struct {
	// value is op-specific payload (for example a variable name, an
	// array length or a types.Type); it is nil for ops that need none.
	value interface{}
	// op selects the kind of the node.
	op    patternOp
	// subs holds the child patterns, if any.
	subs  []*pattern
}
45
46func (pat pattern) String() string {
47	if len(pat.subs) == 0 {
48		return fmt.Sprintf("<%s %#v>", pat.op, pat.value)
49	}
50	parts := make([]string, len(pat.subs))
51	for i, sub := range pat.subs {
52		parts[i] = sub.String()
53	}
54	return fmt.Sprintf("<%s %#v (%s)>", pat.op, pat.value, strings.Join(parts, ", "))
55}
56
// ImportsTab is a scoped table that maps package names to import paths.
//
// Scopes form a stack: Lookup searches from the innermost scope outwards,
// while Load always adds to the innermost scope.
type ImportsTab struct {
	// imports is the scope stack; the last element is the innermost scope.
	imports []map[string]string
}

// NewImportsTab returns a table whose outermost scope is initial.
//
// The map is used as is (it is not copied), so entries loaded into the
// outermost scope are also visible through initial. A nil initial is
// replaced by an empty map so that Load never writes to a nil map.
func NewImportsTab(initial map[string]string) *ImportsTab {
	if initial == nil {
		initial = map[string]string{}
	}
	return &ImportsTab{imports: []map[string]string{initial}}
}

// Lookup returns the import path registered for pkgName.
//
// Scopes are searched from the innermost to the outermost one, so an
// inner binding shadows an outer binding of the same name. The boolean
// result is false if no scope knows pkgName.
func (itab *ImportsTab) Lookup(pkgName string) (string, bool) {
	for i := len(itab.imports) - 1; i >= 0; i-- {
		if pkgPath, ok := itab.imports[i][pkgName]; ok {
			return pkgPath, true
		}
	}
	return "", false
}

// Load binds pkgName to pkgPath in the innermost scope, replacing any
// binding of the same name in that scope.
//
// If the table has no scopes (a zero value, or every scope has been
// left), a fresh scope is created first instead of panicking.
func (itab *ImportsTab) Load(pkgName, pkgPath string) {
	if len(itab.imports) == 0 {
		itab.imports = append(itab.imports, map[string]string{})
	}
	itab.imports[len(itab.imports)-1][pkgName] = pkgPath
}

// EnterScope pushes a new empty innermost scope.
func (itab *ImportsTab) EnterScope() {
	itab.imports = append(itab.imports, map[string]string{})
}

// LeaveScope pops the innermost scope together with all of its bindings.
//
// Calling it when there are no scopes left is a no-op rather than a
// slice-bounds panic.
func (itab *ImportsTab) LeaveScope() {
	n := len(itab.imports)
	if n == 0 {
		return
	}
	itab.imports = itab.imports[:n-1]
}
86
// Context carries the environment needed to parse a type pattern.
type Context struct {
	// Itab resolves the package name of a qualified type (pkg.Name)
	// to its import path.
	Itab *ImportsTab
}
90
// Placeholder prefixes that Parse substitutes for "$" and "$*" before
// handing the pattern to the Go expression parser. parseExpr recognizes
// identifiers that start with these prefixes as pattern variables.
const (
	// varPrefix replaces "$" (a single-type or length variable).
	varPrefix    = `ᐸvarᐳ`
	// varSeqPrefix replaces "$*" (a sequence variable).
	varSeqPrefix = `ᐸvar_seqᐳ`
)
95
96func Parse(ctx *Context, s string) (*Pattern, error) {
97	noDollars := strings.ReplaceAll(s, "$*", varSeqPrefix)
98	noDollars = strings.ReplaceAll(noDollars, "$", varPrefix)
99	n, err := parser.ParseExpr(noDollars)
100	if err != nil {
101		return nil, err
102	}
103	root := parseExpr(ctx, n)
104	if root == nil {
105		return nil, fmt.Errorf("can't convert %s type expression", s)
106	}
107	p := &Pattern{
108		typeMatches:  map[string]types.Type{},
109		int64Matches: map[string]int64{},
110		root:         root,
111	}
112	return p, nil
113}
114
var (
	// builtinTypeByName maps the name of a predeclared type to its
	// go/types representation. parseExpr turns identifiers found here
	// into opBuiltinType patterns.
	builtinTypeByName = map[string]types.Type{
		"bool":       types.Typ[types.Bool],
		"int":        types.Typ[types.Int],
		"int8":       types.Typ[types.Int8],
		"int16":      types.Typ[types.Int16],
		"int32":      types.Typ[types.Int32],
		"int64":      types.Typ[types.Int64],
		"uint":       types.Typ[types.Uint],
		"uint8":      types.Typ[types.Uint8],
		"uint16":     types.Typ[types.Uint16],
		"uint32":     types.Typ[types.Uint32],
		"uint64":     types.Typ[types.Uint64],
		"uintptr":    types.Typ[types.Uintptr],
		"float32":    types.Typ[types.Float32],
		"float64":    types.Typ[types.Float64],
		"complex64":  types.Typ[types.Complex64],
		"complex128": types.Typ[types.Complex128],
		"string":     types.Typ[types.String],

		// error is looked up in the universe scope.
		"error": types.Universe.Lookup("error").Type(),

		// Aliases.
		"byte": types.Typ[types.Uint8],
		"rune": types.Typ[types.Int32],
	}

	// efaceType is the empty interface type; parseExpr uses it for
	// interface type expressions that declare no methods.
	efaceType = types.NewInterfaceType(nil, nil)
)
144
145func parseExpr(ctx *Context, e ast.Expr) *pattern {
146	switch e := e.(type) {
147	case *ast.Ident:
148		basic, ok := builtinTypeByName[e.Name]
149		if ok {
150			return &pattern{op: opBuiltinType, value: basic}
151		}
152		if strings.HasPrefix(e.Name, varPrefix) {
153			name := strings.TrimPrefix(e.Name, varPrefix)
154			return &pattern{op: opVar, value: name}
155		}
156		if strings.HasPrefix(e.Name, varSeqPrefix) {
157			name := strings.TrimPrefix(e.Name, varSeqPrefix)
158			// Only unnamed seq are supported right now.
159			if name == "_" {
160				return &pattern{op: opVarSeq, value: name}
161			}
162		}
163
164	case *ast.SelectorExpr:
165		pkg, ok := e.X.(*ast.Ident)
166		if !ok {
167			return nil
168		}
169		pkgPath, ok := ctx.Itab.Lookup(pkg.Name)
170		if !ok {
171			return nil
172		}
173		return &pattern{op: opNamed, value: [2]string{pkgPath, e.Sel.Name}}
174
175	case *ast.StarExpr:
176		elem := parseExpr(ctx, e.X)
177		if elem == nil {
178			return nil
179		}
180		return &pattern{op: opPointer, subs: []*pattern{elem}}
181
182	case *ast.ArrayType:
183		elem := parseExpr(ctx, e.Elt)
184		if elem == nil {
185			return nil
186		}
187		if e.Len == nil {
188			return &pattern{
189				op:   opSlice,
190				subs: []*pattern{elem},
191			}
192		}
193		if id, ok := e.Len.(*ast.Ident); ok && strings.HasPrefix(id.Name, varPrefix) {
194			name := strings.TrimPrefix(id.Name, varPrefix)
195			return &pattern{
196				op:    opArray,
197				value: name,
198				subs:  []*pattern{elem},
199			}
200		}
201		lit, ok := e.Len.(*ast.BasicLit)
202		if !ok || lit.Kind != token.INT {
203			return nil
204		}
205		length, err := strconv.ParseInt(lit.Value, 10, 64)
206		if err != nil {
207			return nil
208		}
209		return &pattern{
210			op:    opArray,
211			value: length,
212			subs:  []*pattern{elem},
213		}
214
215	case *ast.MapType:
216		keyType := parseExpr(ctx, e.Key)
217		if keyType == nil {
218			return nil
219		}
220		valType := parseExpr(ctx, e.Value)
221		if valType == nil {
222			return nil
223		}
224		return &pattern{
225			op:   opMap,
226			subs: []*pattern{keyType, valType},
227		}
228
229	case *ast.ChanType:
230		valType := parseExpr(ctx, e.Value)
231		if valType == nil {
232			return nil
233		}
234		var dir types.ChanDir
235		switch {
236		case e.Dir&ast.SEND != 0 && e.Dir&ast.RECV != 0:
237			dir = types.SendRecv
238		case e.Dir&ast.SEND != 0:
239			dir = types.SendOnly
240		case e.Dir&ast.RECV != 0:
241			dir = types.RecvOnly
242		default:
243			return nil
244		}
245		return &pattern{
246			op:    opChan,
247			value: dir,
248			subs:  []*pattern{valType},
249		}
250
251	case *ast.ParenExpr:
252		return parseExpr(ctx, e.X)
253
254	case *ast.FuncType:
255		var params []*pattern
256		var results []*pattern
257		if e.Params != nil {
258			for _, field := range e.Params.List {
259				p := parseExpr(ctx, field.Type)
260				if p == nil {
261					return nil
262				}
263				if len(field.Names) != 0 {
264					return nil
265				}
266				params = append(params, p)
267			}
268		}
269		if e.Results != nil {
270			for _, field := range e.Results.List {
271				p := parseExpr(ctx, field.Type)
272				if p == nil {
273					return nil
274				}
275				if len(field.Names) != 0 {
276					return nil
277				}
278				results = append(results, p)
279			}
280		}
281		return &pattern{
282			op:    opFunc,
283			value: len(params),
284			subs:  append(params, results...),
285		}
286
287	case *ast.StructType:
288		hasSeq := false
289		members := make([]*pattern, 0, len(e.Fields.List))
290		for _, field := range e.Fields.List {
291			p := parseExpr(ctx, field.Type)
292			if p == nil {
293				return nil
294			}
295			if len(field.Names) != 0 {
296				return nil
297			}
298			if p.op == opVarSeq {
299				hasSeq = true
300			}
301			members = append(members, p)
302		}
303		op := opStructNoSeq
304		if hasSeq {
305			op = opStruct
306		}
307		return &pattern{
308			op:   op,
309			subs: members,
310		}
311
312	case *ast.InterfaceType:
313		if len(e.Methods.List) == 0 {
314			return &pattern{op: opBuiltinType, value: efaceType}
315		}
316	}
317
318	return nil
319}
320
321// MatchIdentical returns true if the go typ matches pattern p.
322func (p *Pattern) MatchIdentical(typ types.Type) bool {
323	p.reset()
324	return p.matchIdentical(p.root, typ)
325}
326
327func (p *Pattern) reset() {
328	if len(p.int64Matches) != 0 {
329		p.int64Matches = map[string]int64{}
330	}
331	if len(p.typeMatches) != 0 {
332		p.typeMatches = map[string]types.Type{}
333	}
334}
335
336func (p *Pattern) matchIdenticalFielder(subs []*pattern, f fielder) bool {
337	// TODO: do backtracking.
338
339	numFields := f.NumFields()
340	fieldsMatched := 0
341
342	if len(subs) == 0 && numFields != 0 {
343		return false
344	}
345
346	matchAny := false
347
348	i := 0
349	for i < len(subs) {
350		pat := subs[i]
351
352		if pat.op == opVarSeq {
353			matchAny = true
354		}
355
356		fieldsLeft := numFields - fieldsMatched
357		if matchAny {
358			switch {
359			// "Nothing left to match" stop condition.
360			case fieldsLeft == 0:
361				matchAny = false
362				i++
363			// Lookahead for non-greedy matching.
364			case i+1 < len(subs) && p.matchIdentical(subs[i+1], f.Field(fieldsMatched).Type()):
365				matchAny = false
366				i += 2
367				fieldsMatched++
368			default:
369				fieldsMatched++
370			}
371			continue
372		}
373
374		if fieldsLeft == 0 || !p.matchIdentical(pat, f.Field(fieldsMatched).Type()) {
375			return false
376		}
377		i++
378		fieldsMatched++
379	}
380
381	return numFields == fieldsMatched
382}
383
384func (p *Pattern) matchIdentical(sub *pattern, typ types.Type) bool {
385	switch sub.op {
386	case opVar:
387		name := sub.value.(string)
388		if name == "_" {
389			return true
390		}
391		y, ok := p.typeMatches[name]
392		if !ok {
393			p.typeMatches[name] = typ
394			return true
395		}
396		if y == nil {
397			return typ == nil
398		}
399		return xtypes.Identical(typ, y)
400
401	case opBuiltinType:
402		return xtypes.Identical(typ, sub.value.(types.Type))
403
404	case opPointer:
405		typ, ok := typ.(*types.Pointer)
406		if !ok {
407			return false
408		}
409		return p.matchIdentical(sub.subs[0], typ.Elem())
410
411	case opSlice:
412		typ, ok := typ.(*types.Slice)
413		if !ok {
414			return false
415		}
416		return p.matchIdentical(sub.subs[0], typ.Elem())
417
418	case opArray:
419		typ, ok := typ.(*types.Array)
420		if !ok {
421			return false
422		}
423		var wantLen int64
424		switch v := sub.value.(type) {
425		case string:
426			if v == "_" {
427				wantLen = typ.Len()
428				break
429			}
430			length, ok := p.int64Matches[v]
431			if ok {
432				wantLen = length
433			} else {
434				p.int64Matches[v] = typ.Len()
435				wantLen = typ.Len()
436			}
437		case int64:
438			wantLen = v
439		}
440		return wantLen == typ.Len() && p.matchIdentical(sub.subs[0], typ.Elem())
441
442	case opMap:
443		typ, ok := typ.(*types.Map)
444		if !ok {
445			return false
446		}
447		return p.matchIdentical(sub.subs[0], typ.Key()) &&
448			p.matchIdentical(sub.subs[1], typ.Elem())
449
450	case opChan:
451		typ, ok := typ.(*types.Chan)
452		if !ok {
453			return false
454		}
455		dir := sub.value.(types.ChanDir)
456		return dir == typ.Dir() && p.matchIdentical(sub.subs[0], typ.Elem())
457
458	case opNamed:
459		typ, ok := typ.(*types.Named)
460		if !ok {
461			return false
462		}
463		obj := typ.Obj()
464		pkg := obj.Pkg()
465		// pkg can be nil for builtin named types.
466		// There is no point in checking anything else as we never
467		// generate the opNamed for such types.
468		if pkg == nil {
469			return false
470		}
471		pkgPath := sub.value.([2]string)[0]
472		typeName := sub.value.([2]string)[1]
473		// obj.Pkg().Path() may be in a vendor directory.
474		path := strings.SplitAfter(obj.Pkg().Path(), "/vendor/")
475		return path[len(path)-1] == pkgPath && typeName == obj.Name()
476
477	case opFunc:
478		typ, ok := typ.(*types.Signature)
479		if !ok {
480			return false
481		}
482		numParams := sub.value.(int)
483		params := sub.subs[:numParams]
484		results := sub.subs[numParams:]
485		if typ.Params().Len() != len(params) {
486			return false
487		}
488		if typ.Results().Len() != len(results) {
489			return false
490		}
491		for i := 0; i < typ.Params().Len(); i++ {
492			if !p.matchIdentical(params[i], typ.Params().At(i).Type()) {
493				return false
494			}
495		}
496		for i := 0; i < typ.Results().Len(); i++ {
497			if !p.matchIdentical(results[i], typ.Results().At(i).Type()) {
498				return false
499			}
500		}
501		return true
502
503	case opStructNoSeq:
504		typ, ok := typ.(*types.Struct)
505		if !ok {
506			return false
507		}
508		if typ.NumFields() != len(sub.subs) {
509			return false
510		}
511		for i, member := range sub.subs {
512			if !p.matchIdentical(member, typ.Field(i).Type()) {
513				return false
514			}
515		}
516		return true
517
518	case opStruct:
519		typ, ok := typ.(*types.Struct)
520		if !ok {
521			return false
522		}
523		if !p.matchIdenticalFielder(sub.subs, typ) {
524			return false
525		}
526		return true
527
528	default:
529		return false
530	}
531}
532
// fielder is the view of a field container that matchIdenticalFielder
// needs; *types.Struct is passed as a fielder when matching struct
// patterns.
type fielder interface {
	// Field returns the i-th field.
	Field(i int) *types.Var
	// NumFields returns the number of fields.
	NumFields() int
}
537