//===--- ExpandReductions.cpp - Expand reduction intrinsics ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This pass expands vector reduction intrinsics into equivalent IR for targets
// that request it, so targets can keep the intrinsics in the IR until just
// before codegen.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/CodeGen/ExpandReductions.h"
15 #include "llvm/Analysis/TargetTransformInfo.h"
16 #include "llvm/CodeGen/Passes.h"
17 #include "llvm/IR/IRBuilder.h"
18 #include "llvm/IR/InstIterator.h"
19 #include "llvm/IR/IntrinsicInst.h"
20 #include "llvm/IR/Intrinsics.h"
21 #include "llvm/InitializePasses.h"
22 #include "llvm/Pass.h"
23 #include "llvm/Transforms/Utils/LoopUtils.h"
24 
25 using namespace llvm;
26 
27 namespace {
28 
29 unsigned getOpcode(Intrinsic::ID ID) {
30   switch (ID) {
31   case Intrinsic::vector_reduce_fadd:
32     return Instruction::FAdd;
33   case Intrinsic::vector_reduce_fmul:
34     return Instruction::FMul;
35   case Intrinsic::vector_reduce_add:
36     return Instruction::Add;
37   case Intrinsic::vector_reduce_mul:
38     return Instruction::Mul;
39   case Intrinsic::vector_reduce_and:
40     return Instruction::And;
41   case Intrinsic::vector_reduce_or:
42     return Instruction::Or;
43   case Intrinsic::vector_reduce_xor:
44     return Instruction::Xor;
45   case Intrinsic::vector_reduce_smax:
46   case Intrinsic::vector_reduce_smin:
47   case Intrinsic::vector_reduce_umax:
48   case Intrinsic::vector_reduce_umin:
49     return Instruction::ICmp;
50   case Intrinsic::vector_reduce_fmax:
51   case Intrinsic::vector_reduce_fmin:
52     return Instruction::FCmp;
53   default:
54     llvm_unreachable("Unexpected ID");
55   }
56 }
57 
58 RecurKind getRK(Intrinsic::ID ID) {
59   switch (ID) {
60   case Intrinsic::vector_reduce_smax:
61     return RecurKind::SMax;
62   case Intrinsic::vector_reduce_smin:
63     return RecurKind::SMin;
64   case Intrinsic::vector_reduce_umax:
65     return RecurKind::UMax;
66   case Intrinsic::vector_reduce_umin:
67     return RecurKind::UMin;
68   case Intrinsic::vector_reduce_fmax:
69     return RecurKind::FMax;
70   case Intrinsic::vector_reduce_fmin:
71     return RecurKind::FMin;
72   default:
73     return RecurKind::None;
74   }
75 }
76 
77 bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
78   bool Changed = false;
79   SmallVector<IntrinsicInst *, 4> Worklist;
80   for (auto &I : instructions(F)) {
81     if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
82       switch (II->getIntrinsicID()) {
83       default: break;
84       case Intrinsic::vector_reduce_fadd:
85       case Intrinsic::vector_reduce_fmul:
86       case Intrinsic::vector_reduce_add:
87       case Intrinsic::vector_reduce_mul:
88       case Intrinsic::vector_reduce_and:
89       case Intrinsic::vector_reduce_or:
90       case Intrinsic::vector_reduce_xor:
91       case Intrinsic::vector_reduce_smax:
92       case Intrinsic::vector_reduce_smin:
93       case Intrinsic::vector_reduce_umax:
94       case Intrinsic::vector_reduce_umin:
95       case Intrinsic::vector_reduce_fmax:
96       case Intrinsic::vector_reduce_fmin:
97         if (TTI->shouldExpandReduction(II))
98           Worklist.push_back(II);
99 
100         break;
101       }
102     }
103   }
104 
105   for (auto *II : Worklist) {
106     FastMathFlags FMF =
107         isa<FPMathOperator>(II) ? II->getFastMathFlags() : FastMathFlags{};
108     Intrinsic::ID ID = II->getIntrinsicID();
109     RecurKind RK = getRK(ID);
110 
111     Value *Rdx = nullptr;
112     IRBuilder<> Builder(II);
113     IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
114     Builder.setFastMathFlags(FMF);
115     switch (ID) {
116     default: llvm_unreachable("Unexpected intrinsic!");
117     case Intrinsic::vector_reduce_fadd:
118     case Intrinsic::vector_reduce_fmul: {
119       // FMFs must be attached to the call, otherwise it's an ordered reduction
120       // and it can't be handled by generating a shuffle sequence.
121       Value *Acc = II->getArgOperand(0);
122       Value *Vec = II->getArgOperand(1);
123       if (!FMF.allowReassoc())
124         Rdx = getOrderedReduction(Builder, Acc, Vec, getOpcode(ID), RK);
125       else {
126         if (!isPowerOf2_32(
127                 cast<FixedVectorType>(Vec->getType())->getNumElements()))
128           continue;
129 
130         Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
131         Rdx = Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(ID),
132                                   Acc, Rdx, "bin.rdx");
133       }
134       break;
135     }
136     case Intrinsic::vector_reduce_add:
137     case Intrinsic::vector_reduce_mul:
138     case Intrinsic::vector_reduce_and:
139     case Intrinsic::vector_reduce_or:
140     case Intrinsic::vector_reduce_xor:
141     case Intrinsic::vector_reduce_smax:
142     case Intrinsic::vector_reduce_smin:
143     case Intrinsic::vector_reduce_umax:
144     case Intrinsic::vector_reduce_umin: {
145       Value *Vec = II->getArgOperand(0);
146       if (!isPowerOf2_32(
147               cast<FixedVectorType>(Vec->getType())->getNumElements()))
148         continue;
149 
150       Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
151       break;
152     }
153     case Intrinsic::vector_reduce_fmax:
154     case Intrinsic::vector_reduce_fmin: {
155       // We require "nnan" to use a shuffle reduction; "nsz" is implied by the
156       // semantics of the reduction.
157       Value *Vec = II->getArgOperand(0);
158       if (!isPowerOf2_32(
159               cast<FixedVectorType>(Vec->getType())->getNumElements()) ||
160           !FMF.noNaNs())
161         continue;
162 
163       Rdx = getShuffleReduction(Builder, Vec, getOpcode(ID), RK);
164       break;
165     }
166     }
167     II->replaceAllUsesWith(Rdx);
168     II->eraseFromParent();
169     Changed = true;
170   }
171   return Changed;
172 }
173 
174 class ExpandReductions : public FunctionPass {
175 public:
176   static char ID;
177   ExpandReductions() : FunctionPass(ID) {
178     initializeExpandReductionsPass(*PassRegistry::getPassRegistry());
179   }
180 
181   bool runOnFunction(Function &F) override {
182     const auto *TTI =&getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
183     return expandReductions(F, TTI);
184   }
185 
186   void getAnalysisUsage(AnalysisUsage &AU) const override {
187     AU.addRequired<TargetTransformInfoWrapperPass>();
188     AU.setPreservesCFG();
189   }
190 };
191 }
192 
// Out-of-line definition of the legacy pass's static ID member.
char ExpandReductions::ID;
// Register the legacy pass under the command-line name "expand-reductions" and
// record its dependency on TargetTransformInfoWrapperPass, the analysis that
// runOnFunction queries.
INITIALIZE_PASS_BEGIN(ExpandReductions, "expand-reductions",
                      "Expand reduction intrinsics", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(ExpandReductions, "expand-reductions",
                    "Expand reduction intrinsics", false, false)
199 
200 FunctionPass *llvm::createExpandReductionsPass() {
201   return new ExpandReductions();
202 }
203 
204 PreservedAnalyses ExpandReductionsPass::run(Function &F,
205                                             FunctionAnalysisManager &AM) {
206   const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
207   if (!expandReductions(F, &TTI))
208     return PreservedAnalyses::all();
209   PreservedAnalyses PA;
210   PA.preserveSet<CFGAnalyses>();
211   return PA;
212 }
213