1 //===------ FlattenAlgo.cpp ------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Main algorithm of the FlattenSchedulePass. This is a separate file to avoid
10 // the unittest for this requiring linking against LLVM.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "polly/FlattenAlgo.h"
15 #include "polly/Support/ISLOStream.h"
16 #include "polly/Support/ISLTools.h"
17 #include "llvm/Support/Debug.h"
18 #define DEBUG_TYPE "polly-flatten-algo"
19 
20 using namespace polly;
21 using namespace llvm;
22 
23 namespace {
24 
25 /// Whether a dimension of a set is bounded (lower and upper) by a constant,
26 /// i.e. there are two constants Min and Max, such that every value x of the
27 /// chosen dimensions is Min <= x <= Max.
isDimBoundedByConstant(isl::set Set,unsigned dim)28 bool isDimBoundedByConstant(isl::set Set, unsigned dim) {
29   auto ParamDims = Set.dim(isl::dim::param);
30   Set = Set.project_out(isl::dim::param, 0, ParamDims);
31   Set = Set.project_out(isl::dim::set, 0, dim);
32   auto SetDims = Set.tuple_dim();
33   Set = Set.project_out(isl::dim::set, 1, SetDims - 1);
34   return bool(Set.is_bounded());
35 }
36 
37 /// Whether a dimension of a set is (lower and upper) bounded by a constant or
38 /// parameters, i.e. there are two expressions Min_p and Max_p of the parameters
39 /// p, such that every value x of the chosen dimensions is
40 /// Min_p <= x <= Max_p.
isDimBoundedByParameter(isl::set Set,unsigned dim)41 bool isDimBoundedByParameter(isl::set Set, unsigned dim) {
42   Set = Set.project_out(isl::dim::set, 0, dim);
43   auto SetDims = Set.tuple_dim();
44   Set = Set.project_out(isl::dim::set, 1, SetDims - 1);
45   return bool(Set.is_bounded());
46 }
47 
48 /// Whether BMap's first out-dimension is not a constant.
isVariableDim(const isl::basic_map & BMap)49 bool isVariableDim(const isl::basic_map &BMap) {
50   auto FixedVal = BMap.plain_get_val_if_fixed(isl::dim::out, 0);
51   return FixedVal.is_null() || FixedVal.is_nan();
52 }
53 
54 /// Whether Map's first out dimension is no constant nor piecewise constant.
isVariableDim(const isl::map & Map)55 bool isVariableDim(const isl::map &Map) {
56   for (isl::basic_map BMap : Map.get_basic_map_list())
57     if (isVariableDim(BMap))
58       return false;
59 
60   return true;
61 }
62 
63 /// Whether UMap's first out dimension is no (piecewise) constant.
isVariableDim(const isl::union_map & UMap)64 bool isVariableDim(const isl::union_map &UMap) {
65   for (isl::map Map : UMap.get_map_list())
66     if (isVariableDim(Map))
67       return false;
68   return true;
69 }
70 
71 /// Compute @p UPwAff - @p Val.
subtract(isl::union_pw_aff UPwAff,isl::val Val)72 isl::union_pw_aff subtract(isl::union_pw_aff UPwAff, isl::val Val) {
73   if (Val.is_zero())
74     return UPwAff;
75 
76   auto Result = isl::union_pw_aff::empty(UPwAff.get_space());
77   isl::stat Stat =
78       UPwAff.foreach_pw_aff([=, &Result](isl::pw_aff PwAff) -> isl::stat {
79         auto ValAff =
80             isl::pw_aff(isl::set::universe(PwAff.get_space().domain()), Val);
81         auto Subtracted = PwAff.sub(ValAff);
82         Result = Result.union_add(isl::union_pw_aff(Subtracted));
83         return isl::stat::ok();
84       });
85   if (Stat.is_error())
86     return {};
87   return Result;
88 }
89 
90 /// Compute @UPwAff * @p Val.
multiply(isl::union_pw_aff UPwAff,isl::val Val)91 isl::union_pw_aff multiply(isl::union_pw_aff UPwAff, isl::val Val) {
92   if (Val.is_one())
93     return UPwAff;
94 
95   auto Result = isl::union_pw_aff::empty(UPwAff.get_space());
96   isl::stat Stat =
97       UPwAff.foreach_pw_aff([=, &Result](isl::pw_aff PwAff) -> isl::stat {
98         auto ValAff =
99             isl::pw_aff(isl::set::universe(PwAff.get_space().domain()), Val);
100         auto Multiplied = PwAff.mul(ValAff);
101         Result = Result.union_add(Multiplied);
102         return isl::stat::ok();
103       });
104   if (Stat.is_error())
105     return {};
106   return Result;
107 }
108 
109 /// Remove @p n dimensions from @p UMap's range, starting at @p first.
110 ///
111 /// It is assumed that all maps in the maps have at least the necessary number
112 /// of out dimensions.
scheduleProjectOut(const isl::union_map & UMap,unsigned first,unsigned n)113 isl::union_map scheduleProjectOut(const isl::union_map &UMap, unsigned first,
114                                   unsigned n) {
115   if (n == 0)
116     return UMap; /* isl_map_project_out would also reset the tuple, which should
117                     have no effect on schedule ranges */
118 
119   auto Result = isl::union_map::empty(UMap.ctx());
120   for (isl::map Map : UMap.get_map_list()) {
121     auto Outprojected = Map.project_out(isl::dim::out, first, n);
122     Result = Result.unite(Outprojected);
123   }
124   return Result;
125 }
126 
127 /// Return the number of dimensions in the input map's range.
128 ///
129 /// Because this function takes an isl_union_map, the out dimensions could be
130 /// different. We return the maximum number in this case. However, a different
131 /// number of dimensions is not supported by the other code in this file.
scheduleScatterDims(const isl::union_map & Schedule)132 isl_size scheduleScatterDims(const isl::union_map &Schedule) {
133   isl_size Dims = 0;
134   for (isl::map Map : Schedule.get_map_list()) {
135     if (Map.is_null())
136       continue;
137 
138     Dims = std::max(Dims, Map.range_tuple_dim());
139   }
140   return Dims;
141 }
142 
143 /// Return the @p pos' range dimension, converted to an isl_union_pw_aff.
scheduleExtractDimAff(isl::union_map UMap,unsigned pos)144 isl::union_pw_aff scheduleExtractDimAff(isl::union_map UMap, unsigned pos) {
145   auto SingleUMap = isl::union_map::empty(UMap.ctx());
146   for (isl::map Map : UMap.get_map_list()) {
147     unsigned MapDims = Map.range_tuple_dim();
148     isl::map SingleMap = Map.project_out(isl::dim::out, 0, pos);
149     SingleMap = SingleMap.project_out(isl::dim::out, 1, MapDims - pos - 1);
150     SingleUMap = SingleUMap.unite(SingleMap);
151   };
152 
153   auto UAff = isl::union_pw_multi_aff(SingleUMap);
154   auto FirstMAff = isl::multi_union_pw_aff(UAff);
155   return FirstMAff.get_union_pw_aff(0);
156 }
157 
/// Flatten a sequence-like first dimension.
///
/// A sequence-like scatter dimension is constant, or at least has only small
/// variation, typically the result of ordering a sequence of different
/// statements. An example would be:
///   { Stmt_A[] -> [0, X, ...]; Stmt_B[] -> [1, Y, ...] }
/// to schedule all instances of Stmt_A before any instance of Stmt_B.
///
/// To flatten, first begin with an offset of zero. Then determine the lowest
/// possible value of the dimension, call it "i" [In the example we start at 0].
/// Considering only schedules with that value, consider only instances with
/// that value and determine the extent of the next dimension. Let l_X(i) and
/// u_X(i) be its minimum (lower bound) and maximum (upper bound) value. Add
/// them as "Offset + X - l_X(i)" to the new schedule, then add
/// "u_X(i) - l_X(i) + 1" to Offset and remove all i-instances from the old
/// schedule. Repeat with the remaining lowest value i' until there are no
/// instances in the old schedule left.
/// The example schedule would be transformed to:
///   { Stmt_A[] -> [X - l_X, ...]; Stmt_B[] -> [u_X - l_X + 1 + Y - l_Y, ...] }
///
/// @return The flattened schedule, or a null object if the first dimension is
///         not bounded by constants or one of the sequence steps is not
///         bounded by parameters.
isl::union_map tryFlattenSequence(isl::union_map Schedule) {
  auto IslCtx = Schedule.ctx();
  auto ScatterSet = isl::set(Schedule.range());

  auto ParamSpace = Schedule.get_space().params();
  auto Dims = ScatterSet.tuple_dim();
  assert(Dims >= 2);

  // Would cause an infinite loop.
  if (!isDimBoundedByConstant(ScatterSet, 0)) {
    LLVM_DEBUG(dbgs() << "Abort; dimension is not of fixed size\n");
    return {};
  }

  // Used below to pull parameter-only expressions back onto every statement
  // domain of the schedule.
  auto AllDomains = Schedule.domain();
  auto AllDomainsToNull = isl::union_pw_multi_aff(AllDomains);

  // NewSchedule accumulates the flattened parts; Counter is the running
  // "Offset" of the description above.
  auto NewSchedule = isl::union_map::empty(ParamSpace.ctx());
  auto Counter = isl::pw_aff(isl::local_space(ParamSpace.set_from_params()));

  // Each iteration handles the lexicographically smallest remaining value of
  // the first dimension and removes it from ScatterSet afterwards.
  while (!ScatterSet.is_empty()) {
    LLVM_DEBUG(dbgs() << "Next counter:\n  " << Counter << "\n");
    LLVM_DEBUG(dbgs() << "Remaining scatter set:\n  " << ScatterSet << "\n");
    // Smallest value of the first dimension, re-extended to Dims dimensions so
    // that it selects every scatter point starting with that value.
    auto ThisSet = ScatterSet.project_out(isl::dim::set, 1, Dims - 1);
    auto ThisFirst = ThisSet.lexmin();
    auto ScatterFirst = ThisFirst.add_dims(isl::dim::set, Dims - 1);

    // Restrict the schedule to this step, drop the (now handled) first
    // dimension and recursively flatten what remains.
    auto SubSchedule = Schedule.intersect_range(ScatterFirst);
    SubSchedule = scheduleProjectOut(SubSchedule, 0, 1);
    SubSchedule = flattenSchedule(SubSchedule);

    // Split the flattened sub-schedule into its first dimension and the rest.
    auto SubDims = scheduleScatterDims(SubSchedule);
    auto FirstSubSchedule = scheduleProjectOut(SubSchedule, 1, SubDims - 1);
    auto FirstScheduleAff = scheduleExtractDimAff(FirstSubSchedule, 0);
    auto RemainingSubSchedule = scheduleProjectOut(SubSchedule, 0, 1);

    auto FirstSubScatter = isl::set(FirstSubSchedule.range());
    LLVM_DEBUG(dbgs() << "Next step in sequence is:\n  " << FirstSubScatter
                      << "\n");

    // Without parametric bounds the length of this step cannot be expressed.
    if (!isDimBoundedByParameter(FirstSubScatter, 0)) {
      LLVM_DEBUG(dbgs() << "Abort; sequence step is not bounded\n");
      return {};
    }

    auto FirstSubScatterMap = isl::map::from_range(FirstSubScatter);

    // isl_set_dim_max returns a strange isl_pw_aff with domain tuple_id of
    // 'none'. It doesn't match with any space including a 0-dimensional
    // anonymous tuple.
    // Interesting, one can create such a set using
    // isl_set_universe(ParamSpace). Bug?
    auto PartMin = FirstSubScatterMap.dim_min(0);
    auto PartMax = FirstSubScatterMap.dim_max(0);
    auto One = isl::pw_aff(isl::set::universe(ParamSpace.set_from_params()),
                           isl::val::one(IslCtx));
    // Number of values this step occupies: PartMax - PartMin + 1.
    auto PartLen = PartMax.add(PartMin.neg()).add(One);

    // New first dimension of this step: (old value - PartMin) + Counter.
    auto AllPartMin = isl::union_pw_aff(PartMin).pullback(AllDomainsToNull);
    auto FirstScheduleAffNormalized = FirstScheduleAff.sub(AllPartMin);
    auto AllCounter = isl::union_pw_aff(Counter).pullback(AllDomainsToNull);
    auto FirstScheduleAffWithOffset =
        FirstScheduleAffNormalized.add(AllCounter);

    // Re-attach the remaining dimensions and add the part to the result.
    auto ScheduleWithOffset = isl::union_map(FirstScheduleAffWithOffset)
                                  .flat_range_product(RemainingSubSchedule);
    NewSchedule = NewSchedule.unite(ScheduleWithOffset);

    // This step is done; the next one starts behind it.
    ScatterSet = ScatterSet.subtract(ScatterFirst);
    Counter = Counter.add(PartLen);
  }

  LLVM_DEBUG(dbgs() << "Sequence-flatten result is:\n  " << NewSchedule
                    << "\n");
  return NewSchedule;
}
253 
254 /// Flatten a loop-like first dimension.
255 ///
256 /// A loop-like dimension is one that depends on a variable (usually a loop's
257 /// induction variable). Let the input schedule look like this:
258 ///   { Stmt[i] -> [i, X, ...] }
259 ///
260 /// To flatten, we determine the largest extent of X which may not depend on the
261 /// actual value of i. Let l_X() the smallest possible value of X and u_X() its
262 /// largest value. Then, construct a new schedule
263 ///   { Stmt[i] -> [i * (u_X() - l_X() + 1), ...] }
tryFlattenLoop(isl::union_map Schedule)264 isl::union_map tryFlattenLoop(isl::union_map Schedule) {
265   assert(scheduleScatterDims(Schedule) >= 2);
266 
267   auto Remaining = scheduleProjectOut(Schedule, 0, 1);
268   auto SubSchedule = flattenSchedule(Remaining);
269   auto SubDims = scheduleScatterDims(SubSchedule);
270 
271   auto SubExtent = isl::set(SubSchedule.range());
272   auto SubExtentDims = SubExtent.dim(isl::dim::param);
273   SubExtent = SubExtent.project_out(isl::dim::param, 0, SubExtentDims);
274   SubExtent = SubExtent.project_out(isl::dim::set, 1, SubDims - 1);
275 
276   if (!isDimBoundedByConstant(SubExtent, 0)) {
277     LLVM_DEBUG(dbgs() << "Abort; dimension not bounded by constant\n");
278     return {};
279   }
280 
281   auto Min = SubExtent.dim_min(0);
282   LLVM_DEBUG(dbgs() << "Min bound:\n  " << Min << "\n");
283   auto MinVal = getConstant(Min, false, true);
284   auto Max = SubExtent.dim_max(0);
285   LLVM_DEBUG(dbgs() << "Max bound:\n  " << Max << "\n");
286   auto MaxVal = getConstant(Max, true, false);
287 
288   if (MinVal.is_null() || MaxVal.is_null() || MinVal.is_nan() ||
289       MaxVal.is_nan()) {
290     LLVM_DEBUG(dbgs() << "Abort; dimension bounds could not be determined\n");
291     return {};
292   }
293 
294   auto FirstSubScheduleAff = scheduleExtractDimAff(SubSchedule, 0);
295   auto RemainingSubSchedule = scheduleProjectOut(std::move(SubSchedule), 0, 1);
296 
297   auto LenVal = MaxVal.sub(MinVal).add_ui(1);
298   auto FirstSubScheduleNormalized = subtract(FirstSubScheduleAff, MinVal);
299 
300   // TODO: Normalize FirstAff to zero (convert to isl_map, determine minimum,
301   // subtract it)
302   auto FirstAff = scheduleExtractDimAff(Schedule, 0);
303   auto Offset = multiply(FirstAff, LenVal);
304   auto Index = FirstSubScheduleNormalized.add(Offset);
305   auto IndexMap = isl::union_map(Index);
306 
307   auto Result = IndexMap.flat_range_product(RemainingSubSchedule);
308   LLVM_DEBUG(dbgs() << "Loop-flatten result is:\n  " << Result << "\n");
309   return Result;
310 }
311 } // anonymous namespace
312 
flattenSchedule(isl::union_map Schedule)313 isl::union_map polly::flattenSchedule(isl::union_map Schedule) {
314   auto Dims = scheduleScatterDims(Schedule);
315   LLVM_DEBUG(dbgs() << "Recursive schedule to process:\n  " << Schedule
316                     << "\n");
317 
318   // Base case; no dimensions left
319   if (Dims == 0) {
320     // TODO: Add one dimension?
321     return Schedule;
322   }
323 
324   // Base case; already one-dimensional
325   if (Dims == 1)
326     return Schedule;
327 
328   // Fixed dimension; no need to preserve variabledness.
329   if (!isVariableDim(Schedule)) {
330     LLVM_DEBUG(dbgs() << "Fixed dimension; try sequence flattening\n");
331     auto NewScheduleSequence = tryFlattenSequence(Schedule);
332     if (!NewScheduleSequence.is_null())
333       return NewScheduleSequence;
334   }
335 
336   // Constant stride
337   LLVM_DEBUG(dbgs() << "Try loop flattening\n");
338   auto NewScheduleLoop = tryFlattenLoop(Schedule);
339   if (!NewScheduleLoop.is_null())
340     return NewScheduleLoop;
341 
342   // Try again without loop condition (may blow up the number of pieces!!)
343   LLVM_DEBUG(dbgs() << "Try sequence flattening again\n");
344   auto NewScheduleSequence = tryFlattenSequence(Schedule);
345   if (!NewScheduleSequence.is_null())
346     return NewScheduleSequence;
347 
348   // Cannot flatten
349   return Schedule;
350 }
351