/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file schedule_dataflow_rewrite.cc
 */
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include <unordered_set>

#include "../../tir/transforms/ir_util.h"
#include "message_passing.h"
#include "operation_inline.h"

namespace tvm {
namespace te {
// Find the first occurrence of v in the leaf array; returns the array size if not found.
template <typename T>
size_t FindNodeRef(ArrayNode* array_node, const T& v) {
  const Object* n = v.get();
  for (size_t i = 0; i < array_node->size(); ++i) {
    if (array_node->at(i).get() == n) return i;
  }
  return array_node->size();
}

// Variable replacer used when creating cache stages: substitutes variables
// in an expression according to the given map.
class VarReplacer : public tir::StmtExprMutator {
 public:
  explicit VarReplacer(const std::unordered_map<const VarNode*, PrimExpr>& vsub) : vsub_(vsub) {}
  PrimExpr VisitExpr_(const VarNode* op) final {
    auto it = vsub_.find(op);
    if (it != vsub_.end()) return it->second;
    return GetRef<PrimExpr>(op);
  }

  tir::CommReducer MutateCommReducer(tir::CommReducer combiner) {
    // Replace free variables in combiner
    auto new_identity = tir::UpdateArray(combiner->identity_element,
                                         [this](const PrimExpr& e) { return this->VisitExpr(e); });
    auto new_result = tir::UpdateArray(combiner->result,
                                       [this](const PrimExpr& e) { return this->VisitExpr(e); });

    if (combiner->identity_element.same_as(new_identity) &&
        combiner->result.same_as(new_result)) {
      return combiner;
    } else {
      return tir::CommReducer(combiner->lhs, combiner->rhs, new_result, new_identity);
    }
  }

  PrimExpr VisitExpr_(const tir::ReduceNode* op) final {
    PrimExpr new_e = StmtExprMutator::VisitExpr_(op);
    const tir::ReduceNode* new_reduce = new_e.as<tir::ReduceNode>();
    tir::CommReducer new_combiner = MutateCommReducer(op->combiner);
    if (op->combiner.same_as(new_combiner)) {
      return new_e;
    } else {
      return tir::Reduce(new_combiner, new_reduce->source, new_reduce->axis, new_reduce->condition,
                         new_reduce->value_index, new_reduce->init);
    }
  }

 private:
  const std::unordered_map<const VarNode*, PrimExpr>& vsub_;
};

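/*!
 * \brief Conjoin the given predicates into the body expression.
 *
 * If the body is a Reduce node, the predicates are folded into its condition;
 * otherwise the body is wrapped in a Select that yields zero when the
 * predicates do not hold.
 */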
PrimExpr InjectPredicate(const Array<PrimExpr>& predicates, PrimExpr body) {
  using tir::ReduceNode;
  using tir::SelectNode;
  if (predicates.size() == 0) return body;
  const ReduceNode* reduce = body.as<ReduceNode>();
  auto fand = [](PrimExpr a, PrimExpr b) { return a && b; };

  if (reduce) {
    auto n = make_object<ReduceNode>(*reduce);
    n->condition = foldl(fand, n->condition, predicates);
    return PrimExpr(n);
  }
  return Select(foldl(fand, const_true(1), predicates), body, make_zero(body.dtype()));
}

// Replace the dataflow that appears in all stages given the tensor change.
// Also update vmap if subsequent dataflow needs to be replaced.
// A reverse map is used to keep the transitive closure property of vmap up to date.
void ReplaceDataFlow(const Array<Stage>& stages, std::unordered_map<Tensor, Tensor>* vmap,
                     std::unordered_map<Tensor, Tensor>* rvmap) {
  for (Stage s : stages) {
    Operation op = s->op->ReplaceInputs(s->op, *vmap);
    if (!op.same_as(s->op)) {
      for (int i = 0; i < op->num_outputs(); ++i) {
        auto it = rvmap->find(s->op.output(i));
        if (it != rvmap->end()) {
          (*vmap)[it->second] = op.output(i);
        } else {
          (*vmap)[s->op.output(i)] = op.output(i);
          (*rvmap)[op.output(i)] = s->op.output(i);
        }
      }
      s->op = op;
    }
  }
}

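// Whether two Reduce nodes are equal in everything except value_index
// (combiner, source, axis, condition, and init).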
inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) {
  return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) &&
         (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition)) &&
         ((a->init.empty() && b->init.empty()) || (a->init.same_as(b->init)));
}

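/*!
 * \brief Create a cache stage that reads the tensor in the given memory scope.
 *
 * A new compute stage named "<op>.<scope>" copies the tensor, every reader is
 * rewritten to consume the cache, and the cache stage is inserted right after
 * the producer stage in the stage list.
 *
 * A rough usage sketch (the tensors A, B and the schedule s are hypothetical):
 *   Tensor AA = s.cache_read(A, "shared", {B->op});
 */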
Tensor Schedule::cache_read(const Tensor& tensor, const std::string& scope,
                            const Array<Operation>& readers) {
  (*this)->InvalidateCache();
  // create identity mapping.
  std::ostringstream os;
  os << tensor->op->name;
  if (tensor->op->num_outputs() != 1) {
    os << ".v" << tensor->value_index;
  }
  os << "." << scope;

  std::unordered_map<Tensor, Tensor> vsub;
  Stage s = operator[](tensor->op);
  Tensor sugar_tensor = s->op.output(tensor->value_index);
  Tensor cache = compute(
      sugar_tensor->shape,
      [&sugar_tensor](const Array<Var>& i) {
        return sugar_tensor(Array<PrimExpr>(i.begin(), i.end()));
      },
      os.str());
  vsub[sugar_tensor] = cache;

  std::unordered_map<Tensor, Tensor> vmap;
  std::unordered_map<Tensor, Tensor> rvmap;
  for (Operation op : readers) {
    Stage s = operator[](op);
    Operation repl_op = s->op->ReplaceInputs(s->op, vsub);
    CHECK(!repl_op.same_as(s->op)) << "Cannot find " << tensor << " in the inputs of " << s->op;
    vmap[s->op.output(0)] = repl_op.output(0);
    rvmap[repl_op.output(0)] = s->op.output(0);
    s->op = repl_op;
  }
  ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
  Array<Stage>& stages = (*this)->stages;
  Stage op_stage = operator[](tensor->op);
  size_t pos = FindNodeRef(stages.GetArrayNode(), op_stage);
  Stage cache_stage = Stage(cache->op);
  cache_stage.set_scope(scope);
  CHECK_LT(pos, stages.size());
  stages.insert(stages.begin() + pos + 1, cache_stage);
  (*this)->stage_map.Set(cache->op, cache_stage);
  // Update group
  cache_stage->group = op_stage->group;
  if (cache_stage->group.defined()) {
    ++cache_stage->group->num_child_stages;
  }
  return cache;
}

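/*!
 * \brief Prepare the axis mapping used by cache_write.
 *
 * Collects the reduction axes, creates the new data-parallel axes for the
 * cache stage (with ".c" suffix), propagates iteration domains along the
 * original stage, and fills the variable substitution maps and bound-check
 * predicates used when building the cached body.
 */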
template <typename OpType>
void PrepareAxisMapping(Stage orig_stage, OpType* op, std::unordered_set<IterVar>* p_red_axis,
                        Array<IterVar>* p_new_axis, std::unordered_map<IterVar, Range>* p_dom_map,
                        std::unordered_map<const VarNode*, PrimExpr>* p_vsub,
                        std::unordered_map<const VarNode*, PrimExpr>* p_vsub2newvar,
                        std::vector<PrimExpr>* p_predicates) {
  auto& red_axis = *p_red_axis;
  auto& new_axis = *p_new_axis;
  auto& dom_map = *p_dom_map;
  auto& vsub = *p_vsub;
  auto& vsub2newvar = *p_vsub2newvar;
  auto& predicates = *p_predicates;
  arith::Analyzer analyzer;

  for (IterVar iv : op->reduce_axis) {
    red_axis.insert(iv);
  }
  for (IterVar iv : op->axis) {
    dom_map[iv] = iv->dom;
    analyzer.Bind(iv->var, iv->dom);
  }
  te::PassDownDomain(orig_stage, &dom_map, &analyzer, true);
  {
    // The source->cache
    std::unordered_map<IterVar, PrimExpr> value_map;
    for (IterVar iv : orig_stage->leaf_iter_vars) {
      if (red_axis.count(iv)) continue;
      CHECK_EQ(iv->iter_type, kDataPar) << "Can only relayout within data parallel dimensions";
      Range dom = dom_map.at(iv);
      IterVar new_iv = IterVar(dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
      new_axis.push_back(new_iv);
      if (is_one(dom->extent)) {
        value_map[iv] = dom->min;
      } else {
        value_map[iv] = iv->var;
        vsub2newvar[iv->var.get()] = new_iv->var;
      }
    }
    // skip reduction iteration.
    std::unordered_set<IterVar> skip_bound_check;
    for (IterVar iv : op->reduce_axis) {
      skip_bound_check.insert(iv);
    }
    PassUpIndex(orig_stage, dom_map, &value_map, true);
    predicates = MakeBoundCheck(orig_stage, dom_map, value_map, true, skip_bound_check);
    // The root axis
    for (IterVar iv : op->axis) {
      if (value_map.count(iv)) {
        vsub[iv->var.get()] = value_map.at(iv);
      }  // to handle tensor axis
    }
  }
}

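/*!
 * \brief Replace the op of orig_stage with orig_new_op and register cache_op
 *        as a new stage placed right before it.
 *
 * The dataflow of all stages is rewritten accordingly, the original stage's
 * itervars and relations are reset, and the cache output tensors are returned.
 */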
Array<Tensor> ReplaceOriginalOp(Schedule sch, Stage orig_stage, const std::string& scope,
                                Operation cache_op, Operation orig_new_op, size_t tensor_size) {
  Array<Tensor> cache_tensor_list;
  for (size_t i = 0; i < tensor_size; i++) {
    Tensor cache_tensor = cache_op.output(i);
    cache_tensor_list.push_back(cache_tensor);
  }
  // Replace the dataflow.
  std::unordered_map<Tensor, Tensor> vmap;
  std::unordered_map<Tensor, Tensor> rvmap;
  for (size_t i = 0; i < tensor_size; i++) {
    vmap[orig_stage->op.output(i)] = orig_new_op.output(i);
    rvmap[orig_new_op.output(i)] = orig_stage->op.output(i);
  }
  ReplaceDataFlow(sch->stages, &vmap, &rvmap);
  // mutate orig stage
  orig_stage->op = orig_new_op;
  orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
  orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
  orig_stage->relations = Array<IterVarRelation>();
  // create schedule for new cached stage.
  Array<Stage>& stages = sch->stages;
  size_t pos = FindNodeRef(stages.GetArrayNode(), orig_stage);
  Stage cache_stage = Stage(cache_op);
  cache_stage.set_scope(scope);
  CHECK_LT(pos, stages.size());
  stages.insert(stages.begin() + pos, cache_stage);
  sch->stage_map.Set(cache_op, cache_stage);
  // Update group
  cache_stage->group = orig_stage->group;
  if (cache_stage->group.defined()) {
    ++cache_stage->group->num_child_stages;
  }
  return cache_tensor_list;
}

// Cache write and relayout the data according to the loop pattern.
Array<Tensor> CacheWriteWithReLayout(Schedule sch, const Array<Tensor>& tensor_array,
                                     const std::string& scope) {
  size_t tensor_size = tensor_array.size();
  sch->InvalidateCache();
  Tensor tensor = tensor_array[0];
  Stage orig_stage = sch[tensor->op];
  const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();

  std::unordered_set<IterVar> red_axis;
  Array<IterVar> new_axis;
  std::unordered_map<IterVar, Range> dom_map;

  std::unordered_map<const VarNode*, PrimExpr> vsub;
  std::unordered_map<const VarNode*, PrimExpr> vsub2newvar;
  std::vector<PrimExpr> predicates;

  PrepareAxisMapping(orig_stage, compute, &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar,
                     &predicates);

  PrimExpr body;
  Array<PrimExpr> body_list;
  const tir::ReduceNode* first_reduce = nullptr;
  for (auto cbody : compute->body) {
    body = VarReplacer(vsub)(cbody);
    body = InjectPredicate(predicates, body);
    body = VarReplacer(vsub2newvar)(body);
    // Reduce nodes in ONE computeOp must be the same except value_index.
    // This is correct only if the original body ensures the Reduce nodes are the same.
    if (body->IsInstance<tir::ReduceNode>()) {
      const tir::ReduceNode* reduce_body = body.as<tir::ReduceNode>();
      if (first_reduce != nullptr) {
        CHECK(ReduceEqual(reduce_body, first_reduce));
        body = tir::Reduce(first_reduce->combiner, first_reduce->source, first_reduce->axis,
                           first_reduce->condition, reduce_body->value_index, reduce_body->init);
      } else {
        first_reduce = reduce_body;
      }
    } else {
      CHECK(first_reduce == nullptr) << "cannot mix reduce and other nodes in ONE compute body";
    }
    body_list.push_back(body);
  }
  // The reader args
  Array<PrimExpr> args;
  {
    // cache->compute
    std::unordered_map<IterVar, PrimExpr> value_map;
    for (IterVar iv : compute->axis) {
      value_map[iv] = iv->var;
    }
    te::PassDownIndex(orig_stage, dom_map, &value_map, true);
    for (IterVar iv : orig_stage->leaf_iter_vars) {
      if (red_axis.count(iv)) continue;
      args.push_back(value_map.at(iv));
    }
  }
  Operation cache_op =
      ComputeOp(compute->name + "." + scope, compute->tag, compute->attrs, new_axis, body_list);

  Array<PrimExpr> cache_expr_list;
  for (size_t i = 0; i < tensor_size; i++) {
    Tensor cache_tensor = cache_op.output(i);
    cache_expr_list.push_back(cache_tensor(args));
  }
  Operation orig_new_op =
      ComputeOp(compute->name, compute->tag, compute->attrs, compute->axis, cache_expr_list);
  return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size);
}

// Cache write for tensor compute op.
Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch, const Array<Tensor>& tensor_array,
                                           const std::string& scope) {
  size_t tensor_size = tensor_array.size();
  sch->InvalidateCache();
  Tensor tensor = tensor_array[0];
  Stage orig_stage = sch[tensor->op];
  const TensorComputeOpNode* tensor_op = orig_stage->op.as<TensorComputeOpNode>();
  CHECK_EQ(tensor_op->num_outputs(), 1)
      << "cache write only supports single-output tensor_compute_op";

  std::unordered_set<IterVar> red_axis;
  Array<IterVar> new_axis;
  std::unordered_map<IterVar, Range> dom_map;

  std::unordered_map<const VarNode*, PrimExpr> vsub;
  std::unordered_map<const VarNode*, PrimExpr> vsub2newvar;
  std::vector<PrimExpr> predicates;

  PrepareAxisMapping(orig_stage, tensor_op, &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar,
                     &predicates);

  for (int i = tensor_op->schedulable_ndim; i < static_cast<int>(tensor_op->axis.size()); ++i) {
    IterVar iv = tensor_op->axis[i];
    IterVar new_iv = IterVar(iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
    new_axis.push_back(new_iv);
  }
  Array<Region> new_regions;
  for (Region old_region : tensor_op->input_regions) {
    Region region;
    for (Range r : old_region) {
      PrimExpr min = VarReplacer(vsub2newvar)(r->min);
      PrimExpr extent = VarReplacer(vsub2newvar)(r->extent);
      region.push_back(Range::FromMinExtent(min, extent));
    }
    new_regions.push_back(region);
  }

  Array<PrimExpr> new_scalar_inputs;
  for (PrimExpr old_input : tensor_op->scalar_inputs) {
    new_scalar_inputs.push_back(VarReplacer(vsub2newvar)(old_input));
  }

  Operation cache_op =
      TensorComputeOp(tensor_op->name + "." + scope, tensor_op->tag, new_axis,
                      tensor_op->reduce_axis, tensor_op->schedulable_ndim, tensor_op->intrin,
                      tensor_op->inputs, new_regions, new_scalar_inputs);

  // axis will be used in generating compute op
  Array<IterVar> compute_axis = tensor_op->axis;
  for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
    IterVar iv = tensor_op->axis[i];
    IterVar aiv = IterVar(iv->dom, iv->var, kDataPar);
    compute_axis.Set(i, aiv);
  }

  // The reader args
  Array<PrimExpr> args;
  {
    // cache->compute
    std::unordered_map<IterVar, PrimExpr> value_map;
    for (IterVar iv : compute_axis) {
      value_map[iv] = iv->var;
    }
    PassDownIndex(orig_stage, dom_map, &value_map, true);
    for (IterVar iv : orig_stage->leaf_iter_vars) {
      if (red_axis.count(iv)) continue;
      args.push_back(value_map.at(iv));
    }
    // tensorized region axis
    for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) {
      IterVar iv = compute_axis[i];
      args.push_back(value_map.at(iv));
    }
  }

  Array<PrimExpr> cache_expr_list;
  for (size_t i = 0; i < tensor_size; i++) {
    Tensor cache_tensor = cache_op.output(i);
    cache_expr_list.push_back(cache_tensor(args));
  }
  Operation orig_new_op =
      ComputeOp(tensor_op->name, tensor_op->tag, {}, compute_axis, cache_expr_list);
  return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size);
}

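/*!
 * \brief Create cache stages for all outputs of a single ComputeOp so that
 *        they are computed in the given memory scope before being written to
 *        their original locations.
 */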
Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array, const std::string& scope) {
  (*this)->InvalidateCache();
  CHECK(tensor_array.size() > 0) << "size of tensor_array must be greater than 0";
  Tensor tensor = tensor_array[0];
  Stage orig_stage = operator[](tensor->op);
  const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
  CHECK(static_cast<size_t>(compute->num_outputs()) == tensor_array.size())
      << "size of input tensor list must be same as number of stage outputs";
  for (size_t i = 1; i < tensor_array.size(); i++) {
    Stage tmp_stage = operator[](tensor_array[i]->op);
    CHECK(orig_stage.same_as(tmp_stage)) << "Input tensor list must be generated by ONE computeOp";
  }
  return CacheWriteWithReLayout(*this, tensor_array, scope);
}

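/*!
 * \brief Single-tensor variant of cache_write; dispatches to
 *        CacheWriteWithReLayout for ComputeOp producers and to
 *        CacheWriteWithReLayoutTensor for TensorComputeOp producers.
 *
 * A rough usage sketch (the tensor B and the schedule s are hypothetical):
 *   Tensor BL = s.cache_write(B, "local");
 */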
Tensor Schedule::cache_write(const Tensor& tensor, const std::string& scope) {
  // support both original compute op and tensor compute op
  (*this)->InvalidateCache();
  if (tensor->op.as<ComputeOpNode>()) {
    return (CacheWriteWithReLayout(*this, {tensor}, scope))[0];
  } else if (tensor->op.as<TensorComputeOpNode>()) {
    return (CacheWriteWithReLayoutTensor(*this, {tensor}, scope))[0];
  } else {
    LOG(FATAL) << "cache write only takes ComputeOp or TensorComputeOp as writers";
    return Tensor();
  }
}

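/*!
 * \brief Rebase the leaf iteration variables of every stage so that the
 *        generated loops start from zero.
 *
 * Itervars that are bound to threads are skipped; attach points that refer to
 * a rebased itervar are remapped to the new one.
 */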
void RebaseNonZeroMinLoop(ScheduleNode* sch) {
  std::unordered_map<IterVar, IterVar> rebase_map;
  for (Stage s : sch->stages) {
    if (s->attach_type == kInlinedAlready) continue;

    auto root_iter_vars = s->op->root_iter_vars();
    ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
    for (IterVar iv : root_iter_vars) {
      size_t idx = FindNodeRef(leaf_vars, iv);
      auto it = s->iter_var_attrs.find(iv);
      // don't need to rebase paths that are bound to threads.
      if (it != s->iter_var_attrs.end() && (*it).second->bind_thread.defined()) {
        continue;
      }
      if (idx < leaf_vars->size()) {
        // insert rebase
        IterVar rebased = IterVar(Range(), iv->var.copy_with_suffix(""), iv->iter_type);
        s->relations.push_back(te::Rebase(iv, rebased));
        if (s->iter_var_attrs.count(iv)) {
          s->iter_var_attrs.Set(rebased, s->iter_var_attrs.at(iv));
        }
        leaf_vars->SetItem(idx, rebased);
        rebase_map[iv] = rebased;
      }
    }
  }
  // remap the parent relation
  for (Stage s : sch->stages) {
    if (s->attach_type != kScope) continue;
    if (rebase_map.count(s->attach_ivar)) {
      s->attach_ivar = rebase_map.at(s->attach_ivar);
    }
  }
  for (Stage s : sch->groups) {
    if (s->attach_type != kScope) continue;
    if (rebase_map.count(s->attach_ivar)) {
      s->attach_ivar = rebase_map.at(s->attach_ivar);
    }
  }
}

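/*!
 * \brief Inline every stage marked as kInline into its consumers and rewrite
 *        the dataflow of the schedule.
 *
 * Reduction bodies are handled specially so that all outputs of a
 * multi-output reduction keep consistent Reduce nodes after inlining.
 */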
void InjectInline(ScheduleNode* sch) {
  sch->InvalidateCache();

  std::vector<Array<PrimExpr> > new_body(sch->stages.size());
  std::vector<bool> changed(sch->stages.size(), false);
  std::vector<Stmt> new_hybrid_body(sch->stages.size());
  std::vector<bool> hybrid_changed(sch->stages.size(), false);
  // inline all the ops
  for (size_t i = sch->stages.size(); i != 0; --i) {
    Stage stage = sch->stages[i - 1];
    if (stage->attach_type == kInline) {
      stage->attach_type = kInlinedAlready;
      Array<Var> args;
      PrimExpr body;
      {
        // setup args
        const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
        CHECK(compute) << "can only inline compute op";
        for (auto iv : compute->axis) {
          args.push_back(iv->var);
        }
        CHECK_EQ(compute->body.size(), 1U) << "can only inline compute op with 1 output";
        body = compute->body[0];
      }
      for (size_t j = i; j < sch->stages.size(); ++j) {
        Stage s = sch->stages[j];
        const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
        const HybridOpNode* hybrid = s->op.as<HybridOpNode>();
        if (compute) {
          if (!new_body[j].size()) {
            new_body[j] = compute->body;
          }
          if (new_body[j][0]->IsInstance<tir::ReduceNode>()) {
            // specially handle reduction inline for multiple reductions.
            const tir::ReduceNode* reduce = new_body[j][0].as<tir::ReduceNode>();
            for (size_t k = 1; k < new_body[j].size(); ++k) {
              const tir::ReduceNode* reduce_ = new_body[j][k].as<tir::ReduceNode>();
              CHECK(reduce_);
              CHECK(ReduceEqual(reduce_, reduce)) << "The Reduce inputs of ComputeOp should "
                                                  << "have the same attribute except value_index";
            }
            PrimExpr new_value = Inline(tir::Evaluate(new_body[j][0]), stage->op, args, body)
                                     .as<tir::EvaluateNode>()
                                     ->value;
            if (!new_value.same_as(new_body[j][0])) {
              changed[j] = true;
              const tir::ReduceNode* r = new_value.as<tir::ReduceNode>();
              CHECK(r != nullptr);
              CHECK_EQ(new_body[j].size(), r->source.size());
              for (size_t k = 0; k < new_body[j].size(); ++k) {
                auto n = make_object<tir::ReduceNode>(*r);
                n->value_index = static_cast<int>(k);
                n->dtype = r->source[k].dtype();
                new_body[j].Set(k, PrimExpr(n));
              }
            }
          } else {
            for (size_t k = 0; k < new_body[j].size(); ++k) {
              PrimExpr new_value = Inline(tir::Evaluate(new_body[j][k]), stage->op, args, body)
                                       .as<tir::EvaluateNode>()
                                       ->value;
              if (!new_value.same_as(new_body[j][k])) {
                new_body[j].Set(k, new_value);
                changed[j] = true;
              }
            }
          }
        } else if (hybrid) {
          if (!new_hybrid_body[j].defined()) {
            new_hybrid_body[j] = hybrid->body;
          }
          Stmt new_stmt = Inline(new_hybrid_body[j], stage->op, args, body);
          if (!new_stmt.same_as(new_hybrid_body[j])) {
            new_hybrid_body[j] = new_stmt;
            hybrid_changed[j] = true;
          }
        }
      }
    }
  }
  std::unordered_map<Tensor, Tensor> repl;
  // rewrite dataflow
  for (size_t i = 0; i < sch->stages.size(); ++i) {
    Stage s = sch->stages[i];
    if (s->attach_type == kInlinedAlready) continue;
    if (new_body[i].size()) {
      // Logics from ReplaceDataFlow
      const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
      CHECK(compute);
      Operation op = s->op;
      if (changed[i]) {
        op = ComputeOp(compute->name, compute->tag, compute->attrs, compute->axis, new_body[i]);
      }
      op = op->ReplaceInputs(op, repl);
      if (!op.same_as(s->op)) {
        for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
          repl[s->op.output(idx)] = op.output(idx);
        }
        s->op = op;
      }
    } else if (hybrid_changed[i]) {
      const HybridOpNode* hybrid = sch->stages[i]->op.as<HybridOpNode>();
      CHECK(hybrid);
      Operation op = HybridOp(hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs,
                              hybrid->outputs, new_hybrid_body[i]);
      op = op->ReplaceInputs(op, repl);
      for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
        repl[s->op.output(idx)] = op.output(idx);
      }
      s->op = op;
    } else {
      Operation op = s->op->ReplaceInputs(s->op, repl);
      if (!op.same_as(s->op)) {
        for (int j = 0; j < op->num_outputs(); ++j) {
          repl[s->op.output(j)] = op.output(j);
        }
        s->op = op;
      }
    }
  }
}

void LegalizeInvalidAttach(ScheduleNode* sch) {
  // Legalize the compute_at location if the target iterator of compute_at is split or fused.
  // Case 1: If the target of compute_at is split,
  //         we will move the compute_at location to the inner iterator.
  // Case 2: If the target of compute_at is fused,
  //         we will move the compute_at location to the newly fused iterator.
  // Note that case 2 can only happen if the target of compute_at
  // is the innermost operand of fuse operation.

  // Map an old invalid attach point to its new valid attach point
  std::unordered_map<IterVar, IterVar> replace_map;

  for (Stage stage : sch->stages) {
    for (Stage s = stage; s.defined();) {
      // The following logic is similar to the `CreateAttachPath` in `src/te/schedule/graph.h`,
      // because we follow the validation check in that function to legalize the attach.
      Stage spec = s.GetAttachSpec();
      if (spec->attach_type != kScope) {
        break;
      }
      bool start_attach = false;
      IterVar attach_ivar = spec->attach_ivar;
      s = spec->attach_stage;
      CHECK(attach_ivar.defined());
      CHECK(s.defined());

      for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
        IterVar iv = s->leaf_iter_vars[i - 1];
        if (!start_attach && iv.same_as(attach_ivar)) {
          start_attach = true;
          break;
        }
      }

      if (!start_attach) {
        IterVar new_attach_ivar = attach_ivar;
        bool updated = true;
        // recursively update the relations
        while (updated) {
          updated = false;
          for (const auto& rel : s->relations) {
            if (const FuseNode* r = rel.as<FuseNode>()) {
              if (new_attach_ivar.same_as(r->inner)) {
                new_attach_ivar = r->fused;
                updated = true;
              }
            } else if (const SplitNode* r = rel.as<SplitNode>()) {
              if (new_attach_ivar.same_as(r->parent)) {
                new_attach_ivar = r->inner;
                updated = true;
              }
            }
          }
          replace_map[attach_ivar] = new_attach_ivar;
        }
      }
    }
  }

  // remap the parent relation
  for (Stage s : sch->stages) {
    if (s->attach_type != kScope) continue;
    if (replace_map.count(s->attach_ivar)) {
      s->attach_ivar = replace_map.at(s->attach_ivar);
    }
  }
  for (Stage s : sch->groups) {
    if (s->attach_type != kScope) continue;
    if (replace_map.count(s->attach_ivar)) {
      s->attach_ivar = replace_map.at(s->attach_ivar);
    }
  }
}

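/*!
 * \brief Return a normalized copy of this schedule: inlined stages are
 *        injected, loops are rebased to start at zero, and invalid compute_at
 *        attach points are legalized.
 */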
Schedule Schedule::normalize() {
  Schedule sn = copy();
  InjectInline(sn.operator->());
  RebaseNonZeroMinLoop(sn.operator->());
  LegalizeInvalidAttach(sn.operator->());
  return sn;
}

// Handle reduction factorization: create a new stage that computes partial
// reductions along `axis`, then rewrite the original stage to combine them.
Array<Tensor> Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int factor_axis) {
  (*this)->InvalidateCache();
  using tir::ReduceNode;
  CHECK_EQ(axis->iter_type, kCommReduce) << "Can only factor reduction axis";
  Stage reduce_stage = operator[](tensor->op);
  const ComputeOpNode* compute_op = reduce_stage->op.as<ComputeOpNode>();
  CHECK(compute_op) << "Can only factor ComputeOp";
  ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite();
  {
    size_t axis_pos = FindNodeRef(leaf_vars, axis);
    CHECK_NE(axis_pos, leaf_vars->size()) << "Cannot find IterVar " << axis << " in leaf iter vars";
  }
  // Find touched reduction axis.
  std::unordered_map<IterVar, int> touch_map;
  touch_map[axis] = 1;
  te::PassUpBitMaskOr(reduce_stage, &touch_map, true);
  te::PassDownBitMaskOr(reduce_stage, &touch_map, true);
  // skip reduction iteration.
  std::unordered_set<IterVar> skip_bound_check;
  // Verify normal axis are not touched.
  for (IterVar iv : compute_op->axis) {
    CHECK(!touch_map.count(iv)) << "Factor axis touches normal axis.";
    skip_bound_check.insert(iv);
  }
  // get analyzer.
  arith::Analyzer analyzer;
  // Get the replace index
  std::unordered_map<IterVar, Range> dom_map;
  std::unordered_map<IterVar, PrimExpr> value_map;
  for (IterVar iv : compute_op->reduce_axis) {
    if (touch_map.count(iv)) {
      dom_map[iv] = iv->dom;
    } else {
      skip_bound_check.insert(iv);
    }
    analyzer.Bind(iv->var, iv->dom);
  }
  te::PassDownDomain(reduce_stage, &dom_map, &analyzer, true);
  for (IterVar iv : reduce_stage->leaf_iter_vars) {
    if (touch_map.count(iv)) {
      Range dom = dom_map.at(iv);
      if (is_one(dom->extent)) {
        value_map[iv] = dom->min;
      } else {
        value_map[iv] = iv->var;
      }
    }
  }
  te::PassUpIndex(reduce_stage, dom_map, &value_map, true);
  std::vector<PrimExpr> predicates =
      MakeBoundCheck(reduce_stage, dom_map, value_map, true, skip_bound_check);

  // Get the factored op node.
  const int factor_axis_pos =
      factor_axis >= 0 ? factor_axis : static_cast<int>(compute_op->axis.size() + 1) + factor_axis;
  CHECK_LE(factor_axis_pos, compute_op->axis.size());
  auto n = make_object<ComputeOpNode>();
  n->name = compute_op->name + ".rf";
  {
    // axis replacement.
    auto iv_node = make_object<IterVarNode>();
    iv_node->dom = dom_map.at(axis);
    CHECK(is_zero(iv_node->dom->min)) << "Can only factor reduction domain starting from 0";
    iv_node->var = axis->var;
    iv_node->iter_type = kDataPar;

    const int size = compute_op->axis.size();
    for (int idx = 0; idx < size; ++idx) {
      if (factor_axis_pos == idx) {
        n->axis.push_back(IterVar(iv_node));
      }
      n->axis.push_back(compute_op->axis[idx]);
    }
    if (factor_axis_pos == size) {
      n->axis.push_back(IterVar(iv_node));
    }
  }
  // predicate generation; copy untouched axes.
  int idx = tensor->value_index;
  const ReduceNode* reduce = compute_op->body[idx].as<ReduceNode>();
  CHECK(reduce) << "Can only rfactor non-inline reductions";
  predicates.push_back(reduce->condition);
  auto fand = [](PrimExpr a, PrimExpr b) { return a && b; };

  PrimExpr predicate = likely(foldl(fand, const_true(1), predicates));

  std::unordered_map<const VarNode*, PrimExpr> vsub;

  for (IterVar iv : compute_op->reduce_axis) {
    if (!touch_map.count(iv)) {
      n->reduce_axis.push_back(iv);
    } else {
      CHECK(value_map.count(iv));
      PrimExpr index = value_map.at(iv);
      vsub[iv->var.get()] = index;
    }
  }

  // Copy touched axis.
  for (IterVar iv : reduce_stage->leaf_iter_vars) {
    if (touch_map.count(iv) && !iv.same_as(axis)) {
      CHECK_EQ(iv->iter_type, kCommReduce);
      auto ncpy = make_object<IterVarNode>(*iv.operator->());
      ncpy->dom = dom_map.at(iv);
      n->reduce_axis.push_back(IterVar(ncpy));
    }
  }
  VarReplacer replacer(vsub);
  Array<PrimExpr> new_source =
      tir::UpdateArray(reduce->source, [&replacer](const PrimExpr& e) { return replacer(e); });

  PrimExpr new_pred = replacer(predicate);

  std::vector<PrimExpr> body;
  for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
    body.emplace_back(Reduce(reduce->combiner, new_source, n->reduce_axis, new_pred, idx, {}));
  }
  n->body = Array<PrimExpr>(body);
  // refresh relations, keep the un-touched relations.
  Array<IterVarRelation> rels;
  for (IterVarRelation rel : reduce_stage->relations) {
    bool touched = false;
    if (const SplitNode* r = rel.as<SplitNode>()) {
      if (touch_map.count(r->parent)) touched = true;
    } else if (const FuseNode* r = rel.as<FuseNode>()) {
      if (touch_map.count(r->fused)) touched = true;
    } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
      if (touch_map.count(r->parent)) touched = true;
    } else {
      LOG(FATAL) << "unknown relation type";
    }
    if (!touched) {
      rels.push_back(rel);
    }
  }
  // initialize the factored stage.
  Operation factor_op(n);
  Array<Stage>& stages = (*this)->stages;
  size_t stage_pos = FindNodeRef(stages.GetArrayNode(), reduce_stage);
  Stage factor_stage = Stage(factor_op);
  factor_stage->relations = rels;
  CHECK_LT(stage_pos, stages.size());
  stages.insert(stages.begin() + stage_pos, factor_stage);
  (*this)->stage_map.Set(factor_op, factor_stage);
  factor_stage->group = reduce_stage->group;
  if (factor_stage->group.defined()) {
    ++factor_stage->group->num_child_stages;
  }
  // Replace the old reduction.
  IterVar repl_red_axis = reduce_axis(dom_map.at(axis), axis->var->name_hint + ".v");
  Array<Tensor> factor_tensors;
  Array<Tensor> old_tensors;
  int size = factor_op->num_outputs();
  for (int idx = 0; idx < size; ++idx) {
    factor_tensors.push_back(factor_op.output(idx));
    old_tensors.push_back(reduce_stage->op.output(idx));
  }
  Array<Tensor> repl_tensors = compute(
      old_tensors[0]->shape,
      [&](const Array<Var>& i) {
        Array<PrimExpr> indices;
        const int idx_size = static_cast<int>(i.size());
        for (int idx = 0; idx < idx_size; ++idx) {
          if (factor_axis_pos == idx) {
            indices.push_back(repl_red_axis->var);
          }
          indices.push_back(i[idx]);
        }
        Array<PrimExpr> new_init = reduce->init;
        if (!reduce->init.empty()) {
          std::unordered_map<const VarNode*, PrimExpr> init_vsub;
          for (const auto& init : reduce->init) {
            if (init->IsInstance<ProducerLoadNode>()) {
              CHECK_EQ(compute_op->axis.size(), idx_size)
                  << "'init' should have the same number of dimensions as the output when used "
                     "with rfactor";
              for (int idx = 0; idx < idx_size; idx++) {
                init_vsub[compute_op->axis[idx]->var.get()] = i[idx];
              }
            }
          }
          VarReplacer init_replacer(init_vsub);
          new_init = tir::UpdateArray(
              reduce->init, [&init_replacer](const PrimExpr& e) { return init_replacer(e); });
        }
        if (factor_axis_pos == idx_size) {
          indices.push_back(repl_red_axis->var);
        }
        Array<PrimExpr> factor_exprs;
        for (int idx = 0; idx < size; ++idx) {
          factor_exprs.push_back(factor_tensors[idx](indices));
        }
        Array<PrimExpr> reductions;
        Array<IterVar> axis = {repl_red_axis};
        PrimExpr cond = const_true();
        for (int idx = 0; idx < size; ++idx) {
          reductions.push_back(Reduce(reduce->combiner, factor_exprs, axis, cond, idx, new_init));
        }
        return reductions;
      },
      reduce_stage->op->name + ".repl");

  std::unordered_map<Tensor, Tensor> vmap;
  std::unordered_map<Tensor, Tensor> rvmap;
  for (int idx = 0; idx < size; ++idx) {
    vmap[old_tensors[idx]] = repl_tensors[idx];
    rvmap[repl_tensors[idx]] = old_tensors[idx];
  }
  ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
  // revamp the reduction stage.
  reduce_stage->op = repl_tensors[0]->op;
  reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars();
  reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars;
  reduce_stage->relations = Array<IterVarRelation>();
  return factor_tensors;
}
}  // namespace te
}  // namespace tvm