1 //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Identification:
10 // This step is responsible for finding the patterns that can be lowered to
11 // complex instructions, and building a graph to represent the complex
12 // structures. Starting from the "Converging Shuffle" (a shuffle that
13 // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14 // operands are evaluated and identified as "Composite Nodes" (collections of
15 // instructions that can potentially be lowered to a single complex
16 // instruction). This is performed by checking the real and imaginary components
17 // and tracking the data flow for each component while following the operand
// pairs. Each node is expected to be validated upon creation, and any
// validation error should halt traversal and prevent further graph
// construction.
// Instead of relying on shuffle operations, vector interleaving and
// deinterleaving can be represented by the vector.interleave2 and
// vector.deinterleave2 intrinsics. Scalable vectors can only be represented by
// these intrinsics, whereas fixed-width vectors are recognized in both forms:
// as shufflevector instructions and as intrinsics.
26 //
27 // Replacement:
28 // This step traverses the graph built up by identification, delegating to the
29 // target to validate and generate the correct intrinsics, and plumbs them
30 // together connecting each end of the new intrinsics graph to the existing
31 // use-def chain. This step is assumed to finish successfully, as all
32 // information is expected to be correct by this point.
33 //
34 //
35 // Internal data structure:
36 // ComplexDeinterleavingGraph:
37 // Keeps references to all the valid CompositeNodes formed as part of the
38 // transformation, and every Instruction contained within said nodes. It also
39 // holds onto a reference to the root Instruction, and the root node that should
40 // replace it.
41 //
42 // ComplexDeinterleavingCompositeNode:
43 // A CompositeNode represents a single transformation point; each node should
44 // transform into a single complex instruction (ignoring vector splitting, which
45 // would generate more instructions per node). They are identified in a
46 // depth-first manner, traversing and identifying the operands of each
47 // instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary instructions,
49 // as well as any additional instructions that make up the identified operation
50 // (Internal instructions should only have uses within their containing node).
51 // A Node also contains the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementNode is the transformed Value* that has been emitted
// to the IR.
55 //
// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
// ReplacementNode fields of that Node are relevant, and ReplacementNode
// should be pre-populated.
59 //
60 //===----------------------------------------------------------------------===//
61 
62 #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63 #include "llvm/ADT/MapVector.h"
64 #include "llvm/ADT/Statistic.h"
65 #include "llvm/Analysis/TargetLibraryInfo.h"
66 #include "llvm/Analysis/TargetTransformInfo.h"
67 #include "llvm/CodeGen/TargetLowering.h"
68 #include "llvm/CodeGen/TargetPassConfig.h"
69 #include "llvm/CodeGen/TargetSubtargetInfo.h"
70 #include "llvm/IR/IRBuilder.h"
71 #include "llvm/IR/PatternMatch.h"
72 #include "llvm/InitializePasses.h"
73 #include "llvm/Target/TargetMachine.h"
74 #include "llvm/Transforms/Utils/Local.h"
75 #include <algorithm>
76 
77 using namespace llvm;
78 using namespace PatternMatch;
79 
80 #define DEBUG_TYPE "complex-deinterleaving"
81 
82 STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
83 
84 static cl::opt<bool> ComplexDeinterleavingEnabled(
85     "enable-complex-deinterleaving",
86     cl::desc("Enable generation of complex instructions"), cl::init(true),
87     cl::Hidden);
88 
/// Returns true if \p Mask is an interleaving shuffle mask.
///
/// To be interleaving, a mask must have an even length and alternate between
/// `i` and `i + (Length / 2)` for every `i` in `[0, Length / 2)`, so that it
/// contains every index in `[0, Length)` (e.g. the interleaving mask for a
/// 4-element vector is <0, 2, 1, 3>).
static bool isInterleavingMask(ArrayRef<int> Mask);

/// Returns true if \p Mask is a deinterleaving shuffle mask.
///
/// To be deinterleaving, a mask must start with either 0 or 1 and increment
/// in steps of 2 (e.g. for an 8-element source vector the deinterleaving masks
/// are <0, 2, 4, 6> and <1, 3, 5, 7>).
static bool isDeinterleavingMask(ArrayRef<int> Mask);

/// Returns true if \p V is a negation: either floating-point (`fneg`) or
/// integer (`sub 0, X`).
static bool isNeg(Value *V);

/// Returns the value being negated by \p V. \p V must satisfy isNeg().
static Value *getNegOperand(Value *V);
110 
111 namespace {
112 
113 class ComplexDeinterleavingLegacyPass : public FunctionPass {
114 public:
115   static char ID;
116 
ComplexDeinterleavingLegacyPass(const TargetMachine * TM=nullptr)117   ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
118       : FunctionPass(ID), TM(TM) {
119     initializeComplexDeinterleavingLegacyPassPass(
120         *PassRegistry::getPassRegistry());
121   }
122 
getPassName() const123   StringRef getPassName() const override {
124     return "Complex Deinterleaving Pass";
125   }
126 
127   bool runOnFunction(Function &F) override;
getAnalysisUsage(AnalysisUsage & AU) const128   void getAnalysisUsage(AnalysisUsage &AU) const override {
129     AU.addRequired<TargetLibraryInfoWrapperPass>();
130     AU.setPreservesCFG();
131   }
132 
133 private:
134   const TargetMachine *TM;
135 };
136 
137 class ComplexDeinterleavingGraph;
138 struct ComplexDeinterleavingCompositeNode {
139 
ComplexDeinterleavingCompositeNode__anond8caac5c0111::ComplexDeinterleavingCompositeNode140   ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
141                                      Value *R, Value *I)
142       : Operation(Op), Real(R), Imag(I) {}
143 
144 private:
145   friend class ComplexDeinterleavingGraph;
146   using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
147   using RawNodePtr = ComplexDeinterleavingCompositeNode *;
148 
149 public:
150   ComplexDeinterleavingOperation Operation;
151   Value *Real;
152   Value *Imag;
153 
154   // This two members are required exclusively for generating
155   // ComplexDeinterleavingOperation::Symmetric operations.
156   unsigned Opcode;
157   std::optional<FastMathFlags> Flags;
158 
159   ComplexDeinterleavingRotation Rotation =
160       ComplexDeinterleavingRotation::Rotation_0;
161   SmallVector<RawNodePtr> Operands;
162   Value *ReplacementNode = nullptr;
163 
addOperand__anond8caac5c0111::ComplexDeinterleavingCompositeNode164   void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
165 
dump__anond8caac5c0111::ComplexDeinterleavingCompositeNode166   void dump() { dump(dbgs()); }
dump__anond8caac5c0111::ComplexDeinterleavingCompositeNode167   void dump(raw_ostream &OS) {
168     auto PrintValue = [&](Value *V) {
169       if (V) {
170         OS << "\"";
171         V->print(OS, true);
172         OS << "\"\n";
173       } else
174         OS << "nullptr\n";
175     };
176     auto PrintNodeRef = [&](RawNodePtr Ptr) {
177       if (Ptr)
178         OS << Ptr << "\n";
179       else
180         OS << "nullptr\n";
181     };
182 
183     OS << "- CompositeNode: " << this << "\n";
184     OS << "  Real: ";
185     PrintValue(Real);
186     OS << "  Imag: ";
187     PrintValue(Imag);
188     OS << "  ReplacementNode: ";
189     PrintValue(ReplacementNode);
190     OS << "  Operation: " << (int)Operation << "\n";
191     OS << "  Rotation: " << ((int)Rotation * 90) << "\n";
192     OS << "  Operands: \n";
193     for (const auto &Op : Operands) {
194       OS << "    - ";
195       PrintNodeRef(Op);
196     }
197   }
198 };
199 
/// Builds, validates and applies the graph of composite nodes for one basic
/// block. Nodes are created by prepareCompositeNode(), registered by
/// submitCompositeNode(), and owned through the shared_ptrs stored in
/// CompositeNodes.
class ComplexDeinterleavingGraph {
public:
  /// One multiplication term `Multiplier * Multiplicand`.
  /// NOTE(review): IsPositive presumably records whether the term is added
  /// (true) or subtracted (false) — confirm against identifyReassocNodes.
  struct Product {
    Value *Multiplier;
    Value *Multiplicand;
    bool IsPositive;
  };

  /// A value paired with a flag; presumably the flag is the addend's sign
  /// (true for positive), mirroring Product::IsPositive — TODO confirm.
  using Addend = std::pair<Value *, bool>;
  using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
  using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;

  // Helper struct for holding info about potential partial multiplication
  // candidates.
  struct PartialMulCandidate {
    Value *Common;
    NodePtr Node;
    unsigned RealIdx;
    unsigned ImagIdx;
    bool IsNodeInverted;
  };

  explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
                                      const TargetLibraryInfo *TLI)
      : TL(TL), TLI(TLI) {}

