1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file tvm/target/generic_func.h
 * \brief Generic function that can be specialized on a per-target basis.
23 */
24 #ifndef TVM_TARGET_GENERIC_FUNC_H_
25 #define TVM_TARGET_GENERIC_FUNC_H_
26
27 #include <tvm/runtime/packed_func.h>
28 #include <tvm/support/with.h>
29 #include <tvm/target/target.h>
30
31 #include <string>
32 #include <unordered_map>
33 #include <utility>
34 #include <vector>
35
36 namespace tvm {
37
38 class GenericFuncNode;
39
40 /*!
41 * \brief Generic function that can be specialized on a per-target basis.
42 */
43 class GenericFunc : public ObjectRef {
44 public:
GenericFunc()45 GenericFunc() {}
GenericFunc(ObjectPtr<Object> n)46 explicit GenericFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
47
48 /*!
49 * \brief Set the default function implementaiton.
50 * \param value The default function
51 * \param allow_override If true, this call may override a previously registered function. If
52 * false, an error will be logged if the call would override a previously registered function.
53 * \return reference to self.
54 */
55 TVM_DLL GenericFunc& set_default(const runtime::PackedFunc value, bool allow_override = false);
56 /*!
57 * \brief Register a specialized function
58 * \param tags The tags for this specialization
59 * \param value The specialized function
60 * \param allow_override If true, this call may override previously registered tags. If false,
61 * an error will be logged if the call would override previously registered tags.
62 * \return reference to self.
63 */
64 TVM_DLL GenericFunc& register_func(const std::vector<std::string>& tags,
65 const runtime::PackedFunc value, bool allow_override = false);
66 /*!
67 * \brief Call generic function by directly passing in unpacked format.
68 * \param args Arguments to be passed.
69 * \tparam Args arguments to be passed.
70 *
71 * \code
72 * // Example code on how to call generic function
73 * void CallGeneric(GenericFunc f) {
74 * // call like normal functions by pass in arguments
75 * // return value is automatically converted back
76 * int rvalue = f(1, 2.0);
77 * }
78 * \endcode
79 */
80 template <typename... Args>
81 inline runtime::TVMRetValue operator()(Args&&... args) const;
82 /*!
83 * \brief Invoke the relevant function for the current target context, set by set_target_context.
84 * Arguments are passed in packed format.
85 * \param args The arguments to pass to the function.
86 * \param ret The return value
87 */
88 TVM_DLL void CallPacked(runtime::TVMArgs args, runtime::TVMRetValue* ret) const;
89
90 /*!
91 * \brief Find or register the GenericFunc instance corresponding to the give name
92 * \param name The name of the registered GenericFunc
93 * \return The GenericFunc instance
94 */
95 TVM_DLL static GenericFunc Get(const std::string& name);
96
97 /*!
98 * \brief Add a GenericFunc instance to the registry
99 * \param func The GenericFunc instance
100 * \param name The name of the registered GenericFunc
101 */
102 TVM_DLL static void RegisterGenericFunc(GenericFunc func, const std::string& name);
103
104 /*!
105 * \brief access the internal node container
106 * \return the pointer to the internal node container
107 */
108 inline GenericFuncNode* operator->();
109
110 // declare container type
111 using ContainerType = GenericFuncNode;
112
113 // Internal class.
114 struct Manager;
115
116 private:
117 friend struct Manager;
118 };
119
120 template <typename... Args>
operator()121 inline runtime::TVMRetValue GenericFunc::operator()(Args&&... args) const {
122 const int kNumArgs = sizeof...(Args);
123 const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
124 TVMValue values[kArraySize];
125 int type_codes[kArraySize];
126 runtime::detail::for_each(runtime::TVMArgsSetter(values, type_codes),
127 std::forward<Args>(args)...);
128 runtime::TVMRetValue rv;
129 CallPacked(runtime::TVMArgs(values, type_codes, kNumArgs), &rv);
130 return rv;
131 }
132
133 /*!
134 * \brief Represents a generic function that can be specialized on a per-target basis.
135 */
136 class GenericFuncNode : public Object {
137 public:
138 /*! \brief name of the function */
139 std::string name_;
140 /* \brief the generic builder */
141 runtime::PackedFunc generic_func_;
142 /* \brief map from keys to registered functions */
143 std::unordered_map<std::string, runtime::PackedFunc> dispatch_dict_;
144
VisitAttrs(AttrVisitor * v)145 void VisitAttrs(AttrVisitor* v) {}
146
147 static constexpr const char* _type_key = "GenericFunc";
148 TVM_DECLARE_FINAL_OBJECT_INFO(GenericFuncNode, Object);
149 };
150
151 inline GenericFuncNode* GenericFunc::operator->() {
152 return static_cast<GenericFuncNode*>(get_mutable());
153 }
154
/*!
 * \brief Helper for TVM_REGISTER_GENERIC_FUNC.
 *
 * Expands to the declaration of a static ::tvm::GenericFunc reference whose name
 * ends in the token `__mk_TVM`. TVM_REGISTER_GENERIC_FUNC pastes a __COUNTER__
 * value onto that trailing token, which gives each registration in a translation
 * unit its own variable. TVM_ATTRIBUTE_UNUSED presumably silences unused-variable
 * warnings.
 */
#define TVM_GENERIC_FUNC_REG_VAR_DEF \
  static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_##TVM

/*!
 * \def TVM_REGISTER_GENERIC_FUNC
 * \brief Register a new generic function, or set a device-specific variant
 *  of the corresponding function.
 *
 * The expansion initializes the static reference from ::tvm::GenericFunc::Get(#name).
 * Get returns a GenericFunc by value, and a non-const reference cannot bind to that
 * temporary. The macro must therefore be chained with a member that returns
 * GenericFunc&, for example:
 * \code
 *   TVM_REGISTER_GENERIC_FUNC(my_func).set_default(fdefault);
 * \endcode
 *
 * NOTE(review): the reference ends up referring to the temporary returned by Get,
 * so it must never be read. Registration relies solely on the side effects of the
 * chained calls.
 *
 * \param name The name of the function (stringified, not evaluated).
 */
#define TVM_REGISTER_GENERIC_FUNC(name)                     \
  TVM_STR_CONCAT(TVM_GENERIC_FUNC_REG_VAR_DEF, __COUNTER__) = \
      ::tvm::GenericFunc::Get(#name)
166
167 } // namespace tvm
168 #endif // TVM_TARGET_GENERIC_FUNC_H_
169