1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file tvm/target/generic_func.h
 * \brief Generic function that can be specialized on a per-target basis.
23  */
24 #ifndef TVM_TARGET_GENERIC_FUNC_H_
25 #define TVM_TARGET_GENERIC_FUNC_H_
26 
27 #include <tvm/runtime/packed_func.h>
28 #include <tvm/support/with.h>
29 #include <tvm/target/target.h>
30 
31 #include <string>
32 #include <unordered_map>
33 #include <utility>
34 #include <vector>
35 
36 namespace tvm {
37 
38 class GenericFuncNode;
39 
40 /*!
41  * \brief Generic function that can be specialized on a per-target basis.
42  */
43 class GenericFunc : public ObjectRef {
44  public:
GenericFunc()45   GenericFunc() {}
GenericFunc(ObjectPtr<Object> n)46   explicit GenericFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
47 
48   /*!
49    * \brief Set the default function implementaiton.
50    * \param value The default function
51    * \param allow_override If true, this call may override a previously registered function. If
52    * false, an error will be logged if the call would override a previously registered function.
53    * \return reference to self.
54    */
55   TVM_DLL GenericFunc& set_default(const runtime::PackedFunc value, bool allow_override = false);
56   /*!
57    * \brief Register a specialized function
58    * \param tags The tags for this specialization
59    * \param value The specialized function
60    * \param allow_override If true, this call may override previously registered tags. If false,
61    * an error will be logged if the call would override previously registered tags.
62    * \return reference to self.
63    */
64   TVM_DLL GenericFunc& register_func(const std::vector<std::string>& tags,
65                                      const runtime::PackedFunc value, bool allow_override = false);
66   /*!
67    * \brief Call generic function by directly passing in unpacked format.
68    * \param args Arguments to be passed.
69    * \tparam Args arguments to be passed.
70    *
71    * \code
72    *   // Example code on how to call generic function
73    *   void CallGeneric(GenericFunc f) {
74    *     // call like normal functions by pass in arguments
75    *     // return value is automatically converted back
76    *     int rvalue = f(1, 2.0);
77    *   }
78    * \endcode
79    */
80   template <typename... Args>
81   inline runtime::TVMRetValue operator()(Args&&... args) const;
82   /*!
83    * \brief Invoke the relevant function for the current target context, set by set_target_context.
84    * Arguments are passed in packed format.
85    * \param args The arguments to pass to the function.
86    * \param ret The return value
87    */
88   TVM_DLL void CallPacked(runtime::TVMArgs args, runtime::TVMRetValue* ret) const;
89 
90   /*!
91    * \brief Find or register the GenericFunc instance corresponding to the give name
92    * \param name The name of the registered GenericFunc
93    * \return The GenericFunc instance
94    */
95   TVM_DLL static GenericFunc Get(const std::string& name);
96 
97   /*!
98    * \brief Add a GenericFunc instance to the registry
99    * \param func The GenericFunc instance
100    * \param name The name of the registered GenericFunc
101    */
102   TVM_DLL static void RegisterGenericFunc(GenericFunc func, const std::string& name);
103 
104   /*!
105    * \brief access the internal node container
106    * \return the pointer to the internal node container
107    */
108   inline GenericFuncNode* operator->();
109 
110   // declare container type
111   using ContainerType = GenericFuncNode;
112 
113   // Internal class.
114   struct Manager;
115 
116  private:
117   friend struct Manager;
118 };
119 
120 template <typename... Args>
operator()121 inline runtime::TVMRetValue GenericFunc::operator()(Args&&... args) const {
122   const int kNumArgs = sizeof...(Args);
123   const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
124   TVMValue values[kArraySize];
125   int type_codes[kArraySize];
126   runtime::detail::for_each(runtime::TVMArgsSetter(values, type_codes),
127                             std::forward<Args>(args)...);
128   runtime::TVMRetValue rv;
129   CallPacked(runtime::TVMArgs(values, type_codes, kNumArgs), &rv);
130   return rv;
131 }
132 
/*!
 * \brief Represents a generic function that can be specialized on a per-target basis.
 *
 * Holds the function name, a default implementation, and a table of specialized
 * implementations keyed by string.
 */
class GenericFuncNode : public Object {
 public:
  /*! \brief name of the function */
  std::string name_;
  /*!
   * \brief The default implementation.
   * NOTE(review): presumably the function installed by GenericFunc::set_default and used
   * when no specialization matches; confirm in the implementation.
   */
  runtime::PackedFunc generic_func_;
  /*!
   * \brief Map from keys to registered specialized functions.
   * NOTE(review): keys are presumably the tags passed to GenericFunc::register_func.
   */
  std::unordered_map<std::string, runtime::PackedFunc> dispatch_dict_;

  /*! \brief Visits nothing: none of the fields above are exposed through the attribute visitor. */
  void VisitAttrs(AttrVisitor* v) {}

  /*! \brief Type key registered with the object system for this node. */
  static constexpr const char* _type_key = "GenericFunc";
  TVM_DECLARE_FINAL_OBJECT_INFO(GenericFuncNode, Object);
};
150 
151 inline GenericFuncNode* GenericFunc::operator->() {
152   return static_cast<GenericFuncNode*>(get_mutable());
153 }
154 
// Declaration prefix used by TVM_REGISTER_GENERIC_FUNC: an unused static
// `GenericFunc&` whose name starts as `__mk_TVM`. TVM_STR_CONCAT presumably pastes
// __COUNTER__ onto that name, giving each registration in a translation unit a
// distinct variable.
#define TVM_GENERIC_FUNC_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_##TVM

/*!
 * \def TVM_REGISTER_GENERIC_FUNC
 * \brief Register a new generic function, or set a target-specific variant
 * of the corresponding function.
 *
 * Expands to a static-reference declaration initialized from
 * `::tvm::GenericFunc::Get(#name)`. Because Get returns a GenericFunc by value and
 * the declared variable is a non-const `GenericFunc&`, the macro must be followed by
 * a chained call that returns `GenericFunc&` (set_default or register_func). The
 * reference bound this way refers to a temporary and is never meant to be used
 * afterwards (the variable is marked unused).
 *
 * \code
 *   TVM_REGISTER_GENERIC_FUNC(my_func)
 *       .set_default(default_impl)
 *       .register_func({"some_tag"}, specialized_impl);
 * \endcode
 *
 * \param name The name of the function (stringized by the macro).
 */
#define TVM_REGISTER_GENERIC_FUNC(name) \
  TVM_STR_CONCAT(TVM_GENERIC_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::GenericFunc::Get(#name)
166 
167 }  // namespace tvm
168 #endif  // TVM_TARGET_GENERIC_FUNC_H_
169