private:
  const TargetLowering *TL = nullptr;
  const TargetLibraryInfo *TLI = nullptr;
  /// Every node passed to submitCompositeNode(), in submission order. These
  /// shared_ptrs own the nodes; node Operands only store raw pointers.
  SmallVector<NodePtr> CompositeNodes;
  /// Submitted nodes keyed by their (Real, Imag) pair. Only nodes whose Real
  /// and Imag are both non-null are recorded.
  DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult;

  SmallPtrSet<Instruction *, 16> FinalInstructions;

  /// Root instructions are instructions from which complex computation starts
  std::map<Instruction *, NodePtr> RootToNode;

  /// Topologically sorted root instructions
  SmallVector<Instruction *, 1> OrderedRoots;

  /// When examining a basic block for complex deinterleaving, if it is a simple
  /// one-block loop, then the only incoming block is 'Incoming' and the
  /// 'BackEdge' block is the block itself.
  BasicBlock *BackEdge = nullptr;
  BasicBlock *Incoming = nullptr;

  /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
  /// %OutsideUser as it is shown in the IR:
  ///
  /// vector.body:
  ///   %PHInode = phi <vector type> [ zeroinitializer, %entry ],
  ///                                [ %ReductionOp, %vector.body ]
  ///   ...
  ///   %ReductionOp = fadd <vector type> ...
  ///   ...
  ///   br i1 %condition, label %vector.body, %middle.block
  ///
  /// middle.block:
  ///   %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
  ///
  /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
  /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
  MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;

  /// In the process of detecting a reduction, we consider a pair of
  /// %ReductionOP, which we refer to as real and imag (or vice versa), and
  /// traverse the use-tree to detect complex operations. As this is a reduction
  /// operation, it will eventually reach RealPHI and ImagPHI, which correspond
  /// to the %ReductionOPs that we suspect to be complex.
  /// RealPHI and ImagPHI are used by the identifyPHINode method.
  PHINode *RealPHI = nullptr;
  PHINode *ImagPHI = nullptr;

  /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
  /// detection.
  bool PHIsFound = false;

  /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
  /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
  /// This mapping is populated during
  /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
  /// used in the ComplexDeinterleavingOperation::ReductionOperation node
  /// replacement process.
  std::map<PHINode *, PHINode *> OldToNewPHI;

  /// Allocates a new node for (\p R, \p I) without registering it in the
  /// graph. Reduction-related operations require both \p R and \p I (checked
  /// by assertion). The caller fills in rotation/operands and then hands the
  /// node to submitCompositeNode().
  NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
                               Value *R, Value *I) {
    assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
             Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
            (R && I)) &&
           "Reduction related nodes must have Real and Imaginary parts");
    return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
                                                                I);
  }

  /// Registers \p Node in CompositeNodes and, when both its Real and Imag are
  /// set, caches it under that pair. Returns \p Node.
  NodePtr submitCompositeNode(NodePtr Node) {
    CompositeNodes.push_back(Node);
    if (Node->Real && Node->Imag)
      CachedResult[{Node->Real, Node->Imag}] = Node;
    return Node;
  }

  /// Identifies a complex partial multiply pattern and its rotation, based on
  /// the following patterns
  ///
  ///  0:  r: cr + ar * br
  ///      i: ci + ar * bi
  /// 90:  r: cr - ai * bi
  ///      i: ci + ai * br
  /// 180: r: cr - ar * br
  ///      i: ci - ar * bi
  /// 270: r: cr + ai * bi
  ///      i: ci - ai * br
  NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);

  /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
  /// is partially known from identifyPartialMul, filling in the other half of
  /// the complex pair.
  NodePtr
  identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
                              std::pair<Value *, Value *> &CommonOperandI);

  /// Identifies a complex add pattern and its rotation, based on the following
  /// patterns.
  ///
  /// 90:  r: ar - bi
  ///      i: ai + br
  /// 270: r: ar + bi
  ///      i: ai - br
  NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
  NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);

  /// Returns the composite node for the pair (\p R, \p I), or nullptr when no
  /// node can be formed.
  NodePtr identifyNode(Value *R, Value *I);

  /// Determine if a sum of complex numbers can be formed from \p RealAddends
  /// and \p ImagAddends. If \p Accumulator is not null, add the result to it.
  /// Return nullptr if it is not possible to construct a complex number.
  /// \p Flags are needed to generate symmetric Add and Sub operations.
  NodePtr identifyAdditions(std::list<Addend> &RealAddends,
                            std::list<Addend> &ImagAddends,
                            std::optional<FastMathFlags> Flags,
                            NodePtr Accumulator);

  /// Extract one addend that has both its real and imaginary parts positive.
  NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
                                std::list<Addend> &ImagAddends);

  /// Determine if sum of multiplications of complex numbers can be formed from
  /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
  /// to it. Return nullptr if it is not possible to construct a complex number.
  NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
                                  std::vector<Product> &ImagMuls,
                                  NodePtr Accumulator);

  /// Go through pairs of multiplication (one Real and one Imag) and find all
  /// possible candidates for partial multiplication and put them into \p
  /// Candidates. Returns true if every Product has a partner with a common
  /// operand.
  bool collectPartialMuls(const std::vector<Product> &RealMuls,
                          const std::vector<Product> &ImagMuls,
                          std::vector<PartialMulCandidate> &Candidates);

  /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
  /// the order of complex computation operations may be significantly altered,
  /// and the real and imaginary parts may not be executed in parallel. This
  /// function takes this into consideration and employs a more general approach
  /// to identify complex computations. Initially, it gathers all the addends
  /// and multiplicands and then constructs a complex expression from them.
  NodePtr identifyReassocNodes(Instruction *I, Instruction *J);

  NodePtr identifyRoot(Instruction *I);

  /// Identifies the Deinterleave operation applied to a vector containing
  /// complex numbers. There are two ways to represent the Deinterleave
  /// operation:
  /// * Using two shufflevectors with even indices for the \p Real instruction
  /// and odd indices for the \p Imag instruction (only for fixed-width vectors)
  /// * Using two extractvalue instructions applied to `vector.deinterleave2`
  /// intrinsic (for both fixed and scalable vectors)
  NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);

  /// Identifies the operation that represents a complex number repeated in a
  /// Splat vector. There are two possible types of splats: ConstantExpr with
  /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
  /// initialization mask with all values set to zero.
  NodePtr identifySplat(Value *Real, Value *Imag);

  NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);

  /// Identifies SelectInsts in a loop that has reduction with predication masks
  /// and/or predicated tail folding
  NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);

  Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);

  /// Complete IR modifications after producing new reduction operation:
  /// * Populate the PHINode generated for
  /// ComplexDeinterleavingOperation::ReductionPHI
  /// * Deinterleave the final value outside of the loop and repurpose original
  /// reduction users
  void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);

