1package transform
2
3// This file lowers func values into their final form. This is necessary for
4// funcValueSwitch, which needs full program analysis.
5
6import (
7	"sort"
8	"strconv"
9	"strings"
10
11	"github.com/tinygo-org/tinygo/compiler/llvmutil"
12	"tinygo.org/x/go-llvm"
13)
14
15// funcSignatureInfo keeps information about a single signature and its uses.
16type funcSignatureInfo struct {
17	sig                     llvm.Value   // *uint8 to identify the signature
18	funcValueWithSignatures []llvm.Value // slice of runtime.funcValueWithSignature
19}
20
21// funcWithUses keeps information about a single function used as func value and
22// the assigned function ID. More commonly used functions are assigned a lower
23// ID.
24type funcWithUses struct {
25	funcPtr  llvm.Value
26	useCount int // how often this function is used in a func value
27	id       int // assigned ID
28}
29
30// Slice to sort functions by their use counts, or else their name if they're
31// used equally often.
32type funcWithUsesList []*funcWithUses
33
34func (l funcWithUsesList) Len() int { return len(l) }
35func (l funcWithUsesList) Less(i, j int) bool {
36	if l[i].useCount != l[j].useCount {
37		// return the reverse: we want the highest use counts sorted first
38		return l[i].useCount > l[j].useCount
39	}
40	iName := l[i].funcPtr.Name()
41	jName := l[j].funcPtr.Name()
42	return iName < jName
43}
44func (l funcWithUsesList) Swap(i, j int) {
45	l[i], l[j] = l[j], l[i]
46}
47
48// LowerFuncValues lowers the runtime.funcValueWithSignature type and
49// runtime.getFuncPtr function to their final form.
50func LowerFuncValues(mod llvm.Module) {
51	ctx := mod.Context()
52	builder := ctx.NewBuilder()
53	uintptrType := ctx.IntType(llvm.NewTargetData(mod.DataLayout()).PointerSize() * 8)
54
55	// Find all func values used in the program with their signatures.
56	funcValueWithSignaturePtr := llvm.PointerType(mod.GetTypeByName("runtime.funcValueWithSignature"), 0)
57	signatures := map[string]*funcSignatureInfo{}
58	for global := mod.FirstGlobal(); !global.IsNil(); global = llvm.NextGlobal(global) {
59		var sig, funcVal llvm.Value
60		switch {
61		case global.Type() == funcValueWithSignaturePtr:
62			sig = llvm.ConstExtractValue(global.Initializer(), []uint32{1})
63			funcVal = global
64		case strings.HasPrefix(global.Name(), "reflect/types.type:func:{"):
65			sig = global
66		default:
67			continue
68		}
69
70		name := sig.Name()
71		var funcValueWithSignatures []llvm.Value
72		if funcVal.IsNil() {
73			funcValueWithSignatures = []llvm.Value{}
74		} else {
75			funcValueWithSignatures = []llvm.Value{funcVal}
76		}
77		if info, ok := signatures[name]; ok {
78			info.funcValueWithSignatures = append(info.funcValueWithSignatures, funcValueWithSignatures...)
79		} else {
80			signatures[name] = &funcSignatureInfo{
81				sig:                     sig,
82				funcValueWithSignatures: funcValueWithSignatures,
83			}
84		}
85	}
86
87	// Sort the signatures, for deterministic execution.
88	names := make([]string, 0, len(signatures))
89	for name := range signatures {
90		names = append(names, name)
91	}
92	sort.Strings(names)
93
94	for _, name := range names {
95		info := signatures[name]
96		functions := make(funcWithUsesList, len(info.funcValueWithSignatures))
97		for i, use := range info.funcValueWithSignatures {
98			var useCount int
99			for _, use2 := range getUses(use) {
100				useCount += len(getUses(use2))
101			}
102			functions[i] = &funcWithUses{
103				funcPtr:  llvm.ConstExtractValue(use.Initializer(), []uint32{0}).Operand(0),
104				useCount: useCount,
105			}
106		}
107		sort.Sort(functions)
108
109		for i, fn := range functions {
110			fn.id = i + 1
111			for _, ptrtoint := range getUses(fn.funcPtr) {
112				if ptrtoint.IsAConstantExpr().IsNil() || ptrtoint.Opcode() != llvm.PtrToInt {
113					continue
114				}
115				for _, funcValueWithSignatureConstant := range getUses(ptrtoint) {
116					if !funcValueWithSignatureConstant.IsACallInst().IsNil() && funcValueWithSignatureConstant.CalledValue().Name() == "runtime.makeGoroutine" {
117						// makeGoroutine calls are handled seperately
118						continue
119					}
120					for _, funcValueWithSignatureGlobal := range getUses(funcValueWithSignatureConstant) {
121						for _, use := range getUses(funcValueWithSignatureGlobal) {
122							if ptrtoint.IsAConstantExpr().IsNil() || ptrtoint.Opcode() != llvm.PtrToInt {
123								panic("expected const ptrtoint")
124							}
125							use.ReplaceAllUsesWith(llvm.ConstInt(uintptrType, uint64(fn.id), false))
126						}
127					}
128				}
129			}
130		}
131
132		for _, getFuncPtrCall := range getUses(info.sig) {
133			if getFuncPtrCall.IsACallInst().IsNil() {
134				continue
135			}
136			if getFuncPtrCall.CalledValue().Name() != "runtime.getFuncPtr" {
137				panic("expected all call uses to be runtime.getFuncPtr")
138			}
139			funcID := getFuncPtrCall.Operand(1)
140
141			// There are functions used in a func value that
142			// implement this signature.
143			// What we'll do is transform the following:
144			//     rawPtr := runtime.getFuncPtr(func.ptr)
145			//     if rawPtr == nil {
146			//         runtime.nilPanic()
147			//     }
148			//     result := rawPtr(...args, func.context)
149			// into this:
150			//     if false {
151			//         runtime.nilPanic()
152			//     }
153			//     var result // Phi
154			//     switch fn.id {
155			//     case 0:
156			//         runtime.nilPanic()
157			//     case 1:
158			//         result = call first implementation...
159			//     case 2:
160			//         result = call second implementation...
161			//     default:
162			//         unreachable
163			//     }
164
165			// Remove some casts, checks, and the old call which we're going
166			// to replace.
167			for _, callIntPtr := range getUses(getFuncPtrCall) {
168				if !callIntPtr.IsACallInst().IsNil() && callIntPtr.CalledValue().Name() == "internal/task.start" {
169					// Special case for goroutine starts.
170					addFuncLoweringSwitch(mod, builder, funcID, callIntPtr, func(funcPtr llvm.Value, params []llvm.Value) llvm.Value {
171						i8ptrType := llvm.PointerType(ctx.Int8Type(), 0)
172						calleeValue := builder.CreatePtrToInt(funcPtr, uintptrType, "")
173						start := mod.NamedFunction("internal/task.start")
174						builder.CreateCall(start, []llvm.Value{calleeValue, callIntPtr.Operand(1), llvm.Undef(i8ptrType), llvm.ConstNull(i8ptrType)}, "")
175						return llvm.Value{} // void so no return value
176					}, functions)
177					callIntPtr.EraseFromParentAsInstruction()
178					continue
179				}
180				if callIntPtr.IsAIntToPtrInst().IsNil() {
181					panic("expected inttoptr")
182				}
183				for _, ptrUse := range getUses(callIntPtr) {
184					if !ptrUse.IsAICmpInst().IsNil() {
185						ptrUse.ReplaceAllUsesWith(llvm.ConstInt(ctx.Int1Type(), 0, false))
186					} else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == callIntPtr {
187						addFuncLoweringSwitch(mod, builder, funcID, ptrUse, func(funcPtr llvm.Value, params []llvm.Value) llvm.Value {
188							return builder.CreateCall(funcPtr, params, "")
189						}, functions)
190					} else {
191						panic("unexpected getFuncPtrCall")
192					}
193					ptrUse.EraseFromParentAsInstruction()
194				}
195				callIntPtr.EraseFromParentAsInstruction()
196			}
197			getFuncPtrCall.EraseFromParentAsInstruction()
198		}
199	}
200}
201
202// addFuncLoweringSwitch creates a new switch on a function ID and inserts calls
203// to the newly created direct calls. The funcID is the number to switch on,
204// call is the call instruction to replace, and createCall is the callback that
205// actually creates the new call. By changing createCall to something other than
206// builder.CreateCall, instead of calling a function it can start a new
207// goroutine for example.
208func addFuncLoweringSwitch(mod llvm.Module, builder llvm.Builder, funcID, call llvm.Value, createCall func(funcPtr llvm.Value, params []llvm.Value) llvm.Value, functions funcWithUsesList) {
209	ctx := mod.Context()
210	uintptrType := ctx.IntType(llvm.NewTargetData(mod.DataLayout()).PointerSize() * 8)
211	i8ptrType := llvm.PointerType(ctx.Int8Type(), 0)
212
213	// The block that cannot be reached with correct funcValues (to help the
214	// optimizer).
215	builder.SetInsertPointBefore(call)
216	defaultBlock := ctx.AddBasicBlock(call.InstructionParent().Parent(), "func.default")
217	builder.SetInsertPointAtEnd(defaultBlock)
218	builder.CreateUnreachable()
219
220	// Create the switch.
221	builder.SetInsertPointBefore(call)
222	sw := builder.CreateSwitch(funcID, defaultBlock, len(functions)+1)
223
224	// Split right after the switch. We will need to insert a few basic blocks
225	// in this gap.
226	nextBlock := llvmutil.SplitBasicBlock(builder, sw, llvm.NextBasicBlock(sw.InstructionParent()), "func.next")
227
228	// Temporarily set the insert point to set the correct debug insert location
229	// for the builder. It got destroyed by the SplitBasicBlock call.
230	builder.SetInsertPointBefore(call)
231
232	// The 0 case, which is actually a nil check.
233	nilBlock := ctx.InsertBasicBlock(nextBlock, "func.nil")
234	builder.SetInsertPointAtEnd(nilBlock)
235	nilPanic := mod.NamedFunction("runtime.nilPanic")
236	builder.CreateCall(nilPanic, []llvm.Value{llvm.Undef(i8ptrType), llvm.ConstNull(i8ptrType)}, "")
237	builder.CreateUnreachable()
238	sw.AddCase(llvm.ConstInt(uintptrType, 0, false), nilBlock)
239
240	// Gather the list of parameters for every call we're going to make.
241	callParams := make([]llvm.Value, call.OperandsCount()-1)
242	for i := range callParams {
243		callParams[i] = call.Operand(i)
244	}
245
246	// If the call produces a value, we need to get it using a PHI
247	// node.
248	phiBlocks := make([]llvm.BasicBlock, len(functions))
249	phiValues := make([]llvm.Value, len(functions))
250	for i, fn := range functions {
251		// Insert a switch case.
252		bb := ctx.InsertBasicBlock(nextBlock, "func.call"+strconv.Itoa(fn.id))
253		builder.SetInsertPointAtEnd(bb)
254		result := createCall(fn.funcPtr, callParams)
255		builder.CreateBr(nextBlock)
256		sw.AddCase(llvm.ConstInt(uintptrType, uint64(fn.id), false), bb)
257		phiBlocks[i] = bb
258		phiValues[i] = result
259	}
260	if call.Type().TypeKind() != llvm.VoidTypeKind {
261		if len(functions) > 0 {
262			// Create the PHI node so that the call result flows into the
263			// next block (after the split). This is only necessary when the
264			// call produced a value.
265			builder.SetInsertPointBefore(nextBlock.FirstInstruction())
266			phi := builder.CreatePHI(call.Type(), "")
267			phi.AddIncoming(phiValues, phiBlocks)
268			call.ReplaceAllUsesWith(phi)
269		} else {
270			// This is always a nil panic, so replace the call result with undef.
271			call.ReplaceAllUsesWith(llvm.Undef(call.Type()))
272		}
273	}
274}
275