1 //===- LoopUtils.h - Loop transformation utilities --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines prototypes for various loop transformation utility
10 // methods: these are not passes by themselves but are used either by passes,
11 // optimization sequences, or in turn by other transformation utilities.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_TRANSFORMS_LOOP_UTILS_H
16 #define MLIR_TRANSFORMS_LOOP_UTILS_H
17 
18 #include "mlir/IR/Block.h"
19 #include "mlir/Support/LLVM.h"
20 #include "mlir/Support/LogicalResult.h"
21 
22 namespace mlir {
23 class AffineForOp;
24 class FuncOp;
25 class LoopLikeOpInterface;
26 struct MemRefRegion;
27 class OpBuilder;
28 class Value;
29 class ValueRange;
30 
31 namespace scf {
32 class ForOp;
33 class ParallelOp;
34 } // end namespace scf
35 
36 /// Unrolls this for operation completely if the trip count is known to be
37 /// constant. Returns failure otherwise.
38 LogicalResult loopUnrollFull(AffineForOp forOp);
39 
40 /// Unrolls this for operation by the specified unroll factor. Returns failure
41 /// if the loop cannot be unrolled either due to restrictions or due to invalid
42 /// unroll factors. Requires positive loop bounds and step.
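/// For example, one possible use (a sketch; `forOp` is assumed to be a valid
/// loop with positive bounds and step):
///
/// ```
///    // Unroll by 4; bail out if the loop cannot be unrolled.
///    if (failed(loopUnrollByFactor(forOp, /*unrollFactor=*/4)))
///      return failure();
/// ```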
43 LogicalResult loopUnrollByFactor(AffineForOp forOp, uint64_t unrollFactor);
44 LogicalResult loopUnrollByFactor(scf::ForOp forOp, uint64_t unrollFactor);
45 
46 /// Unrolls this loop by the specified unroll factor or its trip count,
47 /// whichever is lower.
48 LogicalResult loopUnrollUpToFactor(AffineForOp forOp, uint64_t unrollFactor);
49 
50 /// Returns true if `loops` is a perfectly nested loop nest, where loops appear
51 /// in it from outermost to innermost.
52 bool LLVM_ATTRIBUTE_UNUSED isPerfectlyNested(ArrayRef<AffineForOp> loops);
53 
/// Get the perfectly nested sequence of loops starting at the root of the loop
/// nest, from outermost to innermost. A loop is perfectly nested iff the first
/// op in its body is another for loop of the same kind and the second op is
/// the loop's terminator.
58 void getPerfectlyNestedLoops(SmallVectorImpl<AffineForOp> &nestedLoops,
59                              AffineForOp root);
60 void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
61                              scf::ForOp root);
62 
63 /// Unrolls and jams this loop by the specified factor. Returns success if the
64 /// loop is successfully unroll-jammed.
65 LogicalResult loopUnrollJamByFactor(AffineForOp forOp,
66                                     uint64_t unrollJamFactor);
67 
68 /// Unrolls and jams this loop by the specified factor or by the trip count (if
69 /// constant), whichever is lower.
70 LogicalResult loopUnrollJamUpToFactor(AffineForOp forOp,
71                                       uint64_t unrollJamFactor);
72 
/// Promotes the loop body of an AffineForOp/scf::ForOp to its containing block
/// if the loop is known to have a single iteration.
75 LogicalResult promoteIfSingleIteration(AffineForOp forOp);
76 LogicalResult promoteIfSingleIteration(scf::ForOp forOp);
77 
/// Promotes all single-iteration AffineForOps in the function, i.e., moves
/// their bodies into the containing block.
80 void promoteSingleIterationLoops(FuncOp f);
81 
82 /// Skew the operations in an affine.for's body with the specified
83 /// operation-wise shifts. The shifts are with respect to the original execution
84 /// order, and are multiplied by the loop 'step' before being applied. If
85 /// `unrollPrologueEpilogue` is set, fully unroll the prologue and epilogue
86 /// loops when possible.
87 LLVM_NODISCARD
88 LogicalResult affineForOpBodySkew(AffineForOp forOp, ArrayRef<uint64_t> shifts,
89                                   bool unrollPrologueEpilogue = false);
90 
91 /// Tiles the specified band of perfectly nested loops creating tile-space loops
92 /// and intra-tile loops. A band is a contiguous set of loops. `tiledNest` when
93 /// non-null is set to the loops of the tiled nest from outermost to innermost.
94 /// Loops in `input` are erased when the tiling is successful.
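/// For example, one possible use (a sketch; `root` is assumed to be an
/// AffineForOp heading a 3-d perfectly nested band):
///
/// ```
///    SmallVector<AffineForOp, 6> band, tiledNest;
///    getPerfectlyNestedLoops(band, root);
///    // Tile the band with fixed tile sizes of 32 along each dimension.
///    if (failed(tilePerfectlyNested(band, /*tileSizes=*/{32, 32, 32},
///                                   &tiledNest)))
///      return failure();
/// ```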
95 LLVM_NODISCARD
96 LogicalResult
97 tilePerfectlyNested(MutableArrayRef<AffineForOp> input,
98                     ArrayRef<unsigned> tileSizes,
99                     SmallVectorImpl<AffineForOp> *tiledNest = nullptr);
100 
101 /// Performs loop interchange on 'forOpA' and 'forOpB'. Requires that 'forOpA'
102 /// and 'forOpB' are part of a perfectly nested sequence of loops.
103 void interchangeLoops(AffineForOp forOpA, AffineForOp forOpB);
104 
/// Checks if the loop interchange permutation 'loopPermMap', of the perfectly
/// nested sequence of loops in 'loops', would violate dependences (loop 'i' in
/// 'loops' is mapped to location 'j = loopPermMap[i]' in the interchange).
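/// For example, one possible use (a sketch; `loops` is assumed to hold a 2-d
/// perfectly nested sequence of affine loops, outermost first):
///
/// ```
///    // Interchange the two loops only if doing so preserves all dependences.
///    if (isValidLoopInterchangePermutation(loops, /*loopPermMap=*/{1, 0}))
///      interchangeLoops(loops[0], loops[1]);
/// ```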
108 bool isValidLoopInterchangePermutation(ArrayRef<AffineForOp> loops,
109                                        ArrayRef<unsigned> loopPermMap);
110 
/// Performs a loop permutation on a perfectly nested loop nest `inputNest`
/// (where the contained loops appear from outer to inner) as specified by the
/// permutation `permMap`: loop 'i' in `inputNest` is mapped to location
/// 'permMap[i]', where positions 0, 1, ... run from the outermost position
/// inwards. Returns the position in `inputNest` of the AffineForOp that
/// becomes the new outermost loop of this nest. This method always succeeds;
/// it asserts on invalid input / specifications.
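/// For example, one possible use (a sketch; `iLoop`, `jLoop` and `kLoop` are
/// assumed to form a perfect 3-d nest, outermost first):
///
/// ```
///    SmallVector<AffineForOp, 3> nest = {iLoop, jLoop, kLoop};
///    // Make the j-loop outermost, the k-loop the middle loop, and the i-loop
///    // innermost.
///    unsigned newOutermostIdx = permuteLoops(nest, /*permMap=*/{2, 0, 1});
///    // newOutermostIdx == 1: nest[1] (the j-loop) is now the outermost loop.
/// ```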
118 unsigned permuteLoops(MutableArrayRef<AffineForOp> inputNest,
119                       ArrayRef<unsigned> permMap);
120 
/// Sinks all sequential loops to the innermost levels (while preserving
/// relative order among them) and moves all parallel loops to the
/// outermost (while again preserving relative order among them).
/// Returns the AffineForOp at the root of the new loop nest after the
/// interchanges.
125 AffineForOp sinkSequentialLoops(AffineForOp forOp);
126 
using Loops = SmallVector<scf::ForOp, 8>;
using TileLoops = std::pair<Loops, Loops>;

/// Performs tiling of imperfectly nested loops (with interchange) by
/// strip-mining the `forOps` by `sizes` and sinking them, in their order of
/// occurrence in `forOps`, under each of the `targets`.
/// Returns the new AffineForOps, one per (`forOps`, `targets`) pair, nested
/// immediately under each of the `targets`.
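/// For example, one possible use (a sketch; `band` and `targets` are assumed
/// to hold valid AffineForOps, with the targets nested inside the band):
///
/// ```
///    // Strip-mine the band loops by 128 and 64 and sink the resulting
///    // intra-tile loops under each target loop.
///    SmallVector<SmallVector<AffineForOp, 8>, 8> newLoops =
///        tile(band, /*sizes=*/{128, 64}, targets);
/// ```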
134 SmallVector<SmallVector<AffineForOp, 8>, 8> tile(ArrayRef<AffineForOp> forOps,
135                                                  ArrayRef<uint64_t> sizes,
136                                                  ArrayRef<AffineForOp> targets);
137 SmallVector<Loops, 8> tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
138                            ArrayRef<scf::ForOp> targets);
139 
140 /// Performs tiling (with interchange) by strip-mining the `forOps` by `sizes`
141 /// and sinking them, in their order of occurrence in `forOps`, under `target`.
142 /// Returns the new AffineForOps, one per `forOps`, nested immediately under
143 /// `target`.
144 SmallVector<AffineForOp, 8> tile(ArrayRef<AffineForOp> forOps,
145                                  ArrayRef<uint64_t> sizes, AffineForOp target);
146 Loops tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
147            scf::ForOp target);
148 
149 /// Tile a nest of scf::ForOp loops rooted at `rootForOp` with the given
150 /// (parametric) sizes. Sizes are expected to be strictly positive values at
151 /// runtime.  If more sizes than loops are provided, discard the trailing values
152 /// in sizes.  Assumes the loop nest is permutable.
153 /// Returns the newly created intra-tile loops.
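/// For example, one possible use (a sketch; `rootForOp` is assumed to head a
/// permutable 2-d nest, and `ts0`/`ts1` to be positive index-typed SSA values
/// defined above it):
///
/// ```
///    // Tile with parametric sizes; the newly created intra-tile loops are
///    // returned.
///    Loops intraTileLoops =
///        tilePerfectlyNested(rootForOp, /*sizes=*/{ts0, ts1});
/// ```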
154 Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes);
155 
156 /// Explicit copy / DMA generation options for mlir::affineDataCopyGenerate.
157 struct AffineCopyOptions {
158   // True if DMAs should be generated instead of point-wise copies.
159   bool generateDma;
160   // The slower memory space from which data is to be moved.
161   unsigned slowMemorySpace;
162   // Memory space of the faster one (typically a scratchpad).
163   unsigned fastMemorySpace;
164   // Memory space to place tags in: only meaningful for DMAs.
165   unsigned tagMemorySpace;
166   // Capacity of the fast memory space in bytes.
167   uint64_t fastMemCapacityBytes;
168 };
169 
/// Performs explicit copying for the contiguous sequence of operations in the
/// block iterator range [`begin', `end'), where `end' can't be past the
/// terminator of the block (since additional operations are potentially
/// inserted right before `end'). Returns the total size of the fast memory
/// space buffers used. `copyOptions` provides various parameters, and the
/// output argument `copyNests` is the set of all copy nests inserted, each
/// represented by its root affine.for. Since allocs and deallocs are generated
/// for all fast buffers (before and after the range of operations,
/// respectively, or at a hoisted position), all of the fast memory capacity is
/// assumed to be available for processing this block range. When
/// `filterMemRef` is specified, copies are only generated for the provided
/// MemRef.
181 uint64_t affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
182                                 const AffineCopyOptions &copyOptions,
183                                 Optional<Value> filterMemRef,
184                                 DenseSet<Operation *> &copyNests);
185 
186 /// A convenience version of affineDataCopyGenerate for all ops in the body of
187 /// an AffineForOp.
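/// For example, one possible use (a sketch; `forOp` is assumed to be a valid
/// affine.for, and the memory space numbers and capacity below are
/// illustrative only):
///
/// ```
///    AffineCopyOptions copyOptions = {/*generateDma=*/false,
///                                     /*slowMemorySpace=*/0,
///                                     /*fastMemorySpace=*/1,
///                                     /*tagMemorySpace=*/0,
///                                     /*fastMemCapacityBytes=*/32 * 1024};
///    DenseSet<Operation *> copyNests;
///    // Generate copies for all memrefs accessed in the body of `forOp`.
///    uint64_t fastBufBytes = affineDataCopyGenerate(
///        forOp, copyOptions, /*filterMemRef=*/llvm::None, copyNests);
/// ```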
188 uint64_t affineDataCopyGenerate(AffineForOp forOp,
189                                 const AffineCopyOptions &copyOptions,
190                                 Optional<Value> filterMemRef,
191                                 DenseSet<Operation *> &copyNests);
192 
/// Result of calling generateCopyForMemRegion.
194 struct CopyGenerateResult {
195   // Number of bytes used by alloc.
196   uint64_t sizeInBytes;
197 
198   // The newly created buffer allocation.
199   Operation *alloc;
200 
201   // Generated loop nest for copying data between the allocated buffer and the
202   // original memref.
203   Operation *copyNest;
204 };
205 
/// generateCopyForMemRegion is similar to affineDataCopyGenerate, but works
/// with a single memref region. `memrefRegion` is expected to contain the
/// analysis information computed within `analyzedOp`. The generated prologue
/// and epilogue always surround `analyzedOp`.
///
/// Note that `analyzedOp` is a single op for API convenience, and the
/// [begin, end) version can be added as needed.
///
/// Also note that certain options in `copyOptions` are not applicable here and
/// are ignored, e.g., `slowMemorySpace`.
216 LogicalResult generateCopyForMemRegion(const MemRefRegion &memrefRegion,
217                                        Operation *analyzedOp,
218                                        const AffineCopyOptions &copyOptions,
219                                        CopyGenerateResult &result);
220 
/// Tile a nest of scf::ForOp loops rooted at `rootForOp` by finding parametric
/// tile sizes such that the outer loops have a fixed number of iterations as
/// defined in `sizes`.
TileLoops extractFixedOuterLoops(scf::ForOp rootForOp, ArrayRef<int64_t> sizes);
225 
226 /// Replace a perfect nest of "for" loops with a single linearized loop. Assumes
227 /// `loops` contains a list of perfectly nested loops with bounds and steps
228 /// independent of any loop induction variable involved in the nest.
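/// For example, one possible use (a sketch; `outermost` is assumed to head a
/// perfect scf.for nest satisfying the above condition):
///
/// ```
///    SmallVector<scf::ForOp, 4> nest;
///    getPerfectlyNestedLoops(nest, outermost);
///    // Replace the nest with a single loop over the linearized space.
///    coalesceLoops(nest);
/// ```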
229 void coalesceLoops(MutableArrayRef<scf::ForOp> loops);
230 
/// Take the scf.parallel op `loops` and, for each set of dimension indices in
/// `combinedDimensions`, combine those dimensions into a single one.
/// `combinedDimensions` must contain each dimension index of `loops` exactly
/// once.
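/// For example, one possible use (a sketch; `ploop` is assumed to be a 3-d
/// scf.parallel op):
///
/// ```
///    // Merge dimensions 0 and 1 into a single dimension and keep dimension 2
///    // as-is, producing a 2-d scf.parallel op.
///    std::vector<unsigned> firstGroup = {0, 1};
///    std::vector<unsigned> secondGroup = {2};
///    collapseParallelLoops(ploop, {firstGroup, secondGroup});
/// ```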
234 void collapseParallelLoops(scf::ParallelOp loops,
235                            ArrayRef<std::vector<unsigned>> combinedDimensions);
236 
237 /// Maps `forOp` for execution on a parallel grid of virtual `processorIds` of
238 /// size given by `numProcessors`. This is achieved by embedding the SSA values
239 /// corresponding to `processorIds` and `numProcessors` into the bounds and step
240 /// of the `forOp`. No check is performed on the legality of the rewrite, it is
241 /// the caller's responsibility to ensure legality.
242 ///
243 /// Requires that `processorIds` and `numProcessors` have the same size and that
244 /// for each idx, `processorIds`[idx] takes, at runtime, all values between 0
245 /// and `numProcessors`[idx] - 1. This corresponds to traditional use cases for:
246 ///   1. GPU (threadIdx, get_local_id(), ...)
247 ///   2. MPI (MPI_Comm_rank)
248 ///   3. OpenMP (omp_get_thread_num)
249 ///
250 /// Example:
251 /// Assuming a 2-d grid with processorIds = [blockIdx.x, threadIdx.x] and
252 /// numProcessors = [gridDim.x, blockDim.x], the loop:
253 ///
254 /// ```
255 ///    scf.for %i = %lb to %ub step %step {
256 ///      ...
257 ///    }
258 /// ```
259 ///
260 /// is rewritten into a version resembling the following pseudo-IR:
261 ///
262 /// ```
263 ///    scf.for %i = %lb + %step * (threadIdx.x + blockIdx.x * blockDim.x)
264 ///       to %ub step %gridDim.x * blockDim.x * %step {
265 ///      ...
266 ///    }
267 /// ```
void mapLoopToProcessorIds(scf::ForOp forOp, ArrayRef<Value> processorIds,
                           ArrayRef<Value> numProcessors);
270 
271 /// Gathers all AffineForOps in 'func' grouped by loop depth.
272 void gatherLoops(FuncOp func,
273                  std::vector<SmallVector<AffineForOp, 2>> &depthToLoops);
274 
275 /// Creates an AffineForOp while ensuring that the lower and upper bounds are
276 /// canonicalized, i.e., unused and duplicate operands are removed, any constant
277 /// operands propagated/folded in, and duplicate bound maps dropped.
278 AffineForOp createCanonicalizedAffineForOp(OpBuilder b, Location loc,
279                                            ValueRange lbOperands,
280                                            AffineMap lbMap,
281                                            ValueRange ubOperands,
282                                            AffineMap ubMap, int64_t step = 1);
283 
/// Separates full tiles from partial tiles for a perfect nest `nest` by
/// generating a conditional guard that selects between the full tile version
/// and the partial tile version using an AffineIfOp. The original loop nest
/// is replaced by this guarded two-version form.
288 ///
289 ///    affine.if (cond)
290 ///      // full_tile
291 ///    else
292 ///      // partial tile
293 ///
294 LogicalResult
295 separateFullTiles(MutableArrayRef<AffineForOp> nest,
296                   SmallVectorImpl<AffineForOp> *fullTileNest = nullptr);
297 
298 /// Move loop invariant code out of `looplike`.
299 LogicalResult moveLoopInvariantCode(LoopLikeOpInterface looplike);
300 
301 } // end namespace mlir
302 
303 #endif // MLIR_TRANSFORMS_LOOP_UTILS_H
304