1 //===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Identification:
10 // This step is responsible for finding the patterns that can be lowered to
11 // complex instructions, and building a graph to represent the complex
12 // structures. Starting from the "Converging Shuffle" (a shuffle that
13 // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14 // operands are evaluated and identified as "Composite Nodes" (collections of
15 // instructions that can potentially be lowered to a single complex
16 // instruction). This is performed by checking the real and imaginary components
17 // and tracking the data flow for each component while following the operand
18 // pairs. Validity of each node is expected to be done upon creation, and any
19 // validation errors should halt traversal and prevent further graph
20 // construction.
// Instead of relying on shuffle operations, vector interleaving and
// deinterleaving can be represented by the vector.interleave2 and
// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
// these intrinsics, whereas fixed-width vectors are recognized in both forms:
// as shufflevector instructions and as the intrinsics.
26 //
27 // Replacement:
28 // This step traverses the graph built up by identification, delegating to the
29 // target to validate and generate the correct intrinsics, and plumbs them
30 // together connecting each end of the new intrinsics graph to the existing
31 // use-def chain. This step is assumed to finish successfully, as all
32 // information is expected to be correct by this point.
33 //
34 //
35 // Internal data structure:
36 // ComplexDeinterleavingGraph:
37 // Keeps references to all the valid CompositeNodes formed as part of the
38 // transformation, and every Instruction contained within said nodes. It also
39 // holds onto a reference to the root Instruction, and the root node that should
40 // replace it.
41 //
42 // ComplexDeinterleavingCompositeNode:
43 // A CompositeNode represents a single transformation point; each node should
44 // transform into a single complex instruction (ignoring vector splitting, which
45 // would generate more instructions per node). They are identified in a
46 // depth-first manner, traversing and identifying the operands of each
47 // instruction in the order they appear in the IR.
// Each node maintains a reference to its Real and Imaginary instructions,
// as well as any additional instructions that make up the identified operation
// (internal instructions should only have uses within their containing node).
// A node also records the rotation and operation type that it represents.
// Operands contains pointers to other CompositeNodes, acting as the edges in
// the graph. ReplacementNode is the transformed Value* that has been emitted
// to the IR.
55 //
// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
// ReplacementNode fields of that Node are relevant, and ReplacementNode
// should be pre-populated.
59 //
60 //===----------------------------------------------------------------------===//
61 
62 #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63 #include "llvm/ADT/Statistic.h"
64 #include "llvm/Analysis/TargetLibraryInfo.h"
65 #include "llvm/Analysis/TargetTransformInfo.h"
66 #include "llvm/CodeGen/TargetLowering.h"
67 #include "llvm/CodeGen/TargetPassConfig.h"
68 #include "llvm/CodeGen/TargetSubtargetInfo.h"
69 #include "llvm/IR/IRBuilder.h"
70 #include "llvm/IR/PatternMatch.h"
71 #include "llvm/InitializePasses.h"
72 #include "llvm/Target/TargetMachine.h"
73 #include "llvm/Transforms/Utils/Local.h"
74 #include <algorithm>
75 
76 using namespace llvm;
77 using namespace PatternMatch;
78 
79 #define DEBUG_TYPE "complex-deinterleaving"
80 
81 STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
82 
83 static cl::opt<bool> ComplexDeinterleavingEnabled(
84     "enable-complex-deinterleaving",
85     cl::desc("Enable generation of complex instructions"), cl::init(true),
86     cl::Hidden);
87 
88 /// Checks the given mask, and determines whether said mask is interleaving.
89 ///
90 /// To be interleaving, a mask must alternate between `i` and `i + (Length /
91 /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
92 /// 4x vector interleaving mask would be <0, 2, 1, 3>).
93 static bool isInterleavingMask(ArrayRef<int> Mask);
94 
95 /// Checks the given mask, and determines whether said mask is deinterleaving.
96 ///
97 /// To be deinterleaving, a mask must increment in steps of 2, and either start
98 /// with 0 or 1.
99 /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
100 /// <1, 3, 5, 7>).
101 static bool isDeinterleavingMask(ArrayRef<int> Mask);
102 
103 /// Returns true if the operation is a negation of V, and it works for both
104 /// integers and floats.
105 static bool isNeg(Value *V);
106 
107 /// Returns the operand for negation operation.
108 static Value *getNegOperand(Value *V);
109 
110 namespace {
111 
112 class ComplexDeinterleavingLegacyPass : public FunctionPass {
113 public:
114   static char ID;
115 
116   ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
117       : FunctionPass(ID), TM(TM) {
118     initializeComplexDeinterleavingLegacyPassPass(
119         *PassRegistry::getPassRegistry());
120   }
121 
122   StringRef getPassName() const override {
123     return "Complex Deinterleaving Pass";
124   }
125 
126   bool runOnFunction(Function &F) override;
127   void getAnalysisUsage(AnalysisUsage &AU) const override {
128     AU.addRequired<TargetLibraryInfoWrapperPass>();
129     AU.setPreservesCFG();
130   }
131 
132 private:
133   const TargetMachine *TM;
134 };
135 
136 class ComplexDeinterleavingGraph;
137 struct ComplexDeinterleavingCompositeNode {
138 
139   ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
140                                      Value *R, Value *I)
141       : Operation(Op), Real(R), Imag(I) {}
142 
143 private:
144   friend class ComplexDeinterleavingGraph;
145   using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
146   using RawNodePtr = ComplexDeinterleavingCompositeNode *;
147 
148 public:
149   ComplexDeinterleavingOperation Operation;
150   Value *Real;
151   Value *Imag;
152 
153   // This two members are required exclusively for generating
154   // ComplexDeinterleavingOperation::Symmetric operations.
155   unsigned Opcode;
156   std::optional<FastMathFlags> Flags;
157 
158   ComplexDeinterleavingRotation Rotation =
159       ComplexDeinterleavingRotation::Rotation_0;
160   SmallVector<RawNodePtr> Operands;
161   Value *ReplacementNode = nullptr;
162 
163   void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
164 
165   void dump() { dump(dbgs()); }
166   void dump(raw_ostream &OS) {
167     auto PrintValue = [&](Value *V) {
168       if (V) {
169         OS << "\"";
170         V->print(OS, true);
171         OS << "\"\n";
172       } else
173         OS << "nullptr\n";
174     };
175     auto PrintNodeRef = [&](RawNodePtr Ptr) {
176       if (Ptr)
177         OS << Ptr << "\n";
178       else
179         OS << "nullptr\n";
180     };
181 
182     OS << "- CompositeNode: " << this << "\n";
183     OS << "  Real: ";
184     PrintValue(Real);
185     OS << "  Imag: ";
186     PrintValue(Imag);
187     OS << "  ReplacementNode: ";
188     PrintValue(ReplacementNode);
189     OS << "  Operation: " << (int)Operation << "\n";
190     OS << "  Rotation: " << ((int)Rotation * 90) << "\n";
191     OS << "  Operands: \n";
192     for (const auto &Op : Operands) {
193       OS << "    - ";
194       PrintNodeRef(Op);
195     }
196   }
197 };
198 
/// Identifies composite complex nodes within one basic block and, once they
/// have been validated, replaces the original instructions with complex
/// operations.
class ComplexDeinterleavingGraph {
public:
  /// A single multiplication term, Multiplier * Multiplicand. IsPositive
  /// presumably records whether the term is added or subtracted — confirm
  /// against identifyMultiplications.
  struct Product {
    Value *Multiplier;
    Value *Multiplicand;
    bool IsPositive;
  };

  /// An addend value paired with a flag (presumably its sign — confirm
  /// against identifyAdditions).
  using Addend = std::pair<Value *, bool>;
  using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
  using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;

  // Helper struct for holding info about potential partial multiplication
  // candidates.
  struct PartialMulCandidate {
    Value *Common;
    NodePtr Node;
    unsigned RealIdx;
    unsigned ImagIdx;
    bool IsNodeInverted;
  };

  explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
                                      const TargetLibraryInfo *TLI)
      : TL(TL), TLI(TLI) {}

