1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file detail/extern.h
22 * \brief Helpers for using external functions
23 */
24 #ifndef TOPI_DETAIL_EXTERN_H_
25 #define TOPI_DETAIL_EXTERN_H_
26
#include <tvm/operation.h>
#include <functional>
#include <string>
#include <vector>
30
31
32 namespace topi {
33 namespace detail {
34 using namespace tvm;
35
36 /*!
37 * \brief Construct a buffer to pass to an external function
38 *
39 * \param shape The shape of the buffer
40 * \param dtype The type of the buffer elements
41 * \param name The name of the buffer
42 *
43 * \return The Buffer object
44 */
DeclExternBuffer(Array<Expr> shape,Type dtype,std::string name)45 inline Buffer DeclExternBuffer(Array<Expr> shape,
46 Type dtype,
47 std::string name) {
48 auto data = var(name, Handle());
49 auto elem_offset = Expr();
50 return BufferNode::make(data, dtype, shape, Array<Expr>(), elem_offset, name, "",
51 -1, 0, kDefault);
52 }
53
54 /*!
55 * \brief A function which constructs an Expr representing the invocation of an external
56 * function. The function expects two arguments: an array of Buffers holding the input
57 * tensor values, and a pre-allocated array of Buffers to be filled with the outputs.
58 */
59 using FExtern = std::function<Expr(Array<Buffer>, Array<Buffer>)>;
60
61 /*!
62 * \brief Create tensors representing the result of invoking an external function.
63 * This function will create the necessary buffers to hold input and output tensor values.
64 *
65 * \param out_shapes An array where each element is the shape of the corresponding output tensor.
66 * \param out_types An array where each element is the dtype of the corresponding output tensor.
67 * \param inputs An array of input Tensors
68 * \param fextern A function that constructs an Expr representing the invocation of
69 * the external function given the input and output buffers.
70 * \param name The name of the operation
71 * \param tag The tag to mark the operation
72 * \param attrs The additional auxiliary attributes of the operation.
73 *
74 * \return An array of Tensors representing the outputs of the function invocation. There will
75 * be one output Tensor for each element of out_shapes, with dtype equal to the corresponding
76 * element of out_types.
77 */
make_extern(const Array<Array<Expr>> & out_shapes,const std::vector<Type> & out_types,const Array<Tensor> & inputs,FExtern fextern,std::string name,std::string tag,::tvm::Map<std::string,NodeRef> attrs)78 inline Array<Tensor> make_extern(const Array< Array<Expr> >& out_shapes,
79 const std::vector<Type>& out_types,
80 const Array<Tensor>& inputs,
81 FExtern fextern,
82 std::string name,
83 std::string tag,
84 ::tvm::Map<std::string, NodeRef> attrs) {
85 CHECK_EQ(out_shapes.size(), out_types.size())
86 << "make_extern: out_shapes and out_types must have equal size";
87
88 Array<Buffer> input_placeholders;
89 for (auto t : inputs) {
90 input_placeholders.push_back(DeclExternBuffer(t->shape, t->dtype, t->op->name));
91 }
92 Array<Buffer> output_placeholders;
93 for (size_t i = 0; i < out_shapes.size(); ++i) {
94 output_placeholders.push_back(DeclExternBuffer(out_shapes[i], out_types[i], name));
95 }
96
97 auto body = fextern(input_placeholders, output_placeholders);
98 auto body_stmt = tvm::ir::Evaluate::make(body);
99
100 auto op = ExternOpNode::make(
101 name, tag, attrs, inputs,
102 input_placeholders, output_placeholders, body_stmt);
103
104 Array<Tensor> outputs;
105 for (size_t i = 0; i < output_placeholders.size(); ++i) {
106 outputs.push_back(op.output(i));
107 }
108 return outputs;
109 }
110
111 /*!
112 * \brief This function is used to create a DLTensor structure on the stack to
113 * be able to pass a symbolic buffer as arguments to TVM PackedFunc
114 *
115 * \param buf The buffer to pack
116 *
117 * \return An expression representing the pack operation
118 */
pack_buffer(Buffer buf)119 inline Expr pack_buffer(Buffer buf) {
120 CHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element";
121 auto shape = tvm::ir::Call::make(Handle(), tvm::ir::intrinsic::tvm_stack_make_shape,
122 buf->shape, tvm::ir::Call::CallType::Intrinsic);
123 Expr strides;
124 if (buf->strides.size() > 0) {
125 strides = tvm::ir::Call::make(Handle(), tvm::ir::intrinsic::tvm_stack_make_shape,
126 buf->shape, tvm::ir::Call::CallType::Intrinsic);
127 } else {
128 strides = 0;
129 }
130 Array<Expr> pack_args{
131 buf->data,
132 shape,
133 strides,
134 make_const(Int(32), static_cast<int64_t>(buf->shape.size())),
135 make_const(buf->dtype, 0),
136 buf->elem_offset
137 };
138 return tvm::ir::Call::make(Handle(), tvm::ir::intrinsic::tvm_stack_make_array,
139 pack_args, tvm::ir::Call::CallType::Intrinsic);
140 }
141
142 /*!
143 * \brief Construct an Expr representing the invocation of a PackedFunc
144 *
145 * \param args An array containing the registered name of the PackedFunc followed
146 * by the arguments to pass to the PackedFunc when called. The first element of the
147 * array must be a constant string expression.
148 *
149 * \return An expression representing the invocation
150 */
call_packed(Array<Expr> args)151 inline Expr call_packed(Array<Expr> args) {
152 return tvm::ir::Call::make(Int(32), tvm::ir::intrinsic::tvm_call_packed,
153 args, tvm::ir::Call::CallType::Intrinsic);
154 }
155
156 } // namespace detail
157 } // namespace topi
158 #endif // TOPI_DETAIL_EXTERN_H_
159