1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file legalize.cc
22 * \brief Converts an expr to another expr. This pass can be used to transform an op based on its
23 * shape, dtype or layout to another op or a sequence of ops.
24 */
25
26 #include <tvm/operation.h>
27 #include <tvm/relay/expr_functor.h>
28 #include <tvm/relay/op_attr_types.h>
29 #include <tvm/relay/transform.h>
30
31 namespace tvm {
32 namespace relay {
33
34 namespace legalize {
35
36 // Call registered FTVMLegalize of an op
37 // Returns the legalized expression
38 class Legalizer : public ExprMutator {
39 public:
Legalizer(const std::string & legalize_map_attr_name)40 explicit Legalizer(const std::string& legalize_map_attr_name)
41 : legalize_map_attr_name_{legalize_map_attr_name} {}
42
VisitExpr_(const CallNode * call_node)43 Expr VisitExpr_(const CallNode* call_node) {
44 // Get the new_call node without any changes to current call node.
45 Expr new_e = ExprMutator::VisitExpr_(call_node);
46 Call new_call = Downcast<Call>(new_e);
47
48 // Check if the string is registered in the OpRegistry.
49 if (!Op::HasAttr(legalize_map_attr_name_)) {
50 return new_e;
51 }
52
53 // Collect the registered legalize function.
54 auto fop_legalize = Op::GetAttr<FTVMLegalize>(legalize_map_attr_name_);
55 auto call_op = call_node->op;
56 if (call_op.as<OpNode>()) {
57 Op op = Downcast<Op>(call_node->op);
58
59 if (fop_legalize.count(op)) {
60 // Collect the new_args.
61 tvm::Array<Expr> call_args = new_call->args;
62
63 // Collect input and output dtypes to pass on to Legalize API.
64 tvm::Array<tvm::relay::Type> types;
65 for (auto arg : call_node->args) {
66 types.push_back(arg->checked_type());
67 }
68 types.push_back(call_node->checked_type());
69
70 // Transform the op by calling the registered legalize function.
71 Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);
72
73 // Reassign new_e if the transformation succeeded.
74 if (legalized_value.defined()) {
75 // Check that the returned Expr from legalize is CallNode.
76 const CallNode* legalized_call_node = legalized_value.as<CallNode>();
77 CHECK(legalized_call_node)
78 << "Can only replace the original operator with another call node";
79
80 new_e = legalized_value;
81 }
82 }
83 }
84
85 return new_e;
86 }
87
88 private:
89 std::string legalize_map_attr_name_;
90 };
91
/*!
 * \brief Apply the Legalizer to an expression.
 * \param expr The expression to legalize.
 * \param legalize_map_attr_name Op attribute name holding the FTVMLegalize functions.
 * \return The expression after every legalizable call has been rewritten.
 */
Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) {
  Legalizer legalizer(legalize_map_attr_name);
  return legalizer.Mutate(expr);
}
95
96 } // namespace legalize
97
98 namespace transform {
99
Legalize(const std::string & legalize_map_attr_name)100 Pass Legalize(const std::string& legalize_map_attr_name) {
101 runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
102 [=](Function f, Module m, PassContext pc) {
103 return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
104 };
105 return CreateFunctionPass(pass_func, 1, "Legalize", {ir::StringImm::make("InferType")});
106 }
107
108 TVM_REGISTER_API("relay._transform.Legalize").set_body_typed(Legalize);
109
110 } // namespace transform
111
112 } // namespace relay
113 } // namespace tvm
114