1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file legalize.cc
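 * \brief Legalizes Relay operators: every call to an operator that registers a legalization
 * function (under a caller-supplied op attribute name) is replaced by the expression that function
 * returns. This can be used to transform an op, based on its shape, dtype or layout, into another
 * op or a sequence of ops.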
24  */
25 
26 #include <tvm/operation.h>
27 #include <tvm/relay/expr_functor.h>
28 #include <tvm/relay/op_attr_types.h>
29 #include <tvm/relay/transform.h>
30 
31 namespace tvm {
32 namespace relay {
33 
34 namespace legalize {
35 
36 // Call registered FTVMLegalize of an op
37 // Returns the legalized expression
38 class Legalizer : public ExprMutator {
39  public:
Legalizer(const std::string & legalize_map_attr_name)40   explicit Legalizer(const std::string& legalize_map_attr_name)
41       : legalize_map_attr_name_{legalize_map_attr_name} {}
42 
VisitExpr_(const CallNode * call_node)43   Expr VisitExpr_(const CallNode* call_node) {
44     // Get the new_call node without any changes to current call node.
45     Expr new_e = ExprMutator::VisitExpr_(call_node);
46     Call new_call = Downcast<Call>(new_e);
47 
48     // Check if the string is registered in the OpRegistry.
49     if (!Op::HasAttr(legalize_map_attr_name_)) {
50       return new_e;
51     }
52 
53     // Collect the registered legalize function.
54     auto fop_legalize = Op::GetAttr<FTVMLegalize>(legalize_map_attr_name_);
55     auto call_op = call_node->op;
56     if (call_op.as<OpNode>()) {
57       Op op = Downcast<Op>(call_node->op);
58 
59       if (fop_legalize.count(op)) {
60         // Collect the new_args.
61         tvm::Array<Expr> call_args = new_call->args;
62 
63         // Collect input and output dtypes to pass on to Legalize API.
64         tvm::Array<tvm::relay::Type> types;
65         for (auto arg : call_node->args) {
66           types.push_back(arg->checked_type());
67         }
68         types.push_back(call_node->checked_type());
69 
70         // Transform the op by calling the registered legalize function.
71         Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);
72 
73         // Reassign new_e if the transformation succeeded.
74         if (legalized_value.defined()) {
75           // Check that the returned Expr from legalize is CallNode.
76           const CallNode* legalized_call_node = legalized_value.as<CallNode>();
77           CHECK(legalized_call_node)
78               << "Can only replace the original operator with another call node";
79 
80           new_e = legalized_value;
81         }
82       }
83     }
84 
85     return new_e;
86   }
87 
88  private:
89   std::string legalize_map_attr_name_;
90 };
91 
Legalize(const Expr & expr,const std::string & legalize_map_attr_name)92 Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) {
93   return Legalizer(legalize_map_attr_name).Mutate(expr);
94 }
95 
96 }  // namespace legalize
97 
98 namespace transform {
99 
Legalize(const std::string & legalize_map_attr_name)100 Pass Legalize(const std::string& legalize_map_attr_name) {
101   runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
102       [=](Function f, Module m, PassContext pc) {
103         return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
104       };
105   return CreateFunctionPass(pass_func, 1, "Legalize", {ir::StringImm::make("InferType")});
106 }
107 
// Register the pass factory as the global function "relay._transform.Legalize".
TVM_REGISTER_API("relay._transform.Legalize").set_body_typed(Legalize);
109 
110 }  // namespace transform
111 
112 }  // namespace relay
113 }  // namespace tvm
114