1// Copyright 2013 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package main
6
7import (
8	"fmt"
9	"go/token"
10	"go/types"
11
12	"golang.org/x/tools/cmd/guru/serial"
13	"golang.org/x/tools/go/callgraph"
14	"golang.org/x/tools/go/loader"
15	"golang.org/x/tools/go/ssa"
16	"golang.org/x/tools/go/ssa/ssautil"
17)
18
19// The callers function reports the possible callers of the function
20// immediately enclosing the specified source location.
21//
22func callers(q *Query) error {
23	lconf := loader.Config{Build: q.Build}
24
25	if err := setPTAScope(&lconf, q.Scope); err != nil {
26		return err
27	}
28
29	// Load/parse/type-check the program.
30	lprog, err := loadWithSoftErrors(&lconf)
31	if err != nil {
32		return err
33	}
34
35	qpos, err := parseQueryPos(lprog, q.Pos, false)
36	if err != nil {
37		return err
38	}
39
40	prog := ssautil.CreateProgram(lprog, 0)
41
42	ptaConfig, err := setupPTA(prog, lprog, q.PTALog, q.Reflection)
43	if err != nil {
44		return err
45	}
46
47	pkg := prog.Package(qpos.info.Pkg)
48	if pkg == nil {
49		return fmt.Errorf("no SSA package")
50	}
51	if !ssa.HasEnclosingFunction(pkg, qpos.path) {
52		return fmt.Errorf("this position is not inside a function")
53	}
54
55	// Defer SSA construction till after errors are reported.
56	prog.Build()
57
58	target := ssa.EnclosingFunction(pkg, qpos.path)
59	if target == nil {
60		return fmt.Errorf("no SSA function built for this location (dead code?)")
61	}
62
63	// If the function is never address-taken, all calls are direct
64	// and can be found quickly by inspecting the whole SSA program.
65	cg := directCallsTo(target, entryPoints(ptaConfig.Mains))
66	if cg == nil {
67		// Run the pointer analysis, recording each
68		// call found to originate from target.
69		// (Pointer analysis may return fewer results than
70		// directCallsTo because it ignores dead code.)
71		ptaConfig.BuildCallGraph = true
72		cg = ptrAnalysis(ptaConfig).CallGraph
73	}
74	cg.DeleteSyntheticNodes()
75	edges := cg.CreateNode(target).In
76
77	// TODO(adonovan): sort + dedup calls to ensure test determinism.
78
79	q.Output(lprog.Fset, &callersResult{
80		target:    target,
81		callgraph: cg,
82		edges:     edges,
83	})
84	return nil
85}
86
87// directCallsTo inspects the whole program and returns a callgraph
88// containing edges for all direct calls to the target function.
89// directCallsTo returns nil if the function is ever address-taken.
90func directCallsTo(target *ssa.Function, entrypoints []*ssa.Function) *callgraph.Graph {
91	cg := callgraph.New(nil) // use nil as root *Function
92	targetNode := cg.CreateNode(target)
93
94	// Is the function a program entry point?
95	// If so, add edge from callgraph root.
96	for _, f := range entrypoints {
97		if f == target {
98			callgraph.AddEdge(cg.Root, nil, targetNode)
99		}
100	}
101
102	// Find receiver type (for methods).
103	var recvType types.Type
104	if recv := target.Signature.Recv(); recv != nil {
105		recvType = recv.Type()
106	}
107
108	// Find all direct calls to function,
109	// or a place where its address is taken.
110	var space [32]*ssa.Value // preallocate
111	for fn := range ssautil.AllFunctions(target.Prog) {
112		for _, b := range fn.Blocks {
113			for _, instr := range b.Instrs {
114				// Is this a method (T).f of a concrete type T
115				// whose runtime type descriptor is address-taken?
116				// (To be fully sound, we would have to check that
117				// the type doesn't make it to reflection as a
118				// subelement of some other address-taken type.)
119				if recvType != nil {
120					if mi, ok := instr.(*ssa.MakeInterface); ok {
121						if types.Identical(mi.X.Type(), recvType) {
122							return nil // T is address-taken
123						}
124						if ptr, ok := mi.X.Type().(*types.Pointer); ok &&
125							types.Identical(ptr.Elem(), recvType) {
126							return nil // *T is address-taken
127						}
128					}
129				}
130
131				// Direct call to target?
132				rands := instr.Operands(space[:0])
133				if site, ok := instr.(ssa.CallInstruction); ok &&
134					site.Common().Value == target {
135					callgraph.AddEdge(cg.CreateNode(fn), site, targetNode)
136					rands = rands[1:] // skip .Value (rands[0])
137				}
138
139				// Address-taken?
140				for _, rand := range rands {
141					if rand != nil && *rand == target {
142						return nil
143					}
144				}
145			}
146		}
147	}
148
149	return cg
150}
151
152func entryPoints(mains []*ssa.Package) []*ssa.Function {
153	var entrypoints []*ssa.Function
154	for _, pkg := range mains {
155		entrypoints = append(entrypoints, pkg.Func("init"))
156		if main := pkg.Func("main"); main != nil && pkg.Pkg.Name() == "main" {
157			entrypoints = append(entrypoints, main)
158		}
159	}
160	return entrypoints
161}
162
// callersResult is the answer to a callers query. It is rendered by its
// PrintPlain and JSON methods.
type callersResult struct {
	// target is the function whose callers were requested.
	target    *ssa.Function
	// callgraph is the graph that edges belong to. An edge whose Caller
	// is callgraph.Root represents a call from program entry.
	callgraph *callgraph.Graph
	// edges are the incoming edges of target's node in callgraph.
	edges     []*callgraph.Edge
}
168
169func (r *callersResult) PrintPlain(printf printfFunc) {
170	root := r.callgraph.Root
171	if r.edges == nil {
172		printf(r.target, "%s is not reachable in this program.", r.target)
173	} else {
174		printf(r.target, "%s is called from these %d sites:", r.target, len(r.edges))
175		for _, edge := range r.edges {
176			if edge.Caller == root {
177				printf(r.target, "the root of the call graph")
178			} else {
179				printf(edge, "\t%s from %s", edge.Description(), edge.Caller.Func)
180			}
181		}
182	}
183}
184
185func (r *callersResult) JSON(fset *token.FileSet) []byte {
186	var callers []serial.Caller
187	for _, edge := range r.edges {
188		callers = append(callers, serial.Caller{
189			Caller: edge.Caller.Func.String(),
190			Pos:    fset.Position(edge.Pos()).String(),
191			Desc:   edge.Description(),
192		})
193	}
194	return toJSON(callers)
195}
196