1package compiler
2
3// This file implements the 'defer' keyword in Go.
4// Defer statements are implemented by transforming the function in the
5// following way:
6//   * Creating an alloca in the entry block that contains a pointer (initially
7//     null) to the linked list of defer frames.
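//   * Every time a defer statement is executed, a new defer frame is created
//     with a pointer to the previous defer frame, and the head pointer in the
//     entry block is replaced with a pointer to this defer frame. The frame is
//     an alloca in the entry block when the defer statement is not inside a
//     loop, and a heap allocation otherwise (the statement may then execute a
//     variable number of times).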
//   * On return, an inline loop (emitted by createRunDefers) pops each defer
//     frame off the head of the linked list and dispatches on its callback
//     number to run the corresponding deferred call, until the list is empty.
15
16import (
17	"github.com/tinygo-org/tinygo/compiler/llvmutil"
18	"github.com/tinygo-org/tinygo/ir"
19	"go/types"
20	"golang.org/x/tools/go/ssa"
21	"tinygo.org/x/go-llvm"
22)
23
24// deferInitFunc sets up this function for future deferred calls. It must be
25// called from within the entry block when this function contains deferred
26// calls.
27func (b *builder) deferInitFunc() {
28	// Some setup.
29	b.deferFuncs = make(map[*ir.Function]int)
30	b.deferInvokeFuncs = make(map[string]int)
31	b.deferClosureFuncs = make(map[*ir.Function]int)
32	b.deferExprFuncs = make(map[ssa.Value]int)
33	b.deferBuiltinFuncs = make(map[ssa.Value]deferBuiltin)
34
35	// Create defer list pointer.
36	deferType := llvm.PointerType(b.getLLVMRuntimeType("_defer"), 0)
37	b.deferPtr = b.CreateAlloca(deferType, "deferPtr")
38	b.CreateStore(llvm.ConstPointerNull(deferType), b.deferPtr)
39}
40
41// isInLoop checks if there is a path from a basic block to itself.
42func isInLoop(start *ssa.BasicBlock) bool {
43	// Use a breadth-first search to scan backwards through the block graph.
44	queue := []*ssa.BasicBlock{start}
45	checked := map[*ssa.BasicBlock]struct{}{}
46
47	for len(queue) > 0 {
48		// pop a block off of the queue
49		block := queue[len(queue)-1]
50		queue = queue[:len(queue)-1]
51
52		// Search through predecessors.
53		// Searching backwards means that this is pretty fast when the block is close to the start of the function.
54		// Defers are often placed near the start of the function.
55		for _, pred := range block.Preds {
56			if pred == start {
57				// cycle found
58				return true
59			}
60
61			if _, ok := checked[pred]; ok {
62				// block already checked
63				continue
64			}
65
66			// add to queue and checked map
67			queue = append(queue, pred)
68			checked[pred] = struct{}{}
69		}
70	}
71
72	return false
73}
74
// createDefer emits a single defer instruction, to be run when this function
// returns.
//
// It builds a defer frame: a struct whose first two fields match the fields
// that createRunDefers reads through *runtime._defer (field 0: a uintptr
// callback number, field 1: a pointer to the previous frame). These are
// followed by the values the deferred call needs, which depend on the call
// kind:
//   - interface method call: typecode, receiver, then the call arguments;
//   - static function call: the call arguments;
//   - closure (MakeClosure): the call arguments, then the closure context;
//   - builtin (only close is supported): the call arguments;
//   - any other func value: the func value itself, then the call arguments.
//
// The callback number is an index into b.allDeferFuncs. The per-kind maps set
// up by deferInitFunc make sure each callee gets only one number.
// createRunDefers uses that number to select the code that unpacks the frame
// and performs the call. The frame is stored in memory and pushed onto the
// head of the linked list held in b.deferPtr.
func (b *builder) createDefer(instr *ssa.Defer) {
	// The pointer to the previous defer struct, which we will replace to
	// make a linked list.
	next := b.CreateLoad(b.deferPtr, "defer.next")

	// values holds the field values of the defer frame and valueTypes their
	// LLVM types; the two slices must stay index-aligned. The first two types
	// (callback number and next pointer) are shared by every call kind.
	var values []llvm.Value
	valueTypes := []llvm.Type{b.uintptrType, next.Type()}
	if instr.Call.IsInvoke() {
		// Method call on an interface.

		// Get callback type number. Invoke callbacks are shared by method
		// full name, so every deferred call of the same interface method
		// reuses one callback.
		methodName := instr.Call.Method.FullName()
		if _, ok := b.deferInvokeFuncs[methodName]; !ok {
			b.deferInvokeFuncs[methodName] = len(b.allDeferFuncs)
			b.allDeferFuncs = append(b.allDeferFuncs, &instr.Call)
		}
		callback := llvm.ConstInt(b.uintptrType, uint64(b.deferInvokeFuncs[methodName]), false)

		// Collect all values to be put in the struct (starting with
		// runtime._defer fields, followed by the call parameters).
		// The interface value is split into its typecode (field 0) and
		// receiver (field 1); the typecode is used at run time to find the
		// concrete method.
		itf := b.getValue(instr.Call.Value) // interface
		typecode := b.CreateExtractValue(itf, 0, "invoke.func.typecode")
		receiverValue := b.CreateExtractValue(itf, 1, "invoke.func.receiver")
		values = []llvm.Value{callback, next, typecode, receiverValue}
		valueTypes = append(valueTypes, b.uintptrType, b.i8ptrType)
		for _, arg := range instr.Call.Args {
			val := b.getValue(arg)
			values = append(values, val)
			valueTypes = append(valueTypes, val.Type())
		}

	} else if callee, ok := instr.Call.Value.(*ssa.Function); ok {
		// Regular function call.
		fn := b.ir.GetFunction(callee)

		// One callback per distinct statically known callee.
		if _, ok := b.deferFuncs[fn]; !ok {
			b.deferFuncs[fn] = len(b.allDeferFuncs)
			b.allDeferFuncs = append(b.allDeferFuncs, fn)
		}
		callback := llvm.ConstInt(b.uintptrType, uint64(b.deferFuncs[fn]), false)

		// Collect all values to be put in the struct (starting with
		// runtime._defer fields).
		values = []llvm.Value{callback, next}
		for _, param := range instr.Call.Args {
			llvmParam := b.getValue(param)
			values = append(values, llvmParam)
			valueTypes = append(valueTypes, llvmParam.Type())
		}

	} else if makeClosure, ok := instr.Call.Value.(*ssa.MakeClosure); ok {
		// Immediately applied function literal with free variables.

		// Extract the context from the closure. We won't need the function
		// pointer.
		// TODO: ignore this closure entirely and put pointers to the free
		// variables directly in the defer struct, avoiding a memory allocation.
		closure := b.getValue(instr.Call.Value)
		context := b.CreateExtractValue(closure, 0, "")

		// Get the callback number.
		fn := b.ir.GetFunction(makeClosure.Fn.(*ssa.Function))
		if _, ok := b.deferClosureFuncs[fn]; !ok {
			b.deferClosureFuncs[fn] = len(b.allDeferFuncs)
			b.allDeferFuncs = append(b.allDeferFuncs, makeClosure)
		}
		callback := llvm.ConstInt(b.uintptrType, uint64(b.deferClosureFuncs[fn]), false)

		// Collect all values to be put in the struct (starting with
		// runtime._defer fields, followed by all parameters including the
		// context pointer).
		values = []llvm.Value{callback, next}
		for _, param := range instr.Call.Args {
			llvmParam := b.getValue(param)
			values = append(values, llvmParam)
			valueTypes = append(valueTypes, llvmParam.Type())
		}
		// The context goes last, after the regular arguments.
		values = append(values, context)
		valueTypes = append(valueTypes, context.Type())

	} else if builtin, ok := instr.Call.Value.(*ssa.Builtin); ok {
		// Deferred builtin: map it to the runtime function that implements
		// it. Only close is supported; other builtins report an error and
		// emit no defer frame.
		var funcName string
		switch builtin.Name() {
		case "close":
			funcName = "chanClose"
		default:
			b.addError(instr.Pos(), "todo: Implement defer for "+builtin.Name())
			return
		}

		// Callbacks are keyed by the *ssa.Builtin value itself.
		if _, ok := b.deferBuiltinFuncs[instr.Call.Value]; !ok {
			b.deferBuiltinFuncs[instr.Call.Value] = deferBuiltin{
				funcName,
				len(b.allDeferFuncs),
			}
			b.allDeferFuncs = append(b.allDeferFuncs, instr.Call.Value)
		}
		callback := llvm.ConstInt(b.uintptrType, uint64(b.deferBuiltinFuncs[instr.Call.Value].callback), false)

		// Collect all values to be put in the struct (starting with
		// runtime._defer fields).
		values = []llvm.Value{callback, next}
		for _, param := range instr.Call.Args {
			llvmParam := b.getValue(param)
			values = append(values, llvmParam)
			valueTypes = append(valueTypes, llvmParam.Type())
		}

	} else {
		// Call through an arbitrary func value (neither a static function
		// nor a MakeClosure). The whole func value is stored in the frame and
		// decoded into function pointer and context when the defer runs.
		funcValue := b.getValue(instr.Call.Value)

		if _, ok := b.deferExprFuncs[instr.Call.Value]; !ok {
			b.deferExprFuncs[instr.Call.Value] = len(b.allDeferFuncs)
			b.allDeferFuncs = append(b.allDeferFuncs, &instr.Call)
		}

		callback := llvm.ConstInt(b.uintptrType, uint64(b.deferExprFuncs[instr.Call.Value]), false)

		// Collect all values to be put in the struct (starting with
		// runtime._defer fields, followed by the func value and then the
		// call parameters).
		values = []llvm.Value{callback, next, funcValue}
		valueTypes = append(valueTypes, funcValue.Type())
		for _, param := range instr.Call.Args {
			llvmParam := b.getValue(param)
			values = append(values, llvmParam)
			valueTypes = append(valueTypes, llvmParam.Type())
		}
	}

	// Make a struct out of the collected values to put in the defer frame.
	deferFrameType := b.ctx.StructType(valueTypes, false)
	deferFrame := llvm.ConstNull(deferFrameType)
	for i, value := range values {
		deferFrame = b.CreateInsertValue(deferFrame, value, i, "")
	}

	// Put this struct in an allocation.
	var alloca llvm.Value
	if !isInLoop(instr.Block()) {
		// This can safely use a stack allocation.
		alloca = llvmutil.CreateEntryBlockAlloca(b.Builder, deferFrameType, "defer.alloca")
	} else {
		// This may be hit a variable number of times, so use a heap allocation.
		size := b.targetData.TypeAllocSize(deferFrameType)
		sizeValue := llvm.ConstInt(b.uintptrType, size, false)
		allocCall := b.createRuntimeCall("alloc", []llvm.Value{sizeValue}, "defer.alloc.call")
		alloca = b.CreateBitCast(allocCall, llvm.PointerType(deferFrameType, 0), "defer.alloc")
	}
	// NOTE(review): presumably registers the frame pointer with the garbage
	// collector on targets that need explicit stack-object tracking — confirm.
	if b.NeedsStackObjects() {
		b.trackPointer(alloca)
	}
	b.CreateStore(deferFrame, alloca)

	// Push it on top of the linked list by replacing deferPtr.
	allocaCast := b.CreateBitCast(alloca, next.Type(), "defer.alloca.cast")
	b.CreateStore(allocaCast, b.deferPtr)
}
235
// createRunDefers emits code to run all deferred functions.
//
// The emitted loop pops frames off the linked list rooted at b.deferPtr (most
// recently pushed first) and switches on each frame's callback number. Every
// entry of b.allDeferFuncs gets one switch case. That case casts the frame to
// the concrete struct layout createDefer used for that call kind, loads the
// stored values and performs the call. The layouts must stay in sync with
// createDefer.
//
// When this function returns, the builder's insert point is at the end of the
// empty "rundefers.end" block. Code emitted afterwards (presumably the
// function's return) runs once the list is empty.
func (b *builder) createRunDefers() {
	// Add a loop like the following:
	//     for stack != nil {
	//         _stack := stack
	//         stack = stack.next
	//         switch _stack.callback {
	//         case 0:
	//             // run first deferred call
	//         case 1:
	//             // run second deferred call
	//             // etc.
	//         default:
	//             unreachable
	//         }
	//     }

	// Create loop.
	loophead := b.ctx.AddBasicBlock(b.fn.LLVMFn, "rundefers.loophead")
	loop := b.ctx.AddBasicBlock(b.fn.LLVMFn, "rundefers.loop")
	unreachable := b.ctx.AddBasicBlock(b.fn.LLVMFn, "rundefers.default")
	end := b.ctx.AddBasicBlock(b.fn.LLVMFn, "rundefers.end")
	b.CreateBr(loophead)

	// Create loop head:
	//     for stack != nil {
	// deferData is the current head frame (_stack in the sketch above). It
	// is defined here and used in the loop body and every callback block.
	b.SetInsertPointAtEnd(loophead)
	deferData := b.CreateLoad(b.deferPtr, "")
	stackIsNil := b.CreateICmp(llvm.IntEQ, deferData, llvm.ConstPointerNull(deferData.Type()), "stackIsNil")
	b.CreateCondBr(stackIsNil, end, loop)

	// Create loop body:
	//     _stack := stack
	//     stack = stack.next
	//     switch _stack.callback {
	// The frame is unlinked before its call runs.
	b.SetInsertPointAtEnd(loop)
	nextStackGEP := b.CreateInBoundsGEP(deferData, []llvm.Value{
		llvm.ConstInt(b.ctx.Int32Type(), 0, false),
		llvm.ConstInt(b.ctx.Int32Type(), 1, false), // .next field
	}, "stack.next.gep")
	nextStack := b.CreateLoad(nextStackGEP, "stack.next")
	b.CreateStore(nextStack, b.deferPtr)
	gep := b.CreateInBoundsGEP(deferData, []llvm.Value{
		llvm.ConstInt(b.ctx.Int32Type(), 0, false),
		llvm.ConstInt(b.ctx.Int32Type(), 0, false), // .callback field
	}, "callback.gep")
	callback := b.CreateLoad(gep, "callback")
	sw := b.CreateSwitch(callback, unreachable, len(b.allDeferFuncs))

	// Note: the loop variable below shadows the loaded callback value above;
	// inside the loop, callback is the b.allDeferFuncs entry whose index is the
	// case value.
	for i, callback := range b.allDeferFuncs {
		// Create switch case, for example:
		//     case 0:
		//         // run first deferred call
		block := b.ctx.AddBasicBlock(b.fn.LLVMFn, "rundefers.callback")
		sw.AddCase(llvm.ConstInt(b.uintptrType, uint64(i), false), block)
		b.SetInsertPointAtEnd(block)
		switch callback := callback.(type) {
		case *ssa.CallCommon:
			// Call on an value or interface value.
			// Frame layout: callback, next, then either the func value (non-
			// invoke) or typecode + receiver (invoke), then the call args.

			// Get the real defer struct type and cast to it.
			valueTypes := []llvm.Type{b.uintptrType, llvm.PointerType(b.getLLVMRuntimeType("_defer"), 0)}

			if !callback.IsInvoke() {
				// Expect funcValue to be passed through the defer frame.
				valueTypes = append(valueTypes, b.getFuncType(callback.Signature()))
			} else {
				// Expect typecode
				valueTypes = append(valueTypes, b.uintptrType, b.i8ptrType)
			}

			for _, arg := range callback.Args {
				valueTypes = append(valueTypes, b.getLLVMType(arg.Type()))
			}

			deferFrameType := b.ctx.StructType(valueTypes, false)
			deferFramePtr := b.CreateBitCast(deferData, llvm.PointerType(deferFrameType, 0), "deferFrame")

			// Extract the params from the struct (including receiver).
			// Fields 0 and 1 are the callback number and next pointer.
			forwardParams := []llvm.Value{}
			zero := llvm.ConstInt(b.ctx.Int32Type(), 0, false)
			for i := 2; i < len(valueTypes); i++ {
				gep := b.CreateInBoundsGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(b.ctx.Int32Type(), uint64(i), false)}, "gep")
				forwardParam := b.CreateLoad(gep, "param")
				forwardParams = append(forwardParams, forwardParam)
			}

			var fnPtr llvm.Value

			if !callback.IsInvoke() {
				// Isolate the func value.
				funcValue := forwardParams[0]
				forwardParams = forwardParams[1:]

				// Get function pointer and context
				fp, context := b.decodeFuncValue(funcValue, callback.Signature())
				fnPtr = fp

				// Pass context
				forwardParams = append(forwardParams, context)
			} else {
				// Isolate the typecode. The receiver stays as the first
				// forwarded parameter.
				typecode := forwardParams[0]
				forwardParams = forwardParams[1:]
				fnPtr = b.getInvokePtr(callback, typecode)

				// Add the context parameter. An interface call cannot also be a
				// closure but we have to supply the parameter anyway for platforms
				// with a strict calling convention.
				forwardParams = append(forwardParams, llvm.Undef(b.i8ptrType))
			}

			// Parent coroutine handle.
			forwardParams = append(forwardParams, llvm.Undef(b.i8ptrType))

			b.createCall(fnPtr, forwardParams, "")

		case *ir.Function:
			// Direct call.
			// Frame layout: callback, next, then one field per callee param.

			// Get the real defer struct type and cast to it.
			valueTypes := []llvm.Type{b.uintptrType, llvm.PointerType(b.getLLVMRuntimeType("_defer"), 0)}
			for _, param := range callback.Params {
				valueTypes = append(valueTypes, b.getLLVMType(param.Type()))
			}
			deferFrameType := b.ctx.StructType(valueTypes, false)
			deferFramePtr := b.CreateBitCast(deferData, llvm.PointerType(deferFrameType, 0), "deferFrame")

			// Extract the params from the struct.
			forwardParams := []llvm.Value{}
			zero := llvm.ConstInt(b.ctx.Int32Type(), 0, false)
			for i := range callback.Params {
				gep := b.CreateInBoundsGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(b.ctx.Int32Type(), uint64(i+2), false)}, "gep")
				forwardParam := b.CreateLoad(gep, "param")
				forwardParams = append(forwardParams, forwardParam)
			}

			// Plain TinyGo functions add some extra parameters to implement async functionality and function receivers.
			// These parameters should not be supplied when calling into an external C/ASM function.
			if !callback.IsExported() {
				// Add the context parameter. We know it is ignored by the receiving
				// function, but we have to pass one anyway.
				forwardParams = append(forwardParams, llvm.Undef(b.i8ptrType))

				// Parent coroutine handle.
				forwardParams = append(forwardParams, llvm.Undef(b.i8ptrType))
			}

			// Call real function.
			b.createCall(callback.LLVMFn, forwardParams, "")

		case *ssa.MakeClosure:
			// Frame layout: callback, next, one field per signature param,
			// then the closure context (i8*) last.
			// Get the real defer struct type and cast to it.
			fn := b.ir.GetFunction(callback.Fn.(*ssa.Function))
			valueTypes := []llvm.Type{b.uintptrType, llvm.PointerType(b.getLLVMRuntimeType("_defer"), 0)}
			params := fn.Signature.Params()
			for i := 0; i < params.Len(); i++ {
				valueTypes = append(valueTypes, b.getLLVMType(params.At(i).Type()))
			}
			valueTypes = append(valueTypes, b.i8ptrType) // closure
			deferFrameType := b.ctx.StructType(valueTypes, false)
			deferFramePtr := b.CreateBitCast(deferData, llvm.PointerType(deferFrameType, 0), "deferFrame")

			// Extract the params from the struct. This includes the closure
			// context, which is forwarded as the context parameter.
			forwardParams := []llvm.Value{}
			zero := llvm.ConstInt(b.ctx.Int32Type(), 0, false)
			for i := 2; i < len(valueTypes); i++ {
				gep := b.CreateInBoundsGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(b.ctx.Int32Type(), uint64(i), false)}, "")
				forwardParam := b.CreateLoad(gep, "param")
				forwardParams = append(forwardParams, forwardParam)
			}

			// Parent coroutine handle.
			forwardParams = append(forwardParams, llvm.Undef(b.i8ptrType))

			// Call deferred function.
			b.createCall(fn.LLVMFn, forwardParams, "")
		case *ssa.Builtin:
			// Deferred builtin: call the runtime function recorded for it
			// by createDefer (e.g. chanClose for close).
			db := b.deferBuiltinFuncs[callback]

			// Get parameter types
			valueTypes := []llvm.Type{b.uintptrType, llvm.PointerType(b.getLLVMRuntimeType("_defer"), 0)}

			// Get signature from call results
			params := callback.Type().Underlying().(*types.Signature).Params()
			for i := 0; i < params.Len(); i++ {
				valueTypes = append(valueTypes, b.getLLVMType(params.At(i).Type()))
			}

			deferFrameType := b.ctx.StructType(valueTypes, false)
			deferFramePtr := b.CreateBitCast(deferData, llvm.PointerType(deferFrameType, 0), "deferFrame")

			// Extract the params from the struct.
			var forwardParams []llvm.Value
			zero := llvm.ConstInt(b.ctx.Int32Type(), 0, false)
			for i := 0; i < params.Len(); i++ {
				gep := b.CreateInBoundsGEP(deferFramePtr, []llvm.Value{zero, llvm.ConstInt(b.ctx.Int32Type(), uint64(i+2), false)}, "gep")
				forwardParam := b.CreateLoad(gep, "param")
				forwardParams = append(forwardParams, forwardParam)
			}

			b.createRuntimeCall(db.funcName, forwardParams, "")
		default:
			panic("unknown deferred function type")
		}

		// Branch back to the start of the loop.
		b.CreateBr(loophead)
	}

	// Create default unreachable block:
	//     default:
	//         unreachable
	//     }
	// Callback numbers are always indices into b.allDeferFuncs, so no
	// other value can appear.
	b.SetInsertPointAtEnd(unreachable)
	b.CreateUnreachable()

	// End of loop.
	b.SetInsertPointAtEnd(end)
}
456