1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file tir/ir/transform.cc
22  * \brief TIR specific transformation passes.
23  */
24 #include <tvm/node/repr_printer.h>
25 #include <tvm/runtime/registry.h>
26 #include <tvm/tir/transform.h>
27 
28 namespace tvm {
29 namespace tir {
30 namespace transform {
31 
32 /*!
33  * \brief Function level pass that applies transformations to all
34  *        TIR functions within the module.
35  */
36 class PrimFuncPassNode : public PassNode {
37  public:
38   /* \brief The pass meta data.*/
39   PassInfo pass_info;
40 
41   /*! \brief The pass function called on each. */
42   runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func;
43 
VisitAttrs(tvm::AttrVisitor * v)44   void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); }
45 
46   /*!
47    * \brief Run a function pass on given pass context.
48    *
49    * \param mod The module that an optimization pass is applied on.
50    * \param pass_ctx The context that an optimization pass executes on.
51    *
52    * \return Return the updated module.
53    */
54   IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
55 
56   /*!
57    * \brief Get the pass information/meta data.
58    */
Info() const59   PassInfo Info() const override { return pass_info; }
60 
61   static constexpr const char* _type_key = "tir.PrimFuncPass";
62   TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncPassNode, PassNode);
63 };
64 
65 class PrimFuncPass : public Pass {
66  public:
67   /*!
68    * \brief The constructor
69    * \param pass_func The packed function which implements a pass.
70    * \param pass_info The pass info.
71    */
72   TVM_DLL PrimFuncPass(
73       runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
74       PassInfo pass_info);
75 
76   TVM_DEFINE_OBJECT_REF_METHODS(PrimFuncPass, Pass, PrimFuncPassNode);
77 };
78 
PrimFuncPass(runtime::TypedPackedFunc<PrimFunc (PrimFunc,IRModule,PassContext)> pass_func,PassInfo pass_info)79 PrimFuncPass::PrimFuncPass(
80     runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
81     PassInfo pass_info) {
82   auto n = make_object<PrimFuncPassNode>();
83   n->pass_func = std::move(pass_func);
84   n->pass_info = std::move(pass_info);
85   data_ = std::move(n);
86 }
87 
88 // Perform Module -> Module optimizations at the PrimFunc level.
operator ()(IRModule mod,const PassContext & pass_ctx) const89 IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
90   const PassInfo& pass_info = Info();
91   CHECK(mod.defined());
92   pass_ctx.Trace(mod, pass_info, true);
93   std::vector<ObjectRef> deleted_list;
94   IRModuleNode* mod_ptr = mod.CopyOnWrite();
95   auto* func_dict = mod_ptr->functions.CopyOnWrite();
96   // directly loop over the underlying dict
97   for (auto& kv : *func_dict) {
98     // only picks up tir::PrimFunc
99     if (kv.second->IsInstance<PrimFuncNode>()) {
100       // move out the function so that it is the only copy.
101       PrimFunc func = Downcast<PrimFunc>(std::move(kv.second));
102       func = pass_func(std::move(func), mod, pass_ctx);
103       kv.second = std::move(func);
104 
105       if (!kv.second.defined()) {
106         deleted_list.push_back(kv.first);
107       }
108     }
109   }
110 
111   // automatic removal of None
112   for (const auto& gv : deleted_list) {
113     func_dict->erase(gv);
114   }
115   pass_ctx.Trace(mod, pass_info, false);
116   return mod;
117 }
118 
CreatePrimFuncPass(const runtime::TypedPackedFunc<PrimFunc (PrimFunc,IRModule,PassContext)> & pass_func,int opt_level,String name,tvm::Array<String> required)119 Pass CreatePrimFuncPass(
120     const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
121     int opt_level, String name, tvm::Array<String> required) {
122   PassInfo pass_info = PassInfo(opt_level, name, required);
123   return PrimFuncPass(pass_func, pass_info);
124 }
125 
126 TVM_REGISTER_NODE_TYPE(PrimFuncPassNode);
127 
128 TVM_REGISTER_GLOBAL("tir.transform.CreatePrimFuncPass")
129     .set_body_typed(
130         [](runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
__anonee27b92f0102(runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func, PassInfo pass_info) 131            PassInfo pass_info) { return PrimFuncPass(pass_func, pass_info); });
132 
133 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
__anonee27b92f0202(const ObjectRef& ref, ReprPrinter* p) 134     .set_dispatch<PrimFuncPassNode>([](const ObjectRef& ref, ReprPrinter* p) {
135       auto* node = static_cast<const PrimFuncPassNode*>(ref.get());
136       const PassInfo info = node->Info();
137       p->stream << "PrimFuncPass(" << info->name << ", opt_level=" << info->opt_level << ")";
138     });
139 
140 }  // namespace transform
141 }  // namespace tir
142 }  // namespace tvm
143