1 //===-- AArch64Arm64ECCallLowering.cpp - Lower Arm64EC calls ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file contains the IR transform to lower external or indirect calls for
11 /// the ARM64EC calling convention. Such calls must go through the runtime, so
12 /// we can translate the calling convention for calls into the emulator.
13 ///
14 /// This subsumes Control Flow Guard handling.
15 ///
16 //===----------------------------------------------------------------------===//
17
18 #include "AArch64.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallString.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/Statistic.h"
23 #include "llvm/IR/CallingConv.h"
24 #include "llvm/IR/IRBuilder.h"
25 #include "llvm/IR/Instruction.h"
26 #include "llvm/InitializePasses.h"
27 #include "llvm/Object/COFF.h"
28 #include "llvm/Pass.h"
29 #include "llvm/Support/CommandLine.h"
30 #include "llvm/TargetParser/Triple.h"
31
32 using namespace llvm;
33 using namespace llvm::object;
34
// Local shorthand for operand bundles whose inputs are plain Value pointers.
using OperandBundleDef = OperandBundleDefT<Value *>;

#define DEBUG_TYPE "arm64eccalllowering"

// Incremented once per indirect call that processFunction queues for
// lowerCall.
STATISTIC(Arm64ECCallsLowered, "Number of Arm64EC calls lowered");

// When false, processFunction does not record direct calls to external
// declarations, so runOnModule builds no exit / guest exit thunks for them.
static cl::opt<bool> LowerDirectToIndirect("arm64ec-lower-direct-to-indirect",
                                           cl::Hidden, cl::init(true));
// When false, runOnModule returns immediately without touching the module.
static cl::opt<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden,
                                    cl::init(true));
