package transform

// This file lowers asynchronous functions and goroutine starts when using the coroutines scheduler.
// This is accomplished by inserting LLVM intrinsics that are used to save the state of functions.

import (
	"errors"
	"strconv"

	"github.com/tinygo-org/tinygo/compiler/llvmutil"
	"tinygo.org/x/go-llvm"
)

// LowerCoroutines turns async functions into coroutines.
// This must be run with the coroutines scheduler.
//
// Before this pass, goroutine starts are expressed as a call to an intrinsic called "internal/task.start".
// This intrinsic accepts the function pointer and a pointer to a struct containing the function's arguments.
//
// Before this pass, an intrinsic called "internal/task.Pause" is used to express suspensions of the current goroutine.
//
// This pass first accumulates a list of blocking functions.
// A function is considered "blocking" if it calls "internal/task.Pause" or any other blocking function.
//
// Blocking calls are implemented by turning blocking functions into coroutines.
// The body of each blocking function is modified to start a new coroutine, and to return after the first suspend.
// After calling a blocking function, the caller coroutine suspends.
// The caller also provides a buffer to store the return value into.
// When a blocking function returns, the return value is written into this buffer and then the caller is queued to run.
//
// Goroutine starts which invoke non-blocking functions are implemented as direct calls.
// Goroutine starts of blocking functions are replaced with the creation of a new task data structure followed by a call to the start of the blocking function.
// The task structure is populated with a "noop" coroutine before invoking the blocking function.
// When the blocking function returns, it resumes this "noop" coroutine, which does nothing.
// This allows the goroutine starter to continue after the first suspend of the started goroutine.
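//
// For illustration, a goroutine start like "go produce(ch)", where produce is
// blocking, is conceptually lowered to the following pseudocode (a sketch, not
// the emitted IR; the ditch buffer is only needed when produce returns a value):
//
//	t := task.createTask()        // task containing a "noop" coroutine
//	task.setReturnPtr(t, heapBuf) // heapBuf is a heap-allocated ditch buffer
//	produce(ch, t)                // returns at produce's first suspend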
//
// The transformation of a function to a coroutine is accomplished using LLVM's coroutines system (https://llvm.org/docs/Coroutines.html).
// The simplest implementation of a coroutine inserts a suspend point after every blocking call.
//
// Transforming blocking functions into coroutines and calls into suspend points is extremely expensive.
// In many cases, a blocking call is followed immediately by a function terminator (a return or an "unreachable" instruction).
// This is a blocking "tail call".
// In a non-returning tail call (a call to a non-returning function, such as an infinite loop), the coroutine can exit without any extra work.
// In a returning tail call, the returned value must either be the return of the call or a value known before the call.
// If the return value of the caller is the return of the callee, the coroutine can exit without any extra work and the tail call will instead return to the caller's caller.
// If the return value is known in advance, this result can be stored into the parent's return buffer before the call so that a suspend is unnecessary.
// If the callee's return value is unneeded, a return buffer can be allocated on the heap so that it will outlive the coroutine.
//
// In the implementation of time.Sleep, the current task is pushed onto a timer queue and then suspended.
// Since the only suspend point is a call to "internal/task.Pause" followed by a return, there is no need to transform this into a coroutine.
// This generalizes to all blocking functions in which all suspend points can be elided.
// This optimization saves a substantial amount of binary size.
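//
// For illustration, time.Sleep is conceptually shaped like the following sketch
// (addTimer is a hypothetical helper, not the actual runtime code):
//
//	func Sleep(d Duration) {
//		addTimer(task.Current(), d) // push the current task onto the timer queue
//		task.Pause()                // tail call: suspend; the return follows immediately
//	}
//
// The only suspend point is the tail call to Pause, so Sleep itself never needs
// a coroutine frame.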
func LowerCoroutines(mod llvm.Module, needStackSlots bool) error {
	ctx := mod.Context()

	builder := ctx.NewBuilder()
	defer builder.Dispose()

	target := llvm.NewTargetData(mod.DataLayout())
	defer target.Dispose()

	pass := &coroutineLoweringPass{
		mod:            mod,
		ctx:            ctx,
		builder:        builder,
		target:         target,
		needStackSlots: needStackSlots,
	}

	err := pass.load()
	if err != nil {
		return err
	}

	// Supply task operands to async calls.
	pass.supplyTaskOperands()

	// Analyze async returns.
	pass.returnAnalysisPass()

	// Categorize async calls.
	pass.categorizeCalls()

	// Lower async functions.
	pass.lowerFuncsPass()

	// Lower calls to internal/task.Current.
	err = pass.lowerCurrent()
	if err != nil {
		return err
	}

	// Lower goroutine starts.
	pass.lowerStartsPass()

	// Fix annotations on async call params.
	pass.fixAnnotations()

	if needStackSlots {
		// Set up garbage collector tracking of tasks at start.
		err = pass.trackGoroutines()
		if err != nil {
			return err
		}
	}

	return nil
}

// asyncFunc is a metadata container for an asynchronous function.
type asyncFunc struct {
	// fn is the underlying function pointer.
	fn llvm.Value

	// rawTask is the parameter where the task pointer is passed in.
	rawTask llvm.Value

	// callers is a set of all functions which call this async function.
	callers map[llvm.Value]struct{}

	// returns is a list of returns in the function, along with metadata.
	returns []asyncReturn

	// calls is a list of all calls in the asyncFunc.
	// normalCalls is a list of all intermediate suspending calls in the asyncFunc.
	// tailCalls is a list of all tail calls in the asyncFunc.
	calls, normalCalls, tailCalls []llvm.Value
}

// asyncReturn is a metadata container for a return from an asynchronous function.
type asyncReturn struct {
	// block is the basic block terminated by the return.
	block llvm.BasicBlock

	// kind is the kind of the return.
	kind returnKind
}

// coroutineLoweringPass is a goroutine lowering pass which is used with the "coroutines" scheduler.
type coroutineLoweringPass struct {
	mod     llvm.Module
	ctx     llvm.Context
	builder llvm.Builder
	target  llvm.TargetData

	// asyncFuncs is a map of all asyncFuncs.
	// The map keys are function pointers.
	asyncFuncs map[llvm.Value]*asyncFunc

	// asyncFuncsOrdered is a slice of all asyncFuncs, ordered such that the outermost callers come first.
	asyncFuncsOrdered []*asyncFunc

	// calls is a slice of all of the async calls in the module.
	calls []llvm.Value

	i8ptr llvm.Type

	// memory management functions from the runtime
	alloc, free llvm.Value

	// coroutine intrinsics
	start, pause, current                                   llvm.Value
	setState, setRetPtr, getRetPtr, returnTo, returnCurrent llvm.Value
	createTask                                              llvm.Value

	// llvm.coro intrinsics
	coroId, coroSize, coroBegin, coroSuspend, coroEnd, coroFree, coroSave llvm.Value

	trackPointer   llvm.Value
	needStackSlots bool
}

// findAsyncFuncs finds all asynchronous functions.
// A function is considered asynchronous if it calls an asynchronous function or intrinsic.
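//
// For example, given the hypothetical functions
//
//	func sleepTick() { task.Pause() }        // calls Pause directly
//	func worker()    { for { sleepTick() } } // calls an async function
//
// the breadth-first search starting from Pause first marks sleepTick as
// asynchronous and then worker, its caller.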
func (c *coroutineLoweringPass) findAsyncFuncs() {
	asyncs := map[llvm.Value]*asyncFunc{}
	asyncsOrdered := []llvm.Value{}
	calls := []llvm.Value{}

	// Use a breadth-first search to find all async functions.
	worklist := []llvm.Value{c.pause}
	for len(worklist) > 0 {
		// Pop a function off the worklist.
		fn := worklist[0]
		worklist = worklist[1:]

		// Get task pointer argument.
		task := fn.LastParam()
		if fn != c.pause && (task.IsNil() || task.Name() != "parentHandle") {
			panic("trying to make exported function async: " + fn.Name())
		}

		// Search all uses of the function while collecting callers.
		callers := map[llvm.Value]struct{}{}
		for use := fn.FirstUse(); !use.IsNil(); use = use.NextUse() {
			user := use.User()
			if user.IsACallInst().IsNil() {
				// User is not a call instruction, so this is irrelevant.
				continue
			}
			if user.CalledValue() != fn {
				// Not the called value.
				continue
			}

			// Add to calls list.
			calls = append(calls, user)

			// Get the caller.
			caller := user.InstructionParent().Parent()

			// Add as caller.
			callers[caller] = struct{}{}

			if _, ok := asyncs[caller]; ok {
				// Already marked caller as async.
				continue
			}

			// Mark the caller as async.
			// Use nil as a temporary value. It will be replaced later.
			asyncs[caller] = nil
			asyncsOrdered = append(asyncsOrdered, caller)

			// Put the caller on the worklist.
			worklist = append(worklist, caller)
		}

		asyncs[fn] = &asyncFunc{
			fn:      fn,
			rawTask: task,
			callers: callers,
		}
	}

	// Flip the order of the async functions so that the top ones are lowered first.
	for i := 0; i < len(asyncsOrdered)/2; i++ {
		asyncsOrdered[i], asyncsOrdered[len(asyncsOrdered)-(i+1)] = asyncsOrdered[len(asyncsOrdered)-(i+1)], asyncsOrdered[i]
	}

	// Map the elements of asyncsOrdered to *asyncFunc.
	asyncFuncsOrdered := make([]*asyncFunc, len(asyncsOrdered))
	for i, v := range asyncsOrdered {
		asyncFuncsOrdered[i] = asyncs[v]
	}

	c.asyncFuncs = asyncs
	c.asyncFuncsOrdered = asyncFuncsOrdered
	c.calls = calls
}

func (c *coroutineLoweringPass) load() error {
	// Find memory management functions from the runtime.
	c.alloc = c.mod.NamedFunction("runtime.alloc")
	if c.alloc.IsNil() {
		return ErrMissingIntrinsic{"runtime.alloc"}
	}
	c.free = c.mod.NamedFunction("runtime.free")
	if c.free.IsNil() {
		return ErrMissingIntrinsic{"runtime.free"}
	}

	// Find intrinsics.
	c.pause = c.mod.NamedFunction("internal/task.Pause")
	if c.pause.IsNil() {
		return ErrMissingIntrinsic{"internal/task.Pause"}
	}
	c.start = c.mod.NamedFunction("internal/task.start")
	if c.start.IsNil() {
		return ErrMissingIntrinsic{"internal/task.start"}
	}
	c.current = c.mod.NamedFunction("internal/task.Current")
	if c.current.IsNil() {
		return ErrMissingIntrinsic{"internal/task.Current"}
	}
	c.setState = c.mod.NamedFunction("(*internal/task.Task).setState")
	if c.setState.IsNil() {
		return ErrMissingIntrinsic{"(*internal/task.Task).setState"}
	}
	c.setRetPtr = c.mod.NamedFunction("(*internal/task.Task).setReturnPtr")
	if c.setRetPtr.IsNil() {
		return ErrMissingIntrinsic{"(*internal/task.Task).setReturnPtr"}
	}
	c.getRetPtr = c.mod.NamedFunction("(*internal/task.Task).getReturnPtr")
	if c.getRetPtr.IsNil() {
		return ErrMissingIntrinsic{"(*internal/task.Task).getReturnPtr"}
	}
	c.returnTo = c.mod.NamedFunction("(*internal/task.Task).returnTo")
	if c.returnTo.IsNil() {
		return ErrMissingIntrinsic{"(*internal/task.Task).returnTo"}
	}
	c.returnCurrent = c.mod.NamedFunction("(*internal/task.Task).returnCurrent")
	if c.returnCurrent.IsNil() {
		return ErrMissingIntrinsic{"(*internal/task.Task).returnCurrent"}
	}
	c.createTask = c.mod.NamedFunction("internal/task.createTask")
	if c.createTask.IsNil() {
		return ErrMissingIntrinsic{"internal/task.createTask"}
	}

	if c.needStackSlots {
		c.trackPointer = c.mod.NamedFunction("runtime.trackPointer")
		if c.trackPointer.IsNil() {
			return ErrMissingIntrinsic{"runtime.trackPointer"}
		}
	}

	// Find async functions.
	c.findAsyncFuncs()

	// Get i8* type.
	c.i8ptr = llvm.PointerType(c.ctx.Int8Type(), 0)

	// Build LLVM coroutine intrinsics.
	coroIdType := llvm.FunctionType(c.ctx.TokenType(), []llvm.Type{c.ctx.Int32Type(), c.i8ptr, c.i8ptr, c.i8ptr}, false)
	c.coroId = llvm.AddFunction(c.mod, "llvm.coro.id", coroIdType)

	sizeT := c.alloc.Param(0).Type()
	coroSizeType := llvm.FunctionType(sizeT, nil, false)
	c.coroSize = llvm.AddFunction(c.mod, "llvm.coro.size.i"+strconv.Itoa(sizeT.IntTypeWidth()), coroSizeType)

	coroBeginType := llvm.FunctionType(c.i8ptr, []llvm.Type{c.ctx.TokenType(), c.i8ptr}, false)
	c.coroBegin = llvm.AddFunction(c.mod, "llvm.coro.begin", coroBeginType)

	coroSuspendType := llvm.FunctionType(c.ctx.Int8Type(), []llvm.Type{c.ctx.TokenType(), c.ctx.Int1Type()}, false)
	c.coroSuspend = llvm.AddFunction(c.mod, "llvm.coro.suspend", coroSuspendType)

	coroEndType := llvm.FunctionType(c.ctx.Int1Type(), []llvm.Type{c.i8ptr, c.ctx.Int1Type()}, false)
	c.coroEnd = llvm.AddFunction(c.mod, "llvm.coro.end", coroEndType)

	coroFreeType := llvm.FunctionType(c.i8ptr, []llvm.Type{c.ctx.TokenType(), c.i8ptr}, false)
	c.coroFree = llvm.AddFunction(c.mod, "llvm.coro.free", coroFreeType)

	coroSaveType := llvm.FunctionType(c.ctx.TokenType(), []llvm.Type{c.i8ptr}, false)
	c.coroSave = llvm.AddFunction(c.mod, "llvm.coro.save", coroSaveType)

	return nil
}

func (c *coroutineLoweringPass) track(ptr llvm.Value) {
	if c.needStackSlots {
		if ptr.Type() != c.i8ptr {
			ptr = c.builder.CreateBitCast(ptr, c.i8ptr, "track.bitcast")
		}
		c.builder.CreateCall(c.trackPointer, []llvm.Value{ptr, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")
	}
}

// lowerStartSync lowers a goroutine start of a synchronous function to a synchronous call.
func (c *coroutineLoweringPass) lowerStartSync(start llvm.Value) {
	c.builder.SetInsertPointBefore(start)

	// Get function to call.
	fn := start.Operand(0).Operand(0)

	// Create the list of params for the call.
	paramTypes := fn.Type().ElementType().ParamTypes()
	params := llvmutil.EmitPointerUnpack(c.builder, c.mod, start.Operand(1), paramTypes[:len(paramTypes)-1])
	params = append(params, llvm.Undef(c.i8ptr))

	// Generate call to function.
	c.builder.CreateCall(fn, params, "")

	// Remove start call.
	start.EraseFromParentAsInstruction()
}

// supplyTaskOperands fills in the task operands of async calls.
func (c *coroutineLoweringPass) supplyTaskOperands() {
	var curCalls []llvm.Value
	for use := c.current.FirstUse(); !use.IsNil(); use = use.NextUse() {
		curCalls = append(curCalls, use.User())
	}
	for _, call := range append(curCalls, c.calls...) {
		c.builder.SetInsertPointBefore(call)
		task := c.asyncFuncs[call.InstructionParent().Parent()].rawTask
		// The task is passed in the last parameter. On a call instruction the
		// last operand is the callee, so the last parameter is the
		// second-to-last operand.
		call.SetOperand(call.OperandsCount()-2, task)
	}
}

// returnKind is a classification of a type of function terminator.
type returnKind uint8

const (
	// returnNormal is a terminator that returns a value normally from a function.
	returnNormal returnKind = iota

	// returnVoid is a terminator that exits normally without returning a value.
	returnVoid

	// returnVoidTail is a terminator which is a tail call to a void-returning function in a void-returning function.
	returnVoidTail

	// returnTail is a terminator which is a tail call to a value-returning function where the value is returned by the callee.
	returnTail

	// returnDeadTail is a terminator which is a call to a non-returning asynchronous function.
	returnDeadTail

	// returnAlternateTail is a terminator which is a tail call to a value-returning function where a previously acquired value is returned instead of the callee's result.
	returnAlternateTail

	// returnDitchedTail is a terminator which is a tail call to a value-returning function from a void-returning function, so the callee's result is ditched.
	returnDitchedTail

	// returnDelayedValue is a terminator in which a void-returning tail call is followed by a return of a previous value.
	returnDelayedValue
)
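
// For illustration, hypothetical Go functions that roughly exhibit some of these
// kinds (assuming each call remains immediately before its return, and that
// blockingInt and blockingVoid are async functions):
//
//	func a() int { return blockingInt() }     // returnTail
//	func b()     { blockingVoid() }           // returnVoidTail
//	func c()     { blockingInt() }            // returnDitchedTail: result ditched
//	func d() int { blockingVoid(); return 2 } // returnDelayedValue
//	func e() int { blockingInt(); return 2 }  // returnAlternateTail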

// isAsyncCall returns whether the specified call is async.
func (c *coroutineLoweringPass) isAsyncCall(call llvm.Value) bool {
	_, ok := c.asyncFuncs[call.CalledValue()]
	return ok
}

// analyzeFuncReturns analyzes and classifies the returns of a function.
func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
	returns := []asyncReturn{}
	if fn.fn == c.pause {
		// Skip pause.
		fn.returns = returns
		return
	}

	for _, bb := range fn.fn.BasicBlocks() {
		last := bb.LastInstruction()
		switch last.InstructionOpcode() {
		case llvm.Ret:
			// Check if it is a void return.
			isVoid := fn.fn.Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind

			// Analyze previous instruction.
			prev := llvm.PrevInstruction(last)
			switch {
			case prev.IsNil():
				fallthrough
			case prev.IsACallInst().IsNil():
				fallthrough
			case !c.isAsyncCall(prev):
				// This is not any form of asynchronous tail call.
				if isVoid {
					returns = append(returns, asyncReturn{
						block: bb,
						kind:  returnVoid,
					})
				} else {
					returns = append(returns, asyncReturn{
						block: bb,
						kind:  returnNormal,
					})
				}
			case isVoid:
				if prev.CalledValue().Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind {
					// This is a tail call to a void-returning function from a function with a void return.
					returns = append(returns, asyncReturn{
						block: bb,
						kind:  returnVoidTail,
					})
				} else {
					// This is a tail call to a value-returning function from a function with a void return.
					// The returned value will be ditched.
					returns = append(returns, asyncReturn{
						block: bb,
						kind:  returnDitchedTail,
					})
				}
			case last.Operand(0) == prev:
				// This is a regular tail call. The return of the callee is returned to the parent.
				returns = append(returns, asyncReturn{
					block: bb,
					kind:  returnTail,
				})
			case prev.CalledValue().Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind:
				// This is a tail call that returns a previous value after waiting on a void function.
				returns = append(returns, asyncReturn{
					block: bb,
					kind:  returnDelayedValue,
				})
			default:
				// This is a tail call that returns a value that is available before the function call.
				returns = append(returns, asyncReturn{
					block: bb,
					kind:  returnAlternateTail,
				})
			}
		case llvm.Unreachable:
			prev := llvm.PrevInstruction(last)

			if prev.IsNil() || prev.IsACallInst().IsNil() || !c.isAsyncCall(prev) {
				// This unreachable instruction does not behave as an asynchronous return.
				continue
			}

			// This is an asynchronous tail call to a function that does not return.
			returns = append(returns, asyncReturn{
				block: bb,
				kind:  returnDeadTail,
			})
		}
	}

	fn.returns = returns
}

// returnAnalysisPass runs an analysis pass which classifies the returns of all async functions.
func (c *coroutineLoweringPass) returnAnalysisPass() {
	for _, async := range c.asyncFuncsOrdered {
		c.analyzeFuncReturns(async)
	}
}

// categorizeCalls categorizes all asynchronous calls into regular and tail calls and matches them to their callers.
func (c *coroutineLoweringPass) categorizeCalls() {
	// Sort calls into their respective callers.
	for _, call := range c.calls {
		caller := c.asyncFuncs[call.InstructionParent().Parent()]
		caller.calls = append(caller.calls, call)
	}

	// Separate regular and tail calls.
	for _, async := range c.asyncFuncsOrdered {
		// Search returns for tail calls.
		tails := map[llvm.Value]struct{}{}
		for _, ret := range async.returns {
			switch ret.kind {
			case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue:
				// This is a tail return. The previous instruction is a tail call.
				tails[llvm.PrevInstruction(ret.block.LastInstruction())] = struct{}{}
			}
		}

		// Separate tail calls and regular calls.
		normalCalls, tailCalls := []llvm.Value{}, []llvm.Value{}
		for _, call := range async.calls {
			if _, ok := tails[call]; ok {
				// This is a tail call.
				tailCalls = append(tailCalls, call)
			} else {
				// This is a regular call.
				normalCalls = append(normalCalls, call)
			}
		}

		async.normalCalls = normalCalls
		async.tailCalls = tailCalls
	}
}

// lowerFuncsPass lowers all functions, turning them into coroutines if necessary.
func (c *coroutineLoweringPass) lowerFuncsPass() {
	for _, fn := range c.asyncFuncs {
		if fn.fn == c.pause {
			// Skip. It is an intrinsic.
			continue
		}

		if len(fn.normalCalls) == 0 {
			// No suspend points. Lower without turning it into a coroutine.
			c.lowerFuncFast(fn)
		} else {
			// There are suspend points, so it is necessary to turn this into a coroutine.
			c.lowerFuncCoro(fn)
		}
	}
}

// hasValueStoreReturn returns whether the function has any return that stores a value into the caller's return buffer.
func (async *asyncFunc) hasValueStoreReturn() bool {
	for _, ret := range async.returns {
		switch ret.kind {
		case returnNormal, returnAlternateTail, returnDelayedValue:
			return true
		}
	}

	return false
}

// heapAlloc creates a heap allocation large enough to hold the supplied type.
// The allocation is returned as a raw i8* pointer.
// This allocation is not automatically tracked by the garbage collector, and should thus be stored into a tracked memory object immediately.
func (c *coroutineLoweringPass) heapAlloc(t llvm.Type, name string) llvm.Value {
	sizeT := c.alloc.FirstParam().Type()
	size := llvm.ConstInt(sizeT, c.target.TypeAllocSize(t), false)
	return c.builder.CreateCall(c.alloc, []llvm.Value{size, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, name)
}

// lowerFuncFast lowers an async function that has no suspend points.
func (c *coroutineLoweringPass) lowerFuncFast(fn *asyncFunc) {
	// Get return type.
	retType := fn.fn.Type().ElementType().ReturnType()

	// Get task value.
	c.insertPointAfterAllocas(fn.fn)
	task := c.builder.CreateCall(c.current, []llvm.Value{llvm.Undef(c.i8ptr), fn.rawTask}, "task")

	// Get return pointer if applicable.
	var rawRetPtr, retPtr llvm.Value
	if fn.hasValueStoreReturn() {
		rawRetPtr = c.builder.CreateCall(c.getRetPtr, []llvm.Value{task, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "ret.ptr")
		retPtr = c.builder.CreateBitCast(rawRetPtr, llvm.PointerType(retType, 0), "ret.ptr.bitcast")
	}

	// Lower returns.
	for _, ret := range fn.returns {
		// Get terminator.
		terminator := ret.block.LastInstruction()

		// Get tail call if applicable.
		var call llvm.Value
		switch ret.kind {
		case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue:
			call = llvm.PrevInstruction(terminator)
		}

		switch ret.kind {
		case returnNormal:
			c.builder.SetInsertPointBefore(terminator)

			// Store value into return pointer.
			c.builder.CreateStore(terminator.Operand(0), retPtr)

			// Resume caller.
			c.builder.CreateCall(c.returnCurrent, []llvm.Value{task, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")

			// Erase return argument.
			terminator.SetOperand(0, llvm.Undef(retType))
		case returnVoid:
			c.builder.SetInsertPointBefore(terminator)

			// Resume caller.
			c.builder.CreateCall(c.returnCurrent, []llvm.Value{task, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")
		case returnVoidTail:
			// Nothing to do. There is already a tail call followed by a void return.
		case returnTail:
			// Erase return argument.
			terminator.SetOperand(0, llvm.Undef(retType))
		case returnDeadTail:
			// Replace unreachable with immediate return, without resuming the caller.
			c.builder.SetInsertPointBefore(terminator)
			if retType.TypeKind() == llvm.VoidTypeKind {
				c.builder.CreateRetVoid()
			} else {
				c.builder.CreateRet(llvm.Undef(retType))
			}
			terminator.EraseFromParentAsInstruction()
		case returnAlternateTail:
			c.builder.SetInsertPointBefore(call)

			// Store return value.
			c.builder.CreateStore(terminator.Operand(0), retPtr)

			// Heap-allocate a return buffer for the discarded return.
			alternateBuf := c.heapAlloc(call.Type(), "ret.alternate")
			c.builder.CreateCall(c.setRetPtr, []llvm.Value{task, alternateBuf, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")

			// Erase return argument.
			terminator.SetOperand(0, llvm.Undef(retType))
		case returnDitchedTail:
			c.builder.SetInsertPointBefore(call)

			// Heap-allocate a return buffer for the discarded return.
			ditchBuf := c.heapAlloc(call.Type(), "ret.ditch")
			c.builder.CreateCall(c.setRetPtr, []llvm.Value{task, ditchBuf, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")
		case returnDelayedValue:
			c.builder.SetInsertPointBefore(call)

			// Store value into return pointer.
			c.builder.CreateStore(terminator.Operand(0), retPtr)

			// Erase return argument.
			terminator.SetOperand(0, llvm.Undef(retType))
		}

		// Delete call if it is a pause, because it has already been lowered.
		if !call.IsNil() && call.CalledValue() == c.pause {
			call.EraseFromParentAsInstruction()
		}
	}
}

// insertPointAfterAllocas sets the insert point of the builder to be immediately after the last alloca in the entry block.
func (c *coroutineLoweringPass) insertPointAfterAllocas(fn llvm.Value) {
	inst := fn.EntryBasicBlock().FirstInstruction()
	for !inst.IsAAllocaInst().IsNil() {
		inst = llvm.NextInstruction(inst)
	}
	c.builder.SetInsertPointBefore(inst)
}

// lowerCallReturn lowers the return value of an async call by creating a return buffer and loading the returned value from it.
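//
// For illustration, a value-returning async call is roughly rewritten from
//
//	%v = call i32 @asyncFn(..., i8* %parentHandle)
//
// into the following sketch (simplified; the extra context parameters of
// setReturnPtr are omitted):
//
//	%call.return = alloca i32
//	%buf = bitcast i32* %call.return to i8*
//	call void @"(*internal/task.Task).setReturnPtr"(%task, i8* %buf)
//	call i32 @asyncFn(..., i8* %parentHandle) ; result is written into the buffer
//	%v = load i32, i32* %call.return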
func (c *coroutineLoweringPass) lowerCallReturn(caller *asyncFunc, call llvm.Value) {
	// Get return type.
	retType := call.Type()
	if retType.TypeKind() == llvm.VoidTypeKind {
		// Void return. Nothing to do.
		return
	}

	// Create alloca for return buffer.
	alloca := llvmutil.CreateInstructionAlloca(c.builder, c.mod, retType, call, "call.return")

	// Store new return buffer into task before call.
	c.builder.SetInsertPointBefore(call)
	task := c.builder.CreateCall(c.current, []llvm.Value{llvm.Undef(c.i8ptr), caller.rawTask}, "call.task")
	retPtr := c.builder.CreateBitCast(alloca, c.i8ptr, "call.return.bitcast")
	c.builder.CreateCall(c.setRetPtr, []llvm.Value{task, retPtr, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")

	// Load return value after call.
	c.builder.SetInsertPointBefore(llvm.NextInstruction(call))
	ret := c.builder.CreateLoad(alloca, "call.return.load")

	// Replace call value with loaded return.
	call.ReplaceAllUsesWith(ret)
}

// lowerFuncCoro transforms an async function into a coroutine by lowering async operations to `llvm.coro` intrinsics.
// See https://llvm.org/docs/Coroutines.html for more information on these intrinsics.
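//
// The rewritten function roughly takes the following shape (a sketch of the
// inserted control flow, not the exact emitted IR):
//
//	entry:
//	  ; llvm.coro.id, llvm.coro.size, runtime.alloc and llvm.coro.begin set up
//	  ; the coroutine frame, and the state is stored into the task.
//	  ...
//	  ; At each normal async call: llvm.coro.save, the call, llvm.coro.suspend,
//	  ; then a switch to wakeup (resumed), cleanup (done), or suspend.
//	cleanup:
//	  ; llvm.coro.free and runtime.free release the frame, then br to suspend.
//	suspend:
//	  ; llvm.coro.end, followed by a return.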
func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
	returnType := fn.fn.Type().ElementType().ReturnType()

	// Prepare coroutine state.
	c.insertPointAfterAllocas(fn.fn)
	// %coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
	coroId := c.builder.CreateCall(c.coroId, []llvm.Value{
		llvm.ConstInt(c.ctx.Int32Type(), 0, false),
		llvm.ConstNull(c.i8ptr),
		llvm.ConstNull(c.i8ptr),
		llvm.ConstNull(c.i8ptr),
	}, "coro.id")
	// %coro.size = call i32 @llvm.coro.size.i32()
	coroSize := c.builder.CreateCall(c.coroSize, []llvm.Value{}, "coro.size")
	// %coro.alloc = call i8* runtime.alloc(i32 %coro.size)
	coroAlloc := c.builder.CreateCall(c.alloc, []llvm.Value{coroSize, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "coro.alloc")
	// %coro.state = call noalias i8* @llvm.coro.begin(token %coro.id, i8* %coro.alloc)
	coroState := c.builder.CreateCall(c.coroBegin, []llvm.Value{coroId, coroAlloc}, "coro.state")
	c.track(coroState)
	// Store state into task.
	task := c.builder.CreateCall(c.current, []llvm.Value{llvm.Undef(c.i8ptr), fn.rawTask}, "task")
	parentState := c.builder.CreateCall(c.setState, []llvm.Value{task, coroState, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "task.state.parent")
	// Get return pointer if needed.
	var retPtr llvm.Value
	if fn.hasValueStoreReturn() {
		retPtr = c.builder.CreateCall(c.getRetPtr, []llvm.Value{task, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "task.retPtr")
		retPtr = c.builder.CreateBitCast(retPtr, llvm.PointerType(fn.fn.Type().ElementType().ReturnType(), 0), "task.retPtr.bitcast")
	}

	// Build suspend block.
	// This is executed when the coroutine is about to suspend.
	suspend := c.ctx.AddBasicBlock(fn.fn, "suspend")
	c.builder.SetInsertPointAtEnd(suspend)
	// %unused = call i1 @llvm.coro.end(i8* %coro.state, i1 false)
	c.builder.CreateCall(c.coroEnd, []llvm.Value{coroState, llvm.ConstInt(c.ctx.Int1Type(), 0, false)}, "unused")
	// Insert return.
	if returnType.TypeKind() == llvm.VoidTypeKind {
		c.builder.CreateRetVoid()
	} else {
		c.builder.CreateRet(llvm.Undef(returnType))
	}

	// Build cleanup block.
	// This is executed before the function returns in order to clean up resources.
	cleanup := c.ctx.AddBasicBlock(fn.fn, "cleanup")
	c.builder.SetInsertPointAtEnd(cleanup)
	// %coro.memFree = call i8* @llvm.coro.free(token %coro.id, i8* %coro.state)
	coroMemFree := c.builder.CreateCall(c.coroFree, []llvm.Value{coroId, coroState}, "coro.memFree")
	// call i8* runtime.free(i8* %coro.memFree)
	c.builder.CreateCall(c.free, []llvm.Value{coroMemFree, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")
	// Branch to suspend block.
	c.builder.CreateBr(suspend)

	// Restore old state before tail calls.
	for _, call := range fn.tailCalls {
		if !llvm.NextInstruction(call).IsAUnreachableInst().IsNil() {
			// Callee never returns, so the state restore is ineffectual.
			continue
		}

		c.builder.SetInsertPointBefore(call)
		c.builder.CreateCall(c.setState, []llvm.Value{task, parentState, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "coro.state.restore")
	}

	// Lower returns.
	for _, ret := range fn.returns {
		// Get terminator instruction.
		terminator := ret.block.LastInstruction()

		// Get tail call if applicable.
		var call llvm.Value
		switch ret.kind {
		case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue:
			call = llvm.PrevInstruction(terminator)
		}

		switch ret.kind {
		case returnNormal:
			c.builder.SetInsertPointBefore(terminator)

			// Store value into return pointer.
			c.builder.CreateStore(terminator.Operand(0), retPtr)

			// Resume caller.
			c.builder.CreateCall(c.returnTo, []llvm.Value{task, parentState, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")
		case returnVoid:
			c.builder.SetInsertPointBefore(terminator)

			// Resume caller.
			c.builder.CreateCall(c.returnTo, []llvm.Value{task, parentState, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")
		case returnVoidTail, returnTail, returnDeadTail:
			// Nothing to do.
		case returnAlternateTail:
			c.builder.SetInsertPointBefore(call)

			// Store return value.
			c.builder.CreateStore(terminator.Operand(0), retPtr)

			// Heap-allocate a return buffer for the discarded return.
			alternateBuf := c.heapAlloc(call.Type(), "ret.alternate")
			c.builder.CreateCall(c.setRetPtr, []llvm.Value{task, alternateBuf, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")
		case returnDitchedTail:
			c.builder.SetInsertPointBefore(call)

			// Heap-allocate a return buffer for the discarded return.
			ditchBuf := c.heapAlloc(call.Type(), "ret.ditch")
			c.builder.CreateCall(c.setRetPtr, []llvm.Value{task, ditchBuf, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")
		case returnDelayedValue:
			c.builder.SetInsertPointBefore(call)

			// Store return value.
			c.builder.CreateStore(terminator.Operand(0), retPtr)
		}

		// Delete call if it is a pause, because it has already been lowered.
		if !call.IsNil() && call.CalledValue() == c.pause {
			call.EraseFromParentAsInstruction()
		}

		// Replace terminator with branch to cleanup.
		terminator.EraseFromParentAsInstruction()
		c.builder.SetInsertPointAtEnd(ret.block)
		c.builder.CreateBr(cleanup)
	}

	// Lower regular calls.
	for _, call := range fn.normalCalls {
		// Lower return value of call.
		c.lowerCallReturn(fn, call)

		// Get originating basic block.
		bb := call.InstructionParent()

		// Split block.
		wakeup := llvmutil.SplitBasicBlock(c.builder, call, llvm.NextBasicBlock(bb), "wakeup")

		// Insert suspension and switch.
		c.builder.SetInsertPointAtEnd(bb)
		// %coro.save = call token @llvm.coro.save(i8* %coro.state)
		save := c.builder.CreateCall(c.coroSave, []llvm.Value{coroState}, "coro.save")
		// %call.suspend = llvm.coro.suspend(token %coro.save, i1 false)
		// switch i8 %call.suspend, label %suspend [i8 0, label %wakeup
		//                                          i8 1, label %cleanup]
		suspendValue := c.builder.CreateCall(c.coroSuspend, []llvm.Value{save, llvm.ConstInt(c.ctx.Int1Type(), 0, false)}, "call.suspend")
		sw := c.builder.CreateSwitch(suspendValue, suspend, 2)
		sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 0, false), wakeup)
		sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 1, false), cleanup)

		// Delete call if it is a pause, because it has already been lowered.
		if call.CalledValue() == c.pause {
			call.EraseFromParentAsInstruction()
		}

		c.builder.SetInsertPointBefore(wakeup.FirstInstruction())
		c.track(coroState)
	}
}

// lowerCurrent lowers calls to internal/task.Current to bitcasts.
func (c *coroutineLoweringPass) lowerCurrent() error {
	taskType := c.current.Type().ElementType().ReturnType()
	deleteQueue := []llvm.Value{}
	for use := c.current.FirstUse(); !use.IsNil(); use = use.NextUse() {
		// Get user.
		user := use.User()

		if user.IsACallInst().IsNil() || user.CalledValue() != c.current {
			return errorAt(user, "unexpected non-call use of task.Current")
		}

		// Replace with bitcast.
		c.builder.SetInsertPointBefore(user)
		raw := user.Operand(1)
		if !raw.IsAUndefValue().IsNil() || raw.IsNull() {
			return errors.New("undefined task")
		}
		task := c.builder.CreateBitCast(raw, taskType, "task.current")
		user.ReplaceAllUsesWith(task)
		deleteQueue = append(deleteQueue, user)
	}

	// Delete calls.
	for _, inst := range deleteQueue {
		inst.EraseFromParentAsInstruction()
	}

	return nil
}

// lowerStart lowers a goroutine start into a task creation and call or a synchronous call.
func (c *coroutineLoweringPass) lowerStart(start llvm.Value) {
	c.builder.SetInsertPointBefore(start)

	// Get function to call.
	fn := start.Operand(0).Operand(0)

	if _, ok := c.asyncFuncs[fn]; !ok {
		// Turn into synchronous call.
		c.lowerStartSync(start)
		return
	}

	// Create the list of params for the call.
	paramTypes := fn.Type().ElementType().ParamTypes()
	params := llvmutil.EmitPointerUnpack(c.builder, c.mod, start.Operand(1), paramTypes[:len(paramTypes)-1])

	// Create task.
	task := c.builder.CreateCall(c.createTask, []llvm.Value{llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "start.task")
	rawTask := c.builder.CreateBitCast(task, c.i8ptr, "start.task.bitcast")
	params = append(params, rawTask)

	// Generate a return buffer if necessary.
	returnType := fn.Type().ElementType().ReturnType()
	if returnType.TypeKind() == llvm.VoidTypeKind {
		// No return buffer necessary for a void return.
	} else {
		// Check for any undead returns.
		var undead bool
		for _, ret := range c.asyncFuncs[fn].returns {
			if ret.kind != returnDeadTail {
				// This return results in a value being eventually stored.
				undead = true
				break
			}
		}
		if undead {
			// The function stores a value into a return buffer, so we need to create one.
			retBuf := c.heapAlloc(returnType, "ret.ditch")
			c.builder.CreateCall(c.setRetPtr, []llvm.Value{task, retBuf, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")
		}
	}

	// Generate call to function.
	c.builder.CreateCall(fn, params, "")

	// Erase start call.
	start.EraseFromParentAsInstruction()
}

// lowerStartsPass lowers all goroutine starts.
func (c *coroutineLoweringPass) lowerStartsPass() {
	starts := []llvm.Value{}
	for use := c.start.FirstUse(); !use.IsNil(); use = use.NextUse() {
		starts = append(starts, use.User())
	}
	for _, start := range starts {
		c.lowerStart(start)
	}
}

func (c *coroutineLoweringPass) fixAnnotations() {
	for f := range c.asyncFuncs {
		// These properties were added by the functionattrs pass. Remove
		// them, because now we start using the parameter.
		// https://llvm.org/docs/Passes.html#functionattrs-deduce-function-attributes
		for _, kind := range []string{"nocapture", "readnone"} {
			kindID := llvm.AttributeKindID(kind)
			n := f.ParamsCount()
			for i := 0; i <= n; i++ {
				f.RemoveEnumAttributeAtIndex(i, kindID)
			}
		}
	}
}

// trackGoroutines adds runtime.trackPointer calls to track goroutine starts and data.
func (c *coroutineLoweringPass) trackGoroutines() error {
	trackPointer := c.mod.NamedFunction("runtime.trackPointer")
	if trackPointer.IsNil() {
		return ErrMissingIntrinsic{"runtime.trackPointer"}
	}

	trackFunctions := []llvm.Value{c.createTask, c.setState, c.getRetPtr}
	for _, fn := range trackFunctions {
		for use := fn.FirstUse(); !use.IsNil(); use = use.NextUse() {
			call := use.User()

			c.builder.SetInsertPointBefore(llvm.NextInstruction(call))
			ptr := call
			if ptr.Type() != c.i8ptr {
				ptr = c.builder.CreateBitCast(call, c.i8ptr, "")
			}
			c.builder.CreateCall(trackPointer, []llvm.Value{ptr, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")
		}
	}

	return nil
}