1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file op_common.h
22  * \brief A set of utilities and common functionality
23  * for relay ops.
24  */
25 #ifndef TVM_RELAY_OP_OP_COMMON_H_
26 #define TVM_RELAY_OP_OP_COMMON_H_
27 
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "type_relations.h"
#include "../pass/alter_op_layout.h"
36 
37 namespace tvm {
38 namespace relay {
39 
/*! Quick helper macro
 * - Expose a positional make function to construct the node.
 * - Register op to the registry.
 *
 * We make the decision to always only expose positional argument.
 * We will do rewrapping in the frontend to support language
 * sugars such as keyword arguments and default value.
 *
 * The expansion consists of two statements:
 * - a TVM_REGISTER_API entry named "relay.op._make." OpName whose body takes
 *   one Expr and returns CallNode::make(op, {data}, Attrs(), {}), where op is
 *   looked up once through Op::Get(OpName);
 * - a RELAY_REGISTER_OP entry with one input ("data"), the IdentityRel type
 *   relation, the kElemWise pattern, TOpIsStateful set to false and
 *   ElemwiseArbitraryLayout as the layout inference function.
 *
 * The second statement is deliberately left without a terminating semicolon,
 * so the use site can chain further setters and must supply the final ';'.
 *
 * \param OpName the name of registry. It is pasted next to a string literal,
 *        so it has to be a string literal itself.
 */
#define RELAY_REGISTER_UNARY_OP(OpName)                     \
  TVM_REGISTER_API("relay.op._make." OpName)                \
    .set_body_typed<Expr(Expr)>([](Expr data) {             \
        static const Op& op = Op::Get(OpName);              \
        return CallNode::make(op, {data}, Attrs(), {});     \
      });                                                   \
  RELAY_REGISTER_OP(OpName)                                 \
    .set_num_inputs(1)                                      \
    .add_argument("data", "Tensor", "The input tensor.")    \
    .add_type_rel("Identity", IdentityRel)                  \
    .set_attr<TOpPattern>("TOpPattern", kElemWise)          \
    .set_attr<TOpIsStateful>("TOpIsStateful", false)        \
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",   \
                                   ElemwiseArbitraryLayout)
64 
65 
/*! Registration helper for binary broadcasting operators.
 *
 * Expanding the macro emits two statements:
 * - a TVM_REGISTER_API entry named "relay.op._make." OpName that accepts the
 *   two operands positionally and builds the call node for the operator;
 * - a RELAY_REGISTER_OP entry describing a two-input operator ("lhs", "rhs")
 *   typed by BroadcastRel, tagged kBroadcast, marked as not stateful and
 *   using BinaryBroadcastLayout for layout inference.
 *
 * Only positional arguments are exposed on purpose; keyword arguments and
 * default values are language sugar that the frontend adds by rewrapping.
 *
 * The expansion does not end with a semicolon: the use site supplies it and
 * may chain additional setters first.
 *
 * \param OpName the name of registry (a string literal).
 */
#define RELAY_REGISTER_BINARY_OP(OpName)                            \
  TVM_REGISTER_API("relay.op._make." OpName)                        \
      .set_body_typed<Expr(Expr, Expr)>([](Expr left, Expr right) { \
        static const Op& call_op = Op::Get(OpName);                 \
        return CallNode::make(call_op, {left, right}, Attrs(), {}); \
      });                                                           \
  RELAY_REGISTER_OP(OpName)                                         \
      .set_num_inputs(2)                                            \
      .add_argument("lhs", "Tensor", "The left hand side tensor.")  \
      .add_argument("rhs", "Tensor", "The right hand side tensor.") \
      .add_type_rel("Broadcast", BroadcastRel)                      \
      .set_attr<TOpPattern>("TOpPattern", kBroadcast)               \
      .set_attr<TOpIsStateful>("TOpIsStateful", false)              \
      .set_attr<FInferCorrectLayout>("FInferCorrectLayout",         \
                                     BinaryBroadcastLayout)
91 
/*! Registration helper for binary comparison operators.
 *
 * Expanding the macro emits two statements:
 * - a TVM_REGISTER_API entry named "relay.op._make." OpName that accepts the
 *   two operands positionally and builds the call node for the operator;
 * - a RELAY_REGISTER_OP entry describing a two-input operator ("lhs", "rhs")
 *   typed by BroadcastCompRel, tagged kBroadcast, marked as not stateful and
 *   using BinaryBroadcastLayout for layout inference.
 *
 * The expansion does not end with a semicolon: the use site supplies it and
 * may chain additional setters first.
 *
 * \param OpName the name of registry (a string literal).
 */
#define RELAY_REGISTER_CMP_OP(OpName)                               \
  TVM_REGISTER_API("relay.op._make." OpName)                        \
      .set_body_typed<Expr(Expr, Expr)>([](Expr left, Expr right) { \
        static const Op& call_op = Op::Get(OpName);                 \
        return CallNode::make(call_op, {left, right}, Attrs(), {}); \
      });                                                           \
  RELAY_REGISTER_OP(OpName)                                         \
      .set_num_inputs(2)                                            \
      .add_argument("lhs", "Tensor", "The left hand side tensor.")  \
      .add_argument("rhs", "Tensor", "The right hand side tensor.") \
      .add_type_rel("BroadcastComp", BroadcastCompRel)              \
      .set_attr<TOpPattern>("TOpPattern", kBroadcast)               \
      .set_attr<TOpIsStateful>("TOpIsStateful", false)              \
      .set_attr<FInferCorrectLayout>("FInferCorrectLayout",         \
                                     BinaryBroadcastLayout)
108 
109 
110 /*! \brief A helper class for matching and rewriting operators. */
111 template<typename R>
112 class OpMatch {
113  public:
114   using MatchFunc =
115       std::function<R(const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_args)>;
116 
117   /*! \brief Match an operator with the given name.
118    *  \param op_name The name of the operator to match.
119    *  \param func The function to execute when it matches.
120    *  \return A self-reference for builder style API.
121    */
Match(const std::string & op_name,MatchFunc func)122   inline OpMatch& Match(const std::string& op_name, MatchFunc func) {
123     auto op = Op::Get(op_name);
124     match_map_.insert({op, func});
125     return *this;
126   }
127 
128   /*! \brief Rewrite a call operation based on the operator and the registered
129    *  match functions.
130    * \param call The call to rewrite.
131    * \return The result of rewriting.
132    */
operator()133   inline R operator()(const Call& call) {
134     auto it = match_map_.find(Downcast<Op>(call->op));
135     if (it != match_map_.end()) {
136       return it->second(call->args, call->attrs, call->type_args);
137     } else {
138       if (default_ != nullptr) {
139         return default_(call->args, call->attrs, call->type_args);
140       } else {
141         LOG(FATAL) << "unexpected operation " << call->op;
142       }
143     }
144   }
145 
146  private:
147   /*! \brief The match function map. */
148   std::unordered_map<Op, MatchFunc, NodeHash, NodeEqual> match_map_;
149   /*! \brief An optional default case. */
150   MatchFunc default_;
151 };
152 
153 }  // namespace relay
154 }  // namespace tvm
155 
156 #endif  // TVM_RELAY_OP_OP_COMMON_H_
157