//===- LoopUtils.h - Loop transformation utilities --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines prototypes for various loop transformation utility
// methods: these are not passes by themselves but are used either by passes,
// optimization sequences, or in turn by other transformation utilities.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TRANSFORMS_LOOP_UTILS_H
#define MLIR_TRANSFORMS_LOOP_UTILS_H

#include "mlir/IR/Block.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"

namespace mlir {
class AffineForOp;
class AffineMap;
class FuncOp;
class LoopLikeOpInterface;
struct MemRefRegion;
class OpBuilder;
class Value;
class ValueRange;

namespace scf {
class ForOp;
class ParallelOp;
} // end namespace scf

/// Unrolls this for operation completely if the trip count is known to be
/// constant. Returns failure otherwise.
LogicalResult loopUnrollFull(AffineForOp forOp);

/// Unrolls this for operation by the specified unroll factor. Returns failure
/// if the loop cannot be unrolled either due to restrictions or due to invalid
/// unroll factors. Requires positive loop bounds and step.
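///
/// Usage sketch (illustrative only; `func` is a hypothetical enclosing FuncOp
/// and the factor of 4 is arbitrary):
/// ```
///   // Attempt to unroll every affine loop in the function by 4; loops whose
///   // structure or trip count disallows it are simply left unchanged.
///   func.walk([](AffineForOp forOp) {
///     (void)loopUnrollByFactor(forOp, /*unrollFactor=*/4);
///   });
/// ```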
LogicalResult loopUnrollByFactor(AffineForOp forOp, uint64_t unrollFactor);
LogicalResult loopUnrollByFactor(scf::ForOp forOp, uint64_t unrollFactor);

/// Unrolls this loop by the specified unroll factor or its trip count,
/// whichever is lower.
LogicalResult loopUnrollUpToFactor(AffineForOp forOp, uint64_t unrollFactor);

/// Returns true if `loops` is a perfectly nested loop nest, where loops appear
/// in it from outermost to innermost.
bool LLVM_ATTRIBUTE_UNUSED isPerfectlyNested(ArrayRef<AffineForOp> loops);

/// Gets the perfectly nested sequence of loops starting at the root of the
/// loop nest. A loop is perfectly nested iff the first op in its body is
/// another AffineForOp (resp. scf::ForOp) and the second op is a terminator.
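///
/// Usage sketch (illustrative; `outermost` is a hypothetical AffineForOp):
/// ```
///   SmallVector<AffineForOp, 4> nest;
///   getPerfectlyNestedLoops(nest, outermost);
///   // nest.front() is `outermost`; nest.back() is the innermost loop of the
///   // perfect nest rooted at it.
/// ```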
void getPerfectlyNestedLoops(SmallVectorImpl<AffineForOp> &nestedLoops,
                             AffineForOp root);
void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
                             scf::ForOp root);

/// Unrolls and jams this loop by the specified factor. Returns success if the
/// loop is successfully unroll-jammed.
LogicalResult loopUnrollJamByFactor(AffineForOp forOp,
                                    uint64_t unrollJamFactor);

/// Unrolls and jams this loop by the specified factor or by the trip count (if
/// constant), whichever is lower.
LogicalResult loopUnrollJamUpToFactor(AffineForOp forOp,
                                      uint64_t unrollJamFactor);

/// Promotes the loop body of an AffineForOp/scf::ForOp to its containing block
/// if the loop is known to have a single iteration.
LogicalResult promoteIfSingleIteration(AffineForOp forOp);
LogicalResult promoteIfSingleIteration(scf::ForOp forOp);

/// Promotes all single-iteration AffineForOps in the function, i.e., moves
/// their bodies into the containing Block.
void promoteSingleIterationLoops(FuncOp f);

/// Skew the operations in an affine.for's body with the specified
/// operation-wise shifts. The shifts are with respect to the original execution
/// order, and are multiplied by the loop 'step' before being applied. If
/// `unrollPrologueEpilogue` is set, fully unroll the prologue and epilogue
/// loops when possible.
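///
/// Usage sketch (illustrative; `forOp` is a hypothetical AffineForOp and
/// `shifts` is assumed to supply one value per operation of its body, in
/// original program order):
/// ```
///   // Keep the first op in place and delay the second and third ops by one
///   // and two iterations respectively.
///   SmallVector<uint64_t, 4> shifts = {0, 1, 2};
///   if (failed(affineForOpBodySkew(forOp, shifts)))
///     return; // The shifts were not valid for this loop.
/// ```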
LLVM_NODISCARD
LogicalResult affineForOpBodySkew(AffineForOp forOp, ArrayRef<uint64_t> shifts,
                                  bool unrollPrologueEpilogue = false);

/// Identify valid and profitable bands of loops to tile. This is currently just
/// a temporary placeholder to test the mechanics of tiled code generation.
/// Returns all maximal outermost perfect loop nests to tile.
void getTileableBands(FuncOp f,
                      std::vector<SmallVector<AffineForOp, 6>> *bands);

/// Tiles the specified band of perfectly nested loops creating tile-space loops
/// and intra-tile loops. A band is a contiguous set of loops.
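///
/// Usage sketch (illustrative; the 32x...x32 tile shape is arbitrary and
/// `func` is a hypothetical enclosing FuncOp):
/// ```
///   // Tile every maximal perfect loop nest in the function.
///   std::vector<SmallVector<AffineForOp, 6>> bands;
///   getTileableBands(func, &bands);
///   for (MutableArrayRef<AffineForOp> band : bands) {
///     SmallVector<unsigned, 6> tileSizes(band.size(), 32);
///     (void)tilePerfectlyNested(band, tileSizes);
///   }
/// ```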
LLVM_NODISCARD
LogicalResult
tilePerfectlyNested(MutableArrayRef<AffineForOp> input,
                    ArrayRef<unsigned> tileSizes,
                    SmallVectorImpl<AffineForOp> *tiledNest = nullptr);

/// Tiles the specified band of perfectly nested loops creating tile-space
/// loops and intra-tile loops, using SSA values as tiling parameters. A band
/// is a contiguous set of loops.
LLVM_NODISCARD
LogicalResult tilePerfectlyNestedParametric(
    MutableArrayRef<AffineForOp> input, ArrayRef<Value> tileSizes,
    SmallVectorImpl<AffineForOp> *tiledNest = nullptr);

/// Performs loop interchange on 'forOpA' and 'forOpB'. Requires that 'forOpA'
/// and 'forOpB' are part of a perfectly nested sequence of loops.
void interchangeLoops(AffineForOp forOpA, AffineForOp forOpB);

/// Checks if the loop interchange permutation 'loopPermMap', of the perfectly
/// nested sequence of loops in 'loops', would violate dependences (loop 'i' in
/// 'loops' is mapped to location 'j = loopPermMap[i]' in the interchange).
bool isValidLoopInterchangePermutation(ArrayRef<AffineForOp> loops,
                                       ArrayRef<unsigned> loopPermMap);

/// Performs a loop permutation on a perfectly nested loop nest `inputNest`
/// (where the contained loops appear from outer to inner) as specified by the
/// permutation `permMap`: loop 'i' in `inputNest` is mapped to location
/// 'permMap[i]', where positions 0, 1, ... run from the outermost position
/// inwards. Returns the position in `inputNest` of the AffineForOp that
/// becomes the new outermost loop of this nest. This method always succeeds;
/// it asserts on invalid input / specifications.
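///
/// Usage sketch (illustrative; `nest` is a hypothetical 3-d perfect nest,
/// e.g., gathered with getPerfectlyNestedLoops):
/// ```
///   // Rotate (i, j, k) into (k, i, j): loop 0 moves to position 1, loop 1
///   // to position 2, and loop 2 becomes outermost (position 0).
///   unsigned newRootIdx = permuteLoops(nest, /*permMap=*/{1, 2, 0});
///   // newRootIdx == 2, i.e., the original innermost loop is now outermost.
/// ```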
unsigned permuteLoops(MutableArrayRef<AffineForOp> inputNest,
                      ArrayRef<unsigned> permMap);

/// Sinks all sequential loops to the innermost levels (while preserving
/// relative order among them) and moves all parallel loops to the outermost
/// (while again preserving relative order among them). Returns the AffineForOp
/// at the root of the new loop nest after the interchanges.
AffineForOp sinkSequentialLoops(AffineForOp forOp);

/// Performs tiling of imperfectly nested loops (with interchange) by
/// strip-mining the `forOps` by `sizes` and sinking them, in their order of
/// occurrence in `forOps`, under each of the `targets`.
/// Returns the new AffineForOps, one per each (`forOps`, `targets`) pair,
/// nested immediately under each of `targets`.
using Loops = SmallVector<scf::ForOp, 8>;
using TileLoops = std::pair<Loops, Loops>;
SmallVector<SmallVector<AffineForOp, 8>, 8> tile(ArrayRef<AffineForOp> forOps,
                                                 ArrayRef<uint64_t> sizes,
                                                 ArrayRef<AffineForOp> targets);
SmallVector<Loops, 8> tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
                           ArrayRef<scf::ForOp> targets);

/// Performs tiling (with interchange) by strip-mining the `forOps` by `sizes`
/// and sinking them, in their order of occurrence in `forOps`, under `target`.
/// Returns the new AffineForOps, one per `forOps`, nested immediately under
/// `target`.
SmallVector<AffineForOp, 8> tile(ArrayRef<AffineForOp> forOps,
                                 ArrayRef<uint64_t> sizes, AffineForOp target);
Loops tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
           scf::ForOp target);

/// Tile a nest of scf::ForOp loops rooted at `rootForOp` with the given
/// (parametric) sizes. Sizes are expected to be strictly positive values at
/// runtime.  If more sizes than loops are provided, discard the trailing values
/// in sizes.  Assumes the loop nest is permutable.
/// Returns the newly created intra-tile loops.
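///
/// Usage sketch (illustrative; `rootForOp`, `tileSize0`, and `tileSize1` are
/// hypothetical values defined by the surrounding code):
/// ```
///   // Tile a 2-d scf.for nest with dynamically computed tile sizes.
///   Loops intraTileLoops =
///       tilePerfectlyNested(rootForOp, /*sizes=*/{tileSize0, tileSize1});
/// ```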
Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes);

/// Explicit copy / DMA generation options for mlir::affineDataCopyGenerate.
struct AffineCopyOptions {
  // True if DMAs should be generated instead of point-wise copies.
  bool generateDma;
  // The slower memory space from which data is to be moved.
  unsigned slowMemorySpace;
  // Memory space of the faster one (typically a scratchpad).
  unsigned fastMemorySpace;
  // Memory space to place tags in: only meaningful for DMAs.
  unsigned tagMemorySpace;
  // Capacity of the fast memory space in bytes.
  uint64_t fastMemCapacityBytes;
};

/// Performs explicit copying for the contiguous sequence of operations in the
/// block iterator range [`begin`, `end`), where `end` can't be past the
/// terminator of the block (since additional operations are potentially
/// inserted right before `end`). Returns the total size of the fast memory
/// space buffers used. `copyOptions` provides various parameters, and the
/// output argument `copyNests` is the set of all copy nests inserted, each
/// represented by its root affine.for. Since we generate alloc's and dealloc's
/// for all fast buffers (before and after the range of operations resp., or at
/// a hoisted position), all of the fast memory capacity is assumed to be
/// available for processing this block range. When `filterMemRef` is
/// specified, copies are only generated for the provided memref.
uint64_t affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
                                const AffineCopyOptions &copyOptions,
                                Optional<Value> filterMemRef,
                                DenseSet<Operation *> &copyNests);

/// A convenience version of affineDataCopyGenerate for all ops in the body of
/// an AffineForOp.
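///
/// Usage sketch (illustrative; the memory-space numbers and buffer capacity
/// are placeholders, and `forOp` is a hypothetical AffineForOp):
/// ```
///   AffineCopyOptions copyOptions = {/*generateDma=*/false,
///                                    /*slowMemorySpace=*/0,
///                                    /*fastMemorySpace=*/1,
///                                    /*tagMemorySpace=*/0,
///                                    /*fastMemCapacityBytes=*/32 * 1024};
///   DenseSet<Operation *> copyNests;
///   uint64_t fastBufBytes = affineDataCopyGenerate(
///       forOp, copyOptions, /*filterMemRef=*/llvm::None, copyNests);
/// ```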
uint64_t affineDataCopyGenerate(AffineForOp forOp,
                                const AffineCopyOptions &copyOptions,
                                Optional<Value> filterMemRef,
                                DenseSet<Operation *> &copyNests);

/// Result for calling generateCopyForMemRegion.
struct CopyGenerateResult {
  // Number of bytes used by alloc.
  uint64_t sizeInBytes;

  // The newly created buffer allocation.
  Operation *alloc;

  // Generated loop nest for copying data between the allocated buffer and the
  // original memref.
  Operation *copyNest;
};

/// generateCopyForMemRegion is similar to affineDataCopyGenerate, but works
/// with a single memref region. `memrefRegion` is supposed to contain analysis
/// information within `analyzedOp`. The generated prologue and epilogue always
/// surround `analyzedOp`.
///
/// Note that `analyzedOp` is a single op for API convenience, and the
/// [begin, end) version can be added as needed.
///
/// Also note that certain options in `copyOptions` aren't looked at anymore,
/// like `slowMemorySpace`.
LogicalResult generateCopyForMemRegion(const MemRefRegion &memrefRegion,
                                       Operation *analyzedOp,
                                       const AffineCopyOptions &copyOptions,
                                       CopyGenerateResult &result);

/// Tile a nest of standard for loops rooted at `rootForOp` by finding such
/// parametric tile sizes that the outer loops have a fixed number of iterations
/// as defined in `sizes`.
TileLoops extractFixedOuterLoops(scf::ForOp rootForOp, ArrayRef<int64_t> sizes);

/// Replace a perfect nest of "for" loops with a single linearized loop. Assumes
/// `loops` contains a list of perfectly nested loops with bounds and steps
/// independent of any loop induction variable involved in the nest.
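///
/// Usage sketch (illustrative; `outermost` is a hypothetical scf::ForOp at the
/// root of a perfect nest):
/// ```
///   // Collect the perfect nest rooted at `outermost` and coalesce it into a
///   // single loop.
///   SmallVector<scf::ForOp, 4> nest;
///   getPerfectlyNestedLoops(nest, outermost);
///   if (nest.size() > 1)
///     coalesceLoops(nest);
/// ```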
void coalesceLoops(MutableArrayRef<scf::ForOp> loops);

/// Takes the scf.parallel `loops` and, for each set of dimension indices in
/// `combinedDimensions`, combines those dimensions into a single one.
/// `combinedDimensions` must contain each dimension index of `loops` exactly
/// once.
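///
/// Usage sketch (illustrative; `parallelOp` is a hypothetical 3-d
/// scf.parallel):
/// ```
///   // Keep dimension 0 as-is and fold dimensions 1 and 2 into a single one,
///   // yielding a 2-d scf.parallel.
///   collapseParallelLoops(parallelOp, {{0}, {1, 2}});
/// ```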
void collapseParallelLoops(scf::ParallelOp loops,
                           ArrayRef<std::vector<unsigned>> combinedDimensions);

/// Maps `forOp` for execution on a parallel grid of virtual `processorIds` of
/// size given by `numProcessors`. This is achieved by embedding the SSA values
/// corresponding to `processorIds` and `numProcessors` into the bounds and step
/// of the `forOp`. No check is performed on the legality of the rewrite; it is
/// the caller's responsibility to ensure legality.
///
/// Requires that `processorIds` and `numProcessors` have the same size and that
/// for each idx, `processorIds`[idx] takes, at runtime, all values between 0
/// and `numProcessors`[idx] - 1. This corresponds to traditional use cases for:
///   1. GPU (threadIdx, get_local_id(), ...)
///   2. MPI (MPI_Comm_rank)
///   3. OpenMP (omp_get_thread_num)
///
/// Example:
/// Assuming a 2-d grid with processorIds = [blockIdx.x, threadIdx.x] and
/// numProcessors = [gridDim.x, blockDim.x], the loop:
///
/// ```
///    scf.for %i = %lb to %ub step %step {
///      ...
///    }
/// ```
///
/// is rewritten into a version resembling the following pseudo-IR:
///
/// ```
///    scf.for %i = %lb + %step * (threadIdx.x + blockIdx.x * blockDim.x)
///       to %ub step %gridDim.x * blockDim.x * %step {
///      ...
///    }
/// ```
void mapLoopToProcessorIds(scf::ForOp forOp, ArrayRef<Value> processorIds,
                           ArrayRef<Value> numProcessors);

/// Gathers all AffineForOps in 'func' grouped by loop depth.
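///
/// Usage sketch (illustrative; `func` is the FuncOp being analyzed, and the
/// indexing by depth, with 0 holding the top-level loops, is an assumption
/// based on the description above):
/// ```
///   std::vector<SmallVector<AffineForOp, 2>> depthToLoops;
///   gatherLoops(func, depthToLoops);
///   // depthToLoops[0]: top-level loops; depthToLoops[1]: their immediate
///   // children; and so on.
/// ```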
void gatherLoops(FuncOp func,
                 std::vector<SmallVector<AffineForOp, 2>> &depthToLoops);

/// Creates an AffineForOp while ensuring that the lower and upper bounds are
/// canonicalized, i.e., unused and duplicate operands are removed, any constant
/// operands propagated/folded in, and duplicate bound maps dropped.
AffineForOp createCanonicalizedAffineForOp(OpBuilder b, Location loc,
                                           ValueRange lbOperands,
                                           AffineMap lbMap,
                                           ValueRange ubOperands,
                                           AffineMap ubMap, int64_t step = 1);

/// Separates full tiles from partial tiles for a perfect nest `nest` by
/// generating a conditional guard that selects between the full tile version
/// and the partial tile version using an AffineIfOp. The original loop nest
/// is replaced by this guarded two-version form.
///
///    affine.if (cond)
///      // full_tile
///    else
///      // partial tile
///
LogicalResult
separateFullTiles(MutableArrayRef<AffineForOp> nest,
                  SmallVectorImpl<AffineForOp> *fullTileNest = nullptr);

/// Move loop invariant code out of `looplike`.
LogicalResult moveLoopInvariantCode(LoopLikeOpInterface looplike);

} // end namespace mlir

#endif // MLIR_TRANSFORMS_LOOP_UTILS_H