1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file remap_thread_axis.cc
22  */
23 #include <tvm/ir.h>
24 #include <tvm/ir_mutator.h>
25 #include <tvm/ir_visitor.h>
26 #include <tvm/ir_pass.h>
27 #include <unordered_map>
28 
29 
30 namespace tvm {
31 namespace ir {
32 
33 // Mutator to change the read pattern
34 class ThreadAxisRewriter : private IRMutator {
35  public:
ThreadAxisRewriter(const std::unordered_map<std::string,IterVar> & tmap)36   explicit ThreadAxisRewriter(
37       const std::unordered_map<std::string, IterVar>& tmap)
38       : tmap_(tmap) {
39   }
40 
Rewrite(Stmt stmt)41   Stmt Rewrite(Stmt stmt) {
42     return Mutate(stmt);
43   }
44 
45  private:
Mutate_(const AttrStmt * op,const Stmt & stmt)46   Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) final {
47     if (op->attr_key == attr::thread_extent) {
48       IterVar iv = Downcast<IterVar>(op->node);
49       CHECK_NE(iv->thread_tag.length(), 0U);
50       auto it = tmap_.find(iv->thread_tag);
51       if (it != tmap_.end()) {
52         const IterVar& new_iv = it->second;
53         const Variable* v = iv->var.get();
54         if (!vmap_.count(v)) {
55           vmap_[v] = new_iv->var;
56         } else {
57           CHECK(vmap_[v].same_as(new_iv->var));
58         }
59         Stmt body = this->Mutate(op->body);
60         return AttrStmt::make(
61             new_iv, op->attr_key, op->value, body);
62       }
63     }
64     return IRMutator::Mutate_(op, stmt);
65   }
66 
Mutate_(const Variable * op,const Expr & expr)67   Expr Mutate_(const Variable* op, const Expr& expr) final {
68     auto it = vmap_.find(op);
69     if (it != vmap_.end()) return it->second;
70     return IRMutator::Mutate_(op, expr);
71   }
72   // The thread map
73   const std::unordered_map<std::string, IterVar>& tmap_;
74   // variable map
75   std::unordered_map<const Variable*, Var> vmap_;
76 };
77 
78 LoweredFunc
RemapThreadAxis(LoweredFunc f,Map<Expr,IterVar> thread_map)79 RemapThreadAxis(LoweredFunc f, Map<Expr, IterVar> thread_map) {
80   std::unordered_map<std::string, IterVar> tmap;
81   for (const auto& kv : thread_map) {
82     const StringImm* str = kv.first.as<StringImm>();
83     CHECK(str != nullptr);
84     tmap[str->value] = kv.second;
85   }
86 
87   CHECK_EQ(f->func_type, kDeviceFunc);
88   auto n = make_node<LoweredFuncNode>(*f.operator->());
89   // replace the thread axis
90   for (size_t i = 0; i < n->thread_axis.size(); ++i) {
91     auto it = tmap.find(n->thread_axis[i]->thread_tag);
92     if (it != tmap.end()) {
93       n->thread_axis.Set(i, it->second);
94     }
95   }
96   n->body = ThreadAxisRewriter(tmap).Rewrite(n->body);
97   return LoweredFunc(n);
98 }
99 
100 }  // namespace ir
101 }  // namespace tvm
102