public:
  void dump() { dump(dbgs()); }
  /// Dumps every submitted node, in submission order.
  void dump(raw_ostream &OS) {
    for (const auto &Node : CompositeNodes)
      Node->dump(OS);
  }

  /// Returns false if the deinterleaving operation should be cancelled for the
  /// current graph.
  bool identifyNodes(Instruction *RootI);

  /// In case \p B is a one-block loop, this function seeks potential reductions
  /// and populates ReductionInfo. Returns true if any reductions were
  /// identified.
  bool collectPotentialReductions(BasicBlock *B);

  void identifyReductionNodes();

  /// Check that every instruction, from the roots to the leaves, has internal
  /// uses.
  bool checkNodes();

  /// Perform the actual replacement of the underlying instruction graph.
  void replaceNodes();
};
426 
427 class ComplexDeinterleaving {
428 public:
ComplexDeinterleaving(const TargetLowering * tl,const TargetLibraryInfo * tli)429   ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
430       : TL(tl), TLI(tli) {}
431   bool runOnFunction(Function &F);
432 
433 private:
434   bool evaluateBasicBlock(BasicBlock *B);
435 
436   const TargetLowering *TL = nullptr;
437   const TargetLibraryInfo *TLI = nullptr;
438 };
439 
440 } // namespace
441 
442 char ComplexDeinterleavingLegacyPass::ID = 0;
443 
444 INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
445                       "Complex Deinterleaving", false, false)
446 INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
447                     "Complex Deinterleaving", false, false)
448 
run(Function & F,FunctionAnalysisManager & AM)449 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
450                                                  FunctionAnalysisManager &AM) {
451   const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
452   auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
453   if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
454     return PreservedAnalyses::all();
455 
456   PreservedAnalyses PA;
457   PA.preserve<FunctionAnalysisManagerModuleProxy>();
458   return PA;
459 }
460 
createComplexDeinterleavingPass(const TargetMachine * TM)461 FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
462   return new ComplexDeinterleavingLegacyPass(TM);
463 }
464 
runOnFunction(Function & F)465 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
466   const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
467   auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
468   return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
469 }
470 
runOnFunction(Function & F)471 bool ComplexDeinterleaving::runOnFunction(Function &F) {
472   if (!ComplexDeinterleavingEnabled) {
473     LLVM_DEBUG(
474         dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
475     return false;
476   }
477 
478   if (!TL->isComplexDeinterleavingSupported()) {
479     LLVM_DEBUG(
480         dbgs() << "Complex deinterleaving has been disabled, target does "
481                   "not support lowering of complex number operations.\n");
482     return false;
483   }
484 
485   bool Changed = false;
486   for (auto &B : F)
487     Changed |= evaluateBasicBlock(&B);
488 
489   return Changed;
490 }
491 
isInterleavingMask(ArrayRef<int> Mask)492 static bool isInterleavingMask(ArrayRef<int> Mask) {
493   // If the size is not even, it's not an interleaving mask
494   if ((Mask.size() & 1))
495     return false;
496 
497   int HalfNumElements = Mask.size() / 2;
498   for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
499     int MaskIdx = Idx * 2;
500     if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
501       return false;
502   }
503 
504   return true;
505 }
506 
isDeinterleavingMask(ArrayRef<int> Mask)507 static bool isDeinterleavingMask(ArrayRef<int> Mask) {
508   int Offset = Mask[0];
509   int HalfNumElements = Mask.size() / 2;
510 
511   for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
512     if (Mask[Idx] != (Idx * 2) + Offset)
513       return false;
514   }
515 
516   return true;
517 }
518 
isNeg(Value * V)519 bool isNeg(Value *V) {
520   return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
521 }
522 
getNegOperand(Value * V)523 Value *getNegOperand(Value *V) {
524   assert(isNeg(V));
525   auto *I = cast<Instruction>(V);
526   if (I->getOpcode() == Instruction::FNeg)
527     return I->getOperand(0);
528 
529   return I->getOperand(1);
530 }
531 
evaluateBasicBlock(BasicBlock * B)532 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
533   ComplexDeinterleavingGraph Graph(TL, TLI);
534   if (Graph.collectPotentialReductions(B))
535     Graph.identifyReductionNodes();
536 
537   for (auto &I : *B)
538     Graph.identifyNodes(&I);
539 
540   if (Graph.checkNodes()) {
541     Graph.replaceNodes();
542     return true;
543   }
544 
545   return false;
546 }
547 
548 ComplexDeinterleavingGraph::NodePtr
identifyNodeWithImplicitAdd(Instruction * Real,Instruction * Imag,std::pair<Value *,Value * > & PartialMatch)549 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
550     Instruction *Real, Instruction *Imag,
551     std::pair<Value *, Value *> &PartialMatch) {
552   LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
553                     << "\n");
554 
555   if (!Real->hasOneUse() || !Imag->hasOneUse()) {
556     LLVM_DEBUG(dbgs() << "  - Mul operand has multiple uses.\n");
557     return nullptr;
558   }
559 
560   if ((Real->getOpcode() != Instruction::FMul &&
561        Real->getOpcode() != Instruction::Mul) ||
562       (Imag->getOpcode() != Instruction::FMul &&
563        Imag->getOpcode() != Instruction::Mul)) {
564     LLVM_DEBUG(
565         dbgs() << "  - Real or imaginary instruction is not fmul or mul\n");
566     return nullptr;
567   }
568 
569   Value *R0 = Real->getOperand(0);
570   Value *R1 = Real->getOperand(1);
571   Value *I0 = Imag->getOperand(0);
572   Value *I1 = Imag->getOperand(1);
573 
574   // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
575   // rotations and use the operand.
576   unsigned Negs = 0;
577   Value *Op;
578   if (match(R0, m_Neg(m_Value(Op)))) {
579     Negs |= 1;
580     R0 = Op;
581   } else if (match(R1, m_Neg(m_Value(Op)))) {
582     Negs |= 1;
583     R1 = Op;
584   }
585 
586   if (isNeg(I0)) {
587     Negs |= 2;
588     Negs ^= 1;
589     I0 = Op;
590   } else if (match(I1, m_Neg(m_Value(Op)))) {
591     Negs |= 2;
592     Negs ^= 1;
593     I1 = Op;
594   }
595 
596   ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
597 
598   Value *CommonOperand;
599   Value *UncommonRealOp;
600   Value *UncommonImagOp;
601 
602   if (R0 == I0 || R0 == I1) {
603     CommonOperand = R0;
604     UncommonRealOp = R1;
605   } else if (R1 == I0 || R1 == I1) {
606     CommonOperand = R1;
607     UncommonRealOp = R0;
608   } else {
609     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
610     return nullptr;
611   }
612 
613   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
614   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
615       Rotation == ComplexDeinterleavingRotation::Rotation_270)
616     std::swap(UncommonRealOp, UncommonImagOp);
617 
618   // Between identifyPartialMul and here we need to have found a complete valid
619   // pair from the CommonOperand of each part.
620   if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
621       Rotation == ComplexDeinterleavingRotation::Rotation_180)
622     PartialMatch.first = CommonOperand;
623   else
624     PartialMatch.second = CommonOperand;
625 
626   if (!PartialMatch.first || !PartialMatch.second) {
627     LLVM_DEBUG(dbgs() << "  - Incomplete partial match\n");
628     return nullptr;
629   }
630 
631   NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
632   if (!CommonNode) {
633     LLVM_DEBUG(dbgs() << "  - No CommonNode identified\n");
634     return nullptr;
635   }
636 
637   NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
638   if (!UncommonNode) {
639     LLVM_DEBUG(dbgs() << "  - No UncommonNode identified\n");
640     return nullptr;
641   }
642 
643   NodePtr Node = prepareCompositeNode(
644       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
645   Node->Rotation = Rotation;
646   Node->addOperand(CommonNode);
647   Node->addOperand(UncommonNode);
648   return submitCompositeNode(Node);
649 }
650 
/// Matches a partial complex multiply rooted at the add/sub pair \p Real and
/// \p Imag, each of the form `C +/- (X * Y)` with the multiplication as
/// operand 1. The add/sub combination selects the rotation. On success the
/// returned CMulPartial node has three operands, in this order: the
/// common-operand node, the uncommon-operand node, and the node built for the
/// accumulators by identifyNodeWithImplicitAdd. Returns nullptr otherwise.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
                                               Instruction *Imag) {
  LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
                    << "\n");
  // Determine rotation
  auto IsAdd = [](unsigned Op) {
    return Op == Instruction::FAdd || Op == Instruction::Add;
  };
  auto IsSub = [](unsigned Op) {
    return Op == Instruction::FSub || Op == Instruction::Sub;
  };
  ComplexDeinterleavingRotation Rotation;
  if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
    Rotation = ComplexDeinterleavingRotation::Rotation_0;
  else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
    Rotation = ComplexDeinterleavingRotation::Rotation_90;
  else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
    Rotation = ComplexDeinterleavingRotation::Rotation_180;
  else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
    Rotation = ComplexDeinterleavingRotation::Rotation_270;
  else {
    LLVM_DEBUG(dbgs() << "  - Unhandled rotation.\n");
    return nullptr;
  }

  // For floating point, both roots must carry the `contract` fast-math flag.
  if (isa<FPMathOperator>(Real) &&
      (!Real->getFastMathFlags().allowContract() ||
       !Imag->getFastMathFlags().allowContract())) {
    LLVM_DEBUG(dbgs() << "  - Contract is missing from the FastMath flags.\n");
    return nullptr;
  }

  // Operand 0 of each root is the accumulator (cr / ci); operand 1 must be an
  // instruction (the multiplication).
  Value *CR = Real->getOperand(0);
  Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
  if (!RealMulI)
    return nullptr;
  Value *CI = Imag->getOperand(0);
  Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
  if (!ImagMulI)
    return nullptr;

  if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
    LLVM_DEBUG(dbgs() << "  - Mul instruction has multiple uses\n");
    return nullptr;
  }

  Value *R0 = RealMulI->getOperand(0);
  Value *R1 = RealMulI->getOperand(1);
  Value *I0 = ImagMulI->getOperand(0);
  Value *I1 = ImagMulI->getOperand(1);

  Value *CommonOperand;
  Value *UncommonRealOp;
  Value *UncommonImagOp;

  // The two multiplications must share one operand; the remaining operands
  // form the "uncommon" pair.
  if (R0 == I0 || R0 == I1) {
    CommonOperand = R0;
    UncommonRealOp = R1;
  } else if (R1 == I0 || R1 == I1) {
    CommonOperand = R1;
    UncommonRealOp = R0;
  } else {
    LLVM_DEBUG(dbgs() << "  - No equal operand\n");
    return nullptr;
  }

  UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
  if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_270)
    std::swap(UncommonRealOp, UncommonImagOp);

  // Rotations 0/180 fix the real half of the common pair, 90/270 fix the
  // imaginary half; identifyNodeWithImplicitAdd fills in the other half.
  std::pair<Value *, Value *> PartialMatch(
      (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
       Rotation == ComplexDeinterleavingRotation::Rotation_180)
          ? CommonOperand
          : nullptr,
      (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
       Rotation == ComplexDeinterleavingRotation::Rotation_270)
          ? CommonOperand
          : nullptr);

  auto *CRInst = dyn_cast<Instruction>(CR);
  auto *CIInst = dyn_cast<Instruction>(CI);

  if (!CRInst || !CIInst) {
    LLVM_DEBUG(dbgs() << "  - Common operands are not instructions.\n");
    return nullptr;
  }

  // The order of the identifyNode calls below is significant: each call may
  // submit nodes into CompositeNodes/CachedResult.
  NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
  if (!CNode) {
    LLVM_DEBUG(dbgs() << "  - No cnode identified\n");
    return nullptr;
  }

  NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
  if (!UncommonRes) {
    LLVM_DEBUG(dbgs() << "  - No UncommonRes identified\n");
    return nullptr;
  }

  // identifyNodeWithImplicitAdd returns nullptr unless both halves are set.
  assert(PartialMatch.first && PartialMatch.second);
  NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
  if (!CommonRes) {
    LLVM_DEBUG(dbgs() << "  - No CommonRes identified\n");
    return nullptr;
  }

  NodePtr Node = prepareCompositeNode(
      ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
  Node->Rotation = Rotation;
  Node->addOperand(CommonRes);
  Node->addOperand(UncommonRes);
  Node->addOperand(CNode);
  return submitCompositeNode(Node);
}
768 
769 ComplexDeinterleavingGraph::NodePtr
identifyAdd(Instruction * Real,Instruction * Imag)770 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
771   LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
772 
773   // Determine rotation
774   ComplexDeinterleavingRotation Rotation;
775   if ((Real->getOpcode() == Instruction::FSub &&
776        Imag->getOpcode() == Instruction::FAdd) ||
777       (Real->getOpcode() == Instruction::Sub &&
778        Imag->getOpcode() == Instruction::Add))
779     Rotation = ComplexDeinterleavingRotation::Rotation_90;
780   else if ((Real->getOpcode() == Instruction::FAdd &&
781             Imag->getOpcode() == Instruction::FSub) ||
782            (Real->getOpcode() == Instruction::Add &&
783             Imag->getOpcode() == Instruction::Sub))
784     Rotation = ComplexDeinterleavingRotation::Rotation_270;
785   else {
786     LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
787     return nullptr;
788   }
789 
790   auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
791   auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
792   auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
793   auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
794 
795   if (!AR || !AI || !BR || !BI) {
796     LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
797     return nullptr;
798   }
799 
800   NodePtr ResA = identifyNode(AR, AI);
801   if (!ResA) {
802     LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
803     return nullptr;
804   }
805   NodePtr ResB = identifyNode(BR, BI);
806   if (!ResB) {
807     LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
808     return nullptr;
809   }
810 
811   NodePtr Node =
812       prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
813   Node->Rotation = Rotation;
814   Node->addOperand(ResA);
815   Node->addOperand(ResB);
816   return submitCompositeNode(Node);
817 }
818 
isInstructionPairAdd(Instruction * A,Instruction * B)819 static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
820   unsigned OpcA = A->getOpcode();
821   unsigned OpcB = B->getOpcode();
822 
823   return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
824          (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
825          (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
826          (OpcA == Instruction::Add && OpcB == Instruction::Sub);
827 }
828 
isInstructionPairMul(Instruction * A,Instruction * B)829 static bool isInstructionPairMul(Instruction *A, Instruction *B) {
830   auto Pattern =
831       m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
832 
833   return match(A, Pattern) && match(B, Pattern);
834 }
835 
isInstructionPotentiallySymmetric(Instruction * I)836 static bool isInstructionPotentiallySymmetric(Instruction *I) {
837   switch (I->getOpcode()) {
838   case Instruction::FAdd:
839   case Instruction::FSub:
840   case Instruction::FMul:
841   case Instruction::FNeg:
842   case Instruction::Add:
843   case Instruction::Sub:
844   case Instruction::Mul:
845     return true;
846   default:
847     return false;
848   }
849 }
850 
851 ComplexDeinterleavingGraph::NodePtr
identifySymmetricOperation(Instruction * Real,Instruction * Imag)852 ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
853                                                        Instruction *Imag) {
854   if (Real->getOpcode() != Imag->getOpcode())
855     return nullptr;
856 
857   if (!isInstructionPotentiallySymmetric(Real) ||
858       !isInstructionPotentiallySymmetric(Imag))
859     return nullptr;
860 
861   auto *R0 = Real->getOperand(0);
862   auto *I0 = Imag->getOperand(0);
863 
864   NodePtr Op0 = identifyNode(R0, I0);
865   NodePtr Op1 = nullptr;
866   if (Op0 == nullptr)
867     return nullptr;
868 
869   if (Real->isBinaryOp()) {
870     auto *R1 = Real->getOperand(1);
871     auto *I1 = Imag->getOperand(1);
872     Op1 = identifyNode(R1, I1);
873     if (Op1 == nullptr)
874       return nullptr;
875   }
876 
877   if (isa<FPMathOperator>(Real) &&
878       Real->getFastMathFlags() != Imag->getFastMathFlags())
879     return nullptr;
880 
881   auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
882                                    Real, Imag);
883   Node->Opcode = Real->getOpcode();
884   if (isa<FPMathOperator>(Real))
885     Node->Flags = Real->getFastMathFlags();
886 
887   Node->addOperand(Op0);
888   if (Real->isBinaryOp())
889     Node->addOperand(Op1);
890 
891   return submitCompositeNode(Node);
892 }
893 
894 ComplexDeinterleavingGraph::NodePtr
identifyNode(Value * R,Value * I)895 ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
896   LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
897   assert(R->getType() == I->getType() &&
898          "Real and imaginary parts should not have different types");
899 
900   auto It = CachedResult.find({R, I});
901   if (It != CachedResult.end()) {
902     LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
903     return It->second;
904   }
905 
906   if (NodePtr CN = identifySplat(R, I))
907     return CN;
908 
909   auto *Real = dyn_cast<Instruction>(R);
910   auto *Imag = dyn_cast<Instruction>(I);
911   if (!Real || !Imag)
912     return nullptr;
913 
914   if (NodePtr CN = identifyDeinterleave(Real, Imag))
915     return CN;
916 
917   if (NodePtr CN = identifyPHINode(Real, Imag))
918     return CN;
919 
920   if (NodePtr CN = identifySelectNode(Real, Imag))
921     return CN;
922 
923   auto *VTy = cast<VectorType>(Real->getType());
924   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
925 
926   bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
927       ComplexDeinterleavingOperation::CMulPartial, NewVTy);
928   bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
929       ComplexDeinterleavingOperation::CAdd, NewVTy);
930 
931   if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
932     if (NodePtr CN = identifyPartialMul(Real, Imag))
933       return CN;
934   }
935 
936   if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
937     if (NodePtr CN = identifyAdd(Real, Imag))
938       return CN;
939   }
940 
941   if (HasCMulSupport && HasCAddSupport) {
942     if (NodePtr CN = identifyReassocNodes(Real, Imag))
943       return CN;
944   }
945 
946   if (NodePtr CN = identifySymmetricOperation(Real, Imag))
947     return CN;
948 
949   LLVM_DEBUG(dbgs() << "  - Not recognised as a valid pattern.\n");
950   CachedResult[{R, I}] = nullptr;
951   return nullptr;
952 }
953 
954 ComplexDeinterleavingGraph::NodePtr
identifyReassocNodes(Instruction * Real,Instruction * Imag)955 ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
956                                                  Instruction *Imag) {
957   auto IsOperationSupported = [](unsigned Opcode) -> bool {
958     return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
959            Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
960            Opcode == Instruction::Sub;
961   };
962 
963   if (!IsOperationSupported(Real->getOpcode()) ||
964       !IsOperationSupported(Imag->getOpcode()))
965     return nullptr;
966 
967   std::optional<FastMathFlags> Flags;
968   if (isa<FPMathOperator>(Real)) {
969     if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
970       LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
971                            "not identical\n");
972       return nullptr;
973     }
974 
975     Flags = Real->getFastMathFlags();
976     if (!Flags->allowReassoc()) {
977       LLVM_DEBUG(
978           dbgs()
979           << "the 'Reassoc' attribute is missing in the FastMath flags\n");
980       return nullptr;
981     }
982   }
983 
984   // Collect multiplications and addend instructions from the given instruction
985   // while traversing it operands. Additionally, verify that all instructions
986   // have the same fast math flags.
987   auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
988                           std::list<Addend> &Addends) -> bool {
989     SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
990     SmallPtrSet<Value *, 8> Visited;
991     while (!Worklist.empty()) {
992       auto [V, IsPositive] = Worklist.back();
993       Worklist.pop_back();
994       if (!Visited.insert(V).second)
995         continue;
996 
997       Instruction *I = dyn_cast<Instruction>(V);
998       if (!I) {
999         Addends.emplace_back(V, IsPositive);
1000         continue;
1001       }
1002 
1003       // If an instruction has more than one user, it indicates that it either
1004       // has an external user, which will be later checked by the checkNodes
1005       // function, or it is a subexpression utilized by multiple expressions. In
1006       // the latter case, we will attempt to separately identify the complex
1007       // operation from here in order to create a shared
1008       // ComplexDeinterleavingCompositeNode.
1009       if (I != Insn && I->getNumUses() > 1) {
1010         LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
1011         Addends.emplace_back(I, IsPositive);
1012         continue;
1013       }
1014       switch (I->getOpcode()) {
1015       case Instruction::FAdd:
1016       case Instruction::Add:
1017         Worklist.emplace_back(I->getOperand(1), IsPositive);
1018         Worklist.emplace_back(I->getOperand(0), IsPositive);
1019         break;
1020       case Instruction::FSub:
1021         Worklist.emplace_back(I->getOperand(1), !IsPositive);
1022         Worklist.emplace_back(I->getOperand(0), IsPositive);
1023         break;
1024       case Instruction::Sub:
1025         if (isNeg(I)) {
1026           Worklist.emplace_back(getNegOperand(I), !IsPositive);
1027         } else {
1028           Worklist.emplace_back(I->getOperand(1), !IsPositive);
1029           Worklist.emplace_back(I->getOperand(0), IsPositive);
1030         }
1031         break;
1032       case Instruction::FMul:
1033       case Instruction::Mul: {
1034         Value *A, *B;
1035         if (isNeg(I->getOperand(0))) {
1036           A = getNegOperand(I->getOperand(0));
1037           IsPositive = !IsPositive;
1038         } else {
1039           A = I->getOperand(0);
1040         }
1041 
1042         if (isNeg(I->getOperand(1))) {
1043           B = getNegOperand(I->getOperand(1));
1044           IsPositive = !IsPositive;
1045         } else {
1046           B = I->getOperand(1);
1047         }
1048         Muls.push_back(Product{A, B, IsPositive});
1049         break;
1050       }
1051       case Instruction::FNeg:
1052         Worklist.emplace_back(I->getOperand(0), !IsPositive);
1053         break;
1054       default:
1055         Addends.emplace_back(I, IsPositive);
1056         continue;
1057       }
1058 
1059       if (Flags && I->getFastMathFlags() != *Flags) {
1060         LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
1061                              "inconsistent with the root instructions' flags: "
1062                           << *I << "\n");
1063         return false;
1064       }
1065     }
1066     return true;
1067   };
1068 
1069   std::vector<Product> RealMuls, ImagMuls;
1070   std::list<Addend> RealAddends, ImagAddends;
1071   if (!Collect(Real, RealMuls, RealAddends) ||
1072       !Collect(Imag, ImagMuls, ImagAddends))
1073     return nullptr;
1074 
1075   if (RealAddends.size() != ImagAddends.size())
1076     return nullptr;
1077 
1078   NodePtr FinalNode;
1079   if (!RealMuls.empty() || !ImagMuls.empty()) {
1080     // If there are multiplicands, extract positive addend and use it as an
1081     // accumulator
1082     FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
1083     FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
1084     if (!FinalNode)
1085       return nullptr;
1086   }
1087 
1088   // Identify and process remaining additions
1089   if (!RealAddends.empty() || !ImagAddends.empty()) {
1090     FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
1091     if (!FinalNode)
1092       return nullptr;
1093   }
1094   assert(FinalNode && "FinalNode can not be nullptr here");
1095   // Set the Real and Imag fields of the final node and submit it
1096   FinalNode->Real = Real;
1097   FinalNode->Imag = Imag;
1098   submitCompositeNode(FinalNode);
1099   return FinalNode;
1100 }
1101 
collectPartialMuls(const std::vector<Product> & RealMuls,const std::vector<Product> & ImagMuls,std::vector<PartialMulCandidate> & PartialMulCandidates)1102 bool ComplexDeinterleavingGraph::collectPartialMuls(
1103     const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
1104     std::vector<PartialMulCandidate> &PartialMulCandidates) {
1105   // Helper function to extract a common operand from two products
1106   auto FindCommonInstruction = [](const Product &Real,
1107                                   const Product &Imag) -> Value * {
1108     if (Real.Multiplicand == Imag.Multiplicand ||
1109         Real.Multiplicand == Imag.Multiplier)
1110       return Real.Multiplicand;
1111 
1112     if (Real.Multiplier == Imag.Multiplicand ||
1113         Real.Multiplier == Imag.Multiplier)
1114       return Real.Multiplier;
1115 
1116     return nullptr;
1117   };
1118 
1119   // Iterating over real and imaginary multiplications to find common operands
1120   // If a common operand is found, a partial multiplication candidate is created
1121   // and added to the candidates vector The function returns false if no common
1122   // operands are found for any product
1123   for (unsigned i = 0; i < RealMuls.size(); ++i) {
1124     bool FoundCommon = false;
1125     for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1126       auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1127       if (!Common)
1128         continue;
1129 
1130       auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1131                                                    : RealMuls[i].Multiplicand;
1132       auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1133                                                    : ImagMuls[j].Multiplicand;
1134 
1135       auto Node = identifyNode(A, B);
1136       if (Node) {
1137         FoundCommon = true;
1138         PartialMulCandidates.push_back({Common, Node, i, j, false});
1139       }
1140 
1141       Node = identifyNode(B, A);
1142       if (Node) {
1143         FoundCommon = true;
1144         PartialMulCandidates.push_back({Common, Node, i, j, true});
1145       }
1146     }
1147     if (!FoundCommon)
1148       return false;
1149   }
1150   return true;
1151 }
1152 
1153 ComplexDeinterleavingGraph::NodePtr
identifyMultiplications(std::vector<Product> & RealMuls,std::vector<Product> & ImagMuls,NodePtr Accumulator=nullptr)1154 ComplexDeinterleavingGraph::identifyMultiplications(
1155     std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
1156     NodePtr Accumulator = nullptr) {
1157   if (RealMuls.size() != ImagMuls.size())
1158     return nullptr;
1159 
1160   std::vector<PartialMulCandidate> Info;
1161   if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1162     return nullptr;
1163 
1164   // Map to store common instruction to node pointers
1165   std::map<Value *, NodePtr> CommonToNode;
1166   std::vector<bool> Processed(Info.size(), false);
1167   for (unsigned I = 0; I < Info.size(); ++I) {
1168     if (Processed[I])
1169       continue;
1170 
1171     PartialMulCandidate &InfoA = Info[I];
1172     for (unsigned J = I + 1; J < Info.size(); ++J) {
1173       if (Processed[J])
1174         continue;
1175 
1176       PartialMulCandidate &InfoB = Info[J];
1177       auto *InfoReal = &InfoA;
1178       auto *InfoImag = &InfoB;
1179 
1180       auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1181       if (!NodeFromCommon) {
1182         std::swap(InfoReal, InfoImag);
1183         NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1184       }
1185       if (!NodeFromCommon)
1186         continue;
1187 
1188       CommonToNode[InfoReal->Common] = NodeFromCommon;
1189       CommonToNode[InfoImag->Common] = NodeFromCommon;
1190       Processed[I] = true;
1191       Processed[J] = true;
1192     }
1193   }
1194 
1195   std::vector<bool> ProcessedReal(RealMuls.size(), false);
1196   std::vector<bool> ProcessedImag(ImagMuls.size(), false);
1197   NodePtr Result = Accumulator;
1198   for (auto &PMI : Info) {
1199     if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1200       continue;
1201 
1202     auto It = CommonToNode.find(PMI.Common);
1203     // TODO: Process independent complex multiplications. Cases like this:
1204     //  A.real() * B where both A and B are complex numbers.
1205     if (It == CommonToNode.end()) {
1206       LLVM_DEBUG({
1207         dbgs() << "Unprocessed independent partial multiplication:\n";
1208         for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1209           dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1210                            << " multiplied by " << *Mul->Multiplicand << "\n";
1211       });
1212       return nullptr;
1213     }
1214 
1215     auto &RealMul = RealMuls[PMI.RealIdx];
1216     auto &ImagMul = ImagMuls[PMI.ImagIdx];
1217 
1218     auto NodeA = It->second;
1219     auto NodeB = PMI.Node;
1220     auto IsMultiplicandReal = PMI.Common == NodeA->Real;
1221     // The following table illustrates the relationship between multiplications
1222     // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1223     // can see:
1224     //
1225     // Rotation |   Real |   Imag |
1226     // ---------+--------+--------+
1227     //        0 |  x * u |  x * v |
1228     //       90 | -y * v |  y * u |
1229     //      180 | -x * u | -x * v |
1230     //      270 |  y * v | -y * u |
1231     //
1232     // Check if the candidate can indeed be represented by partial
1233     // multiplication
1234     // TODO: Add support for multiplication by complex one
1235     if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1236         (!IsMultiplicandReal && !PMI.IsNodeInverted))
1237       continue;
1238 
1239     // Determine the rotation based on the multiplications
1240     ComplexDeinterleavingRotation Rotation;
1241     if (IsMultiplicandReal) {
1242       // Detect 0 and 180 degrees rotation
1243       if (RealMul.IsPositive && ImagMul.IsPositive)
1244         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
1245       else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1246         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
1247       else
1248         continue;
1249 
1250     } else {
1251       // Detect 90 and 270 degrees rotation
1252       if (!RealMul.IsPositive && ImagMul.IsPositive)
1253         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
1254       else if (RealMul.IsPositive && !ImagMul.IsPositive)
1255         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
1256       else
1257         continue;
1258     }
1259 
1260     LLVM_DEBUG({
1261       dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1262       dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
1263       dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
1264       dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
1265       dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
1266       dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1267     });
1268 
1269     NodePtr NodeMul = prepareCompositeNode(
1270         ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1271     NodeMul->Rotation = Rotation;
1272     NodeMul->addOperand(NodeA);
1273     NodeMul->addOperand(NodeB);
1274     if (Result)
1275       NodeMul->addOperand(Result);
1276     submitCompositeNode(NodeMul);
1277     Result = NodeMul;
1278     ProcessedReal[PMI.RealIdx] = true;
1279     ProcessedImag[PMI.ImagIdx] = true;
1280   }
1281 
1282   // Ensure all products have been processed, if not return nullptr.
1283   if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1284       !all_of(ProcessedImag, [](bool V) { return V; })) {
1285 
1286     // Dump debug information about which partial multiplications are not
1287     // processed.
1288     LLVM_DEBUG({
1289       dbgs() << "Unprocessed products (Real):\n";
1290       for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1291         if (!ProcessedReal[i])
1292           dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1293                            << *RealMuls[i].Multiplier << " multiplied by "
1294                            << *RealMuls[i].Multiplicand << "\n";
1295       }
1296       dbgs() << "Unprocessed products (Imag):\n";
1297       for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1298         if (!ProcessedImag[i])
1299           dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1300                            << *ImagMuls[i].Multiplier << " multiplied by "
1301                            << *ImagMuls[i].Multiplicand << "\n";
1302       }
1303     });
1304     return nullptr;
1305   }
1306 
1307   return Result;
1308 }
1309 
1310 ComplexDeinterleavingGraph::NodePtr
identifyAdditions(std::list<Addend> & RealAddends,std::list<Addend> & ImagAddends,std::optional<FastMathFlags> Flags,NodePtr Accumulator=nullptr)1311 ComplexDeinterleavingGraph::identifyAdditions(
1312     std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
1313     std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {
1314   if (RealAddends.size() != ImagAddends.size())
1315     return nullptr;
1316 
1317   NodePtr Result;
1318   // If we have accumulator use it as first addend
1319   if (Accumulator)
1320     Result = Accumulator;
1321   // Otherwise find an element with both positive real and imaginary parts.
1322   else
1323     Result = extractPositiveAddend(RealAddends, ImagAddends);
1324 
1325   if (!Result)
1326     return nullptr;
1327 
1328   while (!RealAddends.empty()) {
1329     auto ItR = RealAddends.begin();
1330     auto [R, IsPositiveR] = *ItR;
1331 
1332     bool FoundImag = false;
1333     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1334       auto [I, IsPositiveI] = *ItI;
1335       ComplexDeinterleavingRotation Rotation;
1336       if (IsPositiveR && IsPositiveI)
1337         Rotation = ComplexDeinterleavingRotation::Rotation_0;
1338       else if (!IsPositiveR && IsPositiveI)
1339         Rotation = ComplexDeinterleavingRotation::Rotation_90;
1340       else if (!IsPositiveR && !IsPositiveI)
1341         Rotation = ComplexDeinterleavingRotation::Rotation_180;
1342       else
1343         Rotation = ComplexDeinterleavingRotation::Rotation_270;
1344 
1345       NodePtr AddNode;
1346       if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1347           Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1348         AddNode = identifyNode(R, I);
1349       } else {
1350         AddNode = identifyNode(I, R);
1351       }
1352       if (AddNode) {
1353         LLVM_DEBUG({
1354           dbgs() << "Identified addition:\n";
1355           dbgs().indent(4) << "X: " << *R << "\n";
1356           dbgs().indent(4) << "Y: " << *I << "\n";
1357           dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1358         });
1359 
1360         NodePtr TmpNode;
1361         if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
1362           TmpNode = prepareCompositeNode(
1363               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1364           if (Flags) {
1365             TmpNode->Opcode = Instruction::FAdd;
1366             TmpNode->Flags = *Flags;
1367           } else {
1368             TmpNode->Opcode = Instruction::Add;
1369           }
1370         } else if (Rotation ==
1371                    llvm::ComplexDeinterleavingRotation::Rotation_180) {
1372           TmpNode = prepareCompositeNode(
1373               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1374           if (Flags) {
1375             TmpNode->Opcode = Instruction::FSub;
1376             TmpNode->Flags = *Flags;
1377           } else {
1378             TmpNode->Opcode = Instruction::Sub;
1379           }
1380         } else {
1381           TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1382                                          nullptr, nullptr);
1383           TmpNode->Rotation = Rotation;
1384         }
1385 
1386         TmpNode->addOperand(Result);
1387         TmpNode->addOperand(AddNode);
1388         submitCompositeNode(TmpNode);
1389         Result = TmpNode;
1390         RealAddends.erase(ItR);
1391         ImagAddends.erase(ItI);
1392         FoundImag = true;
1393         break;
1394       }
1395     }
1396     if (!FoundImag)
1397       return nullptr;
1398   }
1399   return Result;
1400 }
1401 
1402 ComplexDeinterleavingGraph::NodePtr
extractPositiveAddend(std::list<Addend> & RealAddends,std::list<Addend> & ImagAddends)1403 ComplexDeinterleavingGraph::extractPositiveAddend(
1404     std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
1405   for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1406     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1407       auto [R, IsPositiveR] = *ItR;
1408       auto [I, IsPositiveI] = *ItI;
1409       if (IsPositiveR && IsPositiveI) {
1410         auto Result = identifyNode(R, I);
1411         if (Result) {
1412           RealAddends.erase(ItR);
1413           ImagAddends.erase(ItI);
1414           return Result;
1415         }
1416       }
1417     }
1418   }
1419   return nullptr;
1420 }
1421 
identifyNodes(Instruction * RootI)1422 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1423   // This potential root instruction might already have been recognized as
1424   // reduction. Because RootToNode maps both Real and Imaginary parts to
1425   // CompositeNode we should choose only one either Real or Imag instruction to
1426   // use as an anchor for generating complex instruction.
1427   auto It = RootToNode.find(RootI);
1428   if (It != RootToNode.end()) {
1429     auto RootNode = It->second;
1430     assert(RootNode->Operation ==
1431            ComplexDeinterleavingOperation::ReductionOperation);
1432     // Find out which part, Real or Imag, comes later, and only if we come to
1433     // the latest part, add it to OrderedRoots.
1434     auto *R = cast<Instruction>(RootNode->Real);
1435     auto *I = cast<Instruction>(RootNode->Imag);
1436     auto *ReplacementAnchor = R->comesBefore(I) ? I : R;
1437     if (ReplacementAnchor != RootI)
1438       return false;
1439     OrderedRoots.push_back(RootI);
1440     return true;
1441   }
1442 
1443   auto RootNode = identifyRoot(RootI);
1444   if (!RootNode)
1445     return false;
1446 
1447   LLVM_DEBUG({
1448     Function *F = RootI->getFunction();
1449     BasicBlock *B = RootI->getParent();
1450     dbgs() << "Complex deinterleaving graph for " << F->getName()
1451            << "::" << B->getName() << ".\n";
1452     dump(dbgs());
1453     dbgs() << "\n";
1454   });
1455   RootToNode[RootI] = RootNode;
1456   OrderedRoots.push_back(RootI);
1457   return true;
1458 }
1459 
collectPotentialReductions(BasicBlock * B)1460 bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1461   bool FoundPotentialReduction = false;
1462 
1463   auto *Br = dyn_cast<BranchInst>(B->getTerminator());
1464   if (!Br || Br->getNumSuccessors() != 2)
1465     return false;
1466 
1467   // Identify simple one-block loop
1468   if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1469     return false;
1470 
1471   SmallVector<PHINode *> PHIs;
1472   for (auto &PHI : B->phis()) {
1473     if (PHI.getNumIncomingValues() != 2)
1474       continue;
1475 
1476     if (!PHI.getType()->isVectorTy())
1477       continue;
1478 
1479     auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1480     if (!ReductionOp)
1481       continue;
1482 
1483     // Check if final instruction is reduced outside of current block
1484     Instruction *FinalReduction = nullptr;
1485     auto NumUsers = 0u;
1486     for (auto *U : ReductionOp->users()) {
1487       ++NumUsers;
1488       if (U == &PHI)
1489         continue;
1490       FinalReduction = dyn_cast<Instruction>(U);
1491     }
1492 
1493     if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1494         isa<PHINode>(FinalReduction))
1495       continue;
1496 
1497     ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1498     BackEdge = B;
1499     auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1500     auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1501     Incoming = PHI.getIncomingBlock(IncomingIdx);
1502     FoundPotentialReduction = true;
1503 
1504     // If the initial value of PHINode is an Instruction, consider it a leaf
1505     // value of a complex deinterleaving graph.
1506     if (auto *InitPHI =
1507             dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1508       FinalInstructions.insert(InitPHI);
1509   }
1510   return FoundPotentialReduction;
1511 }
1512 
identifyReductionNodes()1513 void ComplexDeinterleavingGraph::identifyReductionNodes() {
1514   SmallVector<bool> Processed(ReductionInfo.size(), false);
1515   SmallVector<Instruction *> OperationInstruction;
1516   for (auto &P : ReductionInfo)
1517     OperationInstruction.push_back(P.first);
1518 
1519   // Identify a complex computation by evaluating two reduction operations that
1520   // potentially could be involved
1521   for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1522     if (Processed[i])
1523       continue;
1524     for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1525       if (Processed[j])
1526         continue;
1527 
1528       auto *Real = OperationInstruction[i];
1529       auto *Imag = OperationInstruction[j];
1530       if (Real->getType() != Imag->getType())
1531         continue;
1532 
1533       RealPHI = ReductionInfo[Real].first;
1534       ImagPHI = ReductionInfo[Imag].first;
1535       PHIsFound = false;
1536       auto Node = identifyNode(Real, Imag);
1537       if (!Node) {
1538         std::swap(Real, Imag);
1539         std::swap(RealPHI, ImagPHI);
1540         Node = identifyNode(Real, Imag);
1541       }
1542 
1543       // If a node is identified and reduction PHINode is used in the chain of
1544       // operations, mark its operation instructions as used to prevent
1545       // re-identification and attach the node to the real part
1546       if (Node && PHIsFound) {
1547         LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1548                           << *Real << " / " << *Imag << "\n");
1549         Processed[i] = true;
1550         Processed[j] = true;
1551         auto RootNode = prepareCompositeNode(
1552             ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1553         RootNode->addOperand(Node);
1554         RootToNode[Real] = RootNode;
1555         RootToNode[Imag] = RootNode;
1556         submitCompositeNode(RootNode);
1557         break;
1558       }
1559     }
1560   }
1561 
1562   RealPHI = nullptr;
1563   ImagPHI = nullptr;
1564 }
1565 
checkNodes()1566 bool ComplexDeinterleavingGraph::checkNodes() {
1567   // Collect all instructions from roots to leaves
1568   SmallPtrSet<Instruction *, 16> AllInstructions;
1569   SmallVector<Instruction *, 8> Worklist;
1570   for (auto &Pair : RootToNode)
1571     Worklist.push_back(Pair.first);
1572 
1573   // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1574   // chains
1575   while (!Worklist.empty()) {
1576     auto *I = Worklist.back();
1577     Worklist.pop_back();
1578 
1579     if (!AllInstructions.insert(I).second)
1580       continue;
1581 
1582     for (Value *Op : I->operands()) {
1583       if (auto *OpI = dyn_cast<Instruction>(Op)) {
1584         if (!FinalInstructions.count(I))
1585           Worklist.emplace_back(OpI);
1586       }
1587     }
1588   }
1589 
1590   // Find instructions that have users outside of chain
1591   SmallVector<Instruction *, 2> OuterInstructions;
1592   for (auto *I : AllInstructions) {
1593     // Skip root nodes
1594     if (RootToNode.count(I))
1595       continue;
1596 
1597     for (User *U : I->users()) {
1598       if (AllInstructions.count(cast<Instruction>(U)))
1599         continue;
1600 
1601       // Found an instruction that is not used by XCMLA/XCADD chain
1602       Worklist.emplace_back(I);
1603       break;
1604     }
1605   }
1606 
1607   // If any instructions are found to be used outside, find and remove roots
1608   // that somehow connect to those instructions.
1609   SmallPtrSet<Instruction *, 16> Visited;
1610   while (!Worklist.empty()) {
1611     auto *I = Worklist.back();
1612     Worklist.pop_back();
1613     if (!Visited.insert(I).second)
1614       continue;
1615 
1616     // Found an impacted root node. Removing it from the nodes to be
1617     // deinterleaved
1618     if (RootToNode.count(I)) {
1619       LLVM_DEBUG(dbgs() << "Instruction " << *I
1620                         << " could be deinterleaved but its chain of complex "
1621                            "operations have an outside user\n");
1622       RootToNode.erase(I);
1623     }
1624 
1625     if (!AllInstructions.count(I) || FinalInstructions.count(I))
1626       continue;
1627 
1628     for (User *U : I->users())
1629       Worklist.emplace_back(cast<Instruction>(U));
1630 
1631     for (Value *Op : I->operands()) {
1632       if (auto *OpI = dyn_cast<Instruction>(Op))
1633         Worklist.emplace_back(OpI);
1634     }
1635   }
1636   return !RootToNode.empty();
1637 }
1638 
1639 ComplexDeinterleavingGraph::NodePtr
identifyRoot(Instruction * RootI)1640 ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1641   if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1642     if (Intrinsic->getIntrinsicID() !=
1643         Intrinsic::experimental_vector_interleave2)
1644       return nullptr;
1645 
1646     auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
1647     auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
1648     if (!Real || !Imag)
1649       return nullptr;
1650 
1651     return identifyNode(Real, Imag);
1652   }
1653 
1654   auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1655   if (!SVI)
1656     return nullptr;
1657 
1658   // Look for a shufflevector that takes separate vectors of the real and
1659   // imaginary components and recombines them into a single vector.
1660   if (!isInterleavingMask(SVI->getShuffleMask()))
1661     return nullptr;
1662 
1663   Instruction *Real;
1664   Instruction *Imag;
1665   if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
1666     return nullptr;
1667 
1668   return identifyNode(Real, Imag);
1669 }
1670 
1671 ComplexDeinterleavingGraph::NodePtr
identifyDeinterleave(Instruction * Real,Instruction * Imag)1672 ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
1673                                                  Instruction *Imag) {
1674   Instruction *I = nullptr;
1675   Value *FinalValue = nullptr;
1676   if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
1677       match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
1678       match(I, m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>(
1679                    m_Value(FinalValue)))) {
1680     NodePtr PlaceholderNode = prepareCompositeNode(
1681         llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag);
1682     PlaceholderNode->ReplacementNode = FinalValue;
1683     FinalInstructions.insert(Real);
1684     FinalInstructions.insert(Imag);
1685     return submitCompositeNode(PlaceholderNode);
1686   }
1687 
1688   auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1689   auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
1690   if (!RealShuffle || !ImagShuffle) {
1691     if (RealShuffle || ImagShuffle)
1692       LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
1693     return nullptr;
1694   }
1695 
1696   Value *RealOp1 = RealShuffle->getOperand(1);
1697   if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1698     LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1699     return nullptr;
1700   }
1701   Value *ImagOp1 = ImagShuffle->getOperand(1);
1702   if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1703     LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1704     return nullptr;
1705   }
1706 
1707   Value *RealOp0 = RealShuffle->getOperand(0);
1708   Value *ImagOp0 = ImagShuffle->getOperand(0);
1709 
1710   if (RealOp0 != ImagOp0) {
1711     LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1712     return nullptr;
1713   }
1714 
1715   ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
1716   ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
1717   if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
1718     LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1719     return nullptr;
1720   }
1721 
1722   if (RealMask[0] != 0 || ImagMask[0] != 1) {
1723     LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1724     return nullptr;
1725   }
1726 
1727   // Type checking, the shuffle type should be a vector type of the same
1728   // scalar type, but half the size
1729   auto CheckType = [&](ShuffleVectorInst *Shuffle) {
1730     Value *Op = Shuffle->getOperand(0);
1731     auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1732     auto *OpTy = cast<FixedVectorType>(Op->getType());
1733 
1734     if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1735       return false;
1736     if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1737       return false;
1738 
1739     return true;
1740   };
1741 
1742   auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
1743     if (!CheckType(Shuffle))
1744       return false;
1745 
1746     ArrayRef<int> Mask = Shuffle->getShuffleMask();
1747     int Last = *Mask.rbegin();
1748 
1749     Value *Op = Shuffle->getOperand(0);
1750     auto *OpTy = cast<FixedVectorType>(Op->getType());
1751     int NumElements = OpTy->getNumElements();
1752 
1753     // Ensure that the deinterleaving shuffle only pulls from the first
1754     // shuffle operand.
1755     return Last < NumElements;
1756   };
1757 
1758   if (RealShuffle->getType() != ImagShuffle->getType()) {
1759     LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1760     return nullptr;
1761   }
1762   if (!CheckDeinterleavingShuffle(RealShuffle)) {
1763     LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1764     return nullptr;
1765   }
1766   if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1767     LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1768     return nullptr;
1769   }
1770 
1771   NodePtr PlaceholderNode =
1772       prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
1773                            RealShuffle, ImagShuffle);
1774   PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
1775   FinalInstructions.insert(RealShuffle);
1776   FinalInstructions.insert(ImagShuffle);
1777   return submitCompositeNode(PlaceholderNode);
1778 }
1779 
1780 ComplexDeinterleavingGraph::NodePtr
identifySplat(Value * R,Value * I)1781 ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
1782   auto IsSplat = [](Value *V) -> bool {
1783     // Fixed-width vector with constants
1784     if (isa<ConstantDataVector>(V))
1785       return true;
1786 
1787     VectorType *VTy;
1788     ArrayRef<int> Mask;
1789     // Splats are represented differently depending on whether the repeated
1790     // value is a constant or an Instruction
1791     if (auto *Const = dyn_cast<ConstantExpr>(V)) {
1792       if (Const->getOpcode() != Instruction::ShuffleVector)
1793         return false;
1794       VTy = cast<VectorType>(Const->getType());
1795       Mask = Const->getShuffleMask();
1796     } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
1797       VTy = Shuf->getType();
1798       Mask = Shuf->getShuffleMask();
1799     } else {
1800       return false;
1801     }
1802 
1803     // When the data type is <1 x Type>, it's not possible to differentiate
1804     // between the ComplexDeinterleaving::Deinterleave and
1805     // ComplexDeinterleaving::Splat operations.
1806     if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
1807       return false;
1808 
1809     return all_equal(Mask) && Mask[0] == 0;
1810   };
1811 
1812   if (!IsSplat(R) || !IsSplat(I))
1813     return nullptr;
1814 
1815   auto *Real = dyn_cast<Instruction>(R);
1816   auto *Imag = dyn_cast<Instruction>(I);
1817   if ((!Real && Imag) || (Real && !Imag))
1818     return nullptr;
1819 
1820   if (Real && Imag) {
1821     // Non-constant splats should be in the same basic block
1822     if (Real->getParent() != Imag->getParent())
1823       return nullptr;
1824 
1825     FinalInstructions.insert(Real);
1826     FinalInstructions.insert(Imag);
1827   }
1828   NodePtr PlaceholderNode =
1829       prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I);
1830   return submitCompositeNode(PlaceholderNode);
1831 }
1832 
1833 ComplexDeinterleavingGraph::NodePtr
identifyPHINode(Instruction * Real,Instruction * Imag)1834 ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
1835                                             Instruction *Imag) {
1836   if (Real != RealPHI || Imag != ImagPHI)
1837     return nullptr;
1838 
1839   PHIsFound = true;
1840   NodePtr PlaceholderNode = prepareCompositeNode(
1841       ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
1842   return submitCompositeNode(PlaceholderNode);
1843 }
1844 
1845 ComplexDeinterleavingGraph::NodePtr
identifySelectNode(Instruction * Real,Instruction * Imag)1846 ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
1847                                                Instruction *Imag) {
1848   auto *SelectReal = dyn_cast<SelectInst>(Real);
1849   auto *SelectImag = dyn_cast<SelectInst>(Imag);
1850   if (!SelectReal || !SelectImag)
1851     return nullptr;
1852 
1853   Instruction *MaskA, *MaskB;
1854   Instruction *AR, *AI, *RA, *BI;
1855   if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
1856                             m_Instruction(RA))) ||
1857       !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
1858                             m_Instruction(BI))))
1859     return nullptr;
1860 
1861   if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
1862     return nullptr;
1863 
1864   if (!MaskA->getType()->isVectorTy())
1865     return nullptr;
1866 
1867   auto NodeA = identifyNode(AR, AI);
1868   if (!NodeA)
1869     return nullptr;
1870 
1871   auto NodeB = identifyNode(RA, BI);
1872   if (!NodeB)
1873     return nullptr;
1874 
1875   NodePtr PlaceholderNode = prepareCompositeNode(
1876       ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
1877   PlaceholderNode->addOperand(NodeA);
1878   PlaceholderNode->addOperand(NodeB);
1879   FinalInstructions.insert(MaskA);
1880   FinalInstructions.insert(MaskB);
1881   return submitCompositeNode(PlaceholderNode);
1882 }
1883 
replaceSymmetricNode(IRBuilderBase & B,unsigned Opcode,std::optional<FastMathFlags> Flags,Value * InputA,Value * InputB)1884 static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
1885                                    std::optional<FastMathFlags> Flags,
1886                                    Value *InputA, Value *InputB) {
1887   Value *I;
1888   switch (Opcode) {
1889   case Instruction::FNeg:
1890     I = B.CreateFNeg(InputA);
1891     break;
1892   case Instruction::FAdd:
1893     I = B.CreateFAdd(InputA, InputB);
1894     break;
1895   case Instruction::Add:
1896     I = B.CreateAdd(InputA, InputB);
1897     break;
1898   case Instruction::FSub:
1899     I = B.CreateFSub(InputA, InputB);
1900     break;
1901   case Instruction::Sub:
1902     I = B.CreateSub(InputA, InputB);
1903     break;
1904   case Instruction::FMul:
1905     I = B.CreateFMul(InputA, InputB);
1906     break;
1907   case Instruction::Mul:
1908     I = B.CreateMul(InputA, InputB);
1909     break;
1910   default:
1911     llvm_unreachable("Incorrect symmetric opcode");
1912   }
1913   if (Flags)
1914     cast<Instruction>(I)->setFastMathFlags(*Flags);
1915   return I;
1916 }
1917 
replaceNode(IRBuilderBase & Builder,RawNodePtr Node)1918 Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
1919                                                RawNodePtr Node) {
1920   if (Node->ReplacementNode)
1921     return Node->ReplacementNode;
1922 
1923   auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
1924     return Node->Operands.size() > Idx
1925                ? replaceNode(Builder, Node->Operands[Idx])
1926                : nullptr;
1927   };
1928 
1929   Value *ReplacementNode;
1930   switch (Node->Operation) {
1931   case ComplexDeinterleavingOperation::CAdd:
1932   case ComplexDeinterleavingOperation::CMulPartial:
1933   case ComplexDeinterleavingOperation::Symmetric: {
1934     Value *Input0 = ReplaceOperandIfExist(Node, 0);
1935     Value *Input1 = ReplaceOperandIfExist(Node, 1);
1936     Value *Accumulator = ReplaceOperandIfExist(Node, 2);
1937     assert(!Input1 || (Input0->getType() == Input1->getType() &&
1938                        "Node inputs need to be of the same type"));
1939     assert(!Accumulator ||
1940            (Input0->getType() == Accumulator->getType() &&
1941             "Accumulator and input need to be of the same type"));
1942     if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
1943       ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
1944                                              Input0, Input1);
1945     else
1946       ReplacementNode = TL->createComplexDeinterleavingIR(
1947           Builder, Node->Operation, Node->Rotation, Input0, Input1,
1948           Accumulator);
1949     break;
1950   }
1951   case ComplexDeinterleavingOperation::Deinterleave:
1952     llvm_unreachable("Deinterleave node should already have ReplacementNode");
1953     break;
1954   case ComplexDeinterleavingOperation::Splat: {
1955     auto *NewTy = VectorType::getDoubleElementsVectorType(
1956         cast<VectorType>(Node->Real->getType()));
1957     auto *R = dyn_cast<Instruction>(Node->Real);
1958     auto *I = dyn_cast<Instruction>(Node->Imag);
1959     if (R && I) {
1960       // Splats that are not constant are interleaved where they are located
1961       Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode();
1962       IRBuilder<> IRB(InsertPoint);
1963       ReplacementNode =
1964           IRB.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, NewTy,
1965                               {Node->Real, Node->Imag});
1966     } else {
1967       ReplacementNode =
1968           Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
1969                                   NewTy, {Node->Real, Node->Imag});
1970     }
1971     break;
1972   }
1973   case ComplexDeinterleavingOperation::ReductionPHI: {
1974     // If Operation is ReductionPHI, a new empty PHINode is created.
1975     // It is filled later when the ReductionOperation is processed.
1976     auto *VTy = cast<VectorType>(Node->Real->getType());
1977     auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1978     auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHI());
1979     OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
1980     ReplacementNode = NewPHI;
1981     break;
1982   }
1983   case ComplexDeinterleavingOperation::ReductionOperation:
1984     ReplacementNode = replaceNode(Builder, Node->Operands[0]);
1985     processReductionOperation(ReplacementNode, Node);
1986     break;
1987   case ComplexDeinterleavingOperation::ReductionSelect: {
1988     auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
1989     auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
1990     auto *A = replaceNode(Builder, Node->Operands[0]);
1991     auto *B = replaceNode(Builder, Node->Operands[1]);
1992     auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
1993         cast<VectorType>(MaskReal->getType()));
1994     auto *NewMask =
1995         Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
1996                                 NewMaskTy, {MaskReal, MaskImag});
1997     ReplacementNode = Builder.CreateSelect(NewMask, A, B);
1998     break;
1999   }
2000   }
2001 
2002   assert(ReplacementNode && "Target failed to create Intrinsic call.");
2003   NumComplexTransformations += 1;
2004   Node->ReplacementNode = ReplacementNode;
2005   return ReplacementNode;
2006 }
2007 
processReductionOperation(Value * OperationReplacement,RawNodePtr Node)2008 void ComplexDeinterleavingGraph::processReductionOperation(
2009     Value *OperationReplacement, RawNodePtr Node) {
2010   auto *Real = cast<Instruction>(Node->Real);
2011   auto *Imag = cast<Instruction>(Node->Imag);
2012   auto *OldPHIReal = ReductionInfo[Real].first;
2013   auto *OldPHIImag = ReductionInfo[Imag].first;
2014   auto *NewPHI = OldToNewPHI[OldPHIReal];
2015 
2016   auto *VTy = cast<VectorType>(Real->getType());
2017   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2018 
2019   // We have to interleave initial origin values coming from IncomingBlock
2020   Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2021   Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2022 
2023   IRBuilder<> Builder(Incoming->getTerminator());
2024   auto *NewInit = Builder.CreateIntrinsic(
2025       Intrinsic::experimental_vector_interleave2, NewVTy, {InitReal, InitImag});
2026 
2027   NewPHI->addIncoming(NewInit, Incoming);
2028   NewPHI->addIncoming(OperationReplacement, BackEdge);
2029 
2030   // Deinterleave complex vector outside of loop so that it can be finally
2031   // reduced
2032   auto *FinalReductionReal = ReductionInfo[Real].second;
2033   auto *FinalReductionImag = ReductionInfo[Imag].second;
2034 
2035   Builder.SetInsertPoint(
2036       &*FinalReductionReal->getParent()->getFirstInsertionPt());
2037   auto *Deinterleave = Builder.CreateIntrinsic(
2038       Intrinsic::experimental_vector_deinterleave2,
2039       OperationReplacement->getType(), OperationReplacement);
2040 
2041   auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
2042   FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2043 
2044   Builder.SetInsertPoint(FinalReductionImag);
2045   auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
2046   FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2047 }
2048 
replaceNodes()2049 void ComplexDeinterleavingGraph::replaceNodes() {
2050   SmallVector<Instruction *, 16> DeadInstrRoots;
2051   for (auto *RootInstruction : OrderedRoots) {
2052     // Check if this potential root went through check process and we can
2053     // deinterleave it
2054     if (!RootToNode.count(RootInstruction))
2055       continue;
2056 
2057     IRBuilder<> Builder(RootInstruction);
2058     auto RootNode = RootToNode[RootInstruction];
2059     Value *R = replaceNode(Builder, RootNode.get());
2060 
2061     if (RootNode->Operation ==
2062         ComplexDeinterleavingOperation::ReductionOperation) {
2063       auto *RootReal = cast<Instruction>(RootNode->Real);
2064       auto *RootImag = cast<Instruction>(RootNode->Imag);
2065       ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2066       ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2067       DeadInstrRoots.push_back(cast<Instruction>(RootReal));
2068       DeadInstrRoots.push_back(cast<Instruction>(RootImag));
2069     } else {
2070       assert(R && "Unable to find replacement for RootInstruction");
2071       DeadInstrRoots.push_back(RootInstruction);
2072       RootInstruction->replaceAllUsesWith(R);
2073     }
2074   }
2075 
2076   for (auto *I : DeadInstrRoots)
2077     RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
2078 }
2079