//===- MatmulOptimizer.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "polly/MatmulOptimizer.h"
#include "polly/DependenceInfo.h"
#include "polly/Options.h"
#include "polly/ScheduleTreeTransform.h"
#include "polly/ScopInfo.h"
#include "polly/ScopPass.h"
#include "polly/Simplify.h"
#include "polly/Support/ISLTools.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/TypeSize.h"
#include "llvm/Support/raw_ostream.h"
#include "isl/ctx.h"
#include "isl/schedule_node.h"
#include "isl/schedule_type.h"
#include "isl/union_map.h"
#include "isl/union_set.h"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <string>
#include <vector>

#define DEBUG_TYPE "polly-opt-isl"

using namespace llvm;
using namespace polly;

namespace llvm {
class Value;
}

static cl::opt<int> LatencyVectorFma(
    "polly-target-latency-vector-fma",
    cl::desc("The minimal number of cycles between issuing two "
             "dependent consecutive vector fused multiply-add "
             "instructions."),
    cl::Hidden, cl::init(8), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> ThroughputVectorFma(
    "polly-target-throughput-vector-fma",
    cl::desc("The throughput of the processor's floating-point arithmetic "
             "units expressed in the number of vector fused multiply-add "
             "instructions per clock cycle."),
    cl::Hidden, cl::init(1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> FirstCacheLevelSize(
    "polly-target-1st-cache-level-size",
    cl::desc("The size of the first cache level specified in bytes."),
    cl::Hidden, cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> FirstCacheLevelDefaultSize(
    "polly-target-1st-cache-level-default-size",
    cl::desc("The default size of the first cache level specified in bytes"
             " (if not enough were provided by the TargetTransformInfo)."),
    cl::Hidden, cl::init(32768), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelSize(
    "polly-target-2nd-cache-level-size",
    cl::desc("The size of the second cache level specified in bytes."),
    cl::Hidden, cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelDefaultSize(
    "polly-target-2nd-cache-level-default-size",
    cl::desc("The default size of the second cache level specified in bytes"
             " (if not enough were provided by the TargetTransformInfo)."),
    cl::Hidden, cl::init(262144), cl::ZeroOrMore, cl::cat(PollyCategory));

// This option, along with --polly-target-2nd-cache-level-associativity,
// --polly-target-1st-cache-level-size, and --polly-target-2nd-cache-level-size,
// represents the parameters of the target cache, which do not have typical
// values that can be used by default. However, to apply the pattern matching
// optimizations, we use the values of the parameters of an Intel Core i7-3820
// SandyBridge in case the parameters are not specified or not provided by the
// TargetTransformInfo.
static cl::opt<int> FirstCacheLevelAssociativity(
    "polly-target-1st-cache-level-associativity",
    cl::desc("The associativity of the first cache level."), cl::Hidden,
    cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> FirstCacheLevelDefaultAssociativity(
    "polly-target-1st-cache-level-default-associativity",
    cl::desc("The default associativity of the first cache level"
             " (if not enough were provided by the TargetTransformInfo)."),
    cl::Hidden, cl::init(8), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelAssociativity(
    "polly-target-2nd-cache-level-associativity",
    cl::desc("The associativity of the second cache level."), cl::Hidden,
    cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> SecondCacheLevelDefaultAssociativity(
    "polly-target-2nd-cache-level-default-associativity",
    cl::desc("The default associativity of the second cache level"
             " (if not enough were provided by the TargetTransformInfo)."),
    cl::Hidden, cl::init(8), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> VectorRegisterBitwidth(
    "polly-target-vector-register-bitwidth",
    cl::desc("The size in bits of a vector register (if not set, this "
             "information is taken from LLVM's target information)."),
    cl::Hidden, cl::init(-1), cl::ZeroOrMore, cl::cat(PollyCategory));

static cl::opt<int> PollyPatternMatchingNcQuotient(
    "polly-pattern-matching-nc-quotient",
    cl::desc("Quotient that is obtained by dividing Nc, the parameter of the "
             "macro-kernel, by Nr, the parameter of the micro-kernel."),
    cl::Hidden, cl::init(256), cl::ZeroOrMore, cl::cat(PollyCategory));

namespace {
/// Parameters of the micro kernel.
///
/// Parameters that determine the sizes of the rank-1 (i.e., outer product)
/// update used in the optimized matrix multiplication.
struct MicroKernelParamsTy {
  int Mr;
  int Nr;
};

/// Parameters of the macro kernel.
///
/// Parameters that determine the sizes of the blocks of the partitioned
/// matrices used in the optimized matrix multiplication.
struct MacroKernelParamsTy {
  int Mc;
  int Nc;
  int Kc;
};

/// Parameters of the matrix multiplication operands.
///
/// Parameters that describe the access relations representing operands of the
/// matrix multiplication.
struct MatMulInfoTy {
  MemoryAccess *A = nullptr;
  MemoryAccess *B = nullptr;
  MemoryAccess *ReadFromC = nullptr;
  MemoryAccess *WriteToC = nullptr;
  int i = -1;
  int j = -1;
  int k = -1;
};

/// Create an isl::union_set, which describes the option of the form
/// [isolate[] -> unroll[x]].
///
/// @param Ctx An isl::ctx, which is used to create the isl::union_set.
static isl::union_set getUnrollIsolatedSetOptions(isl::ctx Ctx) {
  isl::space Space = isl::space(Ctx, 0, 0, 1);
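  // The space has zero parameters, zero input dimensions, and one output
  // dimension; wrapping the universe map over it yields the option set
  // { [isolate[] -> unroll[x]] }.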
  isl::map UnrollIsolatedSetOption = isl::map::universe(Space);
  isl::id DimInId = isl::id::alloc(Ctx, "isolate", nullptr);
  isl::id DimOutId = isl::id::alloc(Ctx, "unroll", nullptr);
  UnrollIsolatedSetOption =
      UnrollIsolatedSetOption.set_tuple_id(isl::dim::in, DimInId);
  UnrollIsolatedSetOption =
      UnrollIsolatedSetOption.set_tuple_id(isl::dim::out, DimOutId);
  return UnrollIsolatedSetOption.wrap();
}

/// Permute the two dimensions of the isl map.
///
/// Permute @p DstPos and @p SrcPos dimensions of the isl map @p Map that
/// have type @p DimType.
///
/// @param Map     The isl map to be modified.
/// @param DimType The type of the dimensions.
/// @param DstPos  The first dimension.
/// @param SrcPos  The second dimension.
/// @return        The modified map.
static isl::map permuteDimensions(isl::map Map, isl::dim DimType,
                                  unsigned DstPos, unsigned SrcPos) {
  assert((isl_size)DstPos < Map.dim(DimType) &&
         (isl_size)SrcPos < Map.dim(DimType));
  if (DstPos == SrcPos)
    return Map;
  isl::id DimId;
  if (Map.has_tuple_id(DimType))
    DimId = Map.get_tuple_id(DimType);
  auto FreeDim = DimType == isl::dim::in ? isl::dim::out : isl::dim::in;
  isl::id FreeDimId;
  if (Map.has_tuple_id(FreeDim))
    FreeDimId = Map.get_tuple_id(FreeDim);
  auto MaxDim = std::max(DstPos, SrcPos);
  auto MinDim = std::min(DstPos, SrcPos);
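  // Park both dimensions at the front of the opposite (free) tuple, moving
  // the larger index first so the smaller index stays valid, and then move
  // them back in swapped order to realize the permutation.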
  Map = Map.move_dims(FreeDim, 0, DimType, MaxDim, 1);
  Map = Map.move_dims(FreeDim, 0, DimType, MinDim, 1);
  Map = Map.move_dims(DimType, MinDim, FreeDim, 1, 1);
  Map = Map.move_dims(DimType, MaxDim, FreeDim, 0, 1);
  if (!DimId.is_null())
    Map = Map.set_tuple_id(DimType, DimId);
  if (!FreeDimId.is_null())
    Map = Map.set_tuple_id(FreeDim, FreeDimId);
  return Map;
}

/// Check the form of the access relation.
///
/// Check that the access relation @p AccMap has the form M[i][j], where i
/// is a @p FirstPos and j is a @p SecondPos.
///
/// @param AccMap    The access relation to be checked.
/// @param FirstPos  The index of the input dimension that is mapped to
///                  the first output dimension.
/// @param SecondPos The index of the input dimension that is mapped to the
///                  second output dimension.
/// @return          True if @p AccMap has the expected form and false
///                  otherwise.
static bool isMatMulOperandAcc(isl::set Domain, isl::map AccMap, int &FirstPos,
                               int &SecondPos) {
  isl::space Space = AccMap.get_space();
  isl::map Universe = isl::map::universe(Space);

  if (Space.dim(isl::dim::out) != 2)
    return false;

  // MatMul has the form:
  // for (i = 0; i < N; i++)
  //   for (j = 0; j < M; j++)
  //     for (k = 0; k < P; k++)
  //       C[i, j] += A[i, k] * B[k, j]
  //
  // Permutation of three outer loops: 3! = 6 possibilities.
  int FirstDims[] = {0, 0, 1, 1, 2, 2};
  int SecondDims[] = {1, 2, 2, 0, 0, 1};
  for (int i = 0; i < 6; i += 1) {
    auto PossibleMatMul =
        Universe.equate(isl::dim::in, FirstDims[i], isl::dim::out, 0)
            .equate(isl::dim::in, SecondDims[i], isl::dim::out, 1);

    AccMap = AccMap.intersect_domain(Domain);
    PossibleMatMul = PossibleMatMul.intersect_domain(Domain);

    // If AccMap spans entire domain (Non-partial write),
    // compute FirstPos and SecondPos.
    // If AccMap != PossibleMatMul here (the two maps have been gisted at
    // this point), it means that the writes are not complete, or in other
    // words, it is a Partial write and Partial writes must be rejected.
    if (AccMap.is_equal(PossibleMatMul)) {
      if (FirstPos != -1 && FirstPos != FirstDims[i])
        continue;
      FirstPos = FirstDims[i];
      if (SecondPos != -1 && SecondPos != SecondDims[i])
        continue;
      SecondPos = SecondDims[i];
      return true;
    }
  }

  return false;
}

/// Does the memory access represent a non-scalar operand of the matrix
/// multiplication?
///
/// Check that the memory access @p MemAccess is the read access to a
/// non-scalar operand of the matrix multiplication or its result.
///
/// @param MemAccess The memory access to be checked.
/// @param MMI       Parameters of the matrix multiplication operands.
/// @return          True if the memory access represents the read access
///                  to a non-scalar operand of the matrix multiplication and
///                  false otherwise.
static bool isMatMulNonScalarReadAccess(MemoryAccess *MemAccess,
                                        MatMulInfoTy &MMI) {
  if (!MemAccess->isLatestArrayKind() || !MemAccess->isRead())
    return false;
  auto AccMap = MemAccess->getLatestAccessRelation();
  isl::set StmtDomain = MemAccess->getStatement()->getDomain();
  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.j) && !MMI.ReadFromC) {
    MMI.ReadFromC = MemAccess;
    return true;
  }
  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.i, MMI.k) && !MMI.A) {
    MMI.A = MemAccess;
    return true;
  }
  if (isMatMulOperandAcc(StmtDomain, AccMap, MMI.k, MMI.j) && !MMI.B) {
    MMI.B = MemAccess;
    return true;
  }
  return false;
}

/// Check accesses to operands of the matrix multiplication.
///
/// Check that the accesses of the SCoP statement that corresponds to
/// the partial schedule @p PartialSchedule are scalar in terms of loops
/// containing the matrix multiplication, unless they represent accesses
/// to the non-scalar operands of the matrix multiplication or its result.
///
/// @param  PartialSchedule The partial schedule of the SCoP statement.
/// @param  MMI             Parameters of the matrix multiplication operands.
/// @return                 True if the corresponding SCoP statement
///                         represents matrix multiplication and false
///                         otherwise.
static bool containsOnlyMatrMultAcc(isl::map PartialSchedule,
                                    MatMulInfoTy &MMI) {
  auto InputDimId = PartialSchedule.get_tuple_id(isl::dim::in);
  auto *Stmt = static_cast<ScopStmt *>(InputDimId.get_user());
  isl_size OutDimNum = PartialSchedule.range_tuple_dim();
  assert(OutDimNum > 2 && "In case of the matrix multiplication the loop nest "
                          "and, consequently, the corresponding scheduling "
                          "functions have at least three dimensions.");
  auto MapI =
      permuteDimensions(PartialSchedule, isl::dim::out, MMI.i, OutDimNum - 1);
  auto MapJ =
      permuteDimensions(PartialSchedule, isl::dim::out, MMI.j, OutDimNum - 1);
  auto MapK =
      permuteDimensions(PartialSchedule, isl::dim::out, MMI.k, OutDimNum - 1);

  auto Accesses = getAccessesInOrder(*Stmt);
  for (auto *MemA = Accesses.begin(); MemA != Accesses.end() - 1; MemA++) {
    auto *MemAccessPtr = *MemA;
    if (MemAccessPtr->isLatestArrayKind() && MemAccessPtr != MMI.WriteToC &&
        !isMatMulNonScalarReadAccess(MemAccessPtr, MMI) &&
        !(MemAccessPtr->isStrideZero(MapI)) &&
        MemAccessPtr->isStrideZero(MapJ) && MemAccessPtr->isStrideZero(MapK))
      return false;
  }
  return true;
}

/// Check for dependencies corresponding to the matrix multiplication.
///
/// Check that there is only a true dependence of the form
/// S(..., k, ...) -> S(..., k + 1, ...), where S is the SCoP statement
/// represented by @p Schedule and k is @p Pos. Such a dependence corresponds
/// to the dependency produced by the matrix multiplication.
///
/// @param  Schedule The schedule of the SCoP statement.
/// @param  D The SCoP dependencies.
/// @param  Pos The parameter to describe an acceptable true dependence.
///             In case it has a negative value, try to determine its
///             acceptable value.
/// @return True if the dependencies correspond to the matrix multiplication
///         and false otherwise.
static bool containsOnlyMatMulDep(isl::map Schedule, const Dependences *D,
                                  int &Pos) {
  isl::union_map Dep = D->getDependences(Dependences::TYPE_RAW);
  isl::union_map Red = D->getDependences(Dependences::TYPE_RED);
  if (!Red.is_null())
    Dep = Dep.unite(Red);
  auto DomainSpace = Schedule.get_space().domain();
  auto Space = DomainSpace.map_from_domain_and_range(DomainSpace);
  auto Deltas = Dep.extract_map(Space).deltas();
  isl_size DeltasDimNum = Deltas.dim(isl::dim::set);
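  // Accept only dependence distance vectors that are zero in every dimension
  // except a single reduction dimension Pos, where the distance must be
  // exactly one. If Pos is not fixed yet, the first dimension with distance
  // one is chosen.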
  for (int i = 0; i < DeltasDimNum; i++) {
    auto Val = Deltas.plain_get_val_if_fixed(isl::dim::set, i);
    Pos = Pos < 0 && Val.is_one() ? i : Pos;
    if (Val.is_nan() || !(Val.is_zero() || (i == Pos && Val.is_one())))
      return false;
  }
  if (DeltasDimNum == 0 || Pos < 0)
    return false;
  return true;
}

/// Check if the SCoP statement could probably be optimized with analytical
/// modeling.
///
/// containsMatrMult tries to determine whether the following conditions
/// are true:
/// 1. The last memory access modeling an array, MA1, represents writing to
///    memory and has the form S(..., i1, ..., i2, ...) -> M(i1, i2) or
///    S(..., i2, ..., i1, ...) -> M(i1, i2), where S is the SCoP statement
///    under consideration.
/// 2. There is only one loop-carried true dependency, and it has the
///    form S(..., i3, ...) -> S(..., i3 + 1, ...), and there are no
///    other loop-carried or anti dependencies.
/// 3. SCoP contains three access relations, MA2, MA3, and MA4 that represent
///    reading from memory and have the form S(..., i3, ...) -> M(i1, i3),
///    S(..., i3, ...) -> M(i3, i2), S(...) -> M(i1, i2), respectively,
///    and all memory accesses of the SCoP that are different from MA1, MA2,
///    MA3, and MA4 have stride 0, if the innermost loop is exchanged with any
///    of loops i1, i2 and i3.
///
/// @param PartialSchedule The PartialSchedule that contains a SCoP statement
///        to check.
/// @param D   The SCoP dependencies.
/// @param MMI Parameters of the matrix multiplication operands.
static bool containsMatrMult(isl::map PartialSchedule, const Dependences *D,
                             MatMulInfoTy &MMI) {
  auto InputDimsId = PartialSchedule.get_tuple_id(isl::dim::in);
  auto *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());
  if (Stmt->size() <= 1)
    return false;

  auto Accesses = getAccessesInOrder(*Stmt);
  for (auto *MemA = Accesses.end() - 1; MemA != Accesses.begin(); MemA--) {
    auto *MemAccessPtr = *MemA;
    if (!MemAccessPtr->isLatestArrayKind())
      continue;
    if (!MemAccessPtr->isWrite())
      return false;
    auto AccMap = MemAccessPtr->getLatestAccessRelation();
    if (!isMatMulOperandAcc(Stmt->getDomain(), AccMap, MMI.i, MMI.j))
      return false;
    MMI.WriteToC = MemAccessPtr;
    break;
  }

  if (!containsOnlyMatMulDep(PartialSchedule, D, MMI.k))
    return false;

  if (!MMI.WriteToC || !containsOnlyMatrMultAcc(PartialSchedule, MMI))
    return false;

  if (!MMI.A || !MMI.B || !MMI.ReadFromC)
    return false;
  return true;
}

/// Permute two dimensions of the band node.
///
/// Permute @p FirstDim and @p SecondDim dimensions of the @p Node.
///
/// @param Node The band node to be modified.
/// @param FirstDim The first dimension to be permuted.
/// @param SecondDim The second dimension to be permuted.
static isl::schedule_node permuteBandNodeDimensions(isl::schedule_node Node,
                                                    unsigned FirstDim,
                                                    unsigned SecondDim) {
  assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band &&
         (unsigned)isl_schedule_node_band_n_member(Node.get()) >
             std::max(FirstDim, SecondDim));
  auto PartialSchedule =
      isl::manage(isl_schedule_node_band_get_partial_schedule(Node.get()));
  auto PartialScheduleFirstDim = PartialSchedule.get_union_pw_aff(FirstDim);
  auto PartialScheduleSecondDim = PartialSchedule.get_union_pw_aff(SecondDim);
  PartialSchedule =
      PartialSchedule.set_union_pw_aff(SecondDim, PartialScheduleFirstDim);
  PartialSchedule =
      PartialSchedule.set_union_pw_aff(FirstDim, PartialScheduleSecondDim);
  Node = isl::manage(isl_schedule_node_delete(Node.release()));
  return Node.insert_partial_schedule(PartialSchedule);
}

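/// Create the BLIS micro-kernel.
///
/// We create the BLIS micro-kernel by applying a combination of tiling
/// of dimensions of the band node and interchanging of the two innermost
/// modified dimensions. The values of MicroKernelParams's fields are used
/// as tile sizes.
///
/// @param Node The schedule node to be modified.
/// @param MicroKernelParams Parameters of the micro kernel
///                          to be used as tile sizes.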
static isl::schedule_node
createMicroKernel(isl::schedule_node Node,
                  MicroKernelParamsTy MicroKernelParams) {
  Node = applyRegisterTiling(Node, {MicroKernelParams.Mr, MicroKernelParams.Nr},
                             1);
  Node = Node.parent().parent();
  return permuteBandNodeDimensions(Node, 0, 1).child(0).child(0);
}

/// Create the BLIS macro-kernel.
///
/// We create the BLIS macro-kernel by applying a combination of tiling
/// of dimensions of the band node and interchanging of the two innermost
/// modified dimensions. The values of MacroKernelParams's fields are used
/// as tile sizes.
///
/// @param Node The schedule node to be modified.
/// @param MacroKernelParams Parameters of the macro kernel
///                          to be used as tile sizes.
static isl::schedule_node
createMacroKernel(isl::schedule_node Node,
                  MacroKernelParamsTy MacroKernelParams) {
  assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band);
  if (MacroKernelParams.Mc == 1 && MacroKernelParams.Nc == 1 &&
      MacroKernelParams.Kc == 1)
    return Node;
  int DimOutNum = isl_schedule_node_band_n_member(Node.get());
  std::vector<int> TileSizes(DimOutNum, 1);
  TileSizes[DimOutNum - 3] = MacroKernelParams.Mc;
  TileSizes[DimOutNum - 2] = MacroKernelParams.Nc;
  TileSizes[DimOutNum - 1] = MacroKernelParams.Kc;
  Node = tileNode(Node, "1st level tiling", TileSizes, 1);
  Node = Node.parent().parent();
  Node = permuteBandNodeDimensions(Node, DimOutNum - 2, DimOutNum - 1);
  Node = permuteBandNodeDimensions(Node, DimOutNum - 3, DimOutNum - 1);

  // Mark the outermost loop as parallelizable.
  Node = Node.band_member_set_coincident(0, true);

  return Node.child(0).child(0);
}

/// Get the size of the widest type of the matrix multiplication operands
/// in bytes, including alignment padding.
///
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The size of the widest type of the matrix multiplication operands
///         in bytes, including alignment padding.
static uint64_t getMatMulAlignTypeSize(MatMulInfoTy MMI) {
  auto *S = MMI.A->getStatement()->getParent();
  auto &DL = S->getFunction().getParent()->getDataLayout();
  auto ElementSizeA = DL.getTypeAllocSize(MMI.A->getElementType());
  auto ElementSizeB = DL.getTypeAllocSize(MMI.B->getElementType());
  auto ElementSizeC = DL.getTypeAllocSize(MMI.WriteToC->getElementType());
  return std::max({ElementSizeA, ElementSizeB, ElementSizeC});
}

/// Get the size of the widest type of the matrix multiplication operands
/// in bits.
///
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The size of the widest type of the matrix multiplication operands
///         in bits.
static uint64_t getMatMulTypeSize(MatMulInfoTy MMI) {
  auto *S = MMI.A->getStatement()->getParent();
  auto &DL = S->getFunction().getParent()->getDataLayout();
  auto ElementSizeA = DL.getTypeSizeInBits(MMI.A->getElementType());
  auto ElementSizeB = DL.getTypeSizeInBits(MMI.B->getElementType());
  auto ElementSizeC = DL.getTypeSizeInBits(MMI.WriteToC->getElementType());
  return std::max({ElementSizeA, ElementSizeB, ElementSizeC});
}

/// Get parameters of the BLIS micro kernel.
///
/// We choose the Mr and Nr parameters of the micro kernel to be large enough
/// such that no stalls caused by the combination of latencies and dependencies
/// are introduced during the updates of the resulting matrix of the matrix
/// multiplication. However, they should also be as small as possible to
/// leave more registers available for the entries of the multiplied matrices.
///
/// @param TTI Target Transform Info.
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The structure of type MicroKernelParamsTy.
/// @see MicroKernelParamsTy
static struct MicroKernelParamsTy
getMicroKernelParams(const TargetTransformInfo *TTI, MatMulInfoTy MMI) {
  assert(TTI && "The target transform info should be provided.");

  // Nvec - Number of double-precision floating-point numbers that can be held
  // in a vector register. Use 2 by default.
  long RegisterBitwidth = VectorRegisterBitwidth;

  if (RegisterBitwidth == -1)
    RegisterBitwidth =
        TTI->getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector);
  auto ElementSize = getMatMulTypeSize(MMI);
  assert(ElementSize > 0 && "The element size of the matrix multiplication "
                            "operands should be greater than zero.");
  auto Nvec = RegisterBitwidth / ElementSize;
  if (Nvec == 0)
    Nvec = 2;
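  // Choose Mr and Nr so that Mr * Nr >= Nvec * LatencyVectorFma *
  // ThroughputVectorFma, i.e., enough independent rank-1 updates are in
  // flight to hide the latency of the vector FMA units, with Nr rounded up
  // to a multiple of the vector width Nvec.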
  int Nr = ceil(sqrt((double)(Nvec * LatencyVectorFma * ThroughputVectorFma)) /
                Nvec) *
           Nvec;
  int Mr = ceil((double)(Nvec * LatencyVectorFma * ThroughputVectorFma / Nr));
  return {Mr, Nr};
}

/// Determine parameters of the target cache.
///
/// @param TTI Target Transform Info.
static void getTargetCacheParameters(const llvm::TargetTransformInfo *TTI) {
  auto L1DCache = llvm::TargetTransformInfo::CacheLevel::L1D;
  auto L2DCache = llvm::TargetTransformInfo::CacheLevel::L2D;
  if (FirstCacheLevelSize == -1) {
    if (TTI->getCacheSize(L1DCache).hasValue())
      FirstCacheLevelSize = TTI->getCacheSize(L1DCache).getValue();
    else
      FirstCacheLevelSize = static_cast<int>(FirstCacheLevelDefaultSize);
  }
  if (SecondCacheLevelSize == -1) {
    if (TTI->getCacheSize(L2DCache).hasValue())
      SecondCacheLevelSize = TTI->getCacheSize(L2DCache).getValue();
    else
      SecondCacheLevelSize = static_cast<int>(SecondCacheLevelDefaultSize);
  }
  if (FirstCacheLevelAssociativity == -1) {
    if (TTI->getCacheAssociativity(L1DCache).hasValue())
      FirstCacheLevelAssociativity =
          TTI->getCacheAssociativity(L1DCache).getValue();
    else
      FirstCacheLevelAssociativity =
          static_cast<int>(FirstCacheLevelDefaultAssociativity);
  }
  if (SecondCacheLevelAssociativity == -1) {
    if (TTI->getCacheAssociativity(L2DCache).hasValue())
      SecondCacheLevelAssociativity =
          TTI->getCacheAssociativity(L2DCache).getValue();
    else
      SecondCacheLevelAssociativity =
          static_cast<int>(SecondCacheLevelDefaultAssociativity);
  }
}

/// Get parameters of the BLIS macro kernel.
///
/// During the computation of matrix multiplication, blocks of partitioned
/// matrices are mapped to different layers of the memory hierarchy.
/// To optimize data reuse, blocks should ideally be kept in cache between
/// iterations. Since parameters of the macro kernel determine sizes of these
/// blocks, there are upper and lower bounds on these parameters.
///
/// @param TTI Target Transform Info.
/// @param MicroKernelParams Parameters of the micro-kernel
///                          to be taken into account.
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The structure of type MacroKernelParamsTy.
/// @see MacroKernelParamsTy
/// @see MicroKernelParamsTy
static struct MacroKernelParamsTy
getMacroKernelParams(const llvm::TargetTransformInfo *TTI,
                     const MicroKernelParamsTy &MicroKernelParams,
                     MatMulInfoTy MMI) {
  getTargetCacheParameters(TTI);
  // According to www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf,
  // determining all the parameters of a macro-kernel requires information
  // about the first two cache levels. The model also requires that the
  // associativity degree of each cache level be greater than two. Otherwise,
  // another algorithm for determining the parameters should be used.
  if (!(MicroKernelParams.Mr > 0 && MicroKernelParams.Nr > 0 &&
        FirstCacheLevelSize > 0 && SecondCacheLevelSize > 0 &&
        FirstCacheLevelAssociativity > 2 && SecondCacheLevelAssociativity > 2))
    return {1, 1, 1};
  // The quotient should be greater than zero.
  if (PollyPatternMatchingNcQuotient <= 0)
    return {1, 1, 1};
  int Car = floor(
      (FirstCacheLevelAssociativity - 1) /
      (1 + static_cast<double>(MicroKernelParams.Nr) / MicroKernelParams.Mr));

  // Car can end up being zero because of the integer floor. Division by zero
  // does not raise a signal on all platforms (e.g., on macOS), which would
  // cause negative tile sizes to be computed below. Prevent the division by
  // Cac == 0 by returning early if this happens.
  if (Car == 0)
    return {1, 1, 1};

  auto ElementSize = getMatMulAlignTypeSize(MMI);
  assert(ElementSize > 0 && "The element size of the matrix multiplication "
                            "operands should be greater than zero.");
  int Kc = (Car * FirstCacheLevelSize) /
           (MicroKernelParams.Mr * FirstCacheLevelAssociativity * ElementSize);
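  // Cac approximates the number of second-level cache ways that a Kc-long row
  // of the packed block occupies; Mc is then chosen so that an Mc x Kc block
  // fills all but two ways of that cache, leaving room for the other operands
  // (cf. the analytical model cited above).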
  double Cac =
      static_cast<double>(Kc * ElementSize * SecondCacheLevelAssociativity) /
      SecondCacheLevelSize;
  int Mc = floor((SecondCacheLevelAssociativity - 2) / Cac);
  int Nc = PollyPatternMatchingNcQuotient * MicroKernelParams.Nr;

  assert(Mc > 0 && Nc > 0 && Kc > 0 &&
         "Matrix block sizes should be greater than zero");
  return {Mc, Nc, Kc};
}

/// Create an access relation that is specific to
///        the matrix multiplication pattern.
///
/// Create an access relation of the following form:
/// [O0, O1, O2, O3, O4, O5, O6, O7, O8] -> [OI, O5, OJ]
/// where I is @p FirstDim, J is @p SecondDim.
///
/// It can be used, for example, to create relations that help to access
/// elements of the matrix multiplication operands consecutively after the
/// creation of the BLIS micro and macro kernels.
///
/// @see ScheduleTreeOptimizer::createMicroKernel
/// @see ScheduleTreeOptimizer::createMacroKernel
///
/// Subsequently, the described access relation is applied to the range of
/// @p MapOldIndVar, that is used to map original induction variables to
/// the ones, which are produced by schedule transformations. It helps to
/// define relations using a new space and, at the same time, keep them
/// in the original one.
///
/// @param MapOldIndVar The relation, which maps original induction variables
///                     to the ones, which are produced by schedule
///                     transformations.
/// @param FirstDim, SecondDim The input dimensions that are used to define
///        the specified access relation.
/// @return The specified access relation.
static isl::map getMatMulAccRel(isl::map MapOldIndVar, unsigned FirstDim,
                                unsigned SecondDim) {
  auto AccessRelSpace = isl::space(MapOldIndVar.ctx(), 0, 9, 3);
  auto AccessRel = isl::map::universe(AccessRelSpace);
  AccessRel = AccessRel.equate(isl::dim::in, FirstDim, isl::dim::out, 0);
  AccessRel = AccessRel.equate(isl::dim::in, 5, isl::dim::out, 1);
  AccessRel = AccessRel.equate(isl::dim::in, SecondDim, isl::dim::out, 2);
  return MapOldIndVar.apply_range(AccessRel);
}

static isl::schedule_node createExtensionNode(isl::schedule_node Node,
                                              isl::map ExtensionMap) {
  auto Extension = isl::union_map(ExtensionMap);
  auto NewNode = isl::schedule_node::from_extension(Extension);
  return Node.graft_before(NewNode);
}

static isl::schedule_node optimizePackedB(isl::schedule_node Node,
                                          ScopStmt *Stmt, isl::map MapOldIndVar,
                                          MicroKernelParamsTy MicroParams,
                                          MacroKernelParamsTy MacroParams,
                                          MatMulInfoTy &MMI) {
  Scop *S = Stmt->getParent();
  isl::set Domain = Stmt->getDomain();

  // Create packed array.
  unsigned FirstDimSize = MacroParams.Nc / MicroParams.Nr;
  unsigned SecondDimSize = MacroParams.Kc;
  unsigned ThirdDimSize = MicroParams.Nr;
  ScopArrayInfo *PackedB =
      S->createScopArrayInfo(MMI.B->getElementType(), "Packed_B",
                             {FirstDimSize, SecondDimSize, ThirdDimSize});

  // Compute the access relation for copying from B to PackedB.
  isl::map AccRelB = MMI.B->getLatestAccessRelation();
  isl::map AccRelPackedB = getMatMulAccRel(MapOldIndVar, 3, 7);
  AccRelPackedB =
      AccRelPackedB.set_tuple_id(isl::dim::out, PackedB->getBasePtrId());

  // Create the copy statement and redirect access.
  ScopStmt *CopyStmt = S->addScopStmt(AccRelB, AccRelPackedB, Domain);
  MMI.B->setNewAccessRelation(AccRelPackedB);

  // Insert into the schedule tree.
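  // Execute the copy once per iteration of the two outermost tile loops: keep
  // only those two schedule dimensions, and fix the i dimension to zero so
  // the copy is not repeated for every value of i.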
  isl::map ExtMap = MapOldIndVar.project_out(
      isl::dim::out, 2, MapOldIndVar.range_tuple_dim() - 2);
  ExtMap = ExtMap.reverse();
  ExtMap = ExtMap.fix_si(isl::dim::out, MMI.i, 0);
  ExtMap = ExtMap.intersect_range(Domain);
  ExtMap = ExtMap.set_tuple_id(isl::dim::out, CopyStmt->getDomainId());
  return createExtensionNode(Node, ExtMap);
}

static isl::schedule_node optimizePackedA(isl::schedule_node Node, ScopStmt *,
                                          isl::map MapOldIndVar,
                                          MicroKernelParamsTy MicroParams,
                                          MacroKernelParamsTy MacroParams,
                                          MatMulInfoTy &MMI) {
  isl::id InputDimsId = MapOldIndVar.get_tuple_id(isl::dim::in);
  ScopStmt *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());
  isl::set Domain = Stmt->getDomain();
  isl::id DomainId = Domain.get_tuple_id();

  // Create the packed array.
  unsigned FirstDimSize = MacroParams.Mc / MicroParams.Mr;
  unsigned SecondDimSize = MacroParams.Kc;
  unsigned ThirdDimSize = MicroParams.Mr;
  ScopArrayInfo *PackedA = Stmt->getParent()->createScopArrayInfo(
      MMI.A->getElementType(), "Packed_A",
      {FirstDimSize, SecondDimSize, ThirdDimSize});

  // Compute the access relation for copying from A to PackedA.
  isl::map AccRelA = MMI.A->getLatestAccessRelation();
  isl::map AccRelPackedA = getMatMulAccRel(MapOldIndVar, 4, 6);
  AccRelPackedA =
      AccRelPackedA.set_tuple_id(isl::dim::out, PackedA->getBasePtrId());
  // { MemrefA[] -> PackedA[] }
  isl::map PackedATranslator = AccRelPackedA.apply_domain(AccRelA);

  // Compute the domain for the copy statement.
  // Construct the copy statement domain out of the 3 outermost scatter
  // dimensions (to match the 3 band nodes surrounding the extension node) and
  // the array elements to copy (one statement instance per array element).
  // { Scatter[] }
  isl::set ScatterDomain = MapOldIndVar.intersect_domain(Domain).range();
  // { Scatter[] -> OutermostScatter[] }
  isl::map OuterDomainMap =
      makeIdentityMap(ScatterDomain, true).project_out(isl::dim::out, 3, 6);
  // { Scatter[] -> MemrefA[] }
  isl::map CopyFrom = MapOldIndVar.reverse().apply_range(AccRelA);
  // { Scatter[] -> CopyStmt[] }
  isl::map DomainTranslator = OuterDomainMap.range_product(CopyFrom);
  // { CopyStmt[] }
  isl::set CopyDomain = DomainTranslator.range();

  // Translate the access relations to the new domain.
  // { CopyStmt[] -> MemrefA[] }
  CopyFrom = CopyFrom.apply_domain(DomainTranslator);
  // { CopyStmt[] -> PackedA[] }
  isl::map CopyTo = CopyFrom.apply_range(PackedATranslator);

  // Create the copy statement and redirect access.
  ScopStmt *CopyStmt =
      Stmt->getParent()->addScopStmt(CopyFrom, CopyTo, CopyDomain);
  MMI.A->setNewAccessRelation(AccRelPackedA);

  // Insert into the schedule tree.
  // { Scatter[] -> CopyStmt[] }
  isl::map ExtScatterCopy = makeIdentityMap(CopyStmt->getDomain(), true);
  ExtScatterCopy = ExtScatterCopy.project_out(isl::dim::in, 3, 2);
  return createExtensionNode(Node, ExtScatterCopy);
}

/// Apply the packing transformation.
///
/// The packing transformation can be described as a data-layout
/// transformation that requires to introduce a new array, copy data
/// to the array, and change memory access locations to reference the array.
/// It can be used to ensure that elements of the new array are read with
/// in-stride access, aligned to cache line boundaries, and preloaded into
/// certain cache levels.
///
/// As an example let us consider the packing of the array A that would help
/// to read its elements with in-stride access. An access to the array A
/// is represented by an access relation that has the form
/// S[i, j, k] -> A[i, k]. The scheduling function of the SCoP statement S has
/// the form S[i, j, k] -> [floor((j mod Nc) / Nr), floor((i mod Mc) / Mr),
/// k mod Kc, j mod Nr, i mod Mr].
///
/// To ensure that elements of the array A are read with in-stride access, we
/// add a new array Packed_A[Mc/Mr][Kc][Mr] to the SCoP, using
/// Scop::createScopArrayInfo, change the access relation
/// S[i, j, k] -> A[i, k] to
/// S[i, j, k] -> Packed_A[floor((i mod Mc) / Mr), k mod Kc, i mod Mr], using
/// MemoryAccess::setNewAccessRelation, and copy the data to the array, using
/// the copy statement created by Scop::addScopStmt.
///
/// @param Node The schedule node to be optimized.
/// @param MapOldIndVar The relation, which maps original induction variables
///                     to the ones, which are produced by schedule
///                     transformations.
/// @param MicroParams, MacroParams Parameters of the BLIS kernel
///                                 to be taken into account.
/// @param MMI Parameters of the matrix multiplication operands.
/// @return The optimized schedule node.
static isl::schedule_node
optimizeDataLayoutMatrMulPattern(isl::schedule_node Node, isl::map MapOldIndVar,
                                 MicroKernelParamsTy MicroParams,
                                 MacroKernelParamsTy MacroParams,
                                 MatMulInfoTy &MMI) {
  isl::id InputDimsId = MapOldIndVar.get_tuple_id(isl::dim::in);
  ScopStmt *Stmt = static_cast<ScopStmt *>(InputDimsId.get_user());

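  // Ascend past the bands created for the micro- and macro-kernels, then
  // split off the two outermost band members so that the extension nodes for
  // the copy statements can be grafted at the right depths.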
  Node = Node.parent().parent().parent().parent().parent().parent();
  Node = isl::manage(isl_schedule_node_band_split(Node.release(), 2));

  Node = Node.child(0);
  Node =
      optimizePackedB(Node, Stmt, MapOldIndVar, MicroParams, MacroParams, MMI);

  Node = Node.child(0);
  Node =
      optimizePackedA(Node, Stmt, MapOldIndVar, MicroParams, MacroParams, MMI);

  return Node.child(0).child(0).child(0).child(0).child(0);
}

/// Get a relation mapping the original induction variables to the ones
/// produced by schedule transformations.
///
/// @param Node The schedule node produced as the result of creation
///        of the BLIS kernels.
/// @param MicroKernelParams, MacroKernelParams Parameters of the BLIS kernel
///                                             to be taken into account.
/// @return  The relation mapping original induction variables to the ones
///          produced by schedule transformation.
/// @see ScheduleTreeOptimizer::createMicroKernel
/// @see ScheduleTreeOptimizer::createMacroKernel
/// @see getMacroKernelParams
static isl::map
getInductionVariablesSubstitution(isl::schedule_node Node,
                                  MicroKernelParamsTy MicroKernelParams,
                                  MacroKernelParamsTy MacroKernelParams) {
  auto Child = Node.child(0);
  auto UnMapOldIndVar = Child.get_prefix_schedule_union_map();
  auto MapOldIndVar = isl::map::from_union_map(UnMapOldIndVar);
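  // Keep only the nine innermost schedule dimensions introduced by the micro-
  // and macro-kernel creation; any outer dimensions are irrelevant here.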
  if (MapOldIndVar.range_tuple_dim() > 9)
    return MapOldIndVar.project_out(isl::dim::out, 0,
                                    MapOldIndVar.range_tuple_dim() - 9);
  return MapOldIndVar;
}

/// Isolate a set of partial tile prefixes and unroll the isolated part.
///
/// The set should ensure that it contains only partial tile prefixes that have
/// exactly Mr x Nr iterations of the two innermost loops produced by
/// the optimization of the matrix multiplication. Mr and Nr are parameters of
/// the micro-kernel.
///
/// In case of parametric bounds, this helps to auto-vectorize the unrolled
/// innermost loops, using the SLP vectorizer.
///
/// @param Node              The schedule node to be modified.
/// @param MicroKernelParams Parameters of the micro-kernel
///                          to be taken into account.
/// @return The modified isl_schedule_node.
static isl::schedule_node
isolateAndUnrollMatMulInnerLoops(isl::schedule_node Node,
                                 struct MicroKernelParamsTy MicroKernelParams) {
  isl::schedule_node Child = Node.get_child(0);
  isl::union_map UnMapOldIndVar = Child.get_prefix_schedule_relation();
  isl::set Prefix = isl::map::from_union_map(UnMapOldIndVar).range();
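  // Drop the innermost schedule dimension, then restrict the prefixes to
  // those that cover full Nr x Mr tiles; only these are isolated and
  // unrolled.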
  isl_size Dims = Prefix.tuple_dim();
  Prefix = Prefix.project_out(isl::dim::set, Dims - 1, 1);
  Prefix = getPartialTilePrefixes(Prefix, MicroKernelParams.Nr);
  Prefix = getPartialTilePrefixes(Prefix, MicroKernelParams.Mr);

  isl::union_set IsolateOption =
      getIsolateOptions(Prefix.add_dims(isl::dim::set, 3), 3);
  isl::ctx Ctx = Node.ctx();
  auto Options = IsolateOption.unite(getDimOptions(Ctx, "unroll"));
  Options = Options.unite(getUnrollIsolatedSetOptions(Ctx));
  Node = Node.band_set_ast_build_options(Options);
  Node = Node.parent().parent().parent();
  IsolateOption = getIsolateOptions(Prefix, 3);
  Options = IsolateOption.unite(getDimOptions(Ctx, "separate"));
  Node = Node.band_set_ast_build_options(Options);
  Node = Node.child(0).child(0).child(0);
  return Node;
}

/// Mark @p BasePtr with "Inter iteration alias-free" mark node.
///
/// @param Node The child of the mark node to be inserted.
/// @param BasePtr The pointer to be marked.
/// @return The modified isl_schedule_node.
static isl::schedule_node markInterIterationAliasFree(isl::schedule_node Node,
                                                      Value *BasePtr) {
  if (!BasePtr)
    return Node;

  auto Id = isl::id::alloc(Node.ctx(), "Inter iteration alias-free", BasePtr);
  return Node.insert_mark(Id).child(0);
}

/// Insert "Loop Vectorizer Disabled" mark node.
///
/// @param Node The child of the mark node to be inserted.
/// @return The modified isl_schedule_node.
static isl::schedule_node markLoopVectorizerDisabled(isl::schedule_node Node) {
  auto Id = isl::id::alloc(Node.ctx(), "Loop Vectorizer Disabled", nullptr);
  return Node.insert_mark(Id).child(0);
}

/// Restore the initial ordering of dimensions of the band node.
///
/// In case the band node represents all the dimensions of the iteration
/// domain, recreate the band node to restore the initial ordering of the
/// dimensions.
///
/// @param Node The band node to be modified.
/// @return The modified schedule node.
static isl::schedule_node
getBandNodeWithOriginDimOrder(isl::schedule_node Node) {
  assert(isl_schedule_node_get_type(Node.get()) == isl_schedule_node_band);
  if (isl_schedule_node_get_type(Node.child(0).get()) != isl_schedule_node_leaf)
    return Node;
  auto Domain = Node.get_universe_domain();
  assert(isl_union_set_n_set(Domain.get()) == 1);
  if (Node.get_schedule_depth() != 0 ||
      (isl::set(Domain).tuple_dim() !=
       isl_schedule_node_band_n_member(Node.get())))
    return Node;
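  // Delete the band and insert a new one whose partial schedule is the
  // identity over the statement's domain, restoring the original order of
  // the dimensions.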
  Node = isl::manage(isl_schedule_node_delete(Node.copy()));
  auto PartialSchedulePwAff = Domain.identity_union_pw_multi_aff();
  auto PartialScheduleMultiPwAff =
      isl::multi_union_pw_aff(PartialSchedulePwAff);
  PartialScheduleMultiPwAff =
      PartialScheduleMultiPwAff.reset_tuple_id(isl::dim::set);
  return Node.insert_partial_schedule(PartialScheduleMultiPwAff);
}

static isl::schedule_node optimizeMatMulPattern(isl::schedule_node Node,
                                                const TargetTransformInfo *TTI,
                                                MatMulInfoTy &MMI) {
  assert(TTI && "The target transform info should be provided.");
  Node = markInterIterationAliasFree(
      Node, MMI.WriteToC->getLatestScopArrayInfo()->getBasePtr());
  int DimOutNum = isl_schedule_node_band_n_member(Node.get());
  assert(DimOutNum > 2 && "In case of the matrix multiplication the loop nest "
                          "and, consequently, the corresponding scheduling "
                          "functions have at least three dimensions.");
  Node = getBandNodeWithOriginDimOrder(Node);
  Node = permuteBandNodeDimensions(Node, MMI.i, DimOutNum - 3);
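  // Moving i into position DimOutNum - 3 may have displaced j or k; track
  // their new positions before performing the remaining permutations.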
  int NewJ = MMI.j == DimOutNum - 3 ? MMI.i : MMI.j;
  int NewK = MMI.k == DimOutNum - 3 ? MMI.i : MMI.k;
  Node = permuteBandNodeDimensions(Node, NewJ, DimOutNum - 2);
  NewK = NewK == DimOutNum - 2 ? NewJ : NewK;
  Node = permuteBandNodeDimensions(Node, NewK, DimOutNum - 1);
  auto MicroKernelParams = getMicroKernelParams(TTI, MMI);
  auto MacroKernelParams = getMacroKernelParams(TTI, MicroKernelParams, MMI);
  Node = createMacroKernel(Node, MacroKernelParams);
  Node = createMicroKernel(Node, MicroKernelParams);
  if (MacroKernelParams.Mc == 1 || MacroKernelParams.Nc == 1 ||
      MacroKernelParams.Kc == 1)
    return Node;
  auto MapOldIndVar = getInductionVariablesSubstitution(Node, MicroKernelParams,
                                                        MacroKernelParams);
  if (MapOldIndVar.is_null())
    return Node;
  Node = markLoopVectorizerDisabled(Node.parent()).child(0);
  Node = isolateAndUnrollMatMulInnerLoops(Node, MicroKernelParams);
  return optimizeDataLayoutMatrMulPattern(Node, MapOldIndVar, MicroKernelParams,
                                          MacroKernelParams, MMI);
}

/// Check if this node contains a partial schedule that could
///        probably be optimized with analytical modeling.
///
/// isMatrMultPattern tries to determine whether the following conditions
/// are true:
/// 1. the partial schedule contains only one statement.
/// 2. there are exactly three input dimensions.
/// 3. all memory accesses of the statement will have stride 0 or 1, if we
///    interchange loops (switch the variable used in the inner loop to
///    the outer loop).
/// 4. all memory accesses of the statement except the last one are read
///    accesses, and the last one is a write access.
/// 5. all subscripts of the last memory access of the statement don't
///    contain the variable used in the inner loop.
/// If this is the case, we could try to use an approach that is similar to
/// the one used to get close-to-peak performance of matrix multiplications.
///
/// @param Node The node to check.
/// @param D    The SCoP dependencies.
/// @param MMI  Parameters of the matrix multiplication operands.
static bool isMatrMultPattern(isl::schedule_node Node, const Dependences *D,
                              MatMulInfoTy &MMI) {
  auto PartialSchedule = isl::manage(
      isl_schedule_node_band_get_partial_schedule_union_map(Node.get()));
  Node = Node.child(0);
  auto LeafType = isl_schedule_node_get_type(Node.get());
  Node = Node.parent();
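  // The band must directly schedule a single statement with at least three
  // loop dimensions at the very top of the schedule tree.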
  if (LeafType != isl_schedule_node_leaf ||
      isl_schedule_node_band_n_member(Node.get()) < 3 ||
      Node.get_schedule_depth() != 0 ||
      isl_union_map_n_map(PartialSchedule.get()) != 1)
    return false;
  auto NewPartialSchedule = isl::map::from_union_map(PartialSchedule);
  if (containsMatrMult(NewPartialSchedule, D, MMI))
    return true;
  return false;
}

} // namespace

isl::schedule_node
polly::tryOptimizeMatMulPattern(isl::schedule_node Node,
                                const llvm::TargetTransformInfo *TTI,
                                const Dependences *D) {
  MatMulInfoTy MMI;
  if (isMatrMultPattern(Node, D, MMI)) {
    LLVM_DEBUG(dbgs() << "The matrix multiplication pattern was detected\n");
    return optimizeMatMulPattern(Node, TTI, MMI);
  }
  return {};
}