1// Copyright 2017 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package pointer
6
7import (
8	"errors"
9	"fmt"
10	"go/ast"
11	"go/parser"
12	"go/token"
13	"go/types"
14	"strconv"
15)
16
// An extendedQuery represents a sequence of destructuring operations
// applied to an ssa.Value (denoted by "x").
type extendedQuery struct {
	// ops is the destructuring sequence, presumably as returned by
	// parseExtendedQuery: ops[0] is the operand name ("x"), followed by
	// operation names ("load", "field", "arrayelem", "sliceelem",
	// "mapelem", "index", "recv"). "field" and "index" are each
	// followed by an int: the struct field index or tuple element index.
	ops []interface{}
	// ptr presumably is the Pointer through which the query's result is
	// exposed to the client.
	// NOTE(review): ptr is not read or written in this chunk — confirm
	// against its users.
	ptr *Pointer
}
23
// indexValue returns the value of an integer literal used as an
// index.
//
// The literal is interpreted using Go integer literal syntax, so
// hexadecimal ("0x1"), octal ("010", "0o10"), binary ("0b1") and
// underscore-separated ("1_0") forms yield the value the compiler would
// give them. An error is returned if expr is not an integer literal or
// if its value does not fit in an int.
func indexValue(expr ast.Expr) (int, error) {
	lit, ok := expr.(*ast.BasicLit)
	if !ok {
		return 0, fmt.Errorf("non-integer index (%T)", expr)
	}
	if lit.Kind != token.INT {
		return 0, fmt.Errorf("non-integer index %s", lit.Value)
	}
	// Base 0 derives the base from the literal's prefix, as Go does;
	// strconv.Atoi would misread "010" as ten and reject "0x1".
	// bitSize 0 bounds the result to the range of int.
	n, err := strconv.ParseInt(lit.Value, 0, 0)
	if err != nil {
		return 0, err
	}
	return int(n), nil
}
36
37// parseExtendedQuery parses and validates a destructuring Go
38// expression and returns the sequence of destructuring operations.
39// See parseDestructuringExpr for details.
40func parseExtendedQuery(typ types.Type, query string) ([]interface{}, types.Type, error) {
41	expr, err := parser.ParseExpr(query)
42	if err != nil {
43		return nil, nil, err
44	}
45	ops, typ, err := destructuringOps(typ, expr)
46	if err != nil {
47		return nil, nil, err
48	}
49	if len(ops) == 0 {
50		return nil, nil, errors.New("invalid query: must not be empty")
51	}
52	if ops[0] != "x" {
53		return nil, nil, fmt.Errorf("invalid query: query operand must be named x")
54	}
55	if !CanPoint(typ) {
56		return nil, nil, fmt.Errorf("query does not describe a pointer-like value: %s", typ)
57	}
58	return ops, typ, nil
59}
60
61// destructuringOps parses a Go expression consisting only of an
62// identifier "x", field selections, indexing, channel receives, load
63// operations and parens---for example: "<-(*x[i])[key]"--- and
64// returns the sequence of destructuring operations on x.
65func destructuringOps(typ types.Type, expr ast.Expr) ([]interface{}, types.Type, error) {
66	switch expr := expr.(type) {
67	case *ast.SelectorExpr:
68		out, typ, err := destructuringOps(typ, expr.X)
69		if err != nil {
70			return nil, nil, err
71		}
72
73		var structT *types.Struct
74		switch typ := typ.Underlying().(type) {
75		case *types.Pointer:
76			var ok bool
77			structT, ok = typ.Elem().Underlying().(*types.Struct)
78			if !ok {
79				return nil, nil, fmt.Errorf("cannot access field %s of pointer to type %s", expr.Sel.Name, typ.Elem())
80			}
81
82			out = append(out, "load")
83		case *types.Struct:
84			structT = typ
85		default:
86			return nil, nil, fmt.Errorf("cannot access field %s of type %s", expr.Sel.Name, typ)
87		}
88
89		for i := 0; i < structT.NumFields(); i++ {
90			field := structT.Field(i)
91			if field.Name() == expr.Sel.Name {
92				out = append(out, "field", i)
93				return out, field.Type().Underlying(), nil
94			}
95		}
96		// TODO(dh): supporting embedding would need something like
97		// types.LookupFieldOrMethod, but without taking package
98		// boundaries into account, because we may want to access
99		// unexported fields. If we were only interested in one level
100		// of unexported name, we could determine the appropriate
101		// package and run LookupFieldOrMethod with that. However, a
102		// single query may want to cross multiple package boundaries,
103		// and at this point it's not really worth the complexity.
104		return nil, nil, fmt.Errorf("no field %s in %s (embedded fields must be resolved manually)", expr.Sel.Name, structT)
105	case *ast.Ident:
106		return []interface{}{expr.Name}, typ, nil
107	case *ast.BasicLit:
108		return []interface{}{expr.Value}, nil, nil
109	case *ast.IndexExpr:
110		out, typ, err := destructuringOps(typ, expr.X)
111		if err != nil {
112			return nil, nil, err
113		}
114		switch typ := typ.Underlying().(type) {
115		case *types.Array:
116			out = append(out, "arrayelem")
117			return out, typ.Elem().Underlying(), nil
118		case *types.Slice:
119			out = append(out, "sliceelem")
120			return out, typ.Elem().Underlying(), nil
121		case *types.Map:
122			out = append(out, "mapelem")
123			return out, typ.Elem().Underlying(), nil
124		case *types.Tuple:
125			out = append(out, "index")
126			idx, err := indexValue(expr.Index)
127			if err != nil {
128				return nil, nil, err
129			}
130			out = append(out, idx)
131			if idx >= typ.Len() || idx < 0 {
132				return nil, nil, fmt.Errorf("tuple index %d out of bounds", idx)
133			}
134			return out, typ.At(idx).Type().Underlying(), nil
135		default:
136			return nil, nil, fmt.Errorf("cannot index type %s", typ)
137		}
138
139	case *ast.UnaryExpr:
140		if expr.Op != token.ARROW {
141			return nil, nil, fmt.Errorf("unsupported unary operator %s", expr.Op)
142		}
143		out, typ, err := destructuringOps(typ, expr.X)
144		if err != nil {
145			return nil, nil, err
146		}
147		ch, ok := typ.(*types.Chan)
148		if !ok {
149			return nil, nil, fmt.Errorf("cannot receive from value of type %s", typ)
150		}
151		out = append(out, "recv")
152		return out, ch.Elem().Underlying(), err
153	case *ast.ParenExpr:
154		return destructuringOps(typ, expr.X)
155	case *ast.StarExpr:
156		out, typ, err := destructuringOps(typ, expr.X)
157		if err != nil {
158			return nil, nil, err
159		}
160		ptr, ok := typ.(*types.Pointer)
161		if !ok {
162			return nil, nil, fmt.Errorf("cannot dereference type %s", typ)
163		}
164		out = append(out, "load")
165		return out, ptr.Elem().Underlying(), err
166	default:
167		return nil, nil, fmt.Errorf("unsupported expression %T", expr)
168	}
169}
170
171func (a *analysis) evalExtendedQuery(t types.Type, id nodeid, ops []interface{}) (types.Type, nodeid) {
172	pid := id
173	// TODO(dh): we're allocating intermediary nodes each time
174	// evalExtendedQuery is called. We should probably only generate
175	// them once per (v, ops) pair.
176	for i := 1; i < len(ops); i++ {
177		var nid nodeid
178		switch ops[i] {
179		case "recv":
180			t = t.(*types.Chan).Elem().Underlying()
181			nid = a.addNodes(t, "query.extended")
182			a.load(nid, pid, 0, a.sizeof(t))
183		case "field":
184			i++ // fetch field index
185			tt := t.(*types.Struct)
186			idx := ops[i].(int)
187			offset := a.offsetOf(t, idx)
188			t = tt.Field(idx).Type().Underlying()
189			nid = a.addNodes(t, "query.extended")
190			a.copy(nid, pid+nodeid(offset), a.sizeof(t))
191		case "arrayelem":
192			t = t.(*types.Array).Elem().Underlying()
193			nid = a.addNodes(t, "query.extended")
194			a.copy(nid, 1+pid, a.sizeof(t))
195		case "sliceelem":
196			t = t.(*types.Slice).Elem().Underlying()
197			nid = a.addNodes(t, "query.extended")
198			a.load(nid, pid, 1, a.sizeof(t))
199		case "mapelem":
200			tt := t.(*types.Map)
201			t = tt.Elem()
202			ksize := a.sizeof(tt.Key())
203			vsize := a.sizeof(tt.Elem())
204			nid = a.addNodes(t, "query.extended")
205			a.load(nid, pid, ksize, vsize)
206		case "index":
207			i++ // fetch index
208			tt := t.(*types.Tuple)
209			idx := ops[i].(int)
210			t = tt.At(idx).Type().Underlying()
211			nid = a.addNodes(t, "query.extended")
212			a.copy(nid, pid+nodeid(idx), a.sizeof(t))
213		case "load":
214			t = t.(*types.Pointer).Elem().Underlying()
215			nid = a.addNodes(t, "query.extended")
216			a.load(nid, pid, 0, a.sizeof(t))
217		default:
218			// shouldn't happen
219			panic(fmt.Sprintf("unknown op %q", ops[i]))
220		}
221		pid = nid
222	}
223
224	return t, pid
225}
226