1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file tvm/relay/op_strategy.h 22 * \brief The Relay operator Strategy and related data structure. 23 */ 24 25 #ifndef TVM_RELAY_OP_STRATEGY_H_ 26 #define TVM_RELAY_OP_STRATEGY_H_ 27 28 #include <tvm/relay/expr.h> 29 #include <tvm/relay/op_attr_types.h> 30 #include <tvm/target/target.h> 31 #include <tvm/te/schedule.h> 32 #include <tvm/te/tensor.h> 33 34 #include <string> 35 36 namespace tvm { 37 namespace relay { 38 39 /*! 40 * \brief Operator implementation that includes compute and schedule function. 41 */ 42 class OpImplementationNode : public Object { 43 public: 44 /*! \brief Compute function */ 45 FTVMCompute fcompute; 46 /*! \brief Schedule function */ 47 FTVMSchedule fschedule; 48 /*! \brief Name of the implementation */ 49 String name; 50 /*! \brief Priority level */ 51 int plevel; 52 VisitAttrs(tvm::AttrVisitor * v)53 void VisitAttrs(tvm::AttrVisitor* v) { 54 v->Visit("name", &name); 55 v->Visit("plevel", &plevel); 56 } 57 58 static constexpr const char* _type_key = "relay.OpImplementation"; 59 TVM_DECLARE_FINAL_OBJECT_INFO(OpImplementationNode, Object); 60 }; 61 62 /*! 63 * \brief Operator implementation class. 64 */ 65 class OpImplementation : public ObjectRef { 66 public: 67 /*! 68 * \brief Invoke the operator compute function. 69 * \param attrs The attribute of the primitive 70 * \param inputs The input tensors. 71 * \param out_type The output type information. 72 * \return The output compute description of the operator. 73 */ 74 TVM_DLL Array<te::Tensor> Compute(const Attrs& attrs, const Array<te::Tensor>& inputs, 75 const Type& out_type); 76 /*! 77 * \brief Build the computation schedule. 78 * \param attrs The attribute of the node. 79 * \param outs The output tensors. 80 * \param target The build target. 81 * \return The computation schedule. 82 */ 83 TVM_DLL te::Schedule Schedule(const Attrs& attrs, const Array<te::Tensor>& outs, 84 const Target& target); 85 86 TVM_DEFINE_OBJECT_REF_METHODS(OpImplementation, ObjectRef, OpImplementationNode); 87 }; 88 89 /*! 90 * \brief Specialized implementations for operators under certain conditions. 91 */ 92 class OpSpecializationNode : public Object { 93 public: 94 /*! \brief List of implementations. */ 95 Array<OpImplementation> implementations; 96 /*! \brief Condition to enable the specialization. 97 * Could be undefined to represent generic case. */ 98 te::SpecializedCondition condition; 99 VisitAttrs(tvm::AttrVisitor * v)100 void VisitAttrs(tvm::AttrVisitor* v) { 101 v->Visit("condition", &condition); 102 v->Visit("implementations", &implementations); 103 } 104 105 static constexpr const char* _type_key = "relay.OpSpecialization"; 106 TVM_DECLARE_FINAL_OBJECT_INFO(OpSpecializationNode, ExprNode); 107 }; 108 109 /*! 110 * \brief Operator specialization class. 111 */ 112 class OpSpecialization : public ObjectRef { 113 public: 114 /*! 115 * \brief Add an implementation. 116 * \param fcompute Compute function 117 * \param fschedule Schedule function 118 * \param name Name of the implementation 119 * \param plevel Priority level of the implementation 120 */ 121 TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, String name, 122 int plevel); 123 124 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpSpecialization, ObjectRef, OpSpecializationNode); 125 }; 126 127 /*! 128 * \brief Operator strategy to choose implementation. 129 */ 130 class OpStrategyNode : public Object { 131 public: 132 /*! \brief List of operator specializations. */ 133 Array<OpSpecialization> specializations; 134 VisitAttrs(tvm::AttrVisitor * v)135 void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("specializations", &specializations); } 136 137 static constexpr const char* _type_key = "relay.OpStrategy"; 138 TVM_DECLARE_FINAL_OBJECT_INFO(OpStrategyNode, ExprNode); 139 }; 140 141 /*! 142 * \brief Operator strategy class. 143 */ 144 class OpStrategy : public ObjectRef { 145 public: 146 /*! 147 * \brief Add an implementation. 148 * \param fcompute Compute function 149 * \param fschedule Schedule function 150 * \param name Name of the implementation 151 * \param plevel Priority level of the implementation 152 */ 153 TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, String name, 154 int plevel); 155 156 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpStrategy, ObjectRef, OpStrategyNode); 157 }; 158 159 } // namespace relay 160 } // namespace tvm 161 #endif // TVM_RELAY_OP_STRATEGY_H_ 162