45
46 namespace {
47
48 enum class ThunkType { GuestExit, Entry, Exit };
49
50 class AArch64Arm64ECCallLowering : public ModulePass {
51 public:
52 static char ID;
AArch64Arm64ECCallLowering()53 AArch64Arm64ECCallLowering() : ModulePass(ID) {
54 initializeAArch64Arm64ECCallLoweringPass(*PassRegistry::getPassRegistry());
55 }
56
57 Function *buildExitThunk(FunctionType *FnTy, AttributeList Attrs);
58 Function *buildEntryThunk(Function *F);
59 void lowerCall(CallBase *CB);
60 Function *buildGuestExitThunk(Function *F);
61 bool processFunction(Function &F, SetVector<Function *> &DirectCalledFns);
62 bool runOnModule(Module &M) override;
63
64 private:
65 int cfguard_module_flag = 0;
66 FunctionType *GuardFnType = nullptr;
67 PointerType *GuardFnPtrType = nullptr;
68 Constant *GuardFnCFGlobal = nullptr;
69 Constant *GuardFnGlobal = nullptr;
70 Module *M = nullptr;
71
72 Type *PtrTy;
73 Type *I64Ty;
74 Type *VoidTy;
75
76 void getThunkType(FunctionType *FT, AttributeList AttrList, ThunkType TT,
77 raw_ostream &Out, FunctionType *&Arm64Ty,
78 FunctionType *&X64Ty);
79 void getThunkRetType(FunctionType *FT, AttributeList AttrList,
80 raw_ostream &Out, Type *&Arm64RetTy, Type *&X64RetTy,
81 SmallVectorImpl<Type *> &Arm64ArgTypes,
82 SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr);
83 void getThunkArgTypes(FunctionType *FT, AttributeList AttrList, ThunkType TT,
84 raw_ostream &Out,
85 SmallVectorImpl<Type *> &Arm64ArgTypes,
86 SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr);
87 void canonicalizeThunkType(Type *T, Align Alignment, bool Ret,
88 uint64_t ArgSizeBytes, raw_ostream &Out,
89 Type *&Arm64Ty, Type *&X64Ty);
90 };
91
92 } // end anonymous namespace
93
getThunkType(FunctionType * FT,AttributeList AttrList,ThunkType TT,raw_ostream & Out,FunctionType * & Arm64Ty,FunctionType * & X64Ty)94 void AArch64Arm64ECCallLowering::getThunkType(FunctionType *FT,
95 AttributeList AttrList,
96 ThunkType TT, raw_ostream &Out,
97 FunctionType *&Arm64Ty,
98 FunctionType *&X64Ty) {
99 Out << (TT == ThunkType::Entry ? "$ientry_thunk$cdecl$"
100 : "$iexit_thunk$cdecl$");
101
102 Type *Arm64RetTy;
103 Type *X64RetTy;
104
105 SmallVector<Type *> Arm64ArgTypes;
106 SmallVector<Type *> X64ArgTypes;
107
108 // The first argument to a thunk is the called function, stored in x9.
109 // For exit thunks, we pass the called function down to the emulator;
110 // for entry/guest exit thunks, we just call the Arm64 function directly.
111 if (TT == ThunkType::Exit)
112 Arm64ArgTypes.push_back(PtrTy);
113 X64ArgTypes.push_back(PtrTy);
114
115 bool HasSretPtr = false;
116 getThunkRetType(FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes,
117 X64ArgTypes, HasSretPtr);
118
119 getThunkArgTypes(FT, AttrList, TT, Out, Arm64ArgTypes, X64ArgTypes,
120 HasSretPtr);
121
122 Arm64Ty = FunctionType::get(Arm64RetTy, Arm64ArgTypes, false);
123
124 X64Ty = FunctionType::get(X64RetTy, X64ArgTypes, false);
125 }
126
getThunkArgTypes(FunctionType * FT,AttributeList AttrList,ThunkType TT,raw_ostream & Out,SmallVectorImpl<Type * > & Arm64ArgTypes,SmallVectorImpl<Type * > & X64ArgTypes,bool HasSretPtr)127 void AArch64Arm64ECCallLowering::getThunkArgTypes(
128 FunctionType *FT, AttributeList AttrList, ThunkType TT, raw_ostream &Out,
129 SmallVectorImpl<Type *> &Arm64ArgTypes,
130 SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr) {
131
132 Out << "$";
133 if (FT->isVarArg()) {
134 // We treat the variadic function's thunk as a normal function
135 // with the following type on the ARM side:
136 // rettype exitthunk(
137 // ptr x9, ptr x0, i64 x1, i64 x2, i64 x3, ptr x4, i64 x5)
138 //
139 // that can coverage all types of variadic function.
140 // x9 is similar to normal exit thunk, store the called function.
141 // x0-x3 is the arguments be stored in registers.
142 // x4 is the address of the arguments on the stack.
143 // x5 is the size of the arguments on the stack.
144 //
145 // On the x64 side, it's the same except that x5 isn't set.
146 //
147 // If both the ARM and X64 sides are sret, there are only three
148 // arguments in registers.
149 //
150 // If the X64 side is sret, but the ARM side isn't, we pass an extra value
151 // to/from the X64 side, and let SelectionDAG transform it into a memory
152 // location.
153 Out << "varargs";
154
155 // x0-x3
156 for (int i = HasSretPtr ? 1 : 0; i < 4; i++) {
157 Arm64ArgTypes.push_back(I64Ty);
158 X64ArgTypes.push_back(I64Ty);
159 }
160
161 // x4
162 Arm64ArgTypes.push_back(PtrTy);
163 X64ArgTypes.push_back(PtrTy);
164 // x5
165 Arm64ArgTypes.push_back(I64Ty);
166 if (TT != ThunkType::Entry) {
167 // FIXME: x5 isn't actually used by the x64 side; revisit once we
168 // have proper isel for varargs
169 X64ArgTypes.push_back(I64Ty);
170 }
171 return;
172 }
173
174 unsigned I = 0;
175 if (HasSretPtr)
176 I++;
177
178 if (I == FT->getNumParams()) {
179 Out << "v";
180 return;
181 }
182
183 for (unsigned E = FT->getNumParams(); I != E; ++I) {
184 #if 0
185 // FIXME: Need more information about argument size; see
186 // https://reviews.llvm.org/D132926
187 uint64_t ArgSizeBytes = AttrList.getParamArm64ECArgSizeBytes(I);
188 Align ParamAlign = AttrList.getParamAlignment(I).valueOrOne();
189 #else
190 uint64_t ArgSizeBytes = 0;
191 Align ParamAlign = Align();
192 #endif
193 Type *Arm64Ty, *X64Ty;
194 canonicalizeThunkType(FT->getParamType(I), ParamAlign,
195 /*Ret*/ false, ArgSizeBytes, Out, Arm64Ty, X64Ty);
196 Arm64ArgTypes.push_back(Arm64Ty);
197 X64ArgTypes.push_back(X64Ty);
198 }
199 }
200
getThunkRetType(FunctionType * FT,AttributeList AttrList,raw_ostream & Out,Type * & Arm64RetTy,Type * & X64RetTy,SmallVectorImpl<Type * > & Arm64ArgTypes,SmallVectorImpl<Type * > & X64ArgTypes,bool & HasSretPtr)201 void AArch64Arm64ECCallLowering::getThunkRetType(
202 FunctionType *FT, AttributeList AttrList, raw_ostream &Out,
203 Type *&Arm64RetTy, Type *&X64RetTy, SmallVectorImpl<Type *> &Arm64ArgTypes,
204 SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr) {
205 Type *T = FT->getReturnType();
206 #if 0
207 // FIXME: Need more information about argument size; see
208 // https://reviews.llvm.org/D132926
209 uint64_t ArgSizeBytes = AttrList.getRetArm64ECArgSizeBytes();
210 #else
211 int64_t ArgSizeBytes = 0;
212 #endif
213 if (T->isVoidTy()) {
214 if (FT->getNumParams()) {
215 auto SRetAttr = AttrList.getParamAttr(0, Attribute::StructRet);
216 auto InRegAttr = AttrList.getParamAttr(0, Attribute::InReg);
217 if (SRetAttr.isValid() && InRegAttr.isValid()) {
218 // sret+inreg indicates a call that returns a C++ class value. This is
219 // actually equivalent to just passing and returning a void* pointer
220 // as the first argument. Translate it that way, instead of trying
221 // to model "inreg" in the thunk's calling convention, to simplify
222 // the rest of the code.
223 Out << "i8";
224 Arm64RetTy = I64Ty;
225 X64RetTy = I64Ty;
226 return;
227 }
228 if (SRetAttr.isValid()) {
229 // FIXME: Sanity-check the sret type; if it's an integer or pointer,
230 // we'll get screwy mangling/codegen.
231 // FIXME: For large struct types, mangle as an integer argument and
232 // integer return, so we can reuse more thunks, instead of "m" syntax.
233 // (MSVC mangles this case as an integer return with no argument, but
234 // that's a miscompile.)
235 Type *SRetType = SRetAttr.getValueAsType();
236 Align SRetAlign = AttrList.getParamAlignment(0).valueOrOne();
237 Type *Arm64Ty, *X64Ty;
238 canonicalizeThunkType(SRetType, SRetAlign, /*Ret*/ true, ArgSizeBytes,
239 Out, Arm64Ty, X64Ty);
240 Arm64RetTy = VoidTy;
241 X64RetTy = VoidTy;
242 Arm64ArgTypes.push_back(FT->getParamType(0));
243 X64ArgTypes.push_back(FT->getParamType(0));
244 HasSretPtr = true;
245 return;
246 }
247 }
248
249 Out << "v";
250 Arm64RetTy = VoidTy;
251 X64RetTy = VoidTy;
252 return;
253 }
254
255 canonicalizeThunkType(T, Align(), /*Ret*/ true, ArgSizeBytes, Out, Arm64RetTy,
256 X64RetTy);
257 if (X64RetTy->isPointerTy()) {
258 // If the X64 type is canonicalized to a pointer, that means it's
259 // passed/returned indirectly. For a return value, that means it's an
260 // sret pointer.
261 X64ArgTypes.push_back(X64RetTy);
262 X64RetTy = VoidTy;
263 }
264 }
265
canonicalizeThunkType(Type * T,Align Alignment,bool Ret,uint64_t ArgSizeBytes,raw_ostream & Out,Type * & Arm64Ty,Type * & X64Ty)266 void AArch64Arm64ECCallLowering::canonicalizeThunkType(
267 Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes, raw_ostream &Out,
268 Type *&Arm64Ty, Type *&X64Ty) {
269 if (T->isFloatTy()) {
270 Out << "f";
271 Arm64Ty = T;
272 X64Ty = T;
273 return;
274 }
275
276 if (T->isDoubleTy()) {
277 Out << "d";
278 Arm64Ty = T;
279 X64Ty = T;
280 return;
281 }
282
283 if (T->isFloatingPointTy()) {
284 report_fatal_error(
285 "Only 32 and 64 bit floating points are supported for ARM64EC thunks");
286 }
287
288 auto &DL = M->getDataLayout();
289
290 if (auto *StructTy = dyn_cast<StructType>(T))
291 if (StructTy->getNumElements() == 1)
292 T = StructTy->getElementType(0);
293
294 if (T->isArrayTy()) {
295 Type *ElementTy = T->getArrayElementType();
296 uint64_t ElementCnt = T->getArrayNumElements();
297 uint64_t ElementSizePerBytes = DL.getTypeSizeInBits(ElementTy) / 8;
298 uint64_t TotalSizeBytes = ElementCnt * ElementSizePerBytes;
299 if (ElementTy->isFloatTy() || ElementTy->isDoubleTy()) {
300 Out << (ElementTy->isFloatTy() ? "F" : "D") << TotalSizeBytes;
301 if (Alignment.value() >= 16 && !Ret)
302 Out << "a" << Alignment.value();
303 Arm64Ty = T;
304 if (TotalSizeBytes <= 8) {
305 // Arm64 returns small structs of float/double in float registers;
306 // X64 uses RAX.
307 X64Ty = llvm::Type::getIntNTy(M->getContext(), TotalSizeBytes * 8);
308 } else {
309 // Struct is passed directly on Arm64, but indirectly on X64.
310 X64Ty = PtrTy;
311 }
312 return;
313 } else if (T->isFloatingPointTy()) {
314 report_fatal_error("Only 32 and 64 bit floating points are supported for "
315 "ARM64EC thunks");
316 }
317 }
318
319 if ((T->isIntegerTy() || T->isPointerTy()) && DL.getTypeSizeInBits(T) <= 64) {
320 Out << "i8";
321 Arm64Ty = I64Ty;
322 X64Ty = I64Ty;
323 return;
324 }
325
326 unsigned TypeSize = ArgSizeBytes;
327 if (TypeSize == 0)
328 TypeSize = DL.getTypeSizeInBits(T) / 8;
329 Out << "m";
330 if (TypeSize != 4)
331 Out << TypeSize;
332 if (Alignment.value() >= 16 && !Ret)
333 Out << "a" << Alignment.value();
334 // FIXME: Try to canonicalize Arm64Ty more thoroughly?
335 Arm64Ty = T;
336 if (TypeSize == 1 || TypeSize == 2 || TypeSize == 4 || TypeSize == 8) {
337 // Pass directly in an integer register
338 X64Ty = llvm::Type::getIntNTy(M->getContext(), TypeSize * 8);
339 } else {
340 // Passed directly on Arm64, but indirectly on X64.
341 X64Ty = PtrTy;
342 }
343 }
344
345 // This function builds the "exit thunk", a function which translates
346 // arguments and return values when calling x64 code from AArch64 code.
buildExitThunk(FunctionType * FT,AttributeList Attrs)347 Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
348 AttributeList Attrs) {
349 SmallString<256> ExitThunkName;
350 llvm::raw_svector_ostream ExitThunkStream(ExitThunkName);
351 FunctionType *Arm64Ty, *X64Ty;
352 getThunkType(FT, Attrs, ThunkType::Exit, ExitThunkStream, Arm64Ty, X64Ty);
353 if (Function *F = M->getFunction(ExitThunkName))
354 return F;
355
356 Function *F = Function::Create(Arm64Ty, GlobalValue::LinkOnceODRLinkage, 0,
357 ExitThunkName, M);
358 F->setCallingConv(CallingConv::ARM64EC_Thunk_Native);
359 F->setSection(".wowthk$aa");
360 F->setComdat(M->getOrInsertComdat(ExitThunkName));
361 // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
362 F->addFnAttr("frame-pointer", "all");
363 // Only copy sret from the first argument. For C++ instance methods, clang can
364 // stick an sret marking on a later argument, but it doesn't actually affect
365 // the ABI, so we can omit it. This avoids triggering a verifier assertion.
366 if (FT->getNumParams()) {
367 auto SRet = Attrs.getParamAttr(0, Attribute::StructRet);
368 auto InReg = Attrs.getParamAttr(0, Attribute::InReg);
369 if (SRet.isValid() && !InReg.isValid())
370 F->addParamAttr(1, SRet);
371 }
372 // FIXME: Copy anything other than sret? Shouldn't be necessary for normal
373 // C ABI, but might show up in other cases.
374 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", F);
375 IRBuilder<> IRB(BB);
376 Value *CalleePtr =
377 M->getOrInsertGlobal("__os_arm64x_dispatch_call_no_redirect", PtrTy);
378 Value *Callee = IRB.CreateLoad(PtrTy, CalleePtr);
379 auto &DL = M->getDataLayout();
380 SmallVector<Value *> Args;
381
382 // Pass the called function in x9.
383 Args.push_back(F->arg_begin());
384
385 Type *RetTy = Arm64Ty->getReturnType();
386 if (RetTy != X64Ty->getReturnType()) {
387 // If the return type is an array or struct, translate it. Values of size
388 // 8 or less go into RAX; bigger values go into memory, and we pass a
389 // pointer.
390 if (DL.getTypeStoreSize(RetTy) > 8) {
391 Args.push_back(IRB.CreateAlloca(RetTy));
392 }
393 }
394
395 for (auto &Arg : make_range(F->arg_begin() + 1, F->arg_end())) {
396 // Translate arguments from AArch64 calling convention to x86 calling
397 // convention.
398 //
399 // For simple types, we don't need to do any translation: they're
400 // represented the same way. (Implicit sign extension is not part of
401 // either convention.)
402 //
403 // The big thing we have to worry about is struct types... but
404 // fortunately AArch64 clang is pretty friendly here: the cases that need
405 // translation are always passed as a struct or array. (If we run into
406 // some cases where this doesn't work, we can teach clang to mark it up
407 // with an attribute.)
408 //
409 // The first argument is the called function, stored in x9.
410 if (Arg.getType()->isArrayTy() || Arg.getType()->isStructTy() ||
411 DL.getTypeStoreSize(Arg.getType()) > 8) {
412 Value *Mem = IRB.CreateAlloca(Arg.getType());
413 IRB.CreateStore(&Arg, Mem);
414 if (DL.getTypeStoreSize(Arg.getType()) <= 8) {
415 Type *IntTy = IRB.getIntNTy(DL.getTypeStoreSizeInBits(Arg.getType()));
416 Args.push_back(IRB.CreateLoad(IntTy, IRB.CreateBitCast(Mem, PtrTy)));
417 } else
418 Args.push_back(Mem);
419 } else {
420 Args.push_back(&Arg);
421 }
422 }
423 // FIXME: Transfer necessary attributes? sret? anything else?
424
425 Callee = IRB.CreateBitCast(Callee, PtrTy);
426 CallInst *Call = IRB.CreateCall(X64Ty, Callee, Args);
427 Call->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
428
429 Value *RetVal = Call;
430 if (RetTy != X64Ty->getReturnType()) {
431 // If we rewrote the return type earlier, convert the return value to
432 // the proper type.
433 if (DL.getTypeStoreSize(RetTy) > 8) {
434 RetVal = IRB.CreateLoad(RetTy, Args[1]);
435 } else {
436 Value *CastAlloca = IRB.CreateAlloca(RetTy);
437 IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
438 RetVal = IRB.CreateLoad(RetTy, CastAlloca);
439 }
440 }
441
442 if (RetTy->isVoidTy())
443 IRB.CreateRetVoid();
444 else
445 IRB.CreateRet(RetVal);
446 return F;
447 }
448
449 // This function builds the "entry thunk", a function which translates
450 // arguments and return values when calling AArch64 code from x64 code.
buildEntryThunk(Function * F)451 Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
452 SmallString<256> EntryThunkName;
453 llvm::raw_svector_ostream EntryThunkStream(EntryThunkName);
454 FunctionType *Arm64Ty, *X64Ty;
455 getThunkType(F->getFunctionType(), F->getAttributes(), ThunkType::Entry,
456 EntryThunkStream, Arm64Ty, X64Ty);
457 if (Function *F = M->getFunction(EntryThunkName))
458 return F;
459
460 Function *Thunk = Function::Create(X64Ty, GlobalValue::LinkOnceODRLinkage, 0,
461 EntryThunkName, M);
462 Thunk->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
463 Thunk->setSection(".wowthk$aa");
464 Thunk->setComdat(M->getOrInsertComdat(EntryThunkName));
465 // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
466 Thunk->addFnAttr("frame-pointer", "all");
467
468 auto &DL = M->getDataLayout();
469 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", Thunk);
470 IRBuilder<> IRB(BB);
471
472 Type *RetTy = Arm64Ty->getReturnType();
473 Type *X64RetType = X64Ty->getReturnType();
474
475 bool TransformDirectToSRet = X64RetType->isVoidTy() && !RetTy->isVoidTy();
476 unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1;
477 unsigned PassthroughArgSize = F->isVarArg() ? 5 : Thunk->arg_size();
478
479 // Translate arguments to call.
480 SmallVector<Value *> Args;
481 for (unsigned i = ThunkArgOffset, e = PassthroughArgSize; i != e; ++i) {
482 Value *Arg = Thunk->getArg(i);
483 Type *ArgTy = Arm64Ty->getParamType(i - ThunkArgOffset);
484 if (ArgTy->isArrayTy() || ArgTy->isStructTy() ||
485 DL.getTypeStoreSize(ArgTy) > 8) {
486 // Translate array/struct arguments to the expected type.
487 if (DL.getTypeStoreSize(ArgTy) <= 8) {
488 Value *CastAlloca = IRB.CreateAlloca(ArgTy);
489 IRB.CreateStore(Arg, IRB.CreateBitCast(CastAlloca, PtrTy));
490 Arg = IRB.CreateLoad(ArgTy, CastAlloca);
491 } else {
492 Arg = IRB.CreateLoad(ArgTy, IRB.CreateBitCast(Arg, PtrTy));
493 }
494 }
495 Args.push_back(Arg);
496 }
497
498 if (F->isVarArg()) {
499 // The 5th argument to variadic entry thunks is used to model the x64 sp
500 // which is passed to the thunk in x4, this can be passed to the callee as
501 // the variadic argument start address after skipping over the 32 byte
502 // shadow store.
503
504 // The EC thunk CC will assign any argument marked as InReg to x4.
505 Thunk->addParamAttr(5, Attribute::InReg);
506 Value *Arg = Thunk->getArg(5);
507 Arg = IRB.CreatePtrAdd(Arg, IRB.getInt64(0x20));
508 Args.push_back(Arg);
509
510 // Pass in a zero variadic argument size (in x5).
511 Args.push_back(IRB.getInt64(0));
512 }
513
514 // Call the function passed to the thunk.
515 Value *Callee = Thunk->getArg(0);
516 Callee = IRB.CreateBitCast(Callee, PtrTy);
517 CallInst *Call = IRB.CreateCall(Arm64Ty, Callee, Args);
518
519 auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
520 auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
521 if (SRetAttr.isValid() && !InRegAttr.isValid()) {
522 Thunk->addParamAttr(1, SRetAttr);
523 Call->addParamAttr(0, SRetAttr);
524 }
525
526 Value *RetVal = Call;
527 if (TransformDirectToSRet) {
528 IRB.CreateStore(RetVal, IRB.CreateBitCast(Thunk->getArg(1), PtrTy));
529 } else if (X64RetType != RetTy) {
530 Value *CastAlloca = IRB.CreateAlloca(X64RetType);
531 IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
532 RetVal = IRB.CreateLoad(X64RetType, CastAlloca);
533 }
534
535 // Return to the caller. Note that the isel has code to translate this
536 // "ret" to a tail call to __os_arm64x_dispatch_ret. (Alternatively, we
537 // could emit a tail call here, but that would require a dedicated calling
538 // convention, which seems more complicated overall.)
539 if (X64RetType->isVoidTy())
540 IRB.CreateRetVoid();
541 else
542 IRB.CreateRet(RetVal);
543
544 return Thunk;
545 }
546
547 // Builds the "guest exit thunk", a helper to call a function which may or may
548 // not be an exit thunk. (We optimistically assume non-dllimport function
549 // declarations refer to functions defined in AArch64 code; if the linker
550 // can't prove that, we use this routine instead.)
buildGuestExitThunk(Function * F)551 Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
552 llvm::raw_null_ostream NullThunkName;
553 FunctionType *Arm64Ty, *X64Ty;
554 getThunkType(F->getFunctionType(), F->getAttributes(), ThunkType::GuestExit,
555 NullThunkName, Arm64Ty, X64Ty);
556 auto MangledName = getArm64ECMangledFunctionName(F->getName().str());
557 assert(MangledName && "Can't guest exit to function that's already native");
558 std::string ThunkName = *MangledName;
559 if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
560 ThunkName.insert(ThunkName.find("@"), "$exit_thunk");
561 } else {
562 ThunkName.append("$exit_thunk");
563 }
564 Function *GuestExit =
565 Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
566 GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
567 GuestExit->setSection(".wowthk$aa");
568 GuestExit->setMetadata(
569 "arm64ec_unmangled_name",
570 MDNode::get(M->getContext(),
571 MDString::get(M->getContext(), F->getName())));
572 GuestExit->setMetadata(
573 "arm64ec_ecmangled_name",
574 MDNode::get(M->getContext(),
575 MDString::get(M->getContext(), *MangledName)));
576 F->setMetadata("arm64ec_hasguestexit", MDNode::get(M->getContext(), {}));
577 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
578 IRBuilder<> B(BB);
579
580 // Load the global symbol as a pointer to the check function.
581 Value *GuardFn;
582 if (cfguard_module_flag == 2 && !F->hasFnAttribute("guard_nocf"))
583 GuardFn = GuardFnCFGlobal;
584 else
585 GuardFn = GuardFnGlobal;
586 LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
587
588 // Create new call instruction. The CFGuard check should always be a call,
589 // even if the original CallBase is an Invoke or CallBr instruction.
590 Function *Thunk = buildExitThunk(F->getFunctionType(), F->getAttributes());
591 CallInst *GuardCheck = B.CreateCall(
592 GuardFnType, GuardCheckLoad,
593 {B.CreateBitCast(F, B.getPtrTy()), B.CreateBitCast(Thunk, B.getPtrTy())});
594
595 // Ensure that the first argument is passed in the correct register.
596 GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
597
598 Value *GuardRetVal = B.CreateBitCast(GuardCheck, PtrTy);
599 SmallVector<Value *> Args;
600 for (Argument &Arg : GuestExit->args())
601 Args.push_back(&Arg);
602 CallInst *Call = B.CreateCall(Arm64Ty, GuardRetVal, Args);
603 Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
604
605 if (Call->getType()->isVoidTy())
606 B.CreateRetVoid();
607 else
608 B.CreateRet(Call);
609
610 auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
611 auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
612 if (SRetAttr.isValid() && !InRegAttr.isValid()) {
613 GuestExit->addParamAttr(0, SRetAttr);
614 Call->addParamAttr(0, SRetAttr);
615 }
616
617 return GuestExit;
618 }
619
620 // Lower an indirect call with inline code.
lowerCall(CallBase * CB)621 void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
622 assert(Triple(CB->getModule()->getTargetTriple()).isOSWindows() &&
623 "Only applicable for Windows targets");
624
625 IRBuilder<> B(CB);
626 Value *CalledOperand = CB->getCalledOperand();
627
628 // If the indirect call is called within catchpad or cleanuppad,
629 // we need to copy "funclet" bundle of the call.
630 SmallVector<llvm::OperandBundleDef, 1> Bundles;
631 if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet))
632 Bundles.push_back(OperandBundleDef(*Bundle));
633
634 // Load the global symbol as a pointer to the check function.
635 Value *GuardFn;
636 if (cfguard_module_flag == 2 && !CB->hasFnAttr("guard_nocf"))
637 GuardFn = GuardFnCFGlobal;
638 else
639 GuardFn = GuardFnGlobal;
640 LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
641
642 // Create new call instruction. The CFGuard check should always be a call,
643 // even if the original CallBase is an Invoke or CallBr instruction.
644 Function *Thunk = buildExitThunk(CB->getFunctionType(), CB->getAttributes());
645 CallInst *GuardCheck =
646 B.CreateCall(GuardFnType, GuardCheckLoad,
647 {B.CreateBitCast(CalledOperand, B.getPtrTy()),
648 B.CreateBitCast(Thunk, B.getPtrTy())},
649 Bundles);
650
651 // Ensure that the first argument is passed in the correct register.
652 GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
653
654 Value *GuardRetVal = B.CreateBitCast(GuardCheck, CalledOperand->getType());
655 CB->setCalledOperand(GuardRetVal);
656 }
657
runOnModule(Module & Mod)658 bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
659 if (!GenerateThunks)
660 return false;
661
662 M = &Mod;
663
664 // Check if this module has the cfguard flag and read its value.
665 if (auto *MD =
666 mdconst::extract_or_null<ConstantInt>(M->getModuleFlag("cfguard")))
667 cfguard_module_flag = MD->getZExtValue();
668
669 PtrTy = PointerType::getUnqual(M->getContext());
670 I64Ty = Type::getInt64Ty(M->getContext());
671 VoidTy = Type::getVoidTy(M->getContext());
672
673 GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false);
674 GuardFnPtrType = PointerType::get(GuardFnType, 0);
675 GuardFnCFGlobal =
676 M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType);
677 GuardFnGlobal =
678 M->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType);
679
680 SetVector<Function *> DirectCalledFns;
681 for (Function &F : Mod)
682 if (!F.isDeclaration() &&
683 F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
684 F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64)
685 processFunction(F, DirectCalledFns);
686
687 struct ThunkInfo {
688 Constant *Src;
689 Constant *Dst;
690 unsigned Kind;
691 };
692 SmallVector<ThunkInfo> ThunkMapping;
693 for (Function &F : Mod) {
694 if (!F.isDeclaration() && (!F.hasLocalLinkage() || F.hasAddressTaken()) &&
695 F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
696 F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) {
697 if (!F.hasComdat())
698 F.setComdat(Mod.getOrInsertComdat(F.getName()));
699 ThunkMapping.push_back({&F, buildEntryThunk(&F), 1});
700 }
701 }
702 for (Function *F : DirectCalledFns) {
703 ThunkMapping.push_back(
704 {F, buildExitThunk(F->getFunctionType(), F->getAttributes()), 4});
705 if (!F->hasDLLImportStorageClass())
706 ThunkMapping.push_back({buildGuestExitThunk(F), F, 0});
707 }
708
709 if (!ThunkMapping.empty()) {
710 SmallVector<Constant *> ThunkMappingArrayElems;
711 for (ThunkInfo &Thunk : ThunkMapping) {
712 ThunkMappingArrayElems.push_back(ConstantStruct::getAnon(
713 {ConstantExpr::getBitCast(Thunk.Src, PtrTy),
714 ConstantExpr::getBitCast(Thunk.Dst, PtrTy),
715 ConstantInt::get(M->getContext(), APInt(32, Thunk.Kind))}));
716 }
717 Constant *ThunkMappingArray = ConstantArray::get(
718 llvm::ArrayType::get(ThunkMappingArrayElems[0]->getType(),
719 ThunkMappingArrayElems.size()),
720 ThunkMappingArrayElems);
721 new GlobalVariable(Mod, ThunkMappingArray->getType(), /*isConstant*/ false,
722 GlobalValue::ExternalLinkage, ThunkMappingArray,
723 "llvm.arm64ec.symbolmap");
724 }
725
726 return true;
727 }
728
processFunction(Function & F,SetVector<Function * > & DirectCalledFns)729 bool AArch64Arm64ECCallLowering::processFunction(
730 Function &F, SetVector<Function *> &DirectCalledFns) {
731 SmallVector<CallBase *, 8> IndirectCalls;
732
733 // For ARM64EC targets, a function definition's name is mangled differently
734 // from the normal symbol. We currently have no representation of this sort
735 // of symbol in IR, so we change the name to the mangled name, then store
736 // the unmangled name as metadata. Later passes that need the unmangled
737 // name (emitting the definition) can grab it from the metadata.
738 //
739 // FIXME: Handle functions with weak linkage?
740 if (F.hasExternalLinkage() || F.hasWeakLinkage() || F.hasLinkOnceLinkage()) {
741 if (std::optional<std::string> MangledName =
742 getArm64ECMangledFunctionName(F.getName().str())) {
743 F.setMetadata("arm64ec_unmangled_name",
744 MDNode::get(M->getContext(),
745 MDString::get(M->getContext(), F.getName())));
746 if (F.hasComdat() && F.getComdat()->getName() == F.getName()) {
747 Comdat *MangledComdat = M->getOrInsertComdat(MangledName.value());
748 SmallVector<GlobalObject *> ComdatUsers =
749 to_vector(F.getComdat()->getUsers());
750 for (GlobalObject *User : ComdatUsers)
751 User->setComdat(MangledComdat);
752 }
753 F.setName(MangledName.value());
754 }
755 }
756
757 // Iterate over the instructions to find all indirect call/invoke/callbr
758 // instructions. Make a separate list of pointers to indirect
759 // call/invoke/callbr instructions because the original instructions will be
760 // deleted as the checks are added.
761 for (BasicBlock &BB : F) {
762 for (Instruction &I : BB) {
763 auto *CB = dyn_cast<CallBase>(&I);
764 if (!CB || CB->getCallingConv() == CallingConv::ARM64EC_Thunk_X64 ||
765 CB->isInlineAsm())
766 continue;
767
768 // We need to instrument any call that isn't directly calling an
769 // ARM64 function.
770 //
771 // FIXME: getCalledFunction() fails if there's a bitcast (e.g.
772 // unprototyped functions in C)
773 if (Function *F = CB->getCalledFunction()) {
774 if (!LowerDirectToIndirect || F->hasLocalLinkage() ||
775 F->isIntrinsic() || !F->isDeclaration())
776 continue;
777
778 DirectCalledFns.insert(F);
779 continue;
780 }
781
782 IndirectCalls.push_back(CB);
783 ++Arm64ECCallsLowered;
784 }
785 }
786
787 if (IndirectCalls.empty())
788 return false;
789
790 for (CallBase *CB : IndirectCalls)
791 lowerCall(CB);
792
793 return true;
794 }
795
// Storage for the pass ID that the constructor hands to ModulePass.
char AArch64Arm64ECCallLowering::ID = 0;
// Registers the pass under the name "Arm64ECCallLowering" (not CFG-only, not
// an analysis).
INITIALIZE_PASS(AArch64Arm64ECCallLowering, "Arm64ECCallLowering",
                "AArch64Arm64ECCallLowering", false, false)
799
createAArch64Arm64ECCallLoweringPass()800 ModulePass *llvm::createAArch64Arm64ECCallLoweringPass() {
801 return new AArch64Arm64ECCallLowering;
802 }
803