1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file tvm/relay/op_strategy.h
 * \brief The Relay operator strategy and its related data structures.
23  */
24 
25 #ifndef TVM_RELAY_OP_STRATEGY_H_
26 #define TVM_RELAY_OP_STRATEGY_H_
27 
28 #include <tvm/relay/expr.h>
29 #include <tvm/relay/op_attr_types.h>
30 #include <tvm/target/target.h>
31 #include <tvm/te/schedule.h>
32 #include <tvm/te/tensor.h>
33 
34 #include <string>
35 
36 namespace tvm {
37 namespace relay {
38 
39 /*!
40  * \brief Operator implementation that includes compute and schedule function.
41  */
42 class OpImplementationNode : public Object {
43  public:
44   /*! \brief Compute function */
45   FTVMCompute fcompute;
46   /*! \brief Schedule function */
47   FTVMSchedule fschedule;
48   /*! \brief Name of the implementation */
49   String name;
50   /*! \brief Priority level */
51   int plevel;
52 
VisitAttrs(tvm::AttrVisitor * v)53   void VisitAttrs(tvm::AttrVisitor* v) {
54     v->Visit("name", &name);
55     v->Visit("plevel", &plevel);
56   }
57 
58   static constexpr const char* _type_key = "relay.OpImplementation";
59   TVM_DECLARE_FINAL_OBJECT_INFO(OpImplementationNode, Object);
60 };
61 
62 /*!
63  * \brief Operator implementation class.
64  */
65 class OpImplementation : public ObjectRef {
66  public:
67   /*!
68    * \brief Invoke the operator compute function.
69    * \param attrs The attribute of the primitive
70    * \param inputs The input tensors.
71    * \param out_type The output type information.
72    * \return The output compute description of the operator.
73    */
74   TVM_DLL Array<te::Tensor> Compute(const Attrs& attrs, const Array<te::Tensor>& inputs,
75                                     const Type& out_type);
76   /*!
77    * \brief Build the computation schedule.
78    * \param attrs The attribute of the node.
79    * \param outs The output tensors.
80    * \param target The build target.
81    * \return The computation schedule.
82    */
83   TVM_DLL te::Schedule Schedule(const Attrs& attrs, const Array<te::Tensor>& outs,
84                                 const Target& target);
85 
86   TVM_DEFINE_OBJECT_REF_METHODS(OpImplementation, ObjectRef, OpImplementationNode);
87 };
88 
89 /*!
90  * \brief Specialized implementations for operators under certain conditions.
91  */
92 class OpSpecializationNode : public Object {
93  public:
94   /*! \brief List of implementations. */
95   Array<OpImplementation> implementations;
96   /*! \brief Condition to enable the specialization.
97    *    Could be undefined to represent generic case. */
98   te::SpecializedCondition condition;
99 
VisitAttrs(tvm::AttrVisitor * v)100   void VisitAttrs(tvm::AttrVisitor* v) {
101     v->Visit("condition", &condition);
102     v->Visit("implementations", &implementations);
103   }
104 
105   static constexpr const char* _type_key = "relay.OpSpecialization";
106   TVM_DECLARE_FINAL_OBJECT_INFO(OpSpecializationNode, ExprNode);
107 };
108 
109 /*!
110  * \brief Operator specialization class.
111  */
112 class OpSpecialization : public ObjectRef {
113  public:
114   /*!
115    * \brief Add an implementation.
116    * \param fcompute Compute function
117    * \param fschedule Schedule function
118    * \param name Name of the implementation
119    * \param plevel Priority level of the implementation
120    */
121   TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, String name,
122                                  int plevel);
123 
124   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpSpecialization, ObjectRef, OpSpecializationNode);
125 };
126 
127 /*!
128  * \brief Operator strategy to choose implementation.
129  */
130 class OpStrategyNode : public Object {
131  public:
132   /*! \brief List of operator specializations. */
133   Array<OpSpecialization> specializations;
134 
VisitAttrs(tvm::AttrVisitor * v)135   void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("specializations", &specializations); }
136 
137   static constexpr const char* _type_key = "relay.OpStrategy";
138   TVM_DECLARE_FINAL_OBJECT_INFO(OpStrategyNode, ExprNode);
139 };
140 
141 /*!
142  * \brief Operator strategy class.
143  */
144 class OpStrategy : public ObjectRef {
145  public:
146   /*!
147    * \brief Add an implementation.
148    * \param fcompute Compute function
149    * \param fschedule Schedule function
150    * \param name Name of the implementation
151    * \param plevel Priority level of the implementation
152    */
153   TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, String name,
154                                  int plevel);
155 
156   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpStrategy, ObjectRef, OpStrategyNode);
157 };
158 
159 }  // namespace relay
160 }  // namespace tvm
161 #endif  // TVM_RELAY_OP_STRATEGY_H_
162