1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file op_common.h 22 * \brief A set of utilities and common functionality 23 * for relay ops. 24 */ 25 #ifndef TVM_RELAY_OP_OP_COMMON_H_ 26 #define TVM_RELAY_OP_OP_COMMON_H_ 27 28 #include <tvm/relay/expr.h> 29 #include <tvm/relay/op.h> 30 #include <tvm/relay/op_attr_types.h> 31 #include <vector> 32 #include <string> 33 #include <unordered_map> 34 #include "type_relations.h" 35 #include "../pass/alter_op_layout.h" 36 37 namespace tvm { 38 namespace relay { 39 40 /*! Quick helper macro 41 * - Expose a positional make function to construct the node. 42 * - Register op to the registry. 43 * 44 * We make the decision to always only expose positional argument. 45 * We will do rewrapping in the frontend to support language 46 * sugars such as keyword arguments and default value. 47 48 * \param OpName the name of registry. 49 */ 50 #define RELAY_REGISTER_UNARY_OP(OpName) \ 51 TVM_REGISTER_API("relay.op._make." OpName) \ 52 .set_body_typed<Expr(Expr)>([](Expr data) { \ 53 static const Op& op = Op::Get(OpName); \ 54 return CallNode::make(op, {data}, Attrs(), {}); \ 55 }); \ 56 RELAY_REGISTER_OP(OpName) \ 57 .set_num_inputs(1) \ 58 .add_argument("data", "Tensor", "The input tensor.") \ 59 .add_type_rel("Identity", IdentityRel) \ 60 .set_attr<TOpPattern>("TOpPattern", kElemWise) \ 61 .set_attr<TOpIsStateful>("TOpIsStateful", false) \ 62 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", \ 63 ElemwiseArbitraryLayout) \ 64 65 66 /*! Quick helper macro 67 * - Expose a positional make function to construct the node. 68 * - Register op to the registry. 69 * 70 * We make the decision to always only expose positional argument. 71 * We will do rewrapping in the frontend to support language 72 * sugars such as keyword arguments and default value. 73 * 74 * \param OpName the name of registry. 75 */ 76 #define RELAY_REGISTER_BINARY_OP(OpName) \ 77 TVM_REGISTER_API("relay.op._make." OpName) \ 78 .set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) { \ 79 static const Op& op = Op::Get(OpName); \ 80 return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \ 81 }); \ 82 RELAY_REGISTER_OP(OpName) \ 83 .set_num_inputs(2) \ 84 .add_argument("lhs", "Tensor", "The left hand side tensor.") \ 85 .add_argument("rhs", "Tensor", "The right hand side tensor.") \ 86 .add_type_rel("Broadcast", BroadcastRel) \ 87 .set_attr<TOpPattern>("TOpPattern", kBroadcast) \ 88 .set_attr<TOpIsStateful>("TOpIsStateful", false) \ 89 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", \ 90 BinaryBroadcastLayout) 91 92 // Comparisons 93 #define RELAY_REGISTER_CMP_OP(OpName) \ 94 TVM_REGISTER_API("relay.op._make." OpName) \ 95 .set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) { \ 96 static const Op& op = Op::Get(OpName); \ 97 return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \ 98 }); \ 99 RELAY_REGISTER_OP(OpName) \ 100 .set_num_inputs(2) \ 101 .add_argument("lhs", "Tensor", "The left hand side tensor.") \ 102 .add_argument("rhs", "Tensor", "The right hand side tensor.") \ 103 .add_type_rel("BroadcastComp", BroadcastCompRel) \ 104 .set_attr<TOpPattern>("TOpPattern", kBroadcast) \ 105 .set_attr<TOpIsStateful>("TOpIsStateful", false) \ 106 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", \ 107 BinaryBroadcastLayout) 108 109 110 /*! \brief A helper class for matching and rewriting operators. */ 111 template<typename R> 112 class OpMatch { 113 public: 114 using MatchFunc = 115 std::function<R(const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_args)>; 116 117 /*! \brief Match an operator with the given name. 118 * \param op_name The name of the operator to match. 119 * \param func The function to execute when it matches. 120 * \return A self-reference for builder style API. 121 */ Match(const std::string & op_name,MatchFunc func)122 inline OpMatch& Match(const std::string& op_name, MatchFunc func) { 123 auto op = Op::Get(op_name); 124 match_map_.insert({op, func}); 125 return *this; 126 } 127 128 /*! \brief Rewrite a call operation based on the operator and the registered 129 * match functions. 130 * \param call The call to rewrite. 131 * \return The result of rewriting. 132 */ operator()133 inline R operator()(const Call& call) { 134 auto it = match_map_.find(Downcast<Op>(call->op)); 135 if (it != match_map_.end()) { 136 return it->second(call->args, call->attrs, call->type_args); 137 } else { 138 if (default_ != nullptr) { 139 return default_(call->args, call->attrs, call->type_args); 140 } else { 141 LOG(FATAL) << "unexpected operation " << call->op; 142 } 143 } 144 } 145 146 private: 147 /*! \brief The match function map. */ 148 std::unordered_map<Op, MatchFunc, NodeHash, NodeEqual> match_map_; 149 /*! \brief An optional default case. */ 150 MatchFunc default_; 151 }; 152 153 } // namespace relay 154 } // namespace tvm 155 156 #endif // TVM_RELAY_OP_OP_COMMON_H_ 157