1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * Lower TVM related builtin intrinsics such as packed calls.
 * \file lower_tvm_buildin.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <algorithm>
#include <limits>
#include <string>
#include <unordered_set>
#include <vector>
#include "ir_util.h"
#include "../arithmetic/compute_expr.h"
30
31 namespace tvm {
32 namespace ir {
33
ConstInt32(size_t index)34 inline Expr ConstInt32(size_t index) {
35 CHECK_LE(index, std::numeric_limits<int>::max());
36 return make_const(Int(32), static_cast<int>(index));
37 }
38
StackAlloca(std::string type,size_t num)39 inline Expr StackAlloca(std::string type, size_t num) {
40 Array<Expr> args = {StringImm::make(type), ConstInt32(num)};
41 return Call::make(Handle(), intrinsic::tvm_stack_alloca, args, Call::Intrinsic);
42 }
43
44 // Calculate the statistics of packed function.
45 // These information are needed during codegen.
46 class BuiltinLower : public IRMutator {
47 public:
Build(Stmt stmt)48 Stmt Build(Stmt stmt) {
49 stack_shape_ = Var("stack_shape", Handle());
50 stack_array_ = Var("stack_array", Handle());
51 stack_value_ = Var("stack_value", Handle());
52 stack_tcode_ = Var("stack_tcode", Handle());
53 stmt = this->Mutate(stmt);
54 if (max_shape_stack_ != 0) {
55 stmt = LetStmt::make(
56 stack_shape_, StackAlloca("shape", max_shape_stack_), stmt);
57 }
58 if (max_array_stack_ != 0) {
59 stmt = LetStmt::make(
60 stack_array_, StackAlloca("array", max_array_stack_), stmt);
61 }
62 if (max_arg_stack_ != 0) {
63 stmt = LetStmt::make(
64 stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt);
65 stmt = LetStmt::make(
66 stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt);
67 }
68 return stmt;
69 }
70
Mutate(Stmt stmt)71 Stmt Mutate(Stmt stmt) final {
72 stmt = IRMutator::Mutate(stmt);
73 CHECK_EQ(run_shape_stack_, 0);
74 CHECK_EQ(run_array_stack_, 0);
75 while (prep_seq_.size() != 0) {
76 stmt = Block::make(prep_seq_.back(), stmt);
77 prep_seq_.pop_back();
78 }
79 return stmt;
80 }
81
Mutate_(const Allocate * op,const Stmt & s)82 Stmt Mutate_(const Allocate* op, const Stmt& s) {
83 // Lower allocate to device allocate when needed.
84 Stmt stmt = IRMutator::Mutate_(op, s);
85 op = stmt.as<Allocate>();
86 if (op->new_expr.defined()) return stmt;
87 // Get constant allocation bound.
88 int64_t dev_type;
89 int64_t nbytes = GetVectorBytes(op->type);
90 if (device_type_.defined()) {
91 if (arith::GetConst(device_type_, &dev_type)) {
92 if (dev_type == kDLCPU) {
93 int32_t constant_size = op->constant_allocation_size();
94 if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) {
95 return stmt;
96 }
97 }
98 }
99 }
100 Expr total_bytes = make_const(op->extents[0].type(), nbytes);
101 for (size_t i = 0; i < op->extents.size(); ++i) {
102 total_bytes = total_bytes * op->extents[i];
103 }
104 CHECK(device_type_.defined()) << "Unknown device type in current IR";
105 CHECK(device_id_.defined()) << "Unknown device id in current IR";
106 Stmt throw_last_error = Evaluate::make(Call::make(Int(32),
107 intrinsic::tvm_throw_last_error, {},
108 Call::Intrinsic));
109
110 Stmt body = Block::make(
111 IfThenElse::make(Call::make(Bool(1),
112 intrinsic::tvm_handle_is_null,
113 {op->buffer_var}, Call::PureIntrinsic),
114 throw_last_error),
115 op->body);
116
117 Stmt alloca = LetStmt::make(
118 op->buffer_var,
119 Call::make(op->buffer_var.type(),
120 "TVMBackendAllocWorkspace",
121 {cast(Int(32), device_type_),
122 cast(Int(32), device_id_),
123 cast(UInt(64), total_bytes),
124 IntImm::make(Int(32), op->type.code()),
125 IntImm::make(Int(32), op->type.bits())},
126 Call::Extern),
127 body);
128
129 Expr free_op = Call::make(Int(32),
130 "TVMBackendFreeWorkspace",
131 {cast(Int(32), device_type_),
132 cast(Int(32), device_id_),
133 op->buffer_var},
134 Call::Extern);
135 Stmt free_stmt = IfThenElse::make(free_op != make_zero(Int(32)), throw_last_error);
136 body = Block::make(alloca, free_stmt);
137 body = AttrStmt::make(
138 op->buffer_var, attr::storage_alignment,
139 make_const(Int(32), runtime::kTempAllocaAlignment),
140 body);
141 return body;
142 }
143
Mutate_(const AttrStmt * op,const Stmt & s)144 Stmt Mutate_(const AttrStmt* op, const Stmt &s) final {
145 if (op->attr_key == attr::device_context_id) {
146 CHECK(!device_id_.defined());
147 device_id_ = op->value;
148 return Mutate(op->body);
149 } else if (op->attr_key == attr::device_context_type) {
150 CHECK(!device_type_.defined());
151 device_type_ = op->value;
152 return Mutate(op->body);
153 } else {
154 return IRMutator::Mutate_(op, s);
155 }
156 }
Mutate_(const Call * op,const Expr & e)157 Expr Mutate_(const Call* op, const Expr &e) final {
158 if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
159 return MakeCallPacked(op, e);
160 } else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed)) {
161 return MakeCallTracePacked(op, e);
162 } else if (op->is_intrinsic(intrinsic::tvm_stack_make_shape)) {
163 return MakeShape(op, e);
164 } else if (op->is_intrinsic(intrinsic::tvm_stack_make_array)) {
165 return MakeArray(op, e);
166 } else if (op->is_intrinsic(intrinsic::tvm_context_id)) {
167 return make_zero(op->type);
168 } else {
169 return IRMutator::Mutate_(op, e);
170 }
171 }
172 // call shape
MakeShape(const Call * op,const Expr & e)173 Expr MakeShape(const Call* op, const Expr& e) {
174 size_t stack_begin = run_shape_stack_;
175 run_shape_stack_ += op->args.size();
176 Expr expr = IRMutator::Mutate_(op, e);
177 op = expr.as<Call>();
178 for (size_t i = 0; i < op->args.size(); ++i) {
179 prep_seq_.emplace_back(
180 Store::make(stack_shape_, cast(Int(64), op->args[i]),
181 ConstInt32(stack_begin +i), const_true(1)));
182 }
183 return AddressOffset(stack_shape_, Int(64), stack_begin);
184 }
185 // make array
MakeArray(const Call * op,const Expr & e)186 Expr MakeArray(const Call* op, const Expr& e) {
187 size_t idx = run_array_stack_;
188 run_array_stack_ += 1;
189 Expr expr = IRMutator::Mutate_(op, e);
190 op = expr.as<Call>();
191 prep_seq_.emplace_back(
192 TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0]));
193 prep_seq_.emplace_back(
194 TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1]));
195 Expr strides = op->args[2];
196 if (!strides.defined() || is_zero(strides)) {
197 strides = make_zero(Handle());
198 }
199 prep_seq_.emplace_back(
200 TVMStructSet(stack_array_, idx, intrinsic::kArrStrides, strides));
201 prep_seq_.emplace_back(
202 TVMStructSet(stack_array_, idx, intrinsic::kArrNDim, op->args[3]));
203 Type dtype = op->args[4].type();
204 prep_seq_.emplace_back(
205 TVMStructSet(stack_array_, idx, intrinsic::kArrTypeCode,
206 make_const(UInt(8), static_cast<int>(dtype.code()))));
207 prep_seq_.emplace_back(
208 TVMStructSet(stack_array_, idx, intrinsic::kArrTypeBits,
209 make_const(UInt(8), dtype.bits())));
210 prep_seq_.emplace_back(
211 TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes,
212 make_const(UInt(16), dtype.lanes())));
213 // set byte offset
214 int data_bytes = GetVectorBytes(dtype);
215 Expr byte_offset = op->args[5];
216 if (!is_zero(byte_offset)) {
217 byte_offset = byte_offset * make_const(byte_offset.type(), data_bytes);
218 }
219 prep_seq_.emplace_back(
220 TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset,
221 cast(UInt(64), byte_offset)));
222 CHECK(device_type_.defined()) << "Unknown device type in current IR";
223 CHECK(device_id_.defined()) << "Unknown device id in current IR";
224 prep_seq_.emplace_back(
225 TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId,
226 cast(Int(32), device_id_)));
227 prep_seq_.emplace_back(
228 TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType,
229 cast(Int(32), device_type_)));
230 return TVMStructGet(Handle(), stack_array_, idx, intrinsic::kArrAddr);
231 }
232 // call packed.
MakeCallPacked(const Call * op,const Expr & e)233 Expr MakeCallPacked(const Call* op, const Expr& e) {
234 size_t restore_shape_stack = run_shape_stack_;
235 size_t restore_array_stack = run_array_stack_;
236 size_t arg_stack_begin = run_arg_stack_;
237 run_arg_stack_ += op->args.size();
238 // Specially handle the buffer packed intrinsic
239 Expr expr = IRMutator::Mutate_(op, e);
240 op = expr.as<Call>();
241 for (size_t i = 1; i < op->args.size(); ++i) {
242 Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
243 Expr arg = op->args[i];
244 Type t = arg.type();
245 Type api_type = APIType(t);
246 if (t != api_type) {
247 arg = Cast::make(api_type, arg);
248 }
249 prep_seq_.emplace_back(TVMStructSet(
250 stack_value_, static_cast<int>(arg_stack_begin + i - 1),
251 intrinsic::kTVMValueContent, arg));
252 int arg_tcode = api_type.code();
253 if (api_type.is_handle() && arg.as<StringImm>()) {
254 arg_tcode = kStr;
255 }
256 if (IsArrayHandle(arg)) arg_tcode = kArrayHandle;
257 prep_seq_.emplace_back(
258 Store::make(stack_tcode_,
259 ConstInt32(arg_tcode),
260 stack_index, const_true(1)));
261 }
262 // UPDATE stack value
263 max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_);
264 max_shape_stack_ = std::max(run_shape_stack_, max_shape_stack_);
265 max_array_stack_ = std::max(run_array_stack_, max_array_stack_);
266 run_shape_stack_ = restore_shape_stack;
267 run_array_stack_ = restore_array_stack;
268 run_arg_stack_ = arg_stack_begin;
269 Array<Expr> packed_args = {
270 op->args[0],
271 stack_value_,
272 stack_tcode_,
273 ConstInt32(arg_stack_begin),
274 ConstInt32(arg_stack_begin + op->args.size() - 1)
275 };
276 return Call::make(
277 Int(32), intrinsic::tvm_call_packed_lowered,
278 packed_args, Call::Intrinsic);
279 }
280
MakeCallTracePacked(const Call * op,const Expr & e)281 Expr MakeCallTracePacked(const Call *op, const Expr &e) {
282 size_t restore_shape_stack = run_shape_stack_;
283 size_t restore_array_stack = run_array_stack_;
284 size_t arg_stack_begin = run_arg_stack_;
285 run_arg_stack_ += op->args.size();
286 size_t args_size = op->args.size();
287 CHECK_GT(args_size, 0);
288 Expr expr = IRMutator::Mutate_(op, e);
289 op = expr.as<Call>();
290 for (size_t i = 1; i < op->args.size(); ++i) {
291 Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
292 Expr arg = op->args[i];
293 Type t = arg.type();
294 Type api_type = APIType(t);
295 if (t != api_type) {
296 arg = Cast::make(api_type, arg);
297 }
298 prep_seq_.emplace_back(TVMStructSet(
299 stack_value_, static_cast<int>(arg_stack_begin + i - 1),
300 intrinsic::kTVMValueContent, arg));
301 int arg_tcode = api_type.code();
302 CHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers";
303 prep_seq_.emplace_back(
304 Store::make(stack_tcode_,
305 ConstInt32(arg_tcode),
306 stack_index, const_true(1)));
307 }
308 // UPDATE stack value
309 max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_);
310 max_shape_stack_ = std::max(run_shape_stack_, max_shape_stack_);
311 max_array_stack_ = std::max(run_array_stack_, max_array_stack_);
312 run_shape_stack_ = restore_shape_stack;
313 run_array_stack_ = restore_array_stack;
314 // Update the top of the stack, so we can use more than one
315 // packed function's arguments with the one stack.
316 run_arg_stack_ = arg_stack_begin + args_size - 1;
317 Array<Expr> packed_args = {
318 op->args[0],
319 stack_value_,
320 stack_tcode_,
321 ConstInt32(arg_stack_begin),
322 ConstInt32(arg_stack_begin + op->args.size() - 1),
323 // Pass traced value.
324 op->args[args_size - 1]
325 };
326 return Call::make(
327 op->type, intrinsic::tvm_call_trace_packed_lowered,
328 packed_args, Call::Intrinsic);
329 }
330
331 private:
IsArrayHandle(const Expr & arg)332 bool IsArrayHandle(const Expr& arg) {
333 // specially set array handle.
334 if (const Call* buf = arg.as<Call>()) {
335 if (buf->is_intrinsic(intrinsic::tvm_struct_get) &&
336 buf->args[2].as<IntImm>()->value == intrinsic::kArrAddr) {
337 return true;
338 }
339 }
340 return false;
341 }
342
343 // The prepration sequence to be emitted.
344 std::vector<Stmt> prep_seq_;
345 Expr device_type_;
346 Expr device_id_;
347 // Var handle for each stack.
348 Var stack_shape_;
349 Var stack_array_;
350 Var stack_tcode_;
351 Var stack_value_;
352 // The running statistics
353 uint64_t run_shape_stack_{0};
354 uint64_t run_array_stack_{0};
355 uint64_t run_arg_stack_{0};
356 // statistics of stacks
357 uint64_t max_shape_stack_{0};
358 uint64_t max_array_stack_{0};
359 uint64_t max_arg_stack_{0};
360 };
361
LowerTVMBuiltin(LoweredFunc f)362 LoweredFunc LowerTVMBuiltin(LoweredFunc f) {
363 auto n = make_node<LoweredFuncNode>(*f.operator->());
364 n->body = BuiltinLower().Build(n->body);
365 return LoweredFunc(n);
366 }
367
368 } // namespace ir
369 } // namespace tvm
370