1 //===- polly/ScheduleTreeTransform.cpp --------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Make changes to isl's schedule tree data structure.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "polly/ScheduleTreeTransform.h"
14 #include "polly/Support/ISLTools.h"
15 #include "llvm/ADT/ArrayRef.h"
16 #include "llvm/ADT/SmallVector.h"
17 
18 using namespace polly;
19 
20 namespace {
21 
22 /// This class defines a simple visitor class that may be used for
23 /// various schedule tree analysis purposes.
24 template <typename Derived, typename RetTy = void, typename... Args>
25 struct ScheduleTreeVisitor {
getDerived__anonbc368a020111::ScheduleTreeVisitor26   Derived &getDerived() { return *static_cast<Derived *>(this); }
getDerived__anonbc368a020111::ScheduleTreeVisitor27   const Derived &getDerived() const {
28     return *static_cast<const Derived *>(this);
29   }
30 
visit__anonbc368a020111::ScheduleTreeVisitor31   RetTy visit(const isl::schedule_node &Node, Args... args) {
32     assert(!Node.is_null());
33     switch (isl_schedule_node_get_type(Node.get())) {
34     case isl_schedule_node_domain:
35       assert(isl_schedule_node_n_children(Node.get()) == 1);
36       return getDerived().visitDomain(Node, std::forward<Args>(args)...);
37     case isl_schedule_node_band:
38       assert(isl_schedule_node_n_children(Node.get()) == 1);
39       return getDerived().visitBand(Node, std::forward<Args>(args)...);
40     case isl_schedule_node_sequence:
41       assert(isl_schedule_node_n_children(Node.get()) >= 2);
42       return getDerived().visitSequence(Node, std::forward<Args>(args)...);
43     case isl_schedule_node_set:
44       return getDerived().visitSet(Node, std::forward<Args>(args)...);
45       assert(isl_schedule_node_n_children(Node.get()) >= 2);
46     case isl_schedule_node_leaf:
47       assert(isl_schedule_node_n_children(Node.get()) == 0);
48       return getDerived().visitLeaf(Node, std::forward<Args>(args)...);
49     case isl_schedule_node_mark:
50       assert(isl_schedule_node_n_children(Node.get()) == 1);
51       return getDerived().visitMark(Node, std::forward<Args>(args)...);
52     case isl_schedule_node_extension:
53       assert(isl_schedule_node_n_children(Node.get()) == 1);
54       return getDerived().visitExtension(Node, std::forward<Args>(args)...);
55     case isl_schedule_node_filter:
56       assert(isl_schedule_node_n_children(Node.get()) == 1);
57       return getDerived().visitFilter(Node, std::forward<Args>(args)...);
58     default:
59       llvm_unreachable("unimplemented schedule node type");
60     }
61   }
62 
visitDomain__anonbc368a020111::ScheduleTreeVisitor63   RetTy visitDomain(const isl::schedule_node &Domain, Args... args) {
64     return getDerived().visitSingleChild(Domain, std::forward<Args>(args)...);
65   }
66 
visitBand__anonbc368a020111::ScheduleTreeVisitor67   RetTy visitBand(const isl::schedule_node &Band, Args... args) {
68     return getDerived().visitSingleChild(Band, std::forward<Args>(args)...);
69   }
70 
visitSequence__anonbc368a020111::ScheduleTreeVisitor71   RetTy visitSequence(const isl::schedule_node &Sequence, Args... args) {
72     return getDerived().visitMultiChild(Sequence, std::forward<Args>(args)...);
73   }
74 
visitSet__anonbc368a020111::ScheduleTreeVisitor75   RetTy visitSet(const isl::schedule_node &Set, Args... args) {
76     return getDerived().visitMultiChild(Set, std::forward<Args>(args)...);
77   }
78 
visitLeaf__anonbc368a020111::ScheduleTreeVisitor79   RetTy visitLeaf(const isl::schedule_node &Leaf, Args... args) {
80     return getDerived().visitNode(Leaf, std::forward<Args>(args)...);
81   }
82 
visitMark__anonbc368a020111::ScheduleTreeVisitor83   RetTy visitMark(const isl::schedule_node &Mark, Args... args) {
84     return getDerived().visitSingleChild(Mark, std::forward<Args>(args)...);
85   }
86 
visitExtension__anonbc368a020111::ScheduleTreeVisitor87   RetTy visitExtension(const isl::schedule_node &Extension, Args... args) {
88     return getDerived().visitSingleChild(Extension,
89                                          std::forward<Args>(args)...);
90   }
91 
visitFilter__anonbc368a020111::ScheduleTreeVisitor92   RetTy visitFilter(const isl::schedule_node &Extension, Args... args) {
93     return getDerived().visitSingleChild(Extension,
94                                          std::forward<Args>(args)...);
95   }
96 
visitSingleChild__anonbc368a020111::ScheduleTreeVisitor97   RetTy visitSingleChild(const isl::schedule_node &Node, Args... args) {
98     return getDerived().visitNode(Node, std::forward<Args>(args)...);
99   }
100 
visitMultiChild__anonbc368a020111::ScheduleTreeVisitor101   RetTy visitMultiChild(const isl::schedule_node &Node, Args... args) {
102     return getDerived().visitNode(Node, std::forward<Args>(args)...);
103   }
104 
visitNode__anonbc368a020111::ScheduleTreeVisitor105   RetTy visitNode(const isl::schedule_node &Node, Args... args) {
106     llvm_unreachable("Unimplemented other");
107   }
108 };
109 
110 /// Recursively visit all nodes of a schedule tree.
111 template <typename Derived, typename RetTy = void, typename... Args>
112 struct RecursiveScheduleTreeVisitor
113     : public ScheduleTreeVisitor<Derived, RetTy, Args...> {
114   using BaseTy = ScheduleTreeVisitor<Derived, RetTy, Args...>;
getBase__anonbc368a020111::RecursiveScheduleTreeVisitor115   BaseTy &getBase() { return *this; }
getBase__anonbc368a020111::RecursiveScheduleTreeVisitor116   const BaseTy &getBase() const { return *this; }
getDerived__anonbc368a020111::RecursiveScheduleTreeVisitor117   Derived &getDerived() { return *static_cast<Derived *>(this); }
getDerived__anonbc368a020111::RecursiveScheduleTreeVisitor118   const Derived &getDerived() const {
119     return *static_cast<const Derived *>(this);
120   }
121 
122   /// When visiting an entire schedule tree, start at its root node.
visit__anonbc368a020111::RecursiveScheduleTreeVisitor123   RetTy visit(const isl::schedule &Schedule, Args... args) {
124     return getDerived().visit(Schedule.get_root(), std::forward<Args>(args)...);
125   }
126 
127   // Necessary to allow overload resolution with the added visit(isl::schedule)
128   // overload.
visit__anonbc368a020111::RecursiveScheduleTreeVisitor129   RetTy visit(const isl::schedule_node &Node, Args... args) {
130     return getBase().visit(Node, std::forward<Args>(args)...);
131   }
132 
visitNode__anonbc368a020111::RecursiveScheduleTreeVisitor133   RetTy visitNode(const isl::schedule_node &Node, Args... args) {
134     int NumChildren = isl_schedule_node_n_children(Node.get());
135     for (int i = 0; i < NumChildren; i += 1)
136       getDerived().visit(Node.child(i), std::forward<Args>(args)...);
137     return RetTy();
138   }
139 };
140 
141 /// Recursively visit all nodes of a schedule tree while allowing changes.
142 ///
143 /// The visit methods return an isl::schedule_node that is used to continue
144 /// visiting the tree. Structural changes such as returning a different node
145 /// will confuse the visitor.
146 template <typename Derived, typename... Args>
147 struct ScheduleNodeRewriter
148     : public RecursiveScheduleTreeVisitor<Derived, isl::schedule_node,
149                                           Args...> {
getDerived__anonbc368a020111::ScheduleNodeRewriter150   Derived &getDerived() { return *static_cast<Derived *>(this); }
getDerived__anonbc368a020111::ScheduleNodeRewriter151   const Derived &getDerived() const {
152     return *static_cast<const Derived *>(this);
153   }
154 
visitNode__anonbc368a020111::ScheduleNodeRewriter155   isl::schedule_node visitNode(const isl::schedule_node &Node, Args... args) {
156     if (!Node.has_children())
157       return Node;
158 
159     isl::schedule_node It = Node.first_child();
160     while (true) {
161       It = getDerived().visit(It, std::forward<Args>(args)...);
162       if (!It.has_next_sibling())
163         break;
164       It = It.next_sibling();
165     }
166     return It.parent();
167   }
168 };
169 
170 /// Rewrite a schedule tree by reconstructing it bottom-up.
171 ///
172 /// By default, the original schedule tree is reconstructed. To build a
173 /// different tree, redefine visitor methods in a derived class (CRTP).
174 ///
175 /// Note that AST build options are not applied; Setting the isolate[] option
176 /// makes the schedule tree 'anchored' and cannot be modified afterwards. Hence,
177 /// AST build options must be set after the tree has been constructed.
178 template <typename Derived, typename... Args>
179 struct ScheduleTreeRewriter
180     : public RecursiveScheduleTreeVisitor<Derived, isl::schedule, Args...> {
getDerived__anonbc368a020111::ScheduleTreeRewriter181   Derived &getDerived() { return *static_cast<Derived *>(this); }
getDerived__anonbc368a020111::ScheduleTreeRewriter182   const Derived &getDerived() const {
183     return *static_cast<const Derived *>(this);
184   }
185 
visitDomain__anonbc368a020111::ScheduleTreeRewriter186   isl::schedule visitDomain(const isl::schedule_node &Node, Args... args) {
187     // Every schedule_tree already has a domain node, no need to add one.
188     return getDerived().visit(Node.first_child(), std::forward<Args>(args)...);
189   }
190 
visitBand__anonbc368a020111::ScheduleTreeRewriter191   isl::schedule visitBand(const isl::schedule_node &Band, Args... args) {
192     isl::multi_union_pw_aff PartialSched =
193         isl::manage(isl_schedule_node_band_get_partial_schedule(Band.get()));
194     isl::schedule NewChild =
195         getDerived().visit(Band.child(0), std::forward<Args>(args)...);
196     isl::schedule_node NewNode =
197         NewChild.insert_partial_schedule(PartialSched).get_root().get_child(0);
198 
199     // Reapply permutability and coincidence attributes.
200     NewNode = isl::manage(isl_schedule_node_band_set_permutable(
201         NewNode.release(), isl_schedule_node_band_get_permutable(Band.get())));
202     unsigned BandDims = isl_schedule_node_band_n_member(Band.get());
203     for (unsigned i = 0; i < BandDims; i += 1)
204       NewNode = isl::manage(isl_schedule_node_band_member_set_coincident(
205           NewNode.release(), i,
206           isl_schedule_node_band_member_get_coincident(Band.get(), i)));
207 
208     return NewNode.get_schedule();
209   }
210 
visitSequence__anonbc368a020111::ScheduleTreeRewriter211   isl::schedule visitSequence(const isl::schedule_node &Sequence,
212                               Args... args) {
213     int NumChildren = isl_schedule_node_n_children(Sequence.get());
214     isl::schedule Result =
215         getDerived().visit(Sequence.child(0), std::forward<Args>(args)...);
216     for (int i = 1; i < NumChildren; i += 1)
217       Result = Result.sequence(
218           getDerived().visit(Sequence.child(i), std::forward<Args>(args)...));
219     return Result;
220   }
221 
visitSet__anonbc368a020111::ScheduleTreeRewriter222   isl::schedule visitSet(const isl::schedule_node &Set, Args... args) {
223     int NumChildren = isl_schedule_node_n_children(Set.get());
224     isl::schedule Result =
225         getDerived().visit(Set.child(0), std::forward<Args>(args)...);
226     for (int i = 1; i < NumChildren; i += 1)
227       Result = isl::manage(
228           isl_schedule_set(Result.release(),
229                            getDerived()
230                                .visit(Set.child(i), std::forward<Args>(args)...)
231                                .release()));
232     return Result;
233   }
234 
visitLeaf__anonbc368a020111::ScheduleTreeRewriter235   isl::schedule visitLeaf(const isl::schedule_node &Leaf, Args... args) {
236     return isl::schedule::from_domain(Leaf.get_domain());
237   }
238 
visitMark__anonbc368a020111::ScheduleTreeRewriter239   isl::schedule visitMark(const isl::schedule_node &Mark, Args... args) {
240     isl::id TheMark = Mark.mark_get_id();
241     isl::schedule_node NewChild =
242         getDerived()
243             .visit(Mark.first_child(), std::forward<Args>(args)...)
244             .get_root()
245             .first_child();
246     return NewChild.insert_mark(TheMark).get_schedule();
247   }
248 
visitExtension__anonbc368a020111::ScheduleTreeRewriter249   isl::schedule visitExtension(const isl::schedule_node &Extension,
250                                Args... args) {
251     isl::union_map TheExtension = Extension.extension_get_extension();
252     isl::schedule_node NewChild = getDerived()
253                                       .visit(Extension.child(0), args...)
254                                       .get_root()
255                                       .first_child();
256     isl::schedule_node NewExtension =
257         isl::schedule_node::from_extension(TheExtension);
258     return NewChild.graft_before(NewExtension).get_schedule();
259   }
260 
visitFilter__anonbc368a020111::ScheduleTreeRewriter261   isl::schedule visitFilter(const isl::schedule_node &Filter, Args... args) {
262     isl::union_set FilterDomain = Filter.filter_get_filter();
263     isl::schedule NewSchedule =
264         getDerived().visit(Filter.child(0), std::forward<Args>(args)...);
265     return NewSchedule.intersect_domain(FilterDomain);
266   }
267 
visitNode__anonbc368a020111::ScheduleTreeRewriter268   isl::schedule visitNode(const isl::schedule_node &Node, Args... args) {
269     llvm_unreachable("Not implemented");
270   }
271 };
272 
273 /// Rewrite a schedule tree to an equivalent one without extension nodes.
274 ///
275 /// Each visit method takes two additional arguments:
276 ///
277 ///  * The new domain the node, which is the inherited domain plus any domains
278 ///    added by extension nodes.
279 ///
280 ///  * A map of extension domains of all children is returned; it is required by
281 ///    band nodes to schedule the additional domains at the same position as the
282 ///    extension node would.
283 ///
284 struct ExtensionNodeRewriter
285     : public ScheduleTreeRewriter<ExtensionNodeRewriter, const isl::union_set &,
286                                   isl::union_map &> {
287   using BaseTy = ScheduleTreeRewriter<ExtensionNodeRewriter,
288                                       const isl::union_set &, isl::union_map &>;
getBase__anonbc368a020111::ExtensionNodeRewriter289   BaseTy &getBase() { return *this; }
getBase__anonbc368a020111::ExtensionNodeRewriter290   const BaseTy &getBase() const { return *this; }
291 
visitSchedule__anonbc368a020111::ExtensionNodeRewriter292   isl::schedule visitSchedule(const isl::schedule &Schedule) {
293     isl::union_map Extensions;
294     isl::schedule Result =
295         visit(Schedule.get_root(), Schedule.get_domain(), Extensions);
296     assert(Extensions && Extensions.is_empty());
297     return Result;
298   }
299 
visitSequence__anonbc368a020111::ExtensionNodeRewriter300   isl::schedule visitSequence(const isl::schedule_node &Sequence,
301                               const isl::union_set &Domain,
302                               isl::union_map &Extensions) {
303     int NumChildren = isl_schedule_node_n_children(Sequence.get());
304     isl::schedule NewNode = visit(Sequence.first_child(), Domain, Extensions);
305     for (int i = 1; i < NumChildren; i += 1) {
306       isl::schedule_node OldChild = Sequence.child(i);
307       isl::union_map NewChildExtensions;
308       isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions);
309       NewNode = NewNode.sequence(NewChildNode);
310       Extensions = Extensions.unite(NewChildExtensions);
311     }
312     return NewNode;
313   }
314 
visitSet__anonbc368a020111::ExtensionNodeRewriter315   isl::schedule visitSet(const isl::schedule_node &Set,
316                          const isl::union_set &Domain,
317                          isl::union_map &Extensions) {
318     int NumChildren = isl_schedule_node_n_children(Set.get());
319     isl::schedule NewNode = visit(Set.first_child(), Domain, Extensions);
320     for (int i = 1; i < NumChildren; i += 1) {
321       isl::schedule_node OldChild = Set.child(i);
322       isl::union_map NewChildExtensions;
323       isl::schedule NewChildNode = visit(OldChild, Domain, NewChildExtensions);
324       NewNode = isl::manage(
325           isl_schedule_set(NewNode.release(), NewChildNode.release()));
326       Extensions = Extensions.unite(NewChildExtensions);
327     }
328     return NewNode;
329   }
330 
visitLeaf__anonbc368a020111::ExtensionNodeRewriter331   isl::schedule visitLeaf(const isl::schedule_node &Leaf,
332                           const isl::union_set &Domain,
333                           isl::union_map &Extensions) {
334     isl::ctx Ctx = Leaf.get_ctx();
335     Extensions = isl::union_map::empty(isl::space::params_alloc(Ctx, 0));
336     return isl::schedule::from_domain(Domain);
337   }
338 
visitBand__anonbc368a020111::ExtensionNodeRewriter339   isl::schedule visitBand(const isl::schedule_node &OldNode,
340                           const isl::union_set &Domain,
341                           isl::union_map &OuterExtensions) {
342     isl::schedule_node OldChild = OldNode.first_child();
343     isl::multi_union_pw_aff PartialSched =
344         isl::manage(isl_schedule_node_band_get_partial_schedule(OldNode.get()));
345 
346     isl::union_map NewChildExtensions;
347     isl::schedule NewChild = visit(OldChild, Domain, NewChildExtensions);
348 
349     // Add the extensions to the partial schedule.
350     OuterExtensions = isl::union_map::empty(NewChildExtensions.get_space());
351     isl::union_map NewPartialSchedMap = isl::union_map::from(PartialSched);
352     unsigned BandDims = isl_schedule_node_band_n_member(OldNode.get());
353     for (isl::map Ext : NewChildExtensions.get_map_list()) {
354       unsigned ExtDims = Ext.dim(isl::dim::in);
355       assert(ExtDims >= BandDims);
356       unsigned OuterDims = ExtDims - BandDims;
357 
358       isl::map BandSched =
359           Ext.project_out(isl::dim::in, 0, OuterDims).reverse();
360       NewPartialSchedMap = NewPartialSchedMap.unite(BandSched);
361 
362       // There might be more outer bands that have to schedule the extensions.
363       if (OuterDims > 0) {
364         isl::map OuterSched =
365             Ext.project_out(isl::dim::in, OuterDims, BandDims);
366         OuterExtensions = OuterExtensions.add_map(OuterSched);
367       }
368     }
369     isl::multi_union_pw_aff NewPartialSchedAsAsMultiUnionPwAff =
370         isl::multi_union_pw_aff::from_union_map(NewPartialSchedMap);
371     isl::schedule_node NewNode =
372         NewChild.insert_partial_schedule(NewPartialSchedAsAsMultiUnionPwAff)
373             .get_root()
374             .get_child(0);
375 
376     // Reapply permutability and coincidence attributes.
377     NewNode = isl::manage(isl_schedule_node_band_set_permutable(
378         NewNode.release(),
379         isl_schedule_node_band_get_permutable(OldNode.get())));
380     for (unsigned i = 0; i < BandDims; i += 1) {
381       NewNode = isl::manage(isl_schedule_node_band_member_set_coincident(
382           NewNode.release(), i,
383           isl_schedule_node_band_member_get_coincident(OldNode.get(), i)));
384     }
385 
386     return NewNode.get_schedule();
387   }
388 
visitFilter__anonbc368a020111::ExtensionNodeRewriter389   isl::schedule visitFilter(const isl::schedule_node &Filter,
390                             const isl::union_set &Domain,
391                             isl::union_map &Extensions) {
392     isl::union_set FilterDomain = Filter.filter_get_filter();
393     isl::union_set NewDomain = Domain.intersect(FilterDomain);
394 
395     // A filter is added implicitly if necessary when joining schedule trees.
396     return visit(Filter.first_child(), NewDomain, Extensions);
397   }
398 
visitExtension__anonbc368a020111::ExtensionNodeRewriter399   isl::schedule visitExtension(const isl::schedule_node &Extension,
400                                const isl::union_set &Domain,
401                                isl::union_map &Extensions) {
402     isl::union_map ExtDomain = Extension.extension_get_extension();
403     isl::union_set NewDomain = Domain.unite(ExtDomain.range());
404     isl::union_map ChildExtensions;
405     isl::schedule NewChild =
406         visit(Extension.first_child(), NewDomain, ChildExtensions);
407     Extensions = ChildExtensions.unite(ExtDomain);
408     return NewChild;
409   }
410 };
411 
412 /// Collect all AST build options in any schedule tree band.
413 ///
414 /// ScheduleTreeRewriter cannot apply the schedule tree options. This class
415 /// collects these options to apply them later.
416 struct CollectASTBuildOptions
417     : public RecursiveScheduleTreeVisitor<CollectASTBuildOptions> {
418   using BaseTy = RecursiveScheduleTreeVisitor<CollectASTBuildOptions>;
getBase__anonbc368a020111::CollectASTBuildOptions419   BaseTy &getBase() { return *this; }
getBase__anonbc368a020111::CollectASTBuildOptions420   const BaseTy &getBase() const { return *this; }
421 
422   llvm::SmallVector<isl::union_set, 8> ASTBuildOptions;
423 
visitBand__anonbc368a020111::CollectASTBuildOptions424   void visitBand(const isl::schedule_node &Band) {
425     ASTBuildOptions.push_back(
426         isl::manage(isl_schedule_node_band_get_ast_build_options(Band.get())));
427     return getBase().visitBand(Band);
428   }
429 };
430 
431 /// Apply AST build options to the bands in a schedule tree.
432 ///
433 /// This rewrites a schedule tree with the AST build options applied. We assume
434 /// that the band nodes are visited in the same order as they were when the
435 /// build options were collected, typically by CollectASTBuildOptions.
436 struct ApplyASTBuildOptions
437     : public ScheduleNodeRewriter<ApplyASTBuildOptions> {
438   using BaseTy = ScheduleNodeRewriter<ApplyASTBuildOptions>;
getBase__anonbc368a020111::ApplyASTBuildOptions439   BaseTy &getBase() { return *this; }
getBase__anonbc368a020111::ApplyASTBuildOptions440   const BaseTy &getBase() const { return *this; }
441 
442   size_t Pos;
443   llvm::ArrayRef<isl::union_set> ASTBuildOptions;
444 
ApplyASTBuildOptions__anonbc368a020111::ApplyASTBuildOptions445   ApplyASTBuildOptions(llvm::ArrayRef<isl::union_set> ASTBuildOptions)
446       : ASTBuildOptions(ASTBuildOptions) {}
447 
visitSchedule__anonbc368a020111::ApplyASTBuildOptions448   isl::schedule visitSchedule(const isl::schedule &Schedule) {
449     Pos = 0;
450     isl::schedule Result = visit(Schedule).get_schedule();
451     assert(Pos == ASTBuildOptions.size() &&
452            "AST build options must match to band nodes");
453     return Result;
454   }
455 
visitBand__anonbc368a020111::ApplyASTBuildOptions456   isl::schedule_node visitBand(const isl::schedule_node &Band) {
457     isl::schedule_node Result =
458         Band.band_set_ast_build_options(ASTBuildOptions[Pos]);
459     Pos += 1;
460     return getBase().visitBand(Result);
461   }
462 };
463 
464 } // namespace
465 
466 /// Return whether the schedule contains an extension node.
containsExtensionNode(isl::schedule Schedule)467 static bool containsExtensionNode(isl::schedule Schedule) {
468   assert(!Schedule.is_null());
469 
470   auto Callback = [](__isl_keep isl_schedule_node *Node,
471                      void *User) -> isl_bool {
472     if (isl_schedule_node_get_type(Node) == isl_schedule_node_extension) {
473       // Stop walking the schedule tree.
474       return isl_bool_error;
475     }
476 
477     // Continue searching the subtree.
478     return isl_bool_true;
479   };
480   isl_stat RetVal = isl_schedule_foreach_schedule_node_top_down(
481       Schedule.get(), Callback, nullptr);
482 
483   // We assume that the traversal itself does not fail, i.e. the only reason to
484   // return isl_stat_error is that an extension node was found.
485   return RetVal == isl_stat_error;
486 }
487 
hoistExtensionNodes(isl::schedule Sched)488 isl::schedule polly::hoistExtensionNodes(isl::schedule Sched) {
489   // If there is no extension node in the first place, return the original
490   // schedule tree.
491   if (!containsExtensionNode(Sched))
492     return Sched;
493 
494   // Build options can anchor schedule nodes, such that the schedule tree cannot
495   // be modified anymore. Therefore, apply build options after the tree has been
496   // created.
497   CollectASTBuildOptions Collector;
498   Collector.visit(Sched);
499 
500   // Rewrite the schedule tree without extension nodes.
501   ExtensionNodeRewriter Rewriter;
502   isl::schedule NewSched = Rewriter.visitSchedule(Sched);
503 
504   // Reapply the AST build options. The rewriter must not change the iteration
505   // order of bands. Any other node type is ignored.
506   ApplyASTBuildOptions Applicator(Collector.ASTBuildOptions);
507   NewSched = Applicator.visitSchedule(NewSched);
508 
509   return NewSched;
510 }
511