1 //===- ScalarEvolutionNormalization.cpp - See below -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for working with "normalized" expressions.
10 // See the comments at the top of ScalarEvolutionNormalization.h for details.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Analysis/ScalarEvolutionNormalization.h"
15 #include "llvm/Analysis/LoopInfo.h"
16 #include "llvm/Analysis/ScalarEvolution.h"
17 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
18 using namespace llvm;
19 
/// Direction of the rewrite that NormalizeDenormalizeRewriter applies to the
/// add recurrences selected by its predicate.
enum TransformKind {
  Normalize,  ///< Rewrite towards the normalized form ("partial decrement").
  Denormalize ///< Apply the inverse of Normalize ("partial increment").
};
29 
30 namespace {
31 struct NormalizeDenormalizeRewriter
32     : public SCEVRewriteVisitor<NormalizeDenormalizeRewriter> {
33   const TransformKind Kind;
34 
35   // NB! Pred is a function_ref.  Storing it here is okay only because
36   // we're careful about the lifetime of NormalizeDenormalizeRewriter.
37   const NormalizePredTy Pred;
38 
39   NormalizeDenormalizeRewriter(TransformKind Kind, NormalizePredTy Pred,
40                                ScalarEvolution &SE)
41       : SCEVRewriteVisitor<NormalizeDenormalizeRewriter>(SE), Kind(Kind),
42         Pred(Pred) {}
43   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr);
44 };
45 } // namespace
46 
47 const SCEV *
48 NormalizeDenormalizeRewriter::visitAddRecExpr(const SCEVAddRecExpr *AR) {
49   SmallVector<const SCEV *, 8> Operands;
50 
51   transform(AR->operands(), std::back_inserter(Operands),
52             [&](const SCEV *Op) { return visit(Op); });
53 
54   if (!Pred(AR))
55     return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
56 
57   // Normalization and denormalization are fancy names for decrementing and
58   // incrementing a SCEV expression with respect to a set of loops.  Since
59   // Pred(AR) has returned true, we know we need to normalize or denormalize AR
60   // with respect to its loop.
61 
62   if (Kind == Denormalize) {
63     // Denormalization / "partial increment" is essentially the same as \c
64     // SCEVAddRecExpr::getPostIncExpr.  Here we use an explicit loop to make the
65     // symmetry with Normalization clear.
66     for (int i = 0, e = Operands.size() - 1; i < e; i++)
67       Operands[i] = SE.getAddExpr(Operands[i], Operands[i + 1]);
68   } else {
69     assert(Kind == Normalize && "Only two possibilities!");
70 
71     // Normalization / "partial decrement" is a bit more subtle.  Since
72     // incrementing a SCEV expression (in general) changes the step of the SCEV
73     // expression as well, we cannot use the step of the current expression.
74     // Instead, we have to use the step of the very expression we're trying to
75     // compute!
76     //
77     // We solve the issue by recursively building up the result, starting from
78     // the "least significant" operand in the add recurrence:
79     //
80     // Base case:
81     //   Single operand add recurrence.  It's its own normalization.
82     //
83     // N-operand case:
84     //   {S_{N-1},+,S_{N-2},+,...,+,S_0} = S
85     //
86     //   Since the step recurrence of S is {S_{N-2},+,...,+,S_0}, we know its
87     //   normalization by induction.  We subtract the normalized step
88     //   recurrence from S_{N-1} to get the normalization of S.
89 
90     for (int i = Operands.size() - 2; i >= 0; i--)
91       Operands[i] = SE.getMinusSCEV(Operands[i], Operands[i + 1]);
92   }
93 
94   return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
95 }
96 
97 const SCEV *llvm::normalizeForPostIncUse(const SCEV *S,
98                                          const PostIncLoopSet &Loops,
99                                          ScalarEvolution &SE,
100                                          bool CheckInvertible) {
101   if (Loops.empty())
102     return S;
103   auto Pred = [&](const SCEVAddRecExpr *AR) {
104     return Loops.count(AR->getLoop());
105   };
106   const SCEV *Normalized =
107       NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S);
108   const SCEV *Denormalized = denormalizeForPostIncUse(Normalized, Loops, SE);
109   // If the normalized expression isn't invertible.
110   if (CheckInvertible && Denormalized != S)
111     return nullptr;
112   return Normalized;
113 }
114 
115 const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred,
116                                            ScalarEvolution &SE) {
117   return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S);
118 }
119 
120 const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S,
121                                            const PostIncLoopSet &Loops,
122                                            ScalarEvolution &SE) {
123   if (Loops.empty())
124     return S;
125   auto Pred = [&](const SCEVAddRecExpr *AR) {
126     return Loops.count(AR->getLoop());
127   };
128   return NormalizeDenormalizeRewriter(Denormalize, Pred, SE).visit(S);
129 }
130