1package transform 2 3// This file lowers func values into their final form. This is necessary for 4// funcValueSwitch, which needs full program analysis. 5 6import ( 7 "sort" 8 "strconv" 9 "strings" 10 11 "github.com/tinygo-org/tinygo/compiler/llvmutil" 12 "tinygo.org/x/go-llvm" 13) 14 15// funcSignatureInfo keeps information about a single signature and its uses. 16type funcSignatureInfo struct { 17 sig llvm.Value // *uint8 to identify the signature 18 funcValueWithSignatures []llvm.Value // slice of runtime.funcValueWithSignature 19} 20 21// funcWithUses keeps information about a single function used as func value and 22// the assigned function ID. More commonly used functions are assigned a lower 23// ID. 24type funcWithUses struct { 25 funcPtr llvm.Value 26 useCount int // how often this function is used in a func value 27 id int // assigned ID 28} 29 30// Slice to sort functions by their use counts, or else their name if they're 31// used equally often. 32type funcWithUsesList []*funcWithUses 33 34func (l funcWithUsesList) Len() int { return len(l) } 35func (l funcWithUsesList) Less(i, j int) bool { 36 if l[i].useCount != l[j].useCount { 37 // return the reverse: we want the highest use counts sorted first 38 return l[i].useCount > l[j].useCount 39 } 40 iName := l[i].funcPtr.Name() 41 jName := l[j].funcPtr.Name() 42 return iName < jName 43} 44func (l funcWithUsesList) Swap(i, j int) { 45 l[i], l[j] = l[j], l[i] 46} 47 48// LowerFuncValues lowers the runtime.funcValueWithSignature type and 49// runtime.getFuncPtr function to their final form. 50func LowerFuncValues(mod llvm.Module) { 51 ctx := mod.Context() 52 builder := ctx.NewBuilder() 53 uintptrType := ctx.IntType(llvm.NewTargetData(mod.DataLayout()).PointerSize() * 8) 54 55 // Find all func values used in the program with their signatures. 56 funcValueWithSignaturePtr := llvm.PointerType(mod.GetTypeByName("runtime.funcValueWithSignature"), 0) 57 signatures := map[string]*funcSignatureInfo{} 58 for global := mod.FirstGlobal(); !global.IsNil(); global = llvm.NextGlobal(global) { 59 var sig, funcVal llvm.Value 60 switch { 61 case global.Type() == funcValueWithSignaturePtr: 62 sig = llvm.ConstExtractValue(global.Initializer(), []uint32{1}) 63 funcVal = global 64 case strings.HasPrefix(global.Name(), "reflect/types.type:func:{"): 65 sig = global 66 default: 67 continue 68 } 69 70 name := sig.Name() 71 var funcValueWithSignatures []llvm.Value 72 if funcVal.IsNil() { 73 funcValueWithSignatures = []llvm.Value{} 74 } else { 75 funcValueWithSignatures = []llvm.Value{funcVal} 76 } 77 if info, ok := signatures[name]; ok { 78 info.funcValueWithSignatures = append(info.funcValueWithSignatures, funcValueWithSignatures...) 79 } else { 80 signatures[name] = &funcSignatureInfo{ 81 sig: sig, 82 funcValueWithSignatures: funcValueWithSignatures, 83 } 84 } 85 } 86 87 // Sort the signatures, for deterministic execution. 88 names := make([]string, 0, len(signatures)) 89 for name := range signatures { 90 names = append(names, name) 91 } 92 sort.Strings(names) 93 94 for _, name := range names { 95 info := signatures[name] 96 functions := make(funcWithUsesList, len(info.funcValueWithSignatures)) 97 for i, use := range info.funcValueWithSignatures { 98 var useCount int 99 for _, use2 := range getUses(use) { 100 useCount += len(getUses(use2)) 101 } 102 functions[i] = &funcWithUses{ 103 funcPtr: llvm.ConstExtractValue(use.Initializer(), []uint32{0}).Operand(0), 104 useCount: useCount, 105 } 106 } 107 sort.Sort(functions) 108 109 for i, fn := range functions { 110 fn.id = i + 1 111 for _, ptrtoint := range getUses(fn.funcPtr) { 112 if ptrtoint.IsAConstantExpr().IsNil() || ptrtoint.Opcode() != llvm.PtrToInt { 113 continue 114 } 115 for _, funcValueWithSignatureConstant := range getUses(ptrtoint) { 116 if !funcValueWithSignatureConstant.IsACallInst().IsNil() && funcValueWithSignatureConstant.CalledValue().Name() == "runtime.makeGoroutine" { 117 // makeGoroutine calls are handled seperately 118 continue 119 } 120 for _, funcValueWithSignatureGlobal := range getUses(funcValueWithSignatureConstant) { 121 for _, use := range getUses(funcValueWithSignatureGlobal) { 122 if ptrtoint.IsAConstantExpr().IsNil() || ptrtoint.Opcode() != llvm.PtrToInt { 123 panic("expected const ptrtoint") 124 } 125 use.ReplaceAllUsesWith(llvm.ConstInt(uintptrType, uint64(fn.id), false)) 126 } 127 } 128 } 129 } 130 } 131 132 for _, getFuncPtrCall := range getUses(info.sig) { 133 if getFuncPtrCall.IsACallInst().IsNil() { 134 continue 135 } 136 if getFuncPtrCall.CalledValue().Name() != "runtime.getFuncPtr" { 137 panic("expected all call uses to be runtime.getFuncPtr") 138 } 139 funcID := getFuncPtrCall.Operand(1) 140 141 // There are functions used in a func value that 142 // implement this signature. 143 // What we'll do is transform the following: 144 // rawPtr := runtime.getFuncPtr(func.ptr) 145 // if rawPtr == nil { 146 // runtime.nilPanic() 147 // } 148 // result := rawPtr(...args, func.context) 149 // into this: 150 // if false { 151 // runtime.nilPanic() 152 // } 153 // var result // Phi 154 // switch fn.id { 155 // case 0: 156 // runtime.nilPanic() 157 // case 1: 158 // result = call first implementation... 159 // case 2: 160 // result = call second implementation... 161 // default: 162 // unreachable 163 // } 164 165 // Remove some casts, checks, and the old call which we're going 166 // to replace. 167 for _, callIntPtr := range getUses(getFuncPtrCall) { 168 if !callIntPtr.IsACallInst().IsNil() && callIntPtr.CalledValue().Name() == "internal/task.start" { 169 // Special case for goroutine starts. 170 addFuncLoweringSwitch(mod, builder, funcID, callIntPtr, func(funcPtr llvm.Value, params []llvm.Value) llvm.Value { 171 i8ptrType := llvm.PointerType(ctx.Int8Type(), 0) 172 calleeValue := builder.CreatePtrToInt(funcPtr, uintptrType, "") 173 start := mod.NamedFunction("internal/task.start") 174 builder.CreateCall(start, []llvm.Value{calleeValue, callIntPtr.Operand(1), llvm.Undef(i8ptrType), llvm.ConstNull(i8ptrType)}, "") 175 return llvm.Value{} // void so no return value 176 }, functions) 177 callIntPtr.EraseFromParentAsInstruction() 178 continue 179 } 180 if callIntPtr.IsAIntToPtrInst().IsNil() { 181 panic("expected inttoptr") 182 } 183 for _, ptrUse := range getUses(callIntPtr) { 184 if !ptrUse.IsAICmpInst().IsNil() { 185 ptrUse.ReplaceAllUsesWith(llvm.ConstInt(ctx.Int1Type(), 0, false)) 186 } else if !ptrUse.IsACallInst().IsNil() && ptrUse.CalledValue() == callIntPtr { 187 addFuncLoweringSwitch(mod, builder, funcID, ptrUse, func(funcPtr llvm.Value, params []llvm.Value) llvm.Value { 188 return builder.CreateCall(funcPtr, params, "") 189 }, functions) 190 } else { 191 panic("unexpected getFuncPtrCall") 192 } 193 ptrUse.EraseFromParentAsInstruction() 194 } 195 callIntPtr.EraseFromParentAsInstruction() 196 } 197 getFuncPtrCall.EraseFromParentAsInstruction() 198 } 199 } 200} 201 202// addFuncLoweringSwitch creates a new switch on a function ID and inserts calls 203// to the newly created direct calls. The funcID is the number to switch on, 204// call is the call instruction to replace, and createCall is the callback that 205// actually creates the new call. By changing createCall to something other than 206// builder.CreateCall, instead of calling a function it can start a new 207// goroutine for example. 208func addFuncLoweringSwitch(mod llvm.Module, builder llvm.Builder, funcID, call llvm.Value, createCall func(funcPtr llvm.Value, params []llvm.Value) llvm.Value, functions funcWithUsesList) { 209 ctx := mod.Context() 210 uintptrType := ctx.IntType(llvm.NewTargetData(mod.DataLayout()).PointerSize() * 8) 211 i8ptrType := llvm.PointerType(ctx.Int8Type(), 0) 212 213 // The block that cannot be reached with correct funcValues (to help the 214 // optimizer). 215 builder.SetInsertPointBefore(call) 216 defaultBlock := ctx.AddBasicBlock(call.InstructionParent().Parent(), "func.default") 217 builder.SetInsertPointAtEnd(defaultBlock) 218 builder.CreateUnreachable() 219 220 // Create the switch. 221 builder.SetInsertPointBefore(call) 222 sw := builder.CreateSwitch(funcID, defaultBlock, len(functions)+1) 223 224 // Split right after the switch. We will need to insert a few basic blocks 225 // in this gap. 226 nextBlock := llvmutil.SplitBasicBlock(builder, sw, llvm.NextBasicBlock(sw.InstructionParent()), "func.next") 227 228 // Temporarily set the insert point to set the correct debug insert location 229 // for the builder. It got destroyed by the SplitBasicBlock call. 230 builder.SetInsertPointBefore(call) 231 232 // The 0 case, which is actually a nil check. 233 nilBlock := ctx.InsertBasicBlock(nextBlock, "func.nil") 234 builder.SetInsertPointAtEnd(nilBlock) 235 nilPanic := mod.NamedFunction("runtime.nilPanic") 236 builder.CreateCall(nilPanic, []llvm.Value{llvm.Undef(i8ptrType), llvm.ConstNull(i8ptrType)}, "") 237 builder.CreateUnreachable() 238 sw.AddCase(llvm.ConstInt(uintptrType, 0, false), nilBlock) 239 240 // Gather the list of parameters for every call we're going to make. 241 callParams := make([]llvm.Value, call.OperandsCount()-1) 242 for i := range callParams { 243 callParams[i] = call.Operand(i) 244 } 245 246 // If the call produces a value, we need to get it using a PHI 247 // node. 248 phiBlocks := make([]llvm.BasicBlock, len(functions)) 249 phiValues := make([]llvm.Value, len(functions)) 250 for i, fn := range functions { 251 // Insert a switch case. 252 bb := ctx.InsertBasicBlock(nextBlock, "func.call"+strconv.Itoa(fn.id)) 253 builder.SetInsertPointAtEnd(bb) 254 result := createCall(fn.funcPtr, callParams) 255 builder.CreateBr(nextBlock) 256 sw.AddCase(llvm.ConstInt(uintptrType, uint64(fn.id), false), bb) 257 phiBlocks[i] = bb 258 phiValues[i] = result 259 } 260 if call.Type().TypeKind() != llvm.VoidTypeKind { 261 if len(functions) > 0 { 262 // Create the PHI node so that the call result flows into the 263 // next block (after the split). This is only necessary when the 264 // call produced a value. 265 builder.SetInsertPointBefore(nextBlock.FirstInstruction()) 266 phi := builder.CreatePHI(call.Type(), "") 267 phi.AddIncoming(phiValues, phiBlocks) 268 call.ReplaceAllUsesWith(phi) 269 } else { 270 // This is always a nil panic, so replace the call result with undef. 271 call.ReplaceAllUsesWith(llvm.Undef(call.Type())) 272 } 273 } 274} 275