1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file legalize.cc
22  * \brief Converts an expr to another expr. This pass can be used to transform an op based on its
23  * shape, dtype or layout to another op or a sequence of ops.
24  */
25 
26 #include <tvm/relay/expr_functor.h>
27 #include <tvm/relay/op_attr_types.h>
28 #include <tvm/relay/transform.h>
29 #include <tvm/te/operation.h>
30 
31 namespace tvm {
32 namespace relay {
33 
34 namespace legalize {
35 
36 // Call registered FTVMLegalize of an op
37 // Returns the legalized expression
38 class Legalizer : public ExprRewriter {
39  public:
Legalizer(const std::string & legalize_map_attr_name)40   explicit Legalizer(const std::string& legalize_map_attr_name)
41       : legalize_map_attr_name_{legalize_map_attr_name} {}
42 
Rewrite_(const CallNode * call_node,const Expr & post)43   Expr Rewrite_(const CallNode* call_node, const Expr& post) override {
44     // Get the new_call node without any changes to current call node.
45     Call new_call = Downcast<Call>(post);
46 
47     // Check if the string is registered.
48     if (!Op::HasAttrMap(legalize_map_attr_name_)) {
49       return post;
50     }
51 
52     // Collect the registered legalize function.
53     auto fop_legalize = Op::GetAttrMap<FTVMLegalize>(legalize_map_attr_name_);
54     auto call_op = call_node->op;
55     if (call_op.as<OpNode>()) {
56       Op op = Downcast<Op>(call_node->op);
57 
58       if (fop_legalize.count(op)) {
59         // Collect the new_args.
60         tvm::Array<Expr> call_args = new_call->args;
61 
62         // Collect input and output dtypes to pass on to Legalize API.
63         tvm::Array<tvm::relay::Type> types;
64         for (auto arg : call_node->args) {
65           types.push_back(arg->checked_type());
66         }
67         types.push_back(call_node->checked_type());
68 
69         // Transform the op by calling the registered legalize function.
70         Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);
71 
72         // Return the new expr if the transformation succeeded.
73         if (legalized_value.defined()) {
74           // Check that the returned Expr from legalize is CallNode.
75           const CallNode* legalized_call_node = legalized_value.as<CallNode>();
76           CHECK(legalized_call_node)
77               << "Can only replace the original operator with another call node";
78           return legalized_value;
79         }
80       }
81     }
82 
83     return post;
84   }
85 
86  private:
87   std::string legalize_map_attr_name_;
88 };
89 
/*!
 * \brief Run the Legalizer over \p expr with PostOrderRewrite.
 * \param expr The expression to legalize.
 * \param legalize_map_attr_name Name of the op attribute map holding the legalize
 *        functions.
 * \return The rewritten expression.
 */
Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) {
  Legalizer legalizer(legalize_map_attr_name);
  return PostOrderRewrite(expr, &legalizer);
}
94 
95 }  // namespace legalize
96 
97 namespace transform {
98 
Legalize(const String & legalize_map_attr_name)99 Pass Legalize(const String& legalize_map_attr_name) {
100   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
101       [=](Function f, IRModule m, PassContext pc) {
102         return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
103       };
104   return CreateFunctionPass(pass_func, 1, "Legalize", {"InferType"});
105 }
106 
107 TVM_REGISTER_GLOBAL("relay._transform.Legalize").set_body_typed(Legalize);
108 
109 }  // namespace transform
110 
111 }  // namespace relay
112 }  // namespace tvm
113