1 //===- MatrixUtils.h - Utilities to lower matrix intrinsics -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Utilities for generating tiled loops for matrix operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_TRANSFORMS_UTILS_MATRIXUTILS_H
14 #define LLVM_TRANSFORMS_UTILS_MATRIXUTILS_H
15 
16 #include "llvm/ADT/StringRef.h"
17 
namespace llvm {
class DomTreeUpdater;
class BasicBlock;
class Value;
class Loop;
class LoopInfo;
class IRBuilderBase;
class StringRef;

/// A helper struct to create IR loop nests for tiling in IR of the following
/// form:
///   for ColumnLoop.Index = 0..NumColumns
///     for RowLoop.Index = 0..NumRows
///       for KLoop.Index = 0..NumInner
struct TileInfo {
  /// Number of rows of the matrix.
  unsigned NumRows;

  /// Number of columns of the matrix.
  unsigned NumColumns;

  /// Number of columns of the first matrix of a multiply /
  /// number of rows of the second matrix of a multiply.
  unsigned NumInner;

  /// Number of rows/columns in a tile. The in-class initializer (-1, i.e. the
  /// largest unsigned value) is always overwritten by the constructor below.
  unsigned TileSize = -1;

  /// Properties of a single loop used when generating the tiled loop nest.
  /// All members start out null.
  struct MatrixLoop {
    /// The index updated on every iteration.
    Value *Index = nullptr;
    /// The header and latch of the loop.
    BasicBlock *Header = nullptr;
    BasicBlock *Latch = nullptr;
  };

  /// The loop iterating on the rows.
  MatrixLoop RowLoop;
  /// The loop iterating on the columns.
  MatrixLoop ColumnLoop;
  /// The loop iterating on k (inner dimension).
  MatrixLoop KLoop;

  /// Records the matrix dimensions and tile size. The RowLoop, ColumnLoop and
  /// KLoop members are left null here.
  TileInfo(unsigned NumRows, unsigned NumColumns, unsigned NumInner,
           unsigned TileSize)
      : NumRows(NumRows), NumColumns(NumColumns), NumInner(NumInner),
        TileSize(TileSize) {}

  /// Creates an IR loop nest for tiling of the form below. Returns the block
  /// for the inner loop body and sets the Index/Header/Latch fields of
  /// ColumnLoop, RowLoop and KLoop.
  ///
  /// for ColumnLoop.Index = 0..NumColumns
  ///   for RowLoop.Index = 0..NumRows
  ///     for KLoop.Index = 0..NumInner
  BasicBlock *CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
                               IRBuilderBase &B, DomTreeUpdater &DTU,
                               LoopInfo &LI);

private:
  /// Creates a new loop with header, body and latch blocks that iterates from
  /// [0, \p Bound). Updates \p Preheader to branch to the new header and uses
  /// \p Exit as exit block. Adds the new loop blocks to \p L and applies
  /// dominator tree updates to \p DTU.
  static BasicBlock *CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
                                Value *Bound, Value *Step, StringRef Name,
                                IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
                                LoopInfo &LI);
};
} // namespace llvm
88 
89 #endif
90