package transform

// This file lowers asynchronous functions and goroutine starts when using the coroutines scheduler.
// This is accomplished by inserting LLVM coroutine intrinsics that save and restore function state.

import (
	"errors"
	"strconv"

	"github.com/tinygo-org/tinygo/compiler/llvmutil"
	"tinygo.org/x/go-llvm"
)

// LowerCoroutines turns async functions into coroutines.
// This must be run with the coroutines scheduler.
//
// Before this pass, goroutine starts are expressed as a call to an intrinsic called "internal/task.start".
// This intrinsic accepts the function pointer and a pointer to a struct containing the function's arguments.
//
// Before this pass, an intrinsic called "internal/task.Pause" is used to express suspensions of the current goroutine.
//
// This pass first accumulates a list of blocking functions.
// A function is considered "blocking" if it calls "internal/task.Pause" or any other blocking function.
//
// Blocking calls are implemented by turning blocking functions into coroutines.
// The body of each blocking function is modified to start a new coroutine, and to return after the first suspend.
// After calling a blocking function, the caller coroutine suspends.
// The caller also provides a buffer to store the return value into.
// When a blocking function returns, the return value is written into this buffer and then the caller is queued to run.
//
// Goroutine starts which invoke non-blocking functions are implemented as direct calls.
// Goroutine starts of blocking functions are replaced with the creation of a new task data structure followed by a call to the start of the blocking function.
// The task structure is populated with a "noop" coroutine before invoking the blocking function.
// When the blocking function returns, it resumes this "noop" coroutine, which does nothing.
// The goroutine starter is able to continue after the first suspend of the started goroutine.
//
// The transformation of a function to a coroutine is accomplished using LLVM's coroutines system (https://llvm.org/docs/Coroutines.html).
// The simplest implementation of a coroutine inserts a suspend point after every blocking call.
//
// Transforming blocking functions into coroutines and calls into suspend points is extremely expensive.
// In many cases, a blocking call is followed immediately by a function terminator (a return or an "unreachable" instruction).
// This is a blocking "tail call".
// In a non-returning tail call (a call to a non-returning function, such as an infinite loop), the coroutine can exit without any extra work.
// In a returning tail call, the returned value must either be the return of the call or a value known before the call.
// If the return value of the caller is the return of the callee, the coroutine can exit without any extra work and the tail call will instead return to the caller of the caller.
// If the return value is known in advance, this result can be stored into the parent's return buffer before the call so that a suspend is unnecessary.
// If the callee returns a value that the caller discards, a return buffer can be allocated on the heap so that it will outlive the coroutine.
//
// In the implementation of time.Sleep, the current task is pushed onto a timer queue and then suspended.
// Since the only suspend point is a call to "internal/task.Pause" followed by a return, there is no need to transform this into a coroutine.
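// As a simplified sketch (names and details are illustrative, not the actual runtime code):
//
//	func Sleep(duration int64) {
//	    addSleepTask(task.Current(), duration) // push the current task onto the timer queue
//	    task.Pause()                           // blocking "tail call": suspend and return
//	}
//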
// This generalizes to all blocking functions in which all suspend points can be elided.
// This optimization saves a substantial amount of binary size.
func LowerCoroutines(mod llvm.Module, needStackSlots bool) error {
	ctx := mod.Context()

	builder := ctx.NewBuilder()
	defer builder.Dispose()

	target := llvm.NewTargetData(mod.DataLayout())
	defer target.Dispose()

	pass := &coroutineLoweringPass{
		mod:            mod,
		ctx:            ctx,
		builder:        builder,
		target:         target,
		needStackSlots: needStackSlots,
	}

	err := pass.load()
	if err != nil {
		return err
	}

	// Supply task operands to async calls.
	pass.supplyTaskOperands()

	// Analyze async returns.
	pass.returnAnalysisPass()

	// Categorize async calls.
	pass.categorizeCalls()

	// Lower async functions.
	pass.lowerFuncsPass()

	// Lower calls to internal/task.Current.
	err = pass.lowerCurrent()
	if err != nil {
		return err
	}

	// Lower goroutine starts.
	pass.lowerStartsPass()

	// Fix annotations on async call params.
	pass.fixAnnotations()

	if needStackSlots {
		// Set up garbage collector tracking of tasks at start.
		err = pass.trackGoroutines()
		if err != nil {
			return err
		}
	}

	return nil
}

// asyncFunc is a metadata container for an asynchronous function.
type asyncFunc struct {
	// fn is the underlying function pointer.
	fn llvm.Value

	// rawTask is the parameter where the task pointer is passed in.
	rawTask llvm.Value

	// callers is a set of all functions which call this async function.
	callers map[llvm.Value]struct{}

	// returns is a list of returns in the function, along with metadata.
	returns []asyncReturn

	// calls is a list of all calls in the asyncFunc.
	// normalCalls is a list of all intermediate suspending calls in the asyncFunc.
	// tailCalls is a list of all tail calls in the asyncFunc.
	calls, normalCalls, tailCalls []llvm.Value
}

// asyncReturn is a metadata container for a return from an asynchronous function.
type asyncReturn struct {
	// block is the basic block terminated by the return.
	block llvm.BasicBlock

	// kind is the kind of the return.
	kind returnKind
}

// coroutineLoweringPass is a goroutine lowering pass which is used with the "coroutines" scheduler.
type coroutineLoweringPass struct {
	mod     llvm.Module
	ctx     llvm.Context
	builder llvm.Builder
	target  llvm.TargetData

	// asyncFuncs is a map of all asyncFuncs.
	// The map keys are function pointers.
	asyncFuncs map[llvm.Value]*asyncFunc

	asyncFuncsOrdered []*asyncFunc

	// calls is a slice of all of the async calls in the module.
	calls []llvm.Value

	i8ptr llvm.Type

	// memory management functions from the runtime
	alloc, free llvm.Value

	// scheduling intrinsics from internal/task
	start, pause, current                                   llvm.Value
	setState, setRetPtr, getRetPtr, returnTo, returnCurrent llvm.Value
	createTask                                              llvm.Value

	// llvm.coro intrinsics
	coroId, coroSize, coroBegin, coroSuspend, coroEnd, coroFree, coroSave llvm.Value

	trackPointer   llvm.Value
	needStackSlots bool
}

// findAsyncFuncs finds all asynchronous functions.
// A function is considered asynchronous if it calls an asynchronous function or intrinsic.
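// For example (function names are illustrative), if main calls doWork, and
// doWork calls time.Sleep, which in turn calls "internal/task.Pause", then the
// worklist search below marks time.Sleep, doWork, and main as asynchronous:
//
//	main -> doWork -> time.Sleep -> internal/task.Pause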
func (c *coroutineLoweringPass) findAsyncFuncs() {
	asyncs := map[llvm.Value]*asyncFunc{}
	asyncsOrdered := []llvm.Value{}
	calls := []llvm.Value{}

	// Use a breadth-first search to find all async functions.
	worklist := []llvm.Value{c.pause}
	for len(worklist) > 0 {
		// Pop a function off the worklist.
		fn := worklist[0]
		worklist = worklist[1:]

		// Get task pointer argument.
		task := fn.LastParam()
		if fn != c.pause && (task.IsNil() || task.Name() != "parentHandle") {
			panic("trying to make exported function async: " + fn.Name())
		}

		// Search all uses of the function while collecting callers.
		callers := map[llvm.Value]struct{}{}
		for use := fn.FirstUse(); !use.IsNil(); use = use.NextUse() {
			user := use.User()
			if user.IsACallInst().IsNil() {
				// User is not a call instruction, so this is irrelevant.
				continue
			}
			if user.CalledValue() != fn {
				// Not the called value.
				continue
			}

			// Add to calls list.
			calls = append(calls, user)

			// Get the caller.
			caller := user.InstructionParent().Parent()

			// Add as caller.
			callers[caller] = struct{}{}

			if _, ok := asyncs[caller]; ok {
				// Already marked caller as async.
				continue
			}

			// Mark the caller as async.
			// Use nil as a temporary value. It will be replaced later.
			asyncs[caller] = nil
			asyncsOrdered = append(asyncsOrdered, caller)

			// Put the caller on the worklist.
			worklist = append(worklist, caller)
		}

		asyncs[fn] = &asyncFunc{
			fn:      fn,
			rawTask: task,
			callers: callers,
		}
	}

	// Flip the order of the async functions so that the top ones are lowered first.
	for i := 0; i < len(asyncsOrdered)/2; i++ {
		asyncsOrdered[i], asyncsOrdered[len(asyncsOrdered)-(i+1)] = asyncsOrdered[len(asyncsOrdered)-(i+1)], asyncsOrdered[i]
	}

	// Map the elements of asyncsOrdered to *asyncFunc.
	asyncFuncsOrdered := make([]*asyncFunc, len(asyncsOrdered))
	for i, v := range asyncsOrdered {
		asyncFuncsOrdered[i] = asyncs[v]
	}

	c.asyncFuncs = asyncs
	c.asyncFuncsOrdered = asyncFuncsOrdered
	c.calls = calls
}

func (c *coroutineLoweringPass) load() error {
	// Find memory management functions from the runtime.
	c.alloc = c.mod.NamedFunction("runtime.alloc")
	if c.alloc.IsNil() {
		return ErrMissingIntrinsic{"runtime.alloc"}
	}
	c.free = c.mod.NamedFunction("runtime.free")
	if c.free.IsNil() {
		return ErrMissingIntrinsic{"runtime.free"}
	}

	// Find intrinsics.
	c.pause = c.mod.NamedFunction("internal/task.Pause")
	if c.pause.IsNil() {
		return ErrMissingIntrinsic{"internal/task.Pause"}
	}
	c.start = c.mod.NamedFunction("internal/task.start")
	if c.start.IsNil() {
		return ErrMissingIntrinsic{"internal/task.start"}
	}
	c.current = c.mod.NamedFunction("internal/task.Current")
	if c.current.IsNil() {
		return ErrMissingIntrinsic{"internal/task.Current"}
	}
	c.setState = c.mod.NamedFunction("(*internal/task.Task).setState")
	if c.setState.IsNil() {
		return ErrMissingIntrinsic{"(*internal/task.Task).setState"}
	}
	c.setRetPtr = c.mod.NamedFunction("(*internal/task.Task).setReturnPtr")
	if c.setRetPtr.IsNil() {
		return ErrMissingIntrinsic{"(*internal/task.Task).setReturnPtr"}
	}
	c.getRetPtr = c.mod.NamedFunction("(*internal/task.Task).getReturnPtr")
	if c.getRetPtr.IsNil() {
		return ErrMissingIntrinsic{"(*internal/task.Task).getReturnPtr"}
	}
	c.returnTo = c.mod.NamedFunction("(*internal/task.Task).returnTo")
	if c.returnTo.IsNil() {
		return ErrMissingIntrinsic{"(*internal/task.Task).returnTo"}
	}
	c.returnCurrent = c.mod.NamedFunction("(*internal/task.Task).returnCurrent")
	if c.returnCurrent.IsNil() {
		return ErrMissingIntrinsic{"(*internal/task.Task).returnCurrent"}
	}
	c.createTask = c.mod.NamedFunction("internal/task.createTask")
	if c.createTask.IsNil() {
		return ErrMissingIntrinsic{"internal/task.createTask"}
	}

	if c.needStackSlots {
		c.trackPointer = c.mod.NamedFunction("runtime.trackPointer")
		if c.trackPointer.IsNil() {
			return ErrMissingIntrinsic{"runtime.trackPointer"}
		}
	}

	// Find async functions.
	c.findAsyncFuncs()

	// Get i8* type.
	c.i8ptr = llvm.PointerType(c.ctx.Int8Type(), 0)

	// Build LLVM coroutine intrinsics.
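	// The declarations added below correspond to the following IR, assuming a
	// 32-bit size_t (see https://llvm.org/docs/Coroutines.html):
	//
	//	declare token @llvm.coro.id(i32, i8*, i8*, i8*)
	//	declare i32 @llvm.coro.size.i32()
	//	declare i8* @llvm.coro.begin(token, i8*)
	//	declare i8 @llvm.coro.suspend(token, i1)
	//	declare i1 @llvm.coro.end(i8*, i1)
	//	declare i8* @llvm.coro.free(token, i8*)
	//	declare token @llvm.coro.save(i8*)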
	coroIdType := llvm.FunctionType(c.ctx.TokenType(), []llvm.Type{c.ctx.Int32Type(), c.i8ptr, c.i8ptr, c.i8ptr}, false)
	c.coroId = llvm.AddFunction(c.mod, "llvm.coro.id", coroIdType)

	sizeT := c.alloc.Param(0).Type()
	coroSizeType := llvm.FunctionType(sizeT, nil, false)
	c.coroSize = llvm.AddFunction(c.mod, "llvm.coro.size.i"+strconv.Itoa(sizeT.IntTypeWidth()), coroSizeType)

	coroBeginType := llvm.FunctionType(c.i8ptr, []llvm.Type{c.ctx.TokenType(), c.i8ptr}, false)
	c.coroBegin = llvm.AddFunction(c.mod, "llvm.coro.begin", coroBeginType)

	coroSuspendType := llvm.FunctionType(c.ctx.Int8Type(), []llvm.Type{c.ctx.TokenType(), c.ctx.Int1Type()}, false)
	c.coroSuspend = llvm.AddFunction(c.mod, "llvm.coro.suspend", coroSuspendType)

	coroEndType := llvm.FunctionType(c.ctx.Int1Type(), []llvm.Type{c.i8ptr, c.ctx.Int1Type()}, false)
	c.coroEnd = llvm.AddFunction(c.mod, "llvm.coro.end", coroEndType)

	coroFreeType := llvm.FunctionType(c.i8ptr, []llvm.Type{c.ctx.TokenType(), c.i8ptr}, false)
	c.coroFree = llvm.AddFunction(c.mod, "llvm.coro.free", coroFreeType)

	coroSaveType := llvm.FunctionType(c.ctx.TokenType(), []llvm.Type{c.i8ptr}, false)
	c.coroSave = llvm.AddFunction(c.mod, "llvm.coro.save", coroSaveType)

	return nil
}

func (c *coroutineLoweringPass) track(ptr llvm.Value) {
	if c.needStackSlots {
		if ptr.Type() != c.i8ptr {
			ptr = c.builder.CreateBitCast(ptr, c.i8ptr, "track.bitcast")
		}
		c.builder.CreateCall(c.trackPointer, []llvm.Value{ptr, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")
	}
}

// lowerStartSync lowers a goroutine start of a synchronous function to a synchronous call.
func (c *coroutineLoweringPass) lowerStartSync(start llvm.Value) {
	c.builder.SetInsertPointBefore(start)

	// Get function to call.
	fn := start.Operand(0).Operand(0)

	// Create the list of params for the call.
	paramTypes := fn.Type().ElementType().ParamTypes()
	params := llvmutil.EmitPointerUnpack(c.builder, c.mod, start.Operand(1), paramTypes[:len(paramTypes)-1])
	params = append(params, llvm.Undef(c.i8ptr))

	// Generate call to function.
	c.builder.CreateCall(fn, params, "")

	// Remove start call.
	start.EraseFromParentAsInstruction()
}

// supplyTaskOperands fills in the task operands of async calls.
func (c *coroutineLoweringPass) supplyTaskOperands() {
	var curCalls []llvm.Value
	for use := c.current.FirstUse(); !use.IsNil(); use = use.NextUse() {
		curCalls = append(curCalls, use.User())
	}
	for _, call := range append(curCalls, c.calls...) {
		c.builder.SetInsertPointBefore(call)
		task := c.asyncFuncs[call.InstructionParent().Parent()].rawTask
		// The task pointer is the last argument of the call. On a call
		// instruction, the final operand is the called function itself, so the
		// last argument sits at index OperandsCount()-2.
		call.SetOperand(call.OperandsCount()-2, task)
	}
}
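
// For example, an async call that looked like this before the pass (IR
// sketch; the trailing i8* argument is the task pointer slot, initially
// undefined):
//
//	call void @blockingFn(i32 %arg, i8* undef)
//
// is rewritten to forward the caller's own task pointer:
//
//	call void @blockingFn(i32 %arg, i8* %parentHandle)
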
// returnKind is a classification of a type of function terminator.
type returnKind uint8

const (
	// returnNormal is a terminator that returns a value normally from a function.
	returnNormal returnKind = iota

	// returnVoid is a terminator that exits normally without returning a value.
	returnVoid

	// returnVoidTail is a terminator which is a tail call to a void-returning function in a void-returning function.
	returnVoidTail

	// returnTail is a terminator which is a tail call to a value-returning function where the value is returned by the callee.
	returnTail

	// returnDeadTail is a terminator which is a call to a non-returning asynchronous function.
	returnDeadTail

	// returnAlternateTail is a terminator which is a tail call to a value-returning function where a previously acquired value is returned instead of the callee's result.
	returnAlternateTail

	// returnDitchedTail is a terminator which is a tail call to a value-returning function, where the caller returns void and the callee's value is discarded.
	returnDitchedTail

	// returnDelayedValue is a terminator in which a void-returning tail call is followed by a return of a previous value.
	returnDelayedValue
)

// isAsyncCall returns whether the specified call is async.
func (c *coroutineLoweringPass) isAsyncCall(call llvm.Value) bool {
	_, ok := c.asyncFuncs[call.CalledValue()]
	return ok
}

// analyzeFuncReturns analyzes and classifies the returns of a function.
func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
	returns := []asyncReturn{}
	if fn.fn == c.pause {
		// Skip pause.
		fn.returns = returns
		return
	}

	for _, bb := range fn.fn.BasicBlocks() {
		last := bb.LastInstruction()
		switch last.InstructionOpcode() {
		case llvm.Ret:
			// Check if it is a void return.
			isVoid := fn.fn.Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind

			// Analyze previous instruction.
			prev := llvm.PrevInstruction(last)
			switch {
			case prev.IsNil():
				fallthrough
			case prev.IsACallInst().IsNil():
				fallthrough
			case !c.isAsyncCall(prev):
				// This is not any form of asynchronous tail call.
				if isVoid {
					returns = append(returns, asyncReturn{
						block: bb,
						kind:  returnVoid,
					})
				} else {
					returns = append(returns, asyncReturn{
						block: bb,
						kind:  returnNormal,
					})
				}
			case isVoid:
				if prev.CalledValue().Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind {
					// This is a tail call to a void-returning function from a function with a void return.
					returns = append(returns, asyncReturn{
						block: bb,
						kind:  returnVoidTail,
					})
				} else {
					// This is a tail call to a value-returning function from a function with a void return.
					// The returned value will be ditched.
					returns = append(returns, asyncReturn{
						block: bb,
						kind:  returnDitchedTail,
					})
				}
			case last.Operand(0) == prev:
				// This is a regular tail call. The return of the callee is returned to the parent.
				returns = append(returns, asyncReturn{
					block: bb,
					kind:  returnTail,
				})
			case prev.CalledValue().Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind:
				// This is a tail call that returns a previous value after waiting on a void function.
				returns = append(returns, asyncReturn{
					block: bb,
					kind:  returnDelayedValue,
				})
			default:
				// This is a tail call that returns a value that is available before the function call.
				returns = append(returns, asyncReturn{
					block: bb,
					kind:  returnAlternateTail,
				})
			}
		case llvm.Unreachable:
			prev := llvm.PrevInstruction(last)

			if prev.IsNil() || prev.IsACallInst().IsNil() || !c.isAsyncCall(prev) {
				// This unreachable instruction does not behave as an asynchronous return.
				continue
			}

			// This is an asynchronous tail call to a function that does not return.
			returns = append(returns, asyncReturn{
				block: bb,
				kind:  returnDeadTail,
			})
		}
	}

	fn.returns = returns
}
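
// In Go-like pseudocode, the return kinds classified above look roughly as
// follows, where async() stands for any call to an async function:
//
//	return x            // returnNormal: no preceding async tail call
//	return              // returnVoid: no preceding async tail call
//	async(); return     // returnVoidTail: void async call in a void function
//	return async()      // returnTail: the callee's result is returned directly
//	async()             // returnDeadTail: async never returns ("unreachable" follows)
//	async(); return x   // returnAlternateTail: async returns a value; x, known earlier, is returned instead
//	async(); return     // returnDitchedTail: async returns a value, which is discarded (void function)
//	async(); return x   // returnDelayedValue: void async call; x, known earlier, is returned
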
// returnAnalysisPass runs an analysis pass which classifies the returns of all async functions.
func (c *coroutineLoweringPass) returnAnalysisPass() {
	for _, async := range c.asyncFuncsOrdered {
		c.analyzeFuncReturns(async)
	}
}

// categorizeCalls categorizes all asynchronous calls into regular vs. tail calls and matches them to their callers.
func (c *coroutineLoweringPass) categorizeCalls() {
	// Sort calls into their respective callers.
	for _, call := range c.calls {
		caller := c.asyncFuncs[call.InstructionParent().Parent()]
		caller.calls = append(caller.calls, call)
	}

	// Separate regular and tail calls.
	for _, async := range c.asyncFuncsOrdered {
		// Search returns for tail calls.
		tails := map[llvm.Value]struct{}{}
		for _, ret := range async.returns {
			switch ret.kind {
			case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue:
				// This is a tail return. The previous instruction is a tail call.
				tails[llvm.PrevInstruction(ret.block.LastInstruction())] = struct{}{}
			}
		}

		// Separate tail calls and regular calls.
		normalCalls, tailCalls := []llvm.Value{}, []llvm.Value{}
		for _, call := range async.calls {
			if _, ok := tails[call]; ok {
				// This is a tail call.
				tailCalls = append(tailCalls, call)
			} else {
				// This is a regular call.
				normalCalls = append(normalCalls, call)
			}
		}

		async.normalCalls = normalCalls
		async.tailCalls = tailCalls
	}
}

// lowerFuncsPass lowers all functions, turning them into coroutines if necessary.
func (c *coroutineLoweringPass) lowerFuncsPass() {
	for _, fn := range c.asyncFuncs {
		if fn.fn == c.pause {
			// Skip. It is an intrinsic.
			continue
		}

		if len(fn.normalCalls) == 0 {
			// No suspend points. Lower without turning it into a coroutine.
			c.lowerFuncFast(fn)
		} else {
			// There are suspend points, so it is necessary to turn this into a coroutine.
			c.lowerFuncCoro(fn)
		}
	}
}

// hasValueStoreReturn returns whether any return of this function stores a value into the caller's return buffer.
func (async *asyncFunc) hasValueStoreReturn() bool {
	for _, ret := range async.returns {
		switch ret.kind {
		case returnNormal, returnAlternateTail, returnDelayedValue:
			return true
		}
	}

	return false
}

// heapAlloc creates a heap allocation large enough to hold the supplied type.
// The allocation is returned as a raw i8* pointer.
// This allocation is not automatically tracked by the garbage collector, and should thus be stored into a tracked memory object immediately.
func (c *coroutineLoweringPass) heapAlloc(t llvm.Type, name string) llvm.Value {
	sizeT := c.alloc.FirstParam().Type()
	size := llvm.ConstInt(sizeT, c.target.TypeAllocSize(t), false)
	return c.builder.CreateCall(c.alloc, []llvm.Value{size, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, name)
}
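
// A typical use, as in the return lowering below, pairs the allocation with an
// immediate store into the task's return pointer so the buffer stays tracked:
//
//	ditchBuf := c.heapAlloc(call.Type(), "ret.ditch")
//	c.builder.CreateCall(c.setRetPtr, []llvm.Value{task, ditchBuf, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")
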
// lowerFuncFast lowers an async function that has no suspend points.
func (c *coroutineLoweringPass) lowerFuncFast(fn *asyncFunc) {
	// Get return type.
	retType := fn.fn.Type().ElementType().ReturnType()

	// Get task value.
	c.insertPointAfterAllocas(fn.fn)
	task := c.builder.CreateCall(c.current, []llvm.Value{llvm.Undef(c.i8ptr), fn.rawTask}, "task")

	// Get return pointer if applicable.
	var rawRetPtr, retPtr llvm.Value
	if fn.hasValueStoreReturn() {
		rawRetPtr = c.builder.CreateCall(c.getRetPtr, []llvm.Value{task, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "ret.ptr")
		retPtr = c.builder.CreateBitCast(rawRetPtr, llvm.PointerType(retType, 0), "ret.ptr.bitcast")
	}

	// Lower returns.
	for _, ret := range fn.returns {
		// Get terminator.
		terminator := ret.block.LastInstruction()

		// Get tail call if applicable.
		var call llvm.Value
		switch ret.kind {
		case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue:
			call = llvm.PrevInstruction(terminator)
		}

		switch ret.kind {
		case returnNormal:
			c.builder.SetInsertPointBefore(terminator)

			// Store value into return pointer.
			c.builder.CreateStore(terminator.Operand(0), retPtr)

			// Resume caller.
			c.builder.CreateCall(c.returnCurrent, []llvm.Value{task, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")

			// Erase return argument.
			terminator.SetOperand(0, llvm.Undef(retType))
		case returnVoid:
			c.builder.SetInsertPointBefore(terminator)

			// Resume caller.
			c.builder.CreateCall(c.returnCurrent, []llvm.Value{task, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")
		case returnVoidTail:
			// Nothing to do. There is already a tail call followed by a void return.
		case returnTail:
			// Erase return argument.
			terminator.SetOperand(0, llvm.Undef(retType))
		case returnDeadTail:
			// Replace unreachable with immediate return, without resuming the caller.
			c.builder.SetInsertPointBefore(terminator)
			if retType.TypeKind() == llvm.VoidTypeKind {
				c.builder.CreateRetVoid()
			} else {
				c.builder.CreateRet(llvm.Undef(retType))
			}
			terminator.EraseFromParentAsInstruction()
		case returnAlternateTail:
			c.builder.SetInsertPointBefore(call)

			// Store return value.
			c.builder.CreateStore(terminator.Operand(0), retPtr)

			// Heap-allocate a return buffer for the discarded return.
			alternateBuf := c.heapAlloc(call.Type(), "ret.alternate")
			c.builder.CreateCall(c.setRetPtr, []llvm.Value{task, alternateBuf, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")

			// Erase return argument.
			terminator.SetOperand(0, llvm.Undef(retType))
		case returnDitchedTail:
			c.builder.SetInsertPointBefore(call)

			// Heap-allocate a return buffer for the discarded return.
			ditchBuf := c.heapAlloc(call.Type(), "ret.ditch")
			c.builder.CreateCall(c.setRetPtr, []llvm.Value{task, ditchBuf, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")
		case returnDelayedValue:
			c.builder.SetInsertPointBefore(call)

			// Store value into return pointer.
			c.builder.CreateStore(terminator.Operand(0), retPtr)

			// Erase return argument.
			terminator.SetOperand(0, llvm.Undef(retType))
		}

		// Delete call if it is a pause, because it has already been lowered.
		if !call.IsNil() && call.CalledValue() == c.pause {
			call.EraseFromParentAsInstruction()
		}
	}
}
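
// As an illustration of the fast path, consider a Sleep-like function (a
// sketch; see the package comment) whose only suspend point is a pause in
// tail position:
//
//	addSleepTask(task.Current(), duration)
//	task.Pause() // returnVoidTail
//	return
//
// No coroutine frame is needed: the pause call is simply erased, and
// returning from the function acts as the first (and only) suspend.
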
// insertPointAfterAllocas sets the insert point of the builder to be immediately after the last alloca in the entry block.
func (c *coroutineLoweringPass) insertPointAfterAllocas(fn llvm.Value) {
	inst := fn.EntryBasicBlock().FirstInstruction()
	for !inst.IsAAllocaInst().IsNil() {
		inst = llvm.NextInstruction(inst)
	}
	c.builder.SetInsertPointBefore(inst)
}

// lowerCallReturn lowers the return value of an async call by creating a return buffer and loading the returned value from it.
func (c *coroutineLoweringPass) lowerCallReturn(caller *asyncFunc, call llvm.Value) {
	// Get return type.
	retType := call.Type()
	if retType.TypeKind() == llvm.VoidTypeKind {
		// Void return. Nothing to do.
		return
	}

	// Create alloca for return buffer.
	alloca := llvmutil.CreateInstructionAlloca(c.builder, c.mod, retType, call, "call.return")

	// Store new return buffer into task before call.
	c.builder.SetInsertPointBefore(call)
	task := c.builder.CreateCall(c.current, []llvm.Value{llvm.Undef(c.i8ptr), caller.rawTask}, "call.task")
	retPtr := c.builder.CreateBitCast(alloca, c.i8ptr, "call.return.bitcast")
	c.builder.CreateCall(c.setRetPtr, []llvm.Value{task, retPtr, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")

	// Load return value after call.
	c.builder.SetInsertPointBefore(llvm.NextInstruction(call))
	ret := c.builder.CreateLoad(alloca, "call.return.load")

	// Replace call value with loaded return.
	call.ReplaceAllUsesWith(ret)
}
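
// The resulting sequence around an async call returning an i32 looks roughly
// like this (IR sketch; names match those used above, types abbreviated):
//
//	%call.return = alloca i32
//	...
//	%call.task = call %Task* @"internal/task.Current"(i8* undef, i8* %parentHandle)
//	%call.return.bitcast = bitcast i32* %call.return to i8*
//	call void @"(*internal/task.Task).setReturnPtr"(%Task* %call.task, i8* %call.return.bitcast, ...)
//	call i32 @callee(..., i8* %parentHandle)
//	; a suspend point is inserted here later by lowerFuncCoro
//	%call.return.load = load i32, i32* %call.return
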
// lowerFuncCoro transforms an async function into a coroutine by lowering async operations to `llvm.coro` intrinsics.
// See https://llvm.org/docs/Coroutines.html for more information on these intrinsics.
func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
	returnType := fn.fn.Type().ElementType().ReturnType()

	// Prepare coroutine state.
	c.insertPointAfterAllocas(fn.fn)
	// %coro.id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
	coroId := c.builder.CreateCall(c.coroId, []llvm.Value{
		llvm.ConstInt(c.ctx.Int32Type(), 0, false),
		llvm.ConstNull(c.i8ptr),
		llvm.ConstNull(c.i8ptr),
		llvm.ConstNull(c.i8ptr),
	}, "coro.id")
	// %coro.size = call i32 @llvm.coro.size.i32()
	coroSize := c.builder.CreateCall(c.coroSize, []llvm.Value{}, "coro.size")
	// %coro.alloc = call i8* runtime.alloc(i32 %coro.size)
	coroAlloc := c.builder.CreateCall(c.alloc, []llvm.Value{coroSize, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "coro.alloc")
	// %coro.state = call noalias i8* @llvm.coro.begin(token %coro.id, i8* %coro.alloc)
	coroState := c.builder.CreateCall(c.coroBegin, []llvm.Value{coroId, coroAlloc}, "coro.state")
	c.track(coroState)
	// Store state into task.
	task := c.builder.CreateCall(c.current, []llvm.Value{llvm.Undef(c.i8ptr), fn.rawTask}, "task")
	parentState := c.builder.CreateCall(c.setState, []llvm.Value{task, coroState, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "task.state.parent")
	// Get return pointer if needed.
	var retPtr llvm.Value
	if fn.hasValueStoreReturn() {
		retPtr = c.builder.CreateCall(c.getRetPtr, []llvm.Value{task, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "task.retPtr")
		retPtr = c.builder.CreateBitCast(retPtr, llvm.PointerType(returnType, 0), "task.retPtr.bitcast")
	}

	// Build suspend block.
	// This is executed when the coroutine is about to suspend.
	suspend := c.ctx.AddBasicBlock(fn.fn, "suspend")
	c.builder.SetInsertPointAtEnd(suspend)
	// %unused = call i1 @llvm.coro.end(i8* %coro.state, i1 false)
	c.builder.CreateCall(c.coroEnd, []llvm.Value{coroState, llvm.ConstInt(c.ctx.Int1Type(), 0, false)}, "unused")
	// Insert return.
	if returnType.TypeKind() == llvm.VoidTypeKind {
		c.builder.CreateRetVoid()
	} else {
		c.builder.CreateRet(llvm.Undef(returnType))
	}

	// Build cleanup block.
	// This is executed before the function returns in order to clean up resources.
	cleanup := c.ctx.AddBasicBlock(fn.fn, "cleanup")
	c.builder.SetInsertPointAtEnd(cleanup)
	// %coro.memFree = call i8* @llvm.coro.free(token %coro.id, i8* %coro.state)
	coroMemFree := c.builder.CreateCall(c.coroFree, []llvm.Value{coroId, coroState}, "coro.memFree")
	// call i8* runtime.free(i8* %coro.memFree)
	c.builder.CreateCall(c.free, []llvm.Value{coroMemFree, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")
	// Branch to suspend block.
	c.builder.CreateBr(suspend)

	// Restore old state before tail calls.
	for _, call := range fn.tailCalls {
		if !llvm.NextInstruction(call).IsAUnreachableInst().IsNil() {
			// Callee never returns, so the state restore is ineffectual.
			continue
		}

		c.builder.SetInsertPointBefore(call)
		c.builder.CreateCall(c.setState, []llvm.Value{task, parentState, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "coro.state.restore")
	}

	// Lower returns.
	for _, ret := range fn.returns {
		// Get terminator instruction.
		terminator := ret.block.LastInstruction()

		// Get tail call if applicable.
		var call llvm.Value
		switch ret.kind {
		case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue:
			call = llvm.PrevInstruction(terminator)
		}

		switch ret.kind {
		case returnNormal:
			c.builder.SetInsertPointBefore(terminator)

			// Store value into return pointer.
			c.builder.CreateStore(terminator.Operand(0), retPtr)

			// Resume caller.
			c.builder.CreateCall(c.returnTo, []llvm.Value{task, parentState, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")
		case returnVoid:
			c.builder.SetInsertPointBefore(terminator)

			// Resume caller.
			c.builder.CreateCall(c.returnTo, []llvm.Value{task, parentState, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")
		case returnVoidTail, returnTail, returnDeadTail:
			// Nothing to do.
		case returnAlternateTail:
			c.builder.SetInsertPointBefore(call)

			// Store return value.
			c.builder.CreateStore(terminator.Operand(0), retPtr)

			// Heap-allocate a return buffer for the discarded return.
			alternateBuf := c.heapAlloc(call.Type(), "ret.alternate")
			c.builder.CreateCall(c.setRetPtr, []llvm.Value{task, alternateBuf, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")
		case returnDitchedTail:
			c.builder.SetInsertPointBefore(call)

			// Heap-allocate a return buffer for the discarded return.
			ditchBuf := c.heapAlloc(call.Type(), "ret.ditch")
			c.builder.CreateCall(c.setRetPtr, []llvm.Value{task, ditchBuf, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")
		case returnDelayedValue:
			c.builder.SetInsertPointBefore(call)

			// Store return value.
			c.builder.CreateStore(terminator.Operand(0), retPtr)
		}

		// Delete call if it is a pause, because it has already been lowered.
		if !call.IsNil() && call.CalledValue() == c.pause {
			call.EraseFromParentAsInstruction()
		}

		// Replace terminator with branch to cleanup.
		terminator.EraseFromParentAsInstruction()
		c.builder.SetInsertPointAtEnd(ret.block)
		c.builder.CreateBr(cleanup)
	}
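
	// Each remaining non-tail async call below is followed by a suspend point:
	// llvm.coro.suspend returns 0 when the coroutine is resumed, 1 when it is
	// destroyed, and branches to the default (suspend) label when control
	// returns to the resumer.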
	// Lower regular calls.
	for _, call := range fn.normalCalls {
		// Lower return value of call.
		c.lowerCallReturn(fn, call)

		// Get originating basic block.
		bb := call.InstructionParent()

		// Split block.
		wakeup := llvmutil.SplitBasicBlock(c.builder, call, llvm.NextBasicBlock(bb), "wakeup")

		// Insert suspension and switch.
		c.builder.SetInsertPointAtEnd(bb)
		// %coro.save = call token @llvm.coro.save(i8* %coro.state)
		save := c.builder.CreateCall(c.coroSave, []llvm.Value{coroState}, "coro.save")
		// %call.suspend = call i8 @llvm.coro.suspend(token %coro.save, i1 false)
		// switch i8 %call.suspend, label %suspend [i8 0, label %wakeup
		//                                          i8 1, label %cleanup]
		suspendValue := c.builder.CreateCall(c.coroSuspend, []llvm.Value{save, llvm.ConstInt(c.ctx.Int1Type(), 0, false)}, "call.suspend")
		sw := c.builder.CreateSwitch(suspendValue, suspend, 2)
		sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 0, false), wakeup)
		sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 1, false), cleanup)

		// Delete call if it is a pause, because it has already been lowered.
		if call.CalledValue() == c.pause {
			call.EraseFromParentAsInstruction()
		}

		c.builder.SetInsertPointBefore(wakeup.FirstInstruction())
		c.track(coroState)
	}
}

// lowerCurrent lowers calls to internal/task.Current to bitcasts.
func (c *coroutineLoweringPass) lowerCurrent() error {
	taskType := c.current.Type().ElementType().ReturnType()
	deleteQueue := []llvm.Value{}
	for use := c.current.FirstUse(); !use.IsNil(); use = use.NextUse() {
		// Get user.
		user := use.User()

		if user.IsACallInst().IsNil() || user.CalledValue() != c.current {
			return errorAt(user, "unexpected non-call use of task.Current")
		}

		// Replace with bitcast.
		c.builder.SetInsertPointBefore(user)
		raw := user.Operand(1)
		if !raw.IsAUndefValue().IsNil() || raw.IsNull() {
			return errors.New("undefined task")
		}
		task := c.builder.CreateBitCast(raw, taskType, "task.current")
		user.ReplaceAllUsesWith(task)
		deleteQueue = append(deleteQueue, user)
	}

	// Delete calls.
	for _, inst := range deleteQueue {
		inst.EraseFromParentAsInstruction()
	}

	return nil
}
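
// After this step, a call such as
//
//	%task.current = call %Task* @"internal/task.Current"(i8* undef, i8* %parentHandle)
//
// is replaced by a plain bitcast of the raw task pointer (IR sketch):
//
//	%task.current = bitcast i8* %parentHandle to %Task*
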
// lowerStart lowers a goroutine start into a task creation and call or a synchronous call.
func (c *coroutineLoweringPass) lowerStart(start llvm.Value) {
	c.builder.SetInsertPointBefore(start)

	// Get function to call.
	fn := start.Operand(0).Operand(0)

	if _, ok := c.asyncFuncs[fn]; !ok {
		// Turn into synchronous call.
		c.lowerStartSync(start)
		return
	}

	// Create the list of params for the call.
	paramTypes := fn.Type().ElementType().ParamTypes()
	params := llvmutil.EmitPointerUnpack(c.builder, c.mod, start.Operand(1), paramTypes[:len(paramTypes)-1])

	// Create task.
	task := c.builder.CreateCall(c.createTask, []llvm.Value{llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "start.task")
	rawTask := c.builder.CreateBitCast(task, c.i8ptr, "start.task.bitcast")
	params = append(params, rawTask)

	// Generate a return buffer if necessary.
	returnType := fn.Type().ElementType().ReturnType()
	if returnType.TypeKind() == llvm.VoidTypeKind {
		// No return buffer necessary for a void return.
	} else {
		// Check for any undead returns.
		var undead bool
		for _, ret := range c.asyncFuncs[fn].returns {
			if ret.kind != returnDeadTail {
				// This return results in a value being eventually stored.
				undead = true
				break
			}
		}
		if undead {
			// The function stores a value into a return buffer, so we need to create one.
			retBuf := c.heapAlloc(returnType, "ret.ditch")
			c.builder.CreateCall(c.setRetPtr, []llvm.Value{task, retBuf, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")
		}
	}

	// Generate call to function.
	c.builder.CreateCall(fn, params, "")

	// Erase start call.
	start.EraseFromParentAsInstruction()
}

// lowerStartsPass lowers all goroutine starts.
func (c *coroutineLoweringPass) lowerStartsPass() {
	starts := []llvm.Value{}
	for use := c.start.FirstUse(); !use.IsNil(); use = use.NextUse() {
		starts = append(starts, use.User())
	}
	for _, start := range starts {
		c.lowerStart(start)
	}
}

// fixAnnotations removes parameter attributes that no longer hold once the task parameter is in use.
func (c *coroutineLoweringPass) fixAnnotations() {
	for f := range c.asyncFuncs {
		// These properties were added by the functionattrs pass. Remove
		// them, because now we start using the parameter.
		// https://llvm.org/docs/Passes.html#functionattrs-deduce-function-attributes
		for _, kind := range []string{"nocapture", "readnone"} {
			kindID := llvm.AttributeKindID(kind)
			n := f.ParamsCount()
			for i := 0; i <= n; i++ {
				f.RemoveEnumAttributeAtIndex(i, kindID)
			}
		}
	}
}

// trackGoroutines adds runtime.trackPointer calls to track goroutine starts and data.
func (c *coroutineLoweringPass) trackGoroutines() error {
	trackPointer := c.mod.NamedFunction("runtime.trackPointer")
	if trackPointer.IsNil() {
		return ErrMissingIntrinsic{"runtime.trackPointer"}
	}

	trackFunctions := []llvm.Value{c.createTask, c.setState, c.getRetPtr}
	for _, fn := range trackFunctions {
		for use := fn.FirstUse(); !use.IsNil(); use = use.NextUse() {
			call := use.User()

			c.builder.SetInsertPointBefore(llvm.NextInstruction(call))
			ptr := call
			if ptr.Type() != c.i8ptr {
				ptr = c.builder.CreateBitCast(call, c.i8ptr, "")
			}
			c.builder.CreateCall(trackPointer, []llvm.Value{ptr, llvm.Undef(c.i8ptr), llvm.Undef(c.i8ptr)}, "")
		}
	}

	return nil
}