// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package vta

import (
	"fmt"
	"go/token"
	"go/types"

	"golang.org/x/tools/go/callgraph"
	"golang.org/x/tools/go/ssa"
	"golang.org/x/tools/go/types/typeutil"
)

// node interface for VTA nodes.
type node interface {
	Type() types.Type
	String() string
}

// constant node for VTA.
type constant struct {
	typ types.Type
}

func (c constant) Type() types.Type {
	return c.typ
}

func (c constant) String() string {
	return fmt.Sprintf("Constant(%v)", c.Type())
}

// pointer node for VTA.
type pointer struct {
	typ *types.Pointer
}

func (p pointer) Type() types.Type {
	return p.typ
}

func (p pointer) String() string {
	return fmt.Sprintf("Pointer(%v)", p.Type())
}

// mapKey node for VTA, modeling reachable map key types.
type mapKey struct {
	typ types.Type
}

func (mk mapKey) Type() types.Type {
	return mk.typ
}

func (mk mapKey) String() string {
	return fmt.Sprintf("MapKey(%v)", mk.Type())
}

// mapValue node for VTA, modeling reachable map value types.
type mapValue struct {
	typ types.Type
}

func (mv mapValue) Type() types.Type {
	return mv.typ
}

func (mv mapValue) String() string {
	return fmt.Sprintf("MapValue(%v)", mv.Type())
}

// sliceElem node for VTA, modeling reachable slice element types.
type sliceElem struct {
	typ types.Type
}

func (s sliceElem) Type() types.Type {
	return s.typ
}

func (s sliceElem) String() string {
	return fmt.Sprintf("Slice([]%v)", s.Type())
}

// channelElem node for VTA, modeling reachable channel element types.
type channelElem struct {
	typ types.Type
}

func (c channelElem) Type() types.Type {
	return c.typ
}

func (c channelElem) String() string {
	return fmt.Sprintf("Channel(chan %v)", c.Type())
}

// field node for VTA.
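// For example (an illustrative sketch), given
//   type T struct{ x I }
// the node field{StructType: T, index: 0} models the field x, and its
// Type() is I.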
type field struct {
	StructType types.Type
	index      int // index of the field in the struct
}

func (f field) Type() types.Type {
	s := f.StructType.Underlying().(*types.Struct)
	return s.Field(f.index).Type()
}

func (f field) String() string {
	s := f.StructType.Underlying().(*types.Struct)
	return fmt.Sprintf("Field(%v:%s)", f.StructType, s.Field(f.index).Name())
}

// global node for VTA.
type global struct {
	val *ssa.Global
}

func (g global) Type() types.Type {
	return g.val.Type()
}

func (g global) String() string {
	return fmt.Sprintf("Global(%s)", g.val.Name())
}

// local node for VTA modeling local variables
// and function/method parameters.
type local struct {
	val ssa.Value
}

func (l local) Type() types.Type {
	return l.val.Type()
}

func (l local) String() string {
	return fmt.Sprintf("Local(%s)", l.val.Name())
}

// indexedLocal node for VTA. Models indexed locals
// related to the ssa extract instructions.
type indexedLocal struct {
	val   ssa.Value
	index int
	typ   types.Type
}

func (i indexedLocal) Type() types.Type {
	return i.typ
}

func (i indexedLocal) String() string {
	return fmt.Sprintf("Local(%s[%d])", i.val.Name(), i.index)
}

// function node for VTA.
type function struct {
	f *ssa.Function
}

func (f function) Type() types.Type {
	return f.f.Type()
}

func (f function) String() string {
	return fmt.Sprintf("Function(%s)", f.f.Name())
}

// nestedPtrInterface node represents all references and dereferences
// of locals and globals that have a nested pointer to interface type.
// We merge such constructs into a single node for simplicity and without
// much loss of precision, as such variables are rare in practice. Both
// a and b would be represented as the same PtrInterface(I) node in:
//   type I interface{}
//   var a ***I
//   var b **I
type nestedPtrInterface struct {
	typ types.Type
}

func (l nestedPtrInterface) Type() types.Type {
	return l.typ
}

func (l nestedPtrInterface) String() string {
	return fmt.Sprintf("PtrInterface(%v)", l.typ)
}

// panicArg models types of all arguments passed to panic.
type panicArg struct{}

func (p panicArg) Type() types.Type {
	return nil
}

func (p panicArg) String() string {
	return "Panic"
}

// recoverReturn models types of all return values of recover().
type recoverReturn struct{}

func (r recoverReturn) Type() types.Type {
	return nil
}

func (r recoverReturn) String() string {
	return "Recover"
}

// vtaGraph remembers for each VTA node the set of its successors.
// Tailored for VTA, hence does not support singleton (sub)graphs.
type vtaGraph map[node]map[node]bool

// addEdge adds an edge x->y to the graph.
func (g vtaGraph) addEdge(x, y node) {
	succs, ok := g[x]
	if !ok {
		succs = make(map[node]bool)
		g[x] = succs
	}
	succs[y] = true
}

// successors returns all of n's immediate successors in the graph.
// The order of successor nodes is arbitrary.
func (g vtaGraph) successors(n node) []node {
	var succs []node
	for succ := range g[n] {
		succs = append(succs, succ)
	}
	return succs
}

// typePropGraph builds a VTA graph for a set of `funcs` and initial
// `callgraph` needed to establish interprocedural edges. Returns the
// graph and a map for unique type representatives.
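//
// A typical use (an illustrative sketch, where prog is an *ssa.Program;
// any initial call graph works, e.g., one computed by the cha package):
//   funcs := ssautil.AllFunctions(prog)
//   graph, canon := typePropGraph(funcs, cha.CallGraph(prog))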
func typePropGraph(funcs map[*ssa.Function]bool, callgraph *callgraph.Graph) (vtaGraph, *typeutil.Map) {
	b := builder{graph: make(vtaGraph), callGraph: callgraph}
	b.visit(funcs)
	return b.graph, &b.canon
}

// builder is a data structure responsible for linearly traversing the
// code and building a VTA graph.
type builder struct {
	graph     vtaGraph
	callGraph *callgraph.Graph // initial call graph for creating flows at unresolved call sites.

	// Specialized type map for canonicalization of types.Type.
	// Semantically equivalent types can have different implementations,
	// i.e., they are different pointer values. The map allows us to
	// have one unique representative. The keys are fixed and from the
	// client perspective they are types. The values in our case are
	// types too, in particular type representatives. Each value is a
	// pointer so this map is not expected to take much memory.
	canon typeutil.Map
}

func (b *builder) visit(funcs map[*ssa.Function]bool) {
	// Add the fixed edge Panic -> Recover
	b.graph.addEdge(panicArg{}, recoverReturn{})

	for f, in := range funcs {
		if in {
			b.fun(f)
		}
	}
}

func (b *builder) fun(f *ssa.Function) {
	for _, bl := range f.Blocks {
		for _, instr := range bl.Instrs {
			b.instr(instr)
		}
	}
}

func (b *builder) instr(instr ssa.Instruction) {
	switch i := instr.(type) {
	case *ssa.Store:
		b.addInFlowAliasEdges(b.nodeFromVal(i.Addr), b.nodeFromVal(i.Val))
	case *ssa.MakeInterface:
		b.addInFlowEdge(b.nodeFromVal(i.X), b.nodeFromVal(i))
	case *ssa.MakeClosure:
		b.closure(i)
	case *ssa.UnOp:
		b.unop(i)
	case *ssa.Phi:
		b.phi(i)
	case *ssa.ChangeInterface:
		// Although in a ChangeInterface command a := A(b) the values
		// a and b are the same object, the only interesting flow
		// happens when A is an interface. We create the flow b -> a,
		// but omit a -> b. The latter flow is not needed: if a gets
		// assigned a concrete type later on, that cannot be propagated
		// back to b, as b is a separate variable. The a -> b flow can
		// happen when A is a pointer to interface, but then the command
		// is of type ChangeType, handled below.
		b.addInFlowEdge(b.nodeFromVal(i.X), b.nodeFromVal(i))
	case *ssa.ChangeType:
		// A ChangeType command a := A(b) results in a and b being the
		// same value. For a concrete type A, there is no interesting flow.
		//
		// Note: When A is an interface, most interface casts are handled
		// by the ChangeInterface instruction. The relevant case here is
		// when converting a pointer to an interface type. This can happen
		// when the underlying interfaces have the same method set.
		//   type I interface{ foo() }
		//   type J interface{ foo() }
		//   var b *I
		//   a := (*J)(b)
		// When this happens we add flows between a <--> b.
		b.addInFlowAliasEdges(b.nodeFromVal(i), b.nodeFromVal(i.X))
	case *ssa.TypeAssert:
		b.tassert(i)
	case *ssa.Extract:
		b.extract(i)
	case *ssa.Field:
		b.field(i)
	case *ssa.FieldAddr:
		b.fieldAddr(i)
	case *ssa.Send:
		b.send(i)
	case *ssa.Select:
		b.selekt(i)
	case *ssa.Index:
		b.index(i)
	case *ssa.IndexAddr:
		b.indexAddr(i)
	case *ssa.Lookup:
		b.lookup(i)
	case *ssa.MapUpdate:
		b.mapUpdate(i)
	case *ssa.Next:
		b.next(i)
	case ssa.CallInstruction:
		b.call(i)
	case *ssa.Panic:
		b.panic(i)
	case *ssa.Return:
		b.rtrn(i)
	case *ssa.MakeChan, *ssa.MakeMap, *ssa.MakeSlice, *ssa.BinOp,
		*ssa.Alloc, *ssa.DebugRef, *ssa.Convert, *ssa.Jump, *ssa.If,
		*ssa.Slice, *ssa.Range, *ssa.RunDefers:
		// No interesting flow here.
		return
	default:
		panic(fmt.Sprintf("unsupported instruction %v\n", instr))
	}
}

func (b *builder) unop(u *ssa.UnOp) {
	switch u.Op {
	case token.MUL:
		// Multiplication operator * is used here as a dereference operator.
		b.addInFlowAliasEdges(b.nodeFromVal(u), b.nodeFromVal(u.X))
	case token.ARROW:
		t := u.X.Type().Underlying().(*types.Chan).Elem()
		b.addInFlowAliasEdges(b.nodeFromVal(u), channelElem{typ: t})
	default:
		// There is no interesting type flow otherwise.
	}
}

func (b *builder) phi(p *ssa.Phi) {
	for _, edge := range p.Edges {
		b.addInFlowAliasEdges(b.nodeFromVal(p), b.nodeFromVal(edge))
	}
}

func (b *builder) tassert(a *ssa.TypeAssert) {
	if !a.CommaOk {
		b.addInFlowEdge(b.nodeFromVal(a.X), b.nodeFromVal(a))
		return
	}
	// Here a is a tuple register <a.AssertedType, bool>, so there is
	// a flow from a.X to a[0]. a[0] is represented as an indexedLocal:
	// an entry into the local tuple register a at index 0.
	tup := a.Type().Underlying().(*types.Tuple)
	t := tup.At(0).Type()

	local := indexedLocal{val: a, typ: t, index: 0}
	b.addInFlowEdge(b.nodeFromVal(a.X), local)
}

// extract instruction t1 := t2[i] generates flows between t2[i]
// and t1 where the source is an indexed local representing the value
// of tuple register t2 at index i and the target is t1.
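// For example (an illustrative sketch), for
//   t1 = extract t2 #1
// flows are created between the indexed local t2[1] and t1.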
func (b *builder) extract(e *ssa.Extract) {
	tup := e.Tuple.Type().Underlying().(*types.Tuple)
	t := tup.At(e.Index).Type()

	local := indexedLocal{val: e.Tuple, typ: t, index: e.Index}
	b.addInFlowAliasEdges(b.nodeFromVal(e), local)
}

func (b *builder) field(f *ssa.Field) {
	fnode := field{StructType: f.X.Type(), index: f.Field}
	b.addInFlowEdge(fnode, b.nodeFromVal(f))
}

func (b *builder) fieldAddr(f *ssa.FieldAddr) {
	t := f.X.Type().Underlying().(*types.Pointer).Elem()

	// Since we are taking the address of a field, make a bidirectional edge.
	fnode := field{StructType: t, index: f.Field}
	b.addInFlowEdge(fnode, b.nodeFromVal(f))
	b.addInFlowEdge(b.nodeFromVal(f), fnode)
}

func (b *builder) send(s *ssa.Send) {
	t := s.Chan.Type().Underlying().(*types.Chan).Elem()
	b.addInFlowAliasEdges(channelElem{typ: t}, b.nodeFromVal(s.X))
}

// selekt generates flows for a select statement
//   a = select blocking/nonblocking [c_1 <- t_1, c_2 <- t_2, ..., <- o_1, <- o_2, ...]
// between each send channel register c_i and its corresponding input register t_i.
// Further, flows are generated between each receive channel o_i and a[2 + i]. Note
// that a is a tuple register of type <int, bool, r_1, r_2, ...> where the type of
// r_i is the element type of channel o_i.
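//
// For example (an illustrative sketch), for
//   a = select nonblocking [c1 <- t1, <-c2]
// flows are created between the element type of c1 and t1 (the send case)
// and between a[2] and the element type of c2 (the receive case).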
func (b *builder) selekt(s *ssa.Select) {
	recvIndex := 0
	for _, state := range s.States {
		t := state.Chan.Type().Underlying().(*types.Chan).Elem()

		if state.Dir == types.SendOnly {
			b.addInFlowAliasEdges(channelElem{typ: t}, b.nodeFromVal(state.Send))
		} else {
			// state.Dir == RecvOnly by definition of select instructions.
			tupEntry := indexedLocal{val: s, typ: t, index: 2 + recvIndex}
			b.addInFlowAliasEdges(tupEntry, channelElem{typ: t})
			recvIndex++
		}
	}
}

// index instruction a := b[c] on slices creates flows between a and
// SliceElem(t) where t is the element type of slice (or array) b.
// Array and slice elements are both modeled as SliceElem.
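//
// For example (an illustrative sketch), for
//   t1 = t2[t3]
// where t2 has type []I, flows are created between t1 and SliceElem(I).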
func (b *builder) index(i *ssa.Index) {
	et := sliceArrayElem(i.X.Type())
	b.addInFlowAliasEdges(b.nodeFromVal(i), sliceElem{typ: et})
}

// indexAddr instruction a := &b[c] fetches the address of the element
// b[c], so we create a bidirectional flow a <-> SliceElem(t) where t is
// the element type of slice (or array) b. Array and slice elements are
// both modeled as SliceElem.
func (b *builder) indexAddr(i *ssa.IndexAddr) {
	et := sliceArrayElem(i.X.Type())
	b.addInFlowEdge(sliceElem{typ: et}, b.nodeFromVal(i))
	b.addInFlowEdge(b.nodeFromVal(i), sliceElem{typ: et})
}

// lookup handles map query commands a := m[b] where m is of type
// map[...]V and V is an interface. It creates flows between `a`
// and MapValue(V).
func (b *builder) lookup(l *ssa.Lookup) {
	t, ok := l.X.Type().Underlying().(*types.Map)
	if !ok {
		// No interesting flows for string lookups.
		return
	}
	b.addInFlowAliasEdges(b.nodeFromVal(l), mapValue{typ: t.Elem()})
}

// mapUpdate handles map update commands m[b] = a where m is of type
// map[K]V and K and V are interfaces. It creates flows between `a`
// and MapValue(V) as well as between MapKey(K) and `b`.
func (b *builder) mapUpdate(u *ssa.MapUpdate) {
	t, ok := u.Map.Type().Underlying().(*types.Map)
	if !ok {
		// No interesting flows for string updates.
		return
	}

	b.addInFlowAliasEdges(mapKey{typ: t.Key()}, b.nodeFromVal(u.Key))
	b.addInFlowAliasEdges(mapValue{typ: t.Elem()}, b.nodeFromVal(u.Value))
}

// next instruction <ok, key, value> := next r, where r is a range over
// a map or a string, generates flows between key and MapKey as well as
// between value and MapValue nodes.
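//
// For example (an illustrative sketch), for
//   t1 = next t2
// where t2 ranges over a map[K]V, flows are created between t1[1] and
// MapKey(K) and between t1[2] and MapValue(V).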
func (b *builder) next(n *ssa.Next) {
	if n.IsString {
		return
	}
	tup := n.Type().Underlying().(*types.Tuple)
	kt := tup.At(1).Type()
	vt := tup.At(2).Type()

	b.addInFlowAliasEdges(indexedLocal{val: n, typ: kt, index: 1}, mapKey{typ: kt})
	b.addInFlowAliasEdges(indexedLocal{val: n, typ: vt, index: 2}, mapValue{typ: vt})
}

// addInFlowAliasEdges adds an edge r -> l to b.graph if l is a node that can
// have an inflow, i.e., a node that represents an interface or an unresolved
// function value. Similarly for the edge l -> r, with the additional condition
// that l and r can potentially alias.
func (b *builder) addInFlowAliasEdges(l, r node) {
	b.addInFlowEdge(r, l)

	if canAlias(l, r) {
		b.addInFlowEdge(l, r)
	}
}

func (b *builder) closure(c *ssa.MakeClosure) {
	f := c.Fn.(*ssa.Function)
	b.addInFlowEdge(function{f: f}, b.nodeFromVal(c))

	for i, fv := range f.FreeVars {
		b.addInFlowAliasEdges(b.nodeFromVal(fv), b.nodeFromVal(c.Bindings[i]))
	}
}

// panic creates a flow from the arguments of panic instructions to the
// return registers of all recover statements in the program. It introduces
// a global panic node Panic and
//  1) for every panic statement p: add p -> Panic
//  2) for every recover statement r: add Panic -> r (handled in call)
// TODO(zpavlinovic): improve precision by explicitly modeling how panic
// values flow from callees to callers and into deferred recover instructions.
func (b *builder) panic(p *ssa.Panic) {
	// Panics often have, for instance, strings as arguments which do
	// not create interesting flows.
	if !canHaveMethods(p.X.Type()) {
		return
	}

	b.addInFlowEdge(b.nodeFromVal(p.X), panicArg{})
}

// call adds flows between arguments/parameters and return values/registers
// for both static and dynamic calls, as well as go and defer calls.
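//
// For example (an illustrative sketch), for a dynamic method call
//   t2 = invoke t1.Foo(t0)
// and every callee f of the call site according to b.callGraph, flows are
// created between t1 and the receiver parameter f.Params[0] and between
// t0 and f.Params[1].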
func (b *builder) call(c ssa.CallInstruction) {
	// When c is a register instruction r := recover(), we add Recover -> r.
	if bf, ok := c.Common().Value.(*ssa.Builtin); ok && bf.Name() == "recover" {
		b.addInFlowEdge(recoverReturn{}, b.nodeFromVal(c.(*ssa.Call)))
		return
	}

	for _, f := range siteCallees(c, b.callGraph) {
		addArgumentFlows(b, c, f)
	}
}

func addArgumentFlows(b *builder, c ssa.CallInstruction, f *ssa.Function) {
	cc := c.Common()
	// When c is an unresolved method call (cc.Method != nil), cc.Value contains
	// the receiver object rather than cc.Args[0].
	if cc.Method != nil {
		b.addInFlowAliasEdges(b.nodeFromVal(f.Params[0]), b.nodeFromVal(cc.Value))
	}

	offset := 0
	if cc.Method != nil {
		offset = 1
	}
	for i, v := range cc.Args {
		b.addInFlowAliasEdges(b.nodeFromVal(f.Params[i+offset]), b.nodeFromVal(v))
	}
}

// rtrn produces flows between the return values of r and each call
// instruction c that resolves to the enclosing function of r based
// on b.callGraph.
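//
// For example (an illustrative sketch), if a call site
//   t0 = g()
// resolves to the enclosing function of
//   return t1, t2
// then flows are created from t1 to t0[0] and from t2 to t0[1].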
func (b *builder) rtrn(r *ssa.Return) {
	n := b.callGraph.Nodes[r.Parent()]
	// n != nil when b.callGraph is sound, but the client can
	// pass any callgraph, including an underapproximate one.
	if n == nil {
		return
	}

	for _, e := range n.In {
		if cv, ok := e.Site.(ssa.Value); ok {
			addReturnFlows(b, r, cv)
		}
	}
}

func addReturnFlows(b *builder, r *ssa.Return, site ssa.Value) {
	results := r.Results
	if len(results) == 1 {
		// When there is only one return value, the destination register does not
		// have a tuple type.
		b.addInFlowEdge(b.nodeFromVal(results[0]), b.nodeFromVal(site))
		return
	}

	tup := site.Type().Underlying().(*types.Tuple)
	for i, r := range results {
		local := indexedLocal{val: site, typ: tup.At(i).Type(), index: i}
		b.addInFlowEdge(b.nodeFromVal(r), local)
	}
}

// addInFlowEdge adds s -> d to b.graph if d is a node that can have an inflow,
// i.e., a node that represents an interface or an unresolved function value.
// Otherwise, there is no interesting type flow so the edge is omitted.
func (b *builder) addInFlowEdge(s, d node) {
	if hasInFlow(d) {
		b.graph.addEdge(b.representative(s), b.representative(d))
	}
}

// nodeFromVal creates a const, pointer, global, func, or local node based on register instructions.
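// For example (an illustrative sketch), an *ssa.Const becomes a constant
// node, an *ssa.Global a global node, and a register of type **I, where I
// is an interface, becomes a nestedPtrInterface node for I.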
func (b *builder) nodeFromVal(val ssa.Value) node {
	if p, ok := val.Type().(*types.Pointer); ok && !isInterface(p.Elem()) {
		// Nested pointers to interfaces are modeled as a special
		// nestedPtrInterface node.
		if i := interfaceUnderPtr(p.Elem()); i != nil {
			return nestedPtrInterface{typ: i}
		}
		return pointer{typ: p}
	}

	switch v := val.(type) {
	case *ssa.Const:
		return constant{typ: val.Type()}
	case *ssa.Global:
		return global{val: v}
	case *ssa.Function:
		return function{f: v}
	case *ssa.Parameter, *ssa.FreeVar, ssa.Instruction:
		// ssa.Parameter, ssa.FreeVar, and a specific set of "register" instructions,
		// satisfying the ssa.Value interface, can serve as local variables.
		return local{val: v}
	default:
		panic(fmt.Errorf("unsupported value %v in node creation", val))
	}
}

// representative returns a unique representative for node `n`. Since
// semantically equivalent types can have different implementations,
// this method guarantees the same implementation is always used.
func (b *builder) representative(n node) node {
	if !hasInitialTypes(n) {
		return n
	}
	t := canonicalize(n.Type(), &b.canon)

	switch i := n.(type) {
	case constant:
		return constant{typ: t}
	case pointer:
		return pointer{typ: t.(*types.Pointer)}
	case sliceElem:
		return sliceElem{typ: t}
	case mapKey:
		return mapKey{typ: t}
	case mapValue:
		return mapValue{typ: t}
	case channelElem:
		return channelElem{typ: t}
	case nestedPtrInterface:
		return nestedPtrInterface{typ: t}
	case field:
		return field{StructType: canonicalize(i.StructType, &b.canon), index: i.index}
	case indexedLocal:
		return indexedLocal{typ: t, val: i.val, index: i.index}
	case local, global, panicArg, recoverReturn, function:
		return n
	default:
		panic(fmt.Errorf("canonicalizing unrecognized node %v", n))
	}
}

// canonicalize returns the unique type representative of `t` with
// respect to type map `canon`.
func canonicalize(t types.Type, canon *typeutil.Map) types.Type {
	rep := canon.At(t)
	if rep != nil {
		return rep.(types.Type)
	}
	canon.Set(t, t)
	return t
}