private:
  const TargetLowering *TL = nullptr;
  const TargetLibraryInfo *TLI = nullptr;
  /// Every node registered through submitCompositeNode, in submission order.
  /// This vector holds the owning shared_ptrs.
  SmallVector<NodePtr> CompositeNodes;
  /// Submitted nodes that have both a Real and an Imag part, keyed by that
  /// (Real, Imag) pair.
  DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult;

  SmallPtrSet<Instruction *, 16> FinalInstructions;

  /// Root instructions are instructions from which complex computation starts
  std::map<Instruction *, NodePtr> RootToNode;

  /// Topologically sorted root instructions
  SmallVector<Instruction *, 1> OrderedRoots;

  /// When examining a basic block for complex deinterleaving, if it is a simple
  /// one-block loop, then the only incoming block is 'Incoming' and the
  /// 'BackEdge' block is the block itself.
  BasicBlock *BackEdge = nullptr;
  BasicBlock *Incoming = nullptr;

  /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
  /// %OutsideUser as it is shown in the IR:
  ///
  /// vector.body:
  ///   %PHInode = phi <vector type> [ zeroinitializer, %entry ],
  ///                                [ %ReductionOp, %vector.body ]
  ///   ...
  ///   %ReductionOp = fadd <vector type> ...
  ///   ...
  ///   br i1 %condition, label %vector.body, %middle.block
  ///
  /// middle.block:
  ///   %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
  ///
  /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
  /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
  std::map<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;

  /// In the process of detecting a reduction, we consider a pair of
  /// %ReductionOps, which we refer to as real and imag (or vice versa), and
  /// traverse the use-tree to detect complex operations. As this is a reduction
  /// operation, the traversal will eventually reach RealPHI and ImagPHI, the
  /// PHI nodes of the %ReductionOps that we suspect to be complex.
  /// RealPHI and ImagPHI are used by the identifyPHINode method.
  PHINode *RealPHI = nullptr;
  PHINode *ImagPHI = nullptr;

  /// Set this flag to true if RealPHI and ImagPHI were reached during reduction
  /// detection.
  bool PHIsFound = false;

  /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
  /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
  /// This mapping is populated during
  /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
  /// used in the ComplexDeinterleavingOperation::ReductionOperation node
  /// replacement process.
  std::map<PHINode *, PHINode *> OldToNewPHI;

  /// Creates a node for (\p R, \p I) without registering it in the graph; the
  /// caller fills in its fields and then passes it to submitCompositeNode.
  /// Reduction nodes must have both parts.
  NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
                               Value *R, Value *I) {
    assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
             Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
            (R && I)) &&
           "Reduction related nodes must have Real and Imaginary parts");
    return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
                                                                I);
  }

  /// Registers \p Node in CompositeNodes and, when both of its parts are set,
  /// caches it under its (Real, Imag) pair. Returns \p Node.
  NodePtr submitCompositeNode(NodePtr Node) {
    CompositeNodes.push_back(Node);
    if (Node->Real && Node->Imag)
      CachedResult[{Node->Real, Node->Imag}] = Node;
    return Node;
  }

  /// Identifies a complex partial multiply pattern and its rotation, based on
  /// the following patterns
  ///
  ///  0:  r: cr + ar * br
  ///      i: ci + ar * bi
  /// 90:  r: cr - ai * bi
  ///      i: ci + ai * br
  /// 180: r: cr - ar * br
  ///      i: ci - ar * bi
  /// 270: r: cr + ai * bi
  ///      i: ci - ai * br
  NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);

  /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
  /// is partially known from identifyPartialMul, filling in the other half of
  /// the complex pair.
  NodePtr
  identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
                              std::pair<Value *, Value *> &CommonOperandI);

  /// Identifies a complex add pattern and its rotation, based on the following
  /// patterns.
  ///
  /// 90:  r: ar - bi
  ///      i: ai + br
  /// 270: r: ar + bi
  ///      i: ai - br
  NodePtr identifyAdd(Instruction *Real, Instruction *Imag);

  /// Identifies \p Real and \p Imag as the same operation (identical opcodes
  /// accepted by isInstructionPotentiallySymmetric) applied to operand pairs
  /// that are themselves identifiable as nodes.
  NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);

  /// Tries to identify the pair (\p R, \p I) as a composite node. Returns
  /// nullptr if no node can be formed.
  NodePtr identifyNode(Value *R, Value *I);

  /// Determine if a sum of complex numbers can be formed from \p RealAddends
  /// and \p ImagAddends. If \p Accumulator is not null, add the result to it.
  /// Return nullptr if it is not possible to construct a complex number.
  /// \p Flags are needed to generate symmetric Add and Sub operations.
  NodePtr identifyAdditions(std::list<Addend> &RealAddends,
                            std::list<Addend> &ImagAddends,
                            std::optional<FastMathFlags> Flags,
                            NodePtr Accumulator);

  /// Extract one addend that has both real and imaginary parts positive.
  NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
                                std::list<Addend> &ImagAddends);

  /// Determine if sum of multiplications of complex numbers can be formed from
  /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
  /// to it. Return nullptr if it is not possible to construct a complex number.
  NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
                                  std::vector<Product> &ImagMuls,
                                  NodePtr Accumulator);

  /// Go through pairs of multiplication (one Real and one Imag) and find all
  /// possible candidates for partial multiplication and put them into \p
  /// Candidates. Returns true if every Product has a pair with a common
  /// operand.
  bool collectPartialMuls(const std::vector<Product> &RealMuls,
                          const std::vector<Product> &ImagMuls,
                          std::vector<PartialMulCandidate> &Candidates);

  /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
  /// the order of complex computation operations may be significantly altered,
  /// and the real and imaginary parts may not be executed in parallel. This
  /// function takes this into consideration and employs a more general approach
  /// to identify complex computations. Initially, it gathers all the addends
  /// and multiplicands and then constructs a complex expression from them.
  NodePtr identifyReassocNodes(Instruction *I, Instruction *J);

  NodePtr identifyRoot(Instruction *I);

  /// Identifies the Deinterleave operation applied to a vector containing
  /// complex numbers. There are two ways to represent the Deinterleave
  /// operation:
  /// * Using two shufflevectors with even indices for \p Real instruction and
  /// odd indices for \p Imag instructions (only for fixed-width vectors)
  /// * Using two extractvalue instructions applied to `vector.deinterleave2`
  /// intrinsic (for both fixed and scalable vectors)
  NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);

  /// Identifies the operation that represents a complex number repeated in a
  /// Splat vector. There are two possible types of splats: ConstantExpr with
  /// the opcode ShuffleVector and ShuffleVectorInst. Both should have an
  /// initialization mask with all values set to zero.
  NodePtr identifySplat(Value *Real, Value *Imag);

  /// Identifies (\p Real, \p Imag) as reduction PHI nodes; this is the
  /// consumer of RealPHI and ImagPHI.
  NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);

  /// Identifies SelectInsts in a loop that has reduction with predication masks
  /// and/or predicated tail folding
  NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);

  Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);

  /// Complete IR modifications after producing new reduction operation:
  /// * Populate the PHINode generated for
  /// ComplexDeinterleavingOperation::ReductionPHI
  /// * Deinterleave the final value outside of the loop and repurpose original
  /// reduction users
  void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);

public:
  /// Prints every registered composite node.
  void dump() { dump(dbgs()); }
  void dump(raw_ostream &OS) {
    for (const auto &Node : CompositeNodes)
      Node->dump(OS);
  }

  /// Returns false if the deinterleaving operation should be cancelled for the
  /// current graph.
  bool identifyNodes(Instruction *RootI);

  /// In case \p B is a one-block loop, this function seeks potential reductions
  /// and populates ReductionInfo. Returns true if any reductions were
  /// identified.
  bool collectPotentialReductions(BasicBlock *B);

  void identifyReductionNodes();

  /// Check that every instruction, from the roots to the leaves, has internal
  /// uses.
  bool checkNodes();

  /// Perform the actual replacement of the underlying instruction graph.
  void replaceNodes();
};
425 
426 class ComplexDeinterleaving {
427 public:
428   ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
429       : TL(tl), TLI(tli) {}
430   bool runOnFunction(Function &F);
431 
432 private:
433   bool evaluateBasicBlock(BasicBlock *B);
434 
435   const TargetLowering *TL = nullptr;
436   const TargetLibraryInfo *TLI = nullptr;
437 };
438 
439 } // namespace
440 
// Storage for the legacy pass identifier handed to FunctionPass by the
// ComplexDeinterleavingLegacyPass constructor.
char ComplexDeinterleavingLegacyPass::ID = 0;

// Register the legacy pass under DEBUG_TYPE ("complex-deinterleaving") with
// the description "Complex Deinterleaving".
INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                      "Complex Deinterleaving", false, false)
INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
                    "Complex Deinterleaving", false, false)
447 
448 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
449                                                  FunctionAnalysisManager &AM) {
450   const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();
451   auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);
452   if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))
453     return PreservedAnalyses::all();
454 
455   PreservedAnalyses PA;
456   PA.preserve<FunctionAnalysisManagerModuleProxy>();
457   return PA;
458 }
459 
460 FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
461   return new ComplexDeinterleavingLegacyPass(TM);
462 }
463 
464 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
465   const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
466   auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
467   return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
468 }
469 
470 bool ComplexDeinterleaving::runOnFunction(Function &F) {
471   if (!ComplexDeinterleavingEnabled) {
472     LLVM_DEBUG(
473         dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
474     return false;
475   }
476 
477   if (!TL->isComplexDeinterleavingSupported()) {
478     LLVM_DEBUG(
479         dbgs() << "Complex deinterleaving has been disabled, target does "
480                   "not support lowering of complex number operations.\n");
481     return false;
482   }
483 
484   bool Changed = false;
485   for (auto &B : F)
486     Changed |= evaluateBasicBlock(&B);
487 
488   return Changed;
489 }
490 
491 static bool isInterleavingMask(ArrayRef<int> Mask) {
492   // If the size is not even, it's not an interleaving mask
493   if ((Mask.size() & 1))
494     return false;
495 
496   int HalfNumElements = Mask.size() / 2;
497   for (int Idx = 0; Idx < HalfNumElements; ++Idx) {
498     int MaskIdx = Idx * 2;
499     if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))
500       return false;
501   }
502 
503   return true;
504 }
505 
506 static bool isDeinterleavingMask(ArrayRef<int> Mask) {
507   int Offset = Mask[0];
508   int HalfNumElements = Mask.size() / 2;
509 
510   for (int Idx = 1; Idx < HalfNumElements; ++Idx) {
511     if (Mask[Idx] != (Idx * 2) + Offset)
512       return false;
513   }
514 
515   return true;
516 }
517 
518 bool isNeg(Value *V) {
519   return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));
520 }
521 
522 Value *getNegOperand(Value *V) {
523   assert(isNeg(V));
524   auto *I = cast<Instruction>(V);
525   if (I->getOpcode() == Instruction::FNeg)
526     return I->getOperand(0);
527 
528   return I->getOperand(1);
529 }
530 
531 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
532   ComplexDeinterleavingGraph Graph(TL, TLI);
533   if (Graph.collectPotentialReductions(B))
534     Graph.identifyReductionNodes();
535 
536   for (auto &I : *B)
537     Graph.identifyNodes(&I);
538 
539   if (Graph.checkNodes()) {
540     Graph.replaceNodes();
541     return true;
542   }
543 
544   return false;
545 }
546 
547 ComplexDeinterleavingGraph::NodePtr
548 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
549     Instruction *Real, Instruction *Imag,
550     std::pair<Value *, Value *> &PartialMatch) {
551   LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
552                     << "\n");
553 
554   if (!Real->hasOneUse() || !Imag->hasOneUse()) {
555     LLVM_DEBUG(dbgs() << "  - Mul operand has multiple uses.\n");
556     return nullptr;
557   }
558 
559   if ((Real->getOpcode() != Instruction::FMul &&
560        Real->getOpcode() != Instruction::Mul) ||
561       (Imag->getOpcode() != Instruction::FMul &&
562        Imag->getOpcode() != Instruction::Mul)) {
563     LLVM_DEBUG(
564         dbgs() << "  - Real or imaginary instruction is not fmul or mul\n");
565     return nullptr;
566   }
567 
568   Value *R0 = Real->getOperand(0);
569   Value *R1 = Real->getOperand(1);
570   Value *I0 = Imag->getOperand(0);
571   Value *I1 = Imag->getOperand(1);
572 
573   // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
574   // rotations and use the operand.
575   unsigned Negs = 0;
576   Value *Op;
577   if (match(R0, m_Neg(m_Value(Op)))) {
578     Negs |= 1;
579     R0 = Op;
580   } else if (match(R1, m_Neg(m_Value(Op)))) {
581     Negs |= 1;
582     R1 = Op;
583   }
584 
585   if (isNeg(I0)) {
586     Negs |= 2;
587     Negs ^= 1;
588     I0 = Op;
589   } else if (match(I1, m_Neg(m_Value(Op)))) {
590     Negs |= 2;
591     Negs ^= 1;
592     I1 = Op;
593   }
594 
595   ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
596 
597   Value *CommonOperand;
598   Value *UncommonRealOp;
599   Value *UncommonImagOp;
600 
601   if (R0 == I0 || R0 == I1) {
602     CommonOperand = R0;
603     UncommonRealOp = R1;
604   } else if (R1 == I0 || R1 == I1) {
605     CommonOperand = R1;
606     UncommonRealOp = R0;
607   } else {
608     LLVM_DEBUG(dbgs() << "  - No equal operand\n");
609     return nullptr;
610   }
611 
612   UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
613   if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
614       Rotation == ComplexDeinterleavingRotation::Rotation_270)
615     std::swap(UncommonRealOp, UncommonImagOp);
616 
617   // Between identifyPartialMul and here we need to have found a complete valid
618   // pair from the CommonOperand of each part.
619   if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
620       Rotation == ComplexDeinterleavingRotation::Rotation_180)
621     PartialMatch.first = CommonOperand;
622   else
623     PartialMatch.second = CommonOperand;
624 
625   if (!PartialMatch.first || !PartialMatch.second) {
626     LLVM_DEBUG(dbgs() << "  - Incomplete partial match\n");
627     return nullptr;
628   }
629 
630   NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
631   if (!CommonNode) {
632     LLVM_DEBUG(dbgs() << "  - No CommonNode identified\n");
633     return nullptr;
634   }
635 
636   NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
637   if (!UncommonNode) {
638     LLVM_DEBUG(dbgs() << "  - No UncommonNode identified\n");
639     return nullptr;
640   }
641 
642   NodePtr Node = prepareCompositeNode(
643       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
644   Node->Rotation = Rotation;
645   Node->addOperand(CommonNode);
646   Node->addOperand(UncommonNode);
647   return submitCompositeNode(Node);
648 }
649 
/// Identifies (\p Real, \p Imag) as a complex partial multiply-accumulate.
/// Each part is an add/sub whose operand 0 is an accumulator and whose
/// operand 1 is a single-use multiply, and the two multiplies share one
/// operand. The accumulators are matched via identifyNodeWithImplicitAdd.
/// Returns the new CMulPartial node, or nullptr if the pattern does not match.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
                                               Instruction *Imag) {
  LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
                    << "\n");
  // Determine rotation from the add/sub opcodes of the two parts.
  auto IsAdd = [](unsigned Op) {
    return Op == Instruction::FAdd || Op == Instruction::Add;
  };
  auto IsSub = [](unsigned Op) {
    return Op == Instruction::FSub || Op == Instruction::Sub;
  };
  ComplexDeinterleavingRotation Rotation;
  if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
    Rotation = ComplexDeinterleavingRotation::Rotation_0;
  else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
    Rotation = ComplexDeinterleavingRotation::Rotation_90;
  else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
    Rotation = ComplexDeinterleavingRotation::Rotation_180;
  else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
    Rotation = ComplexDeinterleavingRotation::Rotation_270;
  else {
    LLVM_DEBUG(dbgs() << "  - Unhandled rotation.\n");
    return nullptr;
  }

  // Fusing the multiply into the accumulate requires the `contract` flag on
  // both parts when they are floating-point operations.
  if (isa<FPMathOperator>(Real) &&
      (!Real->getFastMathFlags().allowContract() ||
       !Imag->getFastMathFlags().allowContract())) {
    LLVM_DEBUG(dbgs() << "  - Contract is missing from the FastMath flags.\n");
    return nullptr;
  }

  // Operand 0 of each part is the accumulator, operand 1 the multiply.
  Value *CR = Real->getOperand(0);
  Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
  if (!RealMulI)
    return nullptr;
  Value *CI = Imag->getOperand(0);
  Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
  if (!ImagMulI)
    return nullptr;

  // Each multiply may only feed its own add/sub.
  if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
    LLVM_DEBUG(dbgs() << "  - Mul instruction has multiple uses\n");
    return nullptr;
  }

  Value *R0 = RealMulI->getOperand(0);
  Value *R1 = RealMulI->getOperand(1);
  Value *I0 = ImagMulI->getOperand(0);
  Value *I1 = ImagMulI->getOperand(1);

  Value *CommonOperand;
  Value *UncommonRealOp;
  Value *UncommonImagOp;

  // Find the operand shared by the two multiplies.
  if (R0 == I0 || R0 == I1) {
    CommonOperand = R0;
    UncommonRealOp = R1;
  } else if (R1 == I0 || R1 == I1) {
    CommonOperand = R1;
    UncommonRealOp = R0;
  } else {
    LLVM_DEBUG(dbgs() << "  - No equal operand\n");
    return nullptr;
  }

  UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
  // For 90/270 the uncommon operands come from opposite parts.
  if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_270)
    std::swap(UncommonRealOp, UncommonImagOp);

  // Record the common operand in the real slot (0/180) or the imaginary slot
  // (90/270); identifyNodeWithImplicitAdd is expected to fill the other slot.
  std::pair<Value *, Value *> PartialMatch(
      (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
       Rotation == ComplexDeinterleavingRotation::Rotation_180)
          ? CommonOperand
          : nullptr,
      (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
       Rotation == ComplexDeinterleavingRotation::Rotation_270)
          ? CommonOperand
          : nullptr);

  auto *CRInst = dyn_cast<Instruction>(CR);
  auto *CIInst = dyn_cast<Instruction>(CI);

  if (!CRInst || !CIInst) {
    LLVM_DEBUG(dbgs() << "  - Common operands are not instructions.\n");
    return nullptr;
  }

  // The accumulators must form the other half of the complex multiply.
  NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
  if (!CNode) {
    LLVM_DEBUG(dbgs() << "  - No cnode identified\n");
    return nullptr;
  }

  NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
  if (!UncommonRes) {
    LLVM_DEBUG(dbgs() << "  - No UncommonRes identified\n");
    return nullptr;
  }

  // A successful identifyNodeWithImplicitAdd has filled both halves.
  assert(PartialMatch.first && PartialMatch.second);
  NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
  if (!CommonRes) {
    LLVM_DEBUG(dbgs() << "  - No CommonRes identified\n");
    return nullptr;
  }

  NodePtr Node = prepareCompositeNode(
      ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
  Node->Rotation = Rotation;
  Node->addOperand(CommonRes);
  Node->addOperand(UncommonRes);
  Node->addOperand(CNode);
  return submitCompositeNode(Node);
}
767 
768 ComplexDeinterleavingGraph::NodePtr
769 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
770   LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
771 
772   // Determine rotation
773   ComplexDeinterleavingRotation Rotation;
774   if ((Real->getOpcode() == Instruction::FSub &&
775        Imag->getOpcode() == Instruction::FAdd) ||
776       (Real->getOpcode() == Instruction::Sub &&
777        Imag->getOpcode() == Instruction::Add))
778     Rotation = ComplexDeinterleavingRotation::Rotation_90;
779   else if ((Real->getOpcode() == Instruction::FAdd &&
780             Imag->getOpcode() == Instruction::FSub) ||
781            (Real->getOpcode() == Instruction::Add &&
782             Imag->getOpcode() == Instruction::Sub))
783     Rotation = ComplexDeinterleavingRotation::Rotation_270;
784   else {
785     LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
786     return nullptr;
787   }
788 
789   auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
790   auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
791   auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
792   auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));
793 
794   if (!AR || !AI || !BR || !BI) {
795     LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
796     return nullptr;
797   }
798 
799   NodePtr ResA = identifyNode(AR, AI);
800   if (!ResA) {
801     LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
802     return nullptr;
803   }
804   NodePtr ResB = identifyNode(BR, BI);
805   if (!ResB) {
806     LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
807     return nullptr;
808   }
809 
810   NodePtr Node =
811       prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
812   Node->Rotation = Rotation;
813   Node->addOperand(ResA);
814   Node->addOperand(ResB);
815   return submitCompositeNode(Node);
816 }
817 
818 static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
819   unsigned OpcA = A->getOpcode();
820   unsigned OpcB = B->getOpcode();
821 
822   return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
823          (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
824          (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
825          (OpcA == Instruction::Add && OpcB == Instruction::Sub);
826 }
827 
828 static bool isInstructionPairMul(Instruction *A, Instruction *B) {
829   auto Pattern =
830       m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
831 
832   return match(A, Pattern) && match(B, Pattern);
833 }
834 
835 static bool isInstructionPotentiallySymmetric(Instruction *I) {
836   switch (I->getOpcode()) {
837   case Instruction::FAdd:
838   case Instruction::FSub:
839   case Instruction::FMul:
840   case Instruction::FNeg:
841   case Instruction::Add:
842   case Instruction::Sub:
843   case Instruction::Mul:
844     return true;
845   default:
846     return false;
847   }
848 }
849 
850 ComplexDeinterleavingGraph::NodePtr
851 ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
852                                                        Instruction *Imag) {
853   if (Real->getOpcode() != Imag->getOpcode())
854     return nullptr;
855 
856   if (!isInstructionPotentiallySymmetric(Real) ||
857       !isInstructionPotentiallySymmetric(Imag))
858     return nullptr;
859 
860   auto *R0 = Real->getOperand(0);
861   auto *I0 = Imag->getOperand(0);
862 
863   NodePtr Op0 = identifyNode(R0, I0);
864   NodePtr Op1 = nullptr;
865   if (Op0 == nullptr)
866     return nullptr;
867 
868   if (Real->isBinaryOp()) {
869     auto *R1 = Real->getOperand(1);
870     auto *I1 = Imag->getOperand(1);
871     Op1 = identifyNode(R1, I1);
872     if (Op1 == nullptr)
873       return nullptr;
874   }
875 
876   if (isa<FPMathOperator>(Real) &&
877       Real->getFastMathFlags() != Imag->getFastMathFlags())
878     return nullptr;
879 
880   auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
881                                    Real, Imag);
882   Node->Opcode = Real->getOpcode();
883   if (isa<FPMathOperator>(Real))
884     Node->Flags = Real->getFastMathFlags();
885 
886   Node->addOperand(Op0);
887   if (Real->isBinaryOp())
888     Node->addOperand(Op1);
889 
890   return submitCompositeNode(Node);
891 }
892 
893 ComplexDeinterleavingGraph::NodePtr
894 ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
895   LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
896   assert(R->getType() == I->getType() &&
897          "Real and imaginary parts should not have different types");
898 
899   auto It = CachedResult.find({R, I});
900   if (It != CachedResult.end()) {
901     LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
902     return It->second;
903   }
904 
905   if (NodePtr CN = identifySplat(R, I))
906     return CN;
907 
908   auto *Real = dyn_cast<Instruction>(R);
909   auto *Imag = dyn_cast<Instruction>(I);
910   if (!Real || !Imag)
911     return nullptr;
912 
913   if (NodePtr CN = identifyDeinterleave(Real, Imag))
914     return CN;
915 
916   if (NodePtr CN = identifyPHINode(Real, Imag))
917     return CN;
918 
919   if (NodePtr CN = identifySelectNode(Real, Imag))
920     return CN;
921 
922   auto *VTy = cast<VectorType>(Real->getType());
923   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
924 
925   bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
926       ComplexDeinterleavingOperation::CMulPartial, NewVTy);
927   bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
928       ComplexDeinterleavingOperation::CAdd, NewVTy);
929 
930   if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
931     if (NodePtr CN = identifyPartialMul(Real, Imag))
932       return CN;
933   }
934 
935   if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
936     if (NodePtr CN = identifyAdd(Real, Imag))
937       return CN;
938   }
939 
940   if (HasCMulSupport && HasCAddSupport) {
941     if (NodePtr CN = identifyReassocNodes(Real, Imag))
942       return CN;
943   }
944 
945   if (NodePtr CN = identifySymmetricOperation(Real, Imag))
946     return CN;
947 
948   LLVM_DEBUG(dbgs() << "  - Not recognised as a valid pattern.\n");
949   CachedResult[{R, I}] = nullptr;
950   return nullptr;
951 }
952 
/// Identifies the expression trees rooted at Real and Imag as a combination of
/// complex partial multiplications and complex additions. The trees may have
/// been reassociated.
///
/// Each root is flattened into two lists:
///  - signed products: Mul/FMul leaves;
///  - signed addends: any other leaf.
/// Flattening walks through Add/FAdd/Sub/FSub/FNeg. It also walks through
/// negations recognised by isNeg, both on a Sub root and on multiply
/// operands. The products are then matched with identifyMultiplications, and
/// the remaining addends with identifyAdditions. When any products exist, the
/// first positive addend pair that forms a node (if there is one) is used as
/// the initial accumulator.
///
/// For floating-point roots, both roots must carry identical fast-math flags
/// that include 'reassoc'. Every instruction expanded during flattening must
/// carry those same flags.
///
/// Returns the final node, with its Real/Imag fields set to the given roots
/// and submitted, or nullptr if the pattern is not recognised.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
                                                 Instruction *Imag) {
  // Only additive operations may appear at the roots.
  auto IsOperationSupported = [](unsigned Opcode) -> bool {
    return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
           Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
           Opcode == Instruction::Sub;
  };

  if (!IsOperationSupported(Real->getOpcode()) ||
      !IsOperationSupported(Imag->getOpcode()))
    return nullptr;

  // Flags stays empty for integer roots. For FP roots it holds the common
  // fast-math flags that every expanded instruction must match.
  std::optional<FastMathFlags> Flags;
  if (isa<FPMathOperator>(Real)) {
    if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
      LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
                           "not identical\n");
      return nullptr;
    }

    Flags = Real->getFastMathFlags();
    if (!Flags->allowReassoc()) {
      LLVM_DEBUG(
          dbgs()
          << "the 'Reassoc' attribute is missing in the FastMath flags\n");
      return nullptr;
    }
  }

  // Collect multiplications and addend instructions from the given instruction
  // while traversing its operands. Additionally, verify that all instructions
  // have the same fast math flags.
  auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
                          std::list<Addend> &Addends) -> bool {
    // Each entry pairs a value with its sign in the root's sum
    // (true = added, false = subtracted).
    SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
    SmallPtrSet<Value *, 8> Visited;
    while (!Worklist.empty()) {
      auto [V, IsPositive] = Worklist.back();
      Worklist.pop_back();
      // Each value is expanded at most once per root.
      if (!Visited.insert(V).second)
        continue;

      // Non-instruction values (e.g. arguments, constants) become leaf addends.
      Instruction *I = dyn_cast<Instruction>(V);
      if (!I) {
        Addends.emplace_back(V, IsPositive);
        continue;
      }

      // If an instruction has more than one user, it indicates that it either
      // has an external user, which will be later checked by the checkNodes
      // function, or it is a subexpression utilized by multiple expressions. In
      // the latter case, we will attempt to separately identify the complex
      // operation from here in order to create a shared
      // ComplexDeinterleavingCompositeNode.
      if (I != Insn && I->getNumUses() > 1) {
        LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
        Addends.emplace_back(I, IsPositive);
        continue;
      }
      // Operand 1 is pushed before operand 0, so operand 0 is popped (and its
      // leaves recorded) first.
      switch (I->getOpcode()) {
      case Instruction::FAdd:
      case Instruction::Add:
        Worklist.emplace_back(I->getOperand(1), IsPositive);
        Worklist.emplace_back(I->getOperand(0), IsPositive);
        break;
      case Instruction::FSub:
        Worklist.emplace_back(I->getOperand(1), !IsPositive);
        Worklist.emplace_back(I->getOperand(0), IsPositive);
        break;
      case Instruction::Sub:
        // A Sub recognised as a negation contributes only its negated
        // operand, with the sign flipped.
        if (isNeg(I)) {
          Worklist.emplace_back(getNegOperand(I), !IsPositive);
        } else {
          Worklist.emplace_back(I->getOperand(1), !IsPositive);
          Worklist.emplace_back(I->getOperand(0), IsPositive);
        }
        break;
      case Instruction::FMul:
      case Instruction::Mul: {
        // A product is a leaf. Negated factors are peeled off and folded
        // into the product's sign.
        Value *A, *B;
        if (isNeg(I->getOperand(0))) {
          A = getNegOperand(I->getOperand(0));
          IsPositive = !IsPositive;
        } else {
          A = I->getOperand(0);
        }

        if (isNeg(I->getOperand(1))) {
          B = getNegOperand(I->getOperand(1));
          IsPositive = !IsPositive;
        } else {
          B = I->getOperand(1);
        }
        Muls.push_back(Product{A, B, IsPositive});
        break;
      }
      case Instruction::FNeg:
        Worklist.emplace_back(I->getOperand(0), !IsPositive);
        break;
      default:
        // Any other instruction is an opaque leaf addend; its flags are not
        // checked.
        Addends.emplace_back(I, IsPositive);
        continue;
      }

      if (Flags && I->getFastMathFlags() != *Flags) {
        LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
                             "inconsistent with the root instructions' flags: "
                          << *I << "\n");
        return false;
      }
    }
    return true;
  };

  std::vector<Product> RealMuls, ImagMuls;
  std::list<Addend> RealAddends, ImagAddends;
  if (!Collect(Real, RealMuls, RealAddends) ||
      !Collect(Imag, ImagMuls, ImagAddends))
    return nullptr;

  // Both sides must produce the same number of addends.
  if (RealAddends.size() != ImagAddends.size())
    return nullptr;

  NodePtr FinalNode;
  if (!RealMuls.empty() || !ImagMuls.empty()) {
    // If there are multiplicands, extract positive addend and use it as an
    // accumulator
    FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
    FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
    if (!FinalNode)
      return nullptr;
  }

  // Identify and process remaining additions
  if (!RealAddends.empty() || !ImagAddends.empty()) {
    FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
    if (!FinalNode)
      return nullptr;
  }
  assert(FinalNode && "FinalNode can not be nullptr here");
  // Set the Real and Imag fields of the final node and submit it
  FinalNode->Real = Real;
  FinalNode->Imag = Imag;
  submitCompositeNode(FinalNode);
  return FinalNode;
}
1100 
1101 bool ComplexDeinterleavingGraph::collectPartialMuls(
1102     const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
1103     std::vector<PartialMulCandidate> &PartialMulCandidates) {
1104   // Helper function to extract a common operand from two products
1105   auto FindCommonInstruction = [](const Product &Real,
1106                                   const Product &Imag) -> Value * {
1107     if (Real.Multiplicand == Imag.Multiplicand ||
1108         Real.Multiplicand == Imag.Multiplier)
1109       return Real.Multiplicand;
1110 
1111     if (Real.Multiplier == Imag.Multiplicand ||
1112         Real.Multiplier == Imag.Multiplier)
1113       return Real.Multiplier;
1114 
1115     return nullptr;
1116   };
1117 
1118   // Iterating over real and imaginary multiplications to find common operands
1119   // If a common operand is found, a partial multiplication candidate is created
1120   // and added to the candidates vector The function returns false if no common
1121   // operands are found for any product
1122   for (unsigned i = 0; i < RealMuls.size(); ++i) {
1123     bool FoundCommon = false;
1124     for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1125       auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1126       if (!Common)
1127         continue;
1128 
1129       auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1130                                                    : RealMuls[i].Multiplicand;
1131       auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1132                                                    : ImagMuls[j].Multiplicand;
1133 
1134       auto Node = identifyNode(A, B);
1135       if (Node) {
1136         FoundCommon = true;
1137         PartialMulCandidates.push_back({Common, Node, i, j, false});
1138       }
1139 
1140       Node = identifyNode(B, A);
1141       if (Node) {
1142         FoundCommon = true;
1143         PartialMulCandidates.push_back({Common, Node, i, j, true});
1144       }
1145     }
1146     if (!FoundCommon)
1147       return false;
1148   }
1149   return true;
1150 }
1151 
1152 ComplexDeinterleavingGraph::NodePtr
1153 ComplexDeinterleavingGraph::identifyMultiplications(
1154     std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
1155     NodePtr Accumulator = nullptr) {
1156   if (RealMuls.size() != ImagMuls.size())
1157     return nullptr;
1158 
1159   std::vector<PartialMulCandidate> Info;
1160   if (!collectPartialMuls(RealMuls, ImagMuls, Info))
1161     return nullptr;
1162 
1163   // Map to store common instruction to node pointers
1164   std::map<Value *, NodePtr> CommonToNode;
1165   std::vector<bool> Processed(Info.size(), false);
1166   for (unsigned I = 0; I < Info.size(); ++I) {
1167     if (Processed[I])
1168       continue;
1169 
1170     PartialMulCandidate &InfoA = Info[I];
1171     for (unsigned J = I + 1; J < Info.size(); ++J) {
1172       if (Processed[J])
1173         continue;
1174 
1175       PartialMulCandidate &InfoB = Info[J];
1176       auto *InfoReal = &InfoA;
1177       auto *InfoImag = &InfoB;
1178 
1179       auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1180       if (!NodeFromCommon) {
1181         std::swap(InfoReal, InfoImag);
1182         NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
1183       }
1184       if (!NodeFromCommon)
1185         continue;
1186 
1187       CommonToNode[InfoReal->Common] = NodeFromCommon;
1188       CommonToNode[InfoImag->Common] = NodeFromCommon;
1189       Processed[I] = true;
1190       Processed[J] = true;
1191     }
1192   }
1193 
1194   std::vector<bool> ProcessedReal(RealMuls.size(), false);
1195   std::vector<bool> ProcessedImag(ImagMuls.size(), false);
1196   NodePtr Result = Accumulator;
1197   for (auto &PMI : Info) {
1198     if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
1199       continue;
1200 
1201     auto It = CommonToNode.find(PMI.Common);
1202     // TODO: Process independent complex multiplications. Cases like this:
1203     //  A.real() * B where both A and B are complex numbers.
1204     if (It == CommonToNode.end()) {
1205       LLVM_DEBUG({
1206         dbgs() << "Unprocessed independent partial multiplication:\n";
1207         for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})
1208           dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
1209                            << " multiplied by " << *Mul->Multiplicand << "\n";
1210       });
1211       return nullptr;
1212     }
1213 
1214     auto &RealMul = RealMuls[PMI.RealIdx];
1215     auto &ImagMul = ImagMuls[PMI.ImagIdx];
1216 
1217     auto NodeA = It->second;
1218     auto NodeB = PMI.Node;
1219     auto IsMultiplicandReal = PMI.Common == NodeA->Real;
1220     // The following table illustrates the relationship between multiplications
1221     // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
1222     // can see:
1223     //
1224     // Rotation |   Real |   Imag |
1225     // ---------+--------+--------+
1226     //        0 |  x * u |  x * v |
1227     //       90 | -y * v |  y * u |
1228     //      180 | -x * u | -x * v |
1229     //      270 |  y * v | -y * u |
1230     //
1231     // Check if the candidate can indeed be represented by partial
1232     // multiplication
1233     // TODO: Add support for multiplication by complex one
1234     if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
1235         (!IsMultiplicandReal && !PMI.IsNodeInverted))
1236       continue;
1237 
1238     // Determine the rotation based on the multiplications
1239     ComplexDeinterleavingRotation Rotation;
1240     if (IsMultiplicandReal) {
1241       // Detect 0 and 180 degrees rotation
1242       if (RealMul.IsPositive && ImagMul.IsPositive)
1243         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
1244       else if (!RealMul.IsPositive && !ImagMul.IsPositive)
1245         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
1246       else
1247         continue;
1248 
1249     } else {
1250       // Detect 90 and 270 degrees rotation
1251       if (!RealMul.IsPositive && ImagMul.IsPositive)
1252         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
1253       else if (RealMul.IsPositive && !ImagMul.IsPositive)
1254         Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
1255       else
1256         continue;
1257     }
1258 
1259     LLVM_DEBUG({
1260       dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
1261       dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
1262       dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
1263       dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
1264       dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
1265       dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1266     });
1267 
1268     NodePtr NodeMul = prepareCompositeNode(
1269         ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
1270     NodeMul->Rotation = Rotation;
1271     NodeMul->addOperand(NodeA);
1272     NodeMul->addOperand(NodeB);
1273     if (Result)
1274       NodeMul->addOperand(Result);
1275     submitCompositeNode(NodeMul);
1276     Result = NodeMul;
1277     ProcessedReal[PMI.RealIdx] = true;
1278     ProcessedImag[PMI.ImagIdx] = true;
1279   }
1280 
1281   // Ensure all products have been processed, if not return nullptr.
1282   if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
1283       !all_of(ProcessedImag, [](bool V) { return V; })) {
1284 
1285     // Dump debug information about which partial multiplications are not
1286     // processed.
1287     LLVM_DEBUG({
1288       dbgs() << "Unprocessed products (Real):\n";
1289       for (size_t i = 0; i < ProcessedReal.size(); ++i) {
1290         if (!ProcessedReal[i])
1291           dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
1292                            << *RealMuls[i].Multiplier << " multiplied by "
1293                            << *RealMuls[i].Multiplicand << "\n";
1294       }
1295       dbgs() << "Unprocessed products (Imag):\n";
1296       for (size_t i = 0; i < ProcessedImag.size(); ++i) {
1297         if (!ProcessedImag[i])
1298           dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
1299                            << *ImagMuls[i].Multiplier << " multiplied by "
1300                            << *ImagMuls[i].Multiplicand << "\n";
1301       }
1302     });
1303     return nullptr;
1304   }
1305 
1306   return Result;
1307 }
1308 
1309 ComplexDeinterleavingGraph::NodePtr
1310 ComplexDeinterleavingGraph::identifyAdditions(
1311     std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
1312     std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {
1313   if (RealAddends.size() != ImagAddends.size())
1314     return nullptr;
1315 
1316   NodePtr Result;
1317   // If we have accumulator use it as first addend
1318   if (Accumulator)
1319     Result = Accumulator;
1320   // Otherwise find an element with both positive real and imaginary parts.
1321   else
1322     Result = extractPositiveAddend(RealAddends, ImagAddends);
1323 
1324   if (!Result)
1325     return nullptr;
1326 
1327   while (!RealAddends.empty()) {
1328     auto ItR = RealAddends.begin();
1329     auto [R, IsPositiveR] = *ItR;
1330 
1331     bool FoundImag = false;
1332     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1333       auto [I, IsPositiveI] = *ItI;
1334       ComplexDeinterleavingRotation Rotation;
1335       if (IsPositiveR && IsPositiveI)
1336         Rotation = ComplexDeinterleavingRotation::Rotation_0;
1337       else if (!IsPositiveR && IsPositiveI)
1338         Rotation = ComplexDeinterleavingRotation::Rotation_90;
1339       else if (!IsPositiveR && !IsPositiveI)
1340         Rotation = ComplexDeinterleavingRotation::Rotation_180;
1341       else
1342         Rotation = ComplexDeinterleavingRotation::Rotation_270;
1343 
1344       NodePtr AddNode;
1345       if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
1346           Rotation == ComplexDeinterleavingRotation::Rotation_180) {
1347         AddNode = identifyNode(R, I);
1348       } else {
1349         AddNode = identifyNode(I, R);
1350       }
1351       if (AddNode) {
1352         LLVM_DEBUG({
1353           dbgs() << "Identified addition:\n";
1354           dbgs().indent(4) << "X: " << *R << "\n";
1355           dbgs().indent(4) << "Y: " << *I << "\n";
1356           dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
1357         });
1358 
1359         NodePtr TmpNode;
1360         if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
1361           TmpNode = prepareCompositeNode(
1362               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1363           if (Flags) {
1364             TmpNode->Opcode = Instruction::FAdd;
1365             TmpNode->Flags = *Flags;
1366           } else {
1367             TmpNode->Opcode = Instruction::Add;
1368           }
1369         } else if (Rotation ==
1370                    llvm::ComplexDeinterleavingRotation::Rotation_180) {
1371           TmpNode = prepareCompositeNode(
1372               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
1373           if (Flags) {
1374             TmpNode->Opcode = Instruction::FSub;
1375             TmpNode->Flags = *Flags;
1376           } else {
1377             TmpNode->Opcode = Instruction::Sub;
1378           }
1379         } else {
1380           TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
1381                                          nullptr, nullptr);
1382           TmpNode->Rotation = Rotation;
1383         }
1384 
1385         TmpNode->addOperand(Result);
1386         TmpNode->addOperand(AddNode);
1387         submitCompositeNode(TmpNode);
1388         Result = TmpNode;
1389         RealAddends.erase(ItR);
1390         ImagAddends.erase(ItI);
1391         FoundImag = true;
1392         break;
1393       }
1394     }
1395     if (!FoundImag)
1396       return nullptr;
1397   }
1398   return Result;
1399 }
1400 
1401 ComplexDeinterleavingGraph::NodePtr
1402 ComplexDeinterleavingGraph::extractPositiveAddend(
1403     std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
1404   for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
1405     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
1406       auto [R, IsPositiveR] = *ItR;
1407       auto [I, IsPositiveI] = *ItI;
1408       if (IsPositiveR && IsPositiveI) {
1409         auto Result = identifyNode(R, I);
1410         if (Result) {
1411           RealAddends.erase(ItR);
1412           ImagAddends.erase(ItI);
1413           return Result;
1414         }
1415       }
1416     }
1417   }
1418   return nullptr;
1419 }
1420 
1421 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
1422   // This potential root instruction might already have been recognized as
1423   // reduction. Because RootToNode maps both Real and Imaginary parts to
1424   // CompositeNode we should choose only one either Real or Imag instruction to
1425   // use as an anchor for generating complex instruction.
1426   auto It = RootToNode.find(RootI);
1427   if (It != RootToNode.end()) {
1428     auto RootNode = It->second;
1429     assert(RootNode->Operation ==
1430            ComplexDeinterleavingOperation::ReductionOperation);
1431     // Find out which part, Real or Imag, comes later, and only if we come to
1432     // the latest part, add it to OrderedRoots.
1433     auto *R = cast<Instruction>(RootNode->Real);
1434     auto *I = cast<Instruction>(RootNode->Imag);
1435     auto *ReplacementAnchor = R->comesBefore(I) ? I : R;
1436     if (ReplacementAnchor != RootI)
1437       return false;
1438     OrderedRoots.push_back(RootI);
1439     return true;
1440   }
1441 
1442   auto RootNode = identifyRoot(RootI);
1443   if (!RootNode)
1444     return false;
1445 
1446   LLVM_DEBUG({
1447     Function *F = RootI->getFunction();
1448     BasicBlock *B = RootI->getParent();
1449     dbgs() << "Complex deinterleaving graph for " << F->getName()
1450            << "::" << B->getName() << ".\n";
1451     dump(dbgs());
1452     dbgs() << "\n";
1453   });
1454   RootToNode[RootI] = RootNode;
1455   OrderedRoots.push_back(RootI);
1456   return true;
1457 }
1458 
1459 bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
1460   bool FoundPotentialReduction = false;
1461 
1462   auto *Br = dyn_cast<BranchInst>(B->getTerminator());
1463   if (!Br || Br->getNumSuccessors() != 2)
1464     return false;
1465 
1466   // Identify simple one-block loop
1467   if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
1468     return false;
1469 
1470   SmallVector<PHINode *> PHIs;
1471   for (auto &PHI : B->phis()) {
1472     if (PHI.getNumIncomingValues() != 2)
1473       continue;
1474 
1475     if (!PHI.getType()->isVectorTy())
1476       continue;
1477 
1478     auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
1479     if (!ReductionOp)
1480       continue;
1481 
1482     // Check if final instruction is reduced outside of current block
1483     Instruction *FinalReduction = nullptr;
1484     auto NumUsers = 0u;
1485     for (auto *U : ReductionOp->users()) {
1486       ++NumUsers;
1487       if (U == &PHI)
1488         continue;
1489       FinalReduction = dyn_cast<Instruction>(U);
1490     }
1491 
1492     if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
1493         isa<PHINode>(FinalReduction))
1494       continue;
1495 
1496     ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
1497     BackEdge = B;
1498     auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
1499     auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
1500     Incoming = PHI.getIncomingBlock(IncomingIdx);
1501     FoundPotentialReduction = true;
1502 
1503     // If the initial value of PHINode is an Instruction, consider it a leaf
1504     // value of a complex deinterleaving graph.
1505     if (auto *InitPHI =
1506             dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
1507       FinalInstructions.insert(InitPHI);
1508   }
1509   return FoundPotentialReduction;
1510 }
1511 
/// Pairs up the reduction candidates gathered in ReductionInfo and tries to
/// identify each pair as the real and imaginary halves of one complex
/// reduction.
///
/// A pair is accepted only when both of these hold:
///  - identifyNode succeeds on the two loop-carried instructions (both
///    Real/Imag orders are tried);
///  - PHIsFound is true afterwards. It is reset before each attempt;
///    presumably identifyPHINode sets it when it reaches RealPHI/ImagPHI, but
///    that code is not visible here.
///
/// Each accepted pair gets a ReductionOperation root node, registered in
/// RootToNode under both instructions. RealPHI/ImagPHI are cleared on exit.
void ComplexDeinterleavingGraph::identifyReductionNodes() {
  SmallVector<bool> Processed(ReductionInfo.size(), false);
  SmallVector<Instruction *> OperationInstruction;
  for (auto &P : ReductionInfo)
    OperationInstruction.push_back(P.first);

  // Identify a complex computation by evaluating two reduction operations that
  // potentially could be involved
  for (size_t i = 0; i < OperationInstruction.size(); ++i) {
    if (Processed[i])
      continue;
    for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
      if (Processed[j])
        continue;

      auto *Real = OperationInstruction[i];
      auto *Imag = OperationInstruction[j];
      if (Real->getType() != Imag->getType())
        continue;

      // Publish the candidate PHIs for this attempt, and reset PHIsFound so
      // that it reflects only this attempt.
      RealPHI = ReductionInfo[Real].first;
      ImagPHI = ReductionInfo[Imag].first;
      PHIsFound = false;
      auto Node = identifyNode(Real, Imag);
      if (!Node) {
        // Retry with the roles of the two reductions (and their PHIs) swapped.
        std::swap(Real, Imag);
        std::swap(RealPHI, ImagPHI);
        Node = identifyNode(Real, Imag);
      }

      // If a node is identified and reduction PHINode is used in the chain of
      // operations, mark its operation instructions as used to prevent
      // re-identification and attach the node to the real part
      if (Node && PHIsFound) {
        LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
                          << *Real << " / " << *Imag << "\n");
        Processed[i] = true;
        Processed[j] = true;
        auto RootNode = prepareCompositeNode(
            ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
        RootNode->addOperand(Node);
        RootToNode[Real] = RootNode;
        RootToNode[Imag] = RootNode;
        submitCompositeNode(RootNode);
        // Candidate i is now paired; move on to the next i.
        break;
      }
    }
  }

  // Clear the per-attempt PHI context.
  RealPHI = nullptr;
  ImagPHI = nullptr;
}
1564 
1565 bool ComplexDeinterleavingGraph::checkNodes() {
1566   // Collect all instructions from roots to leaves
1567   SmallPtrSet<Instruction *, 16> AllInstructions;
1568   SmallVector<Instruction *, 8> Worklist;
1569   for (auto &Pair : RootToNode)
1570     Worklist.push_back(Pair.first);
1571 
1572   // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1573   // chains
1574   while (!Worklist.empty()) {
1575     auto *I = Worklist.back();
1576     Worklist.pop_back();
1577 
1578     if (!AllInstructions.insert(I).second)
1579       continue;
1580 
1581     for (Value *Op : I->operands()) {
1582       if (auto *OpI = dyn_cast<Instruction>(Op)) {
1583         if (!FinalInstructions.count(I))
1584           Worklist.emplace_back(OpI);
1585       }
1586     }
1587   }
1588 
1589   // Find instructions that have users outside of chain
1590   SmallVector<Instruction *, 2> OuterInstructions;
1591   for (auto *I : AllInstructions) {
1592     // Skip root nodes
1593     if (RootToNode.count(I))
1594       continue;
1595 
1596     for (User *U : I->users()) {
1597       if (AllInstructions.count(cast<Instruction>(U)))
1598         continue;
1599 
1600       // Found an instruction that is not used by XCMLA/XCADD chain
1601       Worklist.emplace_back(I);
1602       break;
1603     }
1604   }
1605 
1606   // If any instructions are found to be used outside, find and remove roots
1607   // that somehow connect to those instructions.
1608   SmallPtrSet<Instruction *, 16> Visited;
1609   while (!Worklist.empty()) {
1610     auto *I = Worklist.back();
1611     Worklist.pop_back();
1612     if (!Visited.insert(I).second)
1613       continue;
1614 
1615     // Found an impacted root node. Removing it from the nodes to be
1616     // deinterleaved
1617     if (RootToNode.count(I)) {
1618       LLVM_DEBUG(dbgs() << "Instruction " << *I
1619                         << " could be deinterleaved but its chain of complex "
1620                            "operations have an outside user\n");
1621       RootToNode.erase(I);
1622     }
1623 
1624     if (!AllInstructions.count(I) || FinalInstructions.count(I))
1625       continue;
1626 
1627     for (User *U : I->users())
1628       Worklist.emplace_back(cast<Instruction>(U));
1629 
1630     for (Value *Op : I->operands()) {
1631       if (auto *OpI = dyn_cast<Instruction>(Op))
1632         Worklist.emplace_back(OpI);
1633     }
1634   }
1635   return !RootToNode.empty();
1636 }
1637 
1638 ComplexDeinterleavingGraph::NodePtr
1639 ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
1640   if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
1641     if (Intrinsic->getIntrinsicID() !=
1642         Intrinsic::experimental_vector_interleave2)
1643       return nullptr;
1644 
1645     auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
1646     auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
1647     if (!Real || !Imag)
1648       return nullptr;
1649 
1650     return identifyNode(Real, Imag);
1651   }
1652 
1653   auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
1654   if (!SVI)
1655     return nullptr;
1656 
1657   // Look for a shufflevector that takes separate vectors of the real and
1658   // imaginary components and recombines them into a single vector.
1659   if (!isInterleavingMask(SVI->getShuffleMask()))
1660     return nullptr;
1661 
1662   Instruction *Real;
1663   Instruction *Imag;
1664   if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
1665     return nullptr;
1666 
1667   return identifyNode(Real, Imag);
1668 }
1669 
1670 ComplexDeinterleavingGraph::NodePtr
1671 ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
1672                                                  Instruction *Imag) {
1673   Instruction *I = nullptr;
1674   Value *FinalValue = nullptr;
1675   if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
1676       match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
1677       match(I, m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>(
1678                    m_Value(FinalValue)))) {
1679     NodePtr PlaceholderNode = prepareCompositeNode(
1680         llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag);
1681     PlaceholderNode->ReplacementNode = FinalValue;
1682     FinalInstructions.insert(Real);
1683     FinalInstructions.insert(Imag);
1684     return submitCompositeNode(PlaceholderNode);
1685   }
1686 
1687   auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
1688   auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
1689   if (!RealShuffle || !ImagShuffle) {
1690     if (RealShuffle || ImagShuffle)
1691       LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
1692     return nullptr;
1693   }
1694 
1695   Value *RealOp1 = RealShuffle->getOperand(1);
1696   if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
1697     LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
1698     return nullptr;
1699   }
1700   Value *ImagOp1 = ImagShuffle->getOperand(1);
1701   if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
1702     LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
1703     return nullptr;
1704   }
1705 
1706   Value *RealOp0 = RealShuffle->getOperand(0);
1707   Value *ImagOp0 = ImagShuffle->getOperand(0);
1708 
1709   if (RealOp0 != ImagOp0) {
1710     LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
1711     return nullptr;
1712   }
1713 
1714   ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
1715   ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
1716   if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
1717     LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
1718     return nullptr;
1719   }
1720 
1721   if (RealMask[0] != 0 || ImagMask[0] != 1) {
1722     LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
1723     return nullptr;
1724   }
1725 
1726   // Type checking, the shuffle type should be a vector type of the same
1727   // scalar type, but half the size
1728   auto CheckType = [&](ShuffleVectorInst *Shuffle) {
1729     Value *Op = Shuffle->getOperand(0);
1730     auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
1731     auto *OpTy = cast<FixedVectorType>(Op->getType());
1732 
1733     if (OpTy->getScalarType() != ShuffleTy->getScalarType())
1734       return false;
1735     if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
1736       return false;
1737 
1738     return true;
1739   };
1740 
1741   auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
1742     if (!CheckType(Shuffle))
1743       return false;
1744 
1745     ArrayRef<int> Mask = Shuffle->getShuffleMask();
1746     int Last = *Mask.rbegin();
1747 
1748     Value *Op = Shuffle->getOperand(0);
1749     auto *OpTy = cast<FixedVectorType>(Op->getType());
1750     int NumElements = OpTy->getNumElements();
1751 
1752     // Ensure that the deinterleaving shuffle only pulls from the first
1753     // shuffle operand.
1754     return Last < NumElements;
1755   };
1756 
1757   if (RealShuffle->getType() != ImagShuffle->getType()) {
1758     LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
1759     return nullptr;
1760   }
1761   if (!CheckDeinterleavingShuffle(RealShuffle)) {
1762     LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
1763     return nullptr;
1764   }
1765   if (!CheckDeinterleavingShuffle(ImagShuffle)) {
1766     LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
1767     return nullptr;
1768   }
1769 
1770   NodePtr PlaceholderNode =
1771       prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
1772                            RealShuffle, ImagShuffle);
1773   PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
1774   FinalInstructions.insert(RealShuffle);
1775   FinalInstructions.insert(ImagShuffle);
1776   return submitCompositeNode(PlaceholderNode);
1777 }
1778 
1779 ComplexDeinterleavingGraph::NodePtr
1780 ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
1781   auto IsSplat = [](Value *V) -> bool {
1782     // Fixed-width vector with constants
1783     if (isa<ConstantDataVector>(V))
1784       return true;
1785 
1786     VectorType *VTy;
1787     ArrayRef<int> Mask;
1788     // Splats are represented differently depending on whether the repeated
1789     // value is a constant or an Instruction
1790     if (auto *Const = dyn_cast<ConstantExpr>(V)) {
1791       if (Const->getOpcode() != Instruction::ShuffleVector)
1792         return false;
1793       VTy = cast<VectorType>(Const->getType());
1794       Mask = Const->getShuffleMask();
1795     } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
1796       VTy = Shuf->getType();
1797       Mask = Shuf->getShuffleMask();
1798     } else {
1799       return false;
1800     }
1801 
1802     // When the data type is <1 x Type>, it's not possible to differentiate
1803     // between the ComplexDeinterleaving::Deinterleave and
1804     // ComplexDeinterleaving::Splat operations.
1805     if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
1806       return false;
1807 
1808     return all_equal(Mask) && Mask[0] == 0;
1809   };
1810 
1811   if (!IsSplat(R) || !IsSplat(I))
1812     return nullptr;
1813 
1814   auto *Real = dyn_cast<Instruction>(R);
1815   auto *Imag = dyn_cast<Instruction>(I);
1816   if ((!Real && Imag) || (Real && !Imag))
1817     return nullptr;
1818 
1819   if (Real && Imag) {
1820     // Non-constant splats should be in the same basic block
1821     if (Real->getParent() != Imag->getParent())
1822       return nullptr;
1823 
1824     FinalInstructions.insert(Real);
1825     FinalInstructions.insert(Imag);
1826   }
1827   NodePtr PlaceholderNode =
1828       prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I);
1829   return submitCompositeNode(PlaceholderNode);
1830 }
1831 
1832 ComplexDeinterleavingGraph::NodePtr
1833 ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
1834                                             Instruction *Imag) {
1835   if (Real != RealPHI || Imag != ImagPHI)
1836     return nullptr;
1837 
1838   PHIsFound = true;
1839   NodePtr PlaceholderNode = prepareCompositeNode(
1840       ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
1841   return submitCompositeNode(PlaceholderNode);
1842 }
1843 
1844 ComplexDeinterleavingGraph::NodePtr
1845 ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
1846                                                Instruction *Imag) {
1847   auto *SelectReal = dyn_cast<SelectInst>(Real);
1848   auto *SelectImag = dyn_cast<SelectInst>(Imag);
1849   if (!SelectReal || !SelectImag)
1850     return nullptr;
1851 
1852   Instruction *MaskA, *MaskB;
1853   Instruction *AR, *AI, *RA, *BI;
1854   if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
1855                             m_Instruction(RA))) ||
1856       !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
1857                             m_Instruction(BI))))
1858     return nullptr;
1859 
1860   if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
1861     return nullptr;
1862 
1863   if (!MaskA->getType()->isVectorTy())
1864     return nullptr;
1865 
1866   auto NodeA = identifyNode(AR, AI);
1867   if (!NodeA)
1868     return nullptr;
1869 
1870   auto NodeB = identifyNode(RA, BI);
1871   if (!NodeB)
1872     return nullptr;
1873 
1874   NodePtr PlaceholderNode = prepareCompositeNode(
1875       ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
1876   PlaceholderNode->addOperand(NodeA);
1877   PlaceholderNode->addOperand(NodeB);
1878   FinalInstructions.insert(MaskA);
1879   FinalInstructions.insert(MaskB);
1880   return submitCompositeNode(PlaceholderNode);
1881 }
1882 
1883 static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
1884                                    std::optional<FastMathFlags> Flags,
1885                                    Value *InputA, Value *InputB) {
1886   Value *I;
1887   switch (Opcode) {
1888   case Instruction::FNeg:
1889     I = B.CreateFNeg(InputA);
1890     break;
1891   case Instruction::FAdd:
1892     I = B.CreateFAdd(InputA, InputB);
1893     break;
1894   case Instruction::Add:
1895     I = B.CreateAdd(InputA, InputB);
1896     break;
1897   case Instruction::FSub:
1898     I = B.CreateFSub(InputA, InputB);
1899     break;
1900   case Instruction::Sub:
1901     I = B.CreateSub(InputA, InputB);
1902     break;
1903   case Instruction::FMul:
1904     I = B.CreateFMul(InputA, InputB);
1905     break;
1906   case Instruction::Mul:
1907     I = B.CreateMul(InputA, InputB);
1908     break;
1909   default:
1910     llvm_unreachable("Incorrect symmetric opcode");
1911   }
1912   if (Flags)
1913     cast<Instruction>(I)->setFastMathFlags(*Flags);
1914   return I;
1915 }
1916 
1917 Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
1918                                                RawNodePtr Node) {
1919   if (Node->ReplacementNode)
1920     return Node->ReplacementNode;
1921 
1922   auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
1923     return Node->Operands.size() > Idx
1924                ? replaceNode(Builder, Node->Operands[Idx])
1925                : nullptr;
1926   };
1927 
1928   Value *ReplacementNode;
1929   switch (Node->Operation) {
1930   case ComplexDeinterleavingOperation::CAdd:
1931   case ComplexDeinterleavingOperation::CMulPartial:
1932   case ComplexDeinterleavingOperation::Symmetric: {
1933     Value *Input0 = ReplaceOperandIfExist(Node, 0);
1934     Value *Input1 = ReplaceOperandIfExist(Node, 1);
1935     Value *Accumulator = ReplaceOperandIfExist(Node, 2);
1936     assert(!Input1 || (Input0->getType() == Input1->getType() &&
1937                        "Node inputs need to be of the same type"));
1938     assert(!Accumulator ||
1939            (Input0->getType() == Accumulator->getType() &&
1940             "Accumulator and input need to be of the same type"));
1941     if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
1942       ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
1943                                              Input0, Input1);
1944     else
1945       ReplacementNode = TL->createComplexDeinterleavingIR(
1946           Builder, Node->Operation, Node->Rotation, Input0, Input1,
1947           Accumulator);
1948     break;
1949   }
1950   case ComplexDeinterleavingOperation::Deinterleave:
1951     llvm_unreachable("Deinterleave node should already have ReplacementNode");
1952     break;
1953   case ComplexDeinterleavingOperation::Splat: {
1954     auto *NewTy = VectorType::getDoubleElementsVectorType(
1955         cast<VectorType>(Node->Real->getType()));
1956     auto *R = dyn_cast<Instruction>(Node->Real);
1957     auto *I = dyn_cast<Instruction>(Node->Imag);
1958     if (R && I) {
1959       // Splats that are not constant are interleaved where they are located
1960       Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode();
1961       IRBuilder<> IRB(InsertPoint);
1962       ReplacementNode =
1963           IRB.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, NewTy,
1964                               {Node->Real, Node->Imag});
1965     } else {
1966       ReplacementNode =
1967           Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
1968                                   NewTy, {Node->Real, Node->Imag});
1969     }
1970     break;
1971   }
1972   case ComplexDeinterleavingOperation::ReductionPHI: {
1973     // If Operation is ReductionPHI, a new empty PHINode is created.
1974     // It is filled later when the ReductionOperation is processed.
1975     auto *VTy = cast<VectorType>(Node->Real->getType());
1976     auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
1977     auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHI());
1978     OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
1979     ReplacementNode = NewPHI;
1980     break;
1981   }
1982   case ComplexDeinterleavingOperation::ReductionOperation:
1983     ReplacementNode = replaceNode(Builder, Node->Operands[0]);
1984     processReductionOperation(ReplacementNode, Node);
1985     break;
1986   case ComplexDeinterleavingOperation::ReductionSelect: {
1987     auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
1988     auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
1989     auto *A = replaceNode(Builder, Node->Operands[0]);
1990     auto *B = replaceNode(Builder, Node->Operands[1]);
1991     auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
1992         cast<VectorType>(MaskReal->getType()));
1993     auto *NewMask =
1994         Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
1995                                 NewMaskTy, {MaskReal, MaskImag});
1996     ReplacementNode = Builder.CreateSelect(NewMask, A, B);
1997     break;
1998   }
1999   }
2000 
2001   assert(ReplacementNode && "Target failed to create Intrinsic call.");
2002   NumComplexTransformations += 1;
2003   Node->ReplacementNode = ReplacementNode;
2004   return ReplacementNode;
2005 }
2006 
2007 void ComplexDeinterleavingGraph::processReductionOperation(
2008     Value *OperationReplacement, RawNodePtr Node) {
2009   auto *Real = cast<Instruction>(Node->Real);
2010   auto *Imag = cast<Instruction>(Node->Imag);
2011   auto *OldPHIReal = ReductionInfo[Real].first;
2012   auto *OldPHIImag = ReductionInfo[Imag].first;
2013   auto *NewPHI = OldToNewPHI[OldPHIReal];
2014 
2015   auto *VTy = cast<VectorType>(Real->getType());
2016   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2017 
2018   // We have to interleave initial origin values coming from IncomingBlock
2019   Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2020   Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2021 
2022   IRBuilder<> Builder(Incoming->getTerminator());
2023   auto *NewInit = Builder.CreateIntrinsic(
2024       Intrinsic::experimental_vector_interleave2, NewVTy, {InitReal, InitImag});
2025 
2026   NewPHI->addIncoming(NewInit, Incoming);
2027   NewPHI->addIncoming(OperationReplacement, BackEdge);
2028 
2029   // Deinterleave complex vector outside of loop so that it can be finally
2030   // reduced
2031   auto *FinalReductionReal = ReductionInfo[Real].second;
2032   auto *FinalReductionImag = ReductionInfo[Imag].second;
2033 
2034   Builder.SetInsertPoint(
2035       &*FinalReductionReal->getParent()->getFirstInsertionPt());
2036   auto *Deinterleave = Builder.CreateIntrinsic(
2037       Intrinsic::experimental_vector_deinterleave2,
2038       OperationReplacement->getType(), OperationReplacement);
2039 
2040   auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
2041   FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2042 
2043   Builder.SetInsertPoint(FinalReductionImag);
2044   auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
2045   FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2046 }
2047 
2048 void ComplexDeinterleavingGraph::replaceNodes() {
2049   SmallVector<Instruction *, 16> DeadInstrRoots;
2050   for (auto *RootInstruction : OrderedRoots) {
2051     // Check if this potential root went through check process and we can
2052     // deinterleave it
2053     if (!RootToNode.count(RootInstruction))
2054       continue;
2055 
2056     IRBuilder<> Builder(RootInstruction);
2057     auto RootNode = RootToNode[RootInstruction];
2058     Value *R = replaceNode(Builder, RootNode.get());
2059 
2060     if (RootNode->Operation ==
2061         ComplexDeinterleavingOperation::ReductionOperation) {
2062       auto *RootReal = cast<Instruction>(RootNode->Real);
2063       auto *RootImag = cast<Instruction>(RootNode->Imag);
2064       ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2065       ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2066       DeadInstrRoots.push_back(cast<Instruction>(RootReal));
2067       DeadInstrRoots.push_back(cast<Instruction>(RootImag));
2068     } else {
2069       assert(R && "Unable to find replacement for RootInstruction");
2070       DeadInstrRoots.push_back(RootInstruction);
2071       RootInstruction->replaceAllUsesWith(R);
2072     }
2073   }
2074 
2075   for (auto *I : DeadInstrRoots)
2076     RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
2077 }
2078