1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file tir/ir/transform.cc
22 * \brief TIR specific transformation passes.
23 */
24 #include <tvm/node/repr_printer.h>
25 #include <tvm/runtime/registry.h>
26 #include <tvm/tir/transform.h>
27
28 namespace tvm {
29 namespace tir {
30 namespace transform {
31
32 /*!
33 * \brief Function level pass that applies transformations to all
34 * TIR functions within the module.
35 */
36 class PrimFuncPassNode : public PassNode {
37 public:
38 /* \brief The pass meta data.*/
39 PassInfo pass_info;
40
41 /*! \brief The pass function called on each. */
42 runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func;
43
VisitAttrs(tvm::AttrVisitor * v)44 void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); }
45
46 /*!
47 * \brief Run a function pass on given pass context.
48 *
49 * \param mod The module that an optimization pass is applied on.
50 * \param pass_ctx The context that an optimization pass executes on.
51 *
52 * \return Return the updated module.
53 */
54 IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
55
56 /*!
57 * \brief Get the pass information/meta data.
58 */
Info() const59 PassInfo Info() const override { return pass_info; }
60
61 static constexpr const char* _type_key = "tir.PrimFuncPass";
62 TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncPassNode, PassNode);
63 };
64
65 class PrimFuncPass : public Pass {
66 public:
67 /*!
68 * \brief The constructor
69 * \param pass_func The packed function which implements a pass.
70 * \param pass_info The pass info.
71 */
72 TVM_DLL PrimFuncPass(
73 runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
74 PassInfo pass_info);
75
76 TVM_DEFINE_OBJECT_REF_METHODS(PrimFuncPass, Pass, PrimFuncPassNode);
77 };
78
PrimFuncPass(runtime::TypedPackedFunc<PrimFunc (PrimFunc,IRModule,PassContext)> pass_func,PassInfo pass_info)79 PrimFuncPass::PrimFuncPass(
80 runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
81 PassInfo pass_info) {
82 auto n = make_object<PrimFuncPassNode>();
83 n->pass_func = std::move(pass_func);
84 n->pass_info = std::move(pass_info);
85 data_ = std::move(n);
86 }
87
88 // Perform Module -> Module optimizations at the PrimFunc level.
operator ()(IRModule mod,const PassContext & pass_ctx) const89 IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
90 const PassInfo& pass_info = Info();
91 CHECK(mod.defined());
92 pass_ctx.Trace(mod, pass_info, true);
93 std::vector<ObjectRef> deleted_list;
94 IRModuleNode* mod_ptr = mod.CopyOnWrite();
95 auto* func_dict = mod_ptr->functions.CopyOnWrite();
96 // directly loop over the underlying dict
97 for (auto& kv : *func_dict) {
98 // only picks up tir::PrimFunc
99 if (kv.second->IsInstance<PrimFuncNode>()) {
100 // move out the function so that it is the only copy.
101 PrimFunc func = Downcast<PrimFunc>(std::move(kv.second));
102 func = pass_func(std::move(func), mod, pass_ctx);
103 kv.second = std::move(func);
104
105 if (!kv.second.defined()) {
106 deleted_list.push_back(kv.first);
107 }
108 }
109 }
110
111 // automatic removal of None
112 for (const auto& gv : deleted_list) {
113 func_dict->erase(gv);
114 }
115 pass_ctx.Trace(mod, pass_info, false);
116 return mod;
117 }
118
CreatePrimFuncPass(const runtime::TypedPackedFunc<PrimFunc (PrimFunc,IRModule,PassContext)> & pass_func,int opt_level,String name,tvm::Array<String> required)119 Pass CreatePrimFuncPass(
120 const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
121 int opt_level, String name, tvm::Array<String> required) {
122 PassInfo pass_info = PassInfo(opt_level, name, required);
123 return PrimFuncPass(pass_func, pass_info);
124 }
125
126 TVM_REGISTER_NODE_TYPE(PrimFuncPassNode);
127
128 TVM_REGISTER_GLOBAL("tir.transform.CreatePrimFuncPass")
129 .set_body_typed(
130 [](runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
__anonee27b92f0102(runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func, PassInfo pass_info) 131 PassInfo pass_info) { return PrimFuncPass(pass_func, pass_info); });
132
133 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
__anonee27b92f0202(const ObjectRef& ref, ReprPrinter* p) 134 .set_dispatch<PrimFuncPassNode>([](const ObjectRef& ref, ReprPrinter* p) {
135 auto* node = static_cast<const PrimFuncPassNode*>(ref.get());
136 const PassInfo info = node->Info();
137 p->stream << "PrimFuncPass(" << info->name << ", opt_level=" << info->opt_level << ")";
138 });
139
140 } // namespace transform
141 } // namespace tir
142 } // namespace tvm
143