1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
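/*!
 * \file remap_thread_axis.cc
 * \brief Replace the thread axes (IterVars) of a device LoweredFunc
 *        according to a thread-tag -> IterVar map.
 */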
23 #include <tvm/ir.h>
24 #include <tvm/ir_mutator.h>
25 #include <tvm/ir_visitor.h>
26 #include <tvm/ir_pass.h>
27 #include <unordered_map>
28
29
30 namespace tvm {
31 namespace ir {
32
33 // Mutator to change the read pattern
34 class ThreadAxisRewriter : private IRMutator {
35 public:
ThreadAxisRewriter(const std::unordered_map<std::string,IterVar> & tmap)36 explicit ThreadAxisRewriter(
37 const std::unordered_map<std::string, IterVar>& tmap)
38 : tmap_(tmap) {
39 }
40
Rewrite(Stmt stmt)41 Stmt Rewrite(Stmt stmt) {
42 return Mutate(stmt);
43 }
44
45 private:
Mutate_(const AttrStmt * op,const Stmt & stmt)46 Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) final {
47 if (op->attr_key == attr::thread_extent) {
48 IterVar iv = Downcast<IterVar>(op->node);
49 CHECK_NE(iv->thread_tag.length(), 0U);
50 auto it = tmap_.find(iv->thread_tag);
51 if (it != tmap_.end()) {
52 const IterVar& new_iv = it->second;
53 const Variable* v = iv->var.get();
54 if (!vmap_.count(v)) {
55 vmap_[v] = new_iv->var;
56 } else {
57 CHECK(vmap_[v].same_as(new_iv->var));
58 }
59 Stmt body = this->Mutate(op->body);
60 return AttrStmt::make(
61 new_iv, op->attr_key, op->value, body);
62 }
63 }
64 return IRMutator::Mutate_(op, stmt);
65 }
66
Mutate_(const Variable * op,const Expr & expr)67 Expr Mutate_(const Variable* op, const Expr& expr) final {
68 auto it = vmap_.find(op);
69 if (it != vmap_.end()) return it->second;
70 return IRMutator::Mutate_(op, expr);
71 }
72 // The thread map
73 const std::unordered_map<std::string, IterVar>& tmap_;
74 // variable map
75 std::unordered_map<const Variable*, Var> vmap_;
76 };
77
78 LoweredFunc
RemapThreadAxis(LoweredFunc f,Map<Expr,IterVar> thread_map)79 RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> thread_map) {
80 std::unordered_map<std::string, IterVar> tmap;
81 for (const auto& kv : thread_map) {
82 const StringImm* str = kv.first.as<StringImm>();
83 CHECK(str != nullptr);
84 tmap[str->value] = kv.second;
85 }
86
87 CHECK_EQ(f->func_type, kDeviceFunc);
88 auto n = make_node<LoweredFuncNode>(*f.operator->());
89 // replace the thread axis
90 for (size_t i = 0; i < n->thread_axis.size(); ++i) {
91 auto it = tmap.find(n->thread_axis[i]->thread_tag);
92 if (it != tmap.end()) {
93 n->thread_axis.Set(i, it->second);
94 }
95 }
96 n->body = ThreadAxisRewriter(tmap).Rewrite(n->body);
97 return LoweredFunc(n);
98 }
99
100 } // namespace ir
101 } // namespace tvm
102