/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Lower TVM related builtin intrinsics such as packed call.
 * \file lower_tvm_buildin.cc
 */
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "ir_util.h"
#include "../arithmetic/compute_expr.h"

namespace tvm {
namespace ir {

inline Expr ConstInt32(size_t index) {
  CHECK_LE(index, std::numeric_limits<int>::max());
  return make_const(Int(32), static_cast<int>(index));
}

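// Allocate a stack region of the given kind via the tvm_stack_alloca
// intrinsic; "type" names the slot kind (e.g. "shape", "array",
// "arg_value", "arg_tcode") and "num" is the number of entries.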
inline Expr StackAlloca(std::string type, size_t num) {
  Array<Expr> args = {StringImm::make(type), ConstInt32(num)};
  return Call::make(Handle(), intrinsic::tvm_stack_alloca, args, Call::Intrinsic);
}

// Calculate the statistics of packed function calls.
// This information is needed during codegen.
class BuiltinLower : public IRMutator {
 public:
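  // Mutate the body once to collect the maximum number of stack
  // entries each category needs, then wrap the result with LetStmt
  // nodes that allocate stacks of exactly that size.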
  Stmt Build(Stmt stmt) {
    stack_shape_ = Var("stack_shape", Handle());
    stack_array_ = Var("stack_array", Handle());
    stack_value_ = Var("stack_value", Handle());
    stack_tcode_ = Var("stack_tcode", Handle());
    stmt = this->Mutate(stmt);
    if (max_shape_stack_ != 0) {
      stmt = LetStmt::make(
          stack_shape_, StackAlloca("shape", max_shape_stack_), stmt);
    }
    if (max_array_stack_ != 0) {
      stmt = LetStmt::make(
          stack_array_, StackAlloca("array", max_array_stack_), stmt);
    }
    if (max_arg_stack_ != 0) {
      stmt = LetStmt::make(
          stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt);
      stmt = LetStmt::make(
          stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt);
    }
    return stmt;
  }

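  // After mutating a statement, splice in any preparation statements
  // (stores into the stacks) that the child expressions generated.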
  Stmt Mutate(Stmt stmt) final {
    stmt = IRMutator::Mutate(stmt);
    CHECK_EQ(run_shape_stack_, 0);
    CHECK_EQ(run_array_stack_, 0);
    while (prep_seq_.size() != 0) {
      stmt = Block::make(prep_seq_.back(), stmt);
      prep_seq_.pop_back();
    }
    return stmt;
  }

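  // Allocations that are small, constant-sized, and on CPU are left
  // untouched; everything else is lowered to a TVMBackendAllocWorkspace
  // call (null-checked via tvm_throw_last_error) paired with a
  // TVMBackendFreeWorkspace call after the body, wrapped in a
  // storage_alignment attribute.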
  Stmt Mutate_(const Allocate* op, const Stmt& s) {
    // Lower allocate to device allocate when needed.
    Stmt stmt = IRMutator::Mutate_(op, s);
    op = stmt.as<Allocate>();
    if (op->new_expr.defined()) return stmt;
    // Get constant allocation bound.
    int64_t dev_type;
    int64_t nbytes = GetVectorBytes(op->type);
    if (device_type_.defined()) {
      if (arith::GetConst(device_type_, &dev_type)) {
        if (dev_type == kDLCPU) {
          int32_t constant_size = op->constant_allocation_size();
          if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) {
            return stmt;
          }
        }
      }
    }
    Expr total_bytes = make_const(op->extents[0].type(), nbytes);
    for (size_t i = 0; i < op->extents.size(); ++i) {
      total_bytes = total_bytes * op->extents[i];
    }
    CHECK(device_type_.defined()) << "Unknown device type in current IR";
    CHECK(device_id_.defined()) << "Unknown device id in current IR";
    Stmt throw_last_error = Evaluate::make(Call::make(Int(32),
                                           intrinsic::tvm_throw_last_error, {},
                                           Call::Intrinsic));

    Stmt body = Block::make(
        IfThenElse::make(Call::make(Bool(1),
                                    intrinsic::tvm_handle_is_null,
                                    {op->buffer_var}, Call::PureIntrinsic),
                         throw_last_error),
        op->body);

    Stmt alloca = LetStmt::make(
        op->buffer_var,
        Call::make(op->buffer_var.type(),
                   "TVMBackendAllocWorkspace",
                   {cast(Int(32), device_type_),
                    cast(Int(32), device_id_),
                    cast(UInt(64), total_bytes),
                    IntImm::make(Int(32), op->type.code()),
                    IntImm::make(Int(32), op->type.bits())},
                   Call::Extern),
        body);

    Expr free_op = Call::make(Int(32),
                              "TVMBackendFreeWorkspace",
                              {cast(Int(32), device_type_),
                               cast(Int(32), device_id_),
                               op->buffer_var},
                              Call::Extern);
    Stmt free_stmt = IfThenElse::make(free_op != make_zero(Int(32)), throw_last_error);
    body = Block::make(alloca, free_stmt);
    body = AttrStmt::make(
        op->buffer_var, attr::storage_alignment,
        make_const(Int(32), runtime::kTempAllocaAlignment),
        body);
    return body;
  }

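  // Record the device id / device type from the device_context_id and
  // device_context_type attributes so later lowering knows the device.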
  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
    if (op->attr_key == attr::device_context_id) {
      CHECK(!device_id_.defined());
      device_id_ = op->value;
      return Mutate(op->body);
    } else if (op->attr_key == attr::device_context_type) {
      CHECK(!device_type_.defined());
      device_type_ = op->value;
      return Mutate(op->body);
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }
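  // Dispatch each supported builtin intrinsic to its lowering routine;
  // tvm_context_id simply lowers to constant zero.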
  Expr Mutate_(const Call* op, const Expr& e) final {
    if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
      return MakeCallPacked(op, e);
    } else if (op->is_intrinsic(intrinsic::tvm_call_trace_packed)) {
      return MakeCallTracePacked(op, e);
    } else if (op->is_intrinsic(intrinsic::tvm_stack_make_shape)) {
      return MakeShape(op, e);
    } else if (op->is_intrinsic(intrinsic::tvm_stack_make_array)) {
      return MakeArray(op, e);
    } else if (op->is_intrinsic(intrinsic::tvm_context_id)) {
      return make_zero(op->type);
    } else {
      return IRMutator::Mutate_(op, e);
    }
  }
  // Lower tvm_stack_make_shape: store each dimension into the shape
  // stack and return the address of the first entry.
  Expr MakeShape(const Call* op, const Expr& e) {
    size_t stack_begin = run_shape_stack_;
    run_shape_stack_ += op->args.size();
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Call>();
    for (size_t i = 0; i < op->args.size(); ++i) {
      prep_seq_.emplace_back(
          Store::make(stack_shape_, cast(Int(64), op->args[i]),
                      ConstInt32(stack_begin + i), const_true(1)));
    }
    return AddressOffset(stack_shape_, Int(64), stack_begin);
  }
  // Lower tvm_stack_make_array: fill one slot of the array stack.
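  // The fields set below (data, shape, strides, ndim, dtype
  // code/bits/lanes, byte offset, device id/type) mirror the members
  // of the runtime DLTensor structure.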
  Expr MakeArray(const Call* op, const Expr& e) {
    size_t idx = run_array_stack_;
    run_array_stack_ += 1;
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Call>();
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0]));
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1]));
    Expr strides = op->args[2];
    if (!strides.defined() || is_zero(strides)) {
      strides = make_zero(Handle());
    }
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrStrides, strides));
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrNDim, op->args[3]));
    Type dtype = op->args[4].type();
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrTypeCode,
                     make_const(UInt(8), static_cast<int>(dtype.code()))));
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrTypeBits,
                     make_const(UInt(8), dtype.bits())));
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes,
                     make_const(UInt(16), dtype.lanes())));
    // Set the byte offset, scaling the element offset by the data size.
    int data_bytes = GetVectorBytes(dtype);
    Expr byte_offset = op->args[5];
    if (!is_zero(byte_offset)) {
      byte_offset = byte_offset * make_const(byte_offset.type(), data_bytes);
    }
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset,
                     cast(UInt(64), byte_offset)));
    CHECK(device_type_.defined()) << "Unknown device type in current IR";
    CHECK(device_id_.defined()) << "Unknown device id in current IR";
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId,
                     cast(Int(32), device_id_)));
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType,
                     cast(Int(32), device_type_)));
    return TVMStructGet(Handle(), stack_array_, idx, intrinsic::kArrAddr);
  }
  // Lower tvm_call_packed into tvm_call_packed_lowered.
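  // For example (a sketch): tvm_call_packed("f", x, y) becomes
  //   stack_value[b]     = x; stack_tcode[b]     = code(x);
  //   stack_value[b + 1] = y; stack_tcode[b + 1] = code(y);
  //   tvm_call_packed_lowered("f", stack_value, stack_tcode, b, b + 2)
  // where b is the current top of the argument stack.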
  Expr MakeCallPacked(const Call* op, const Expr& e) {
    size_t restore_shape_stack = run_shape_stack_;
    size_t restore_array_stack = run_array_stack_;
    size_t arg_stack_begin = run_arg_stack_;
    run_arg_stack_ += op->args.size();
    // Specially handle the buffer packed intrinsic
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Call>();
    for (size_t i = 1; i < op->args.size(); ++i) {
      Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
      Expr arg = op->args[i];
      Type t = arg.type();
      Type api_type = APIType(t);
      if (t != api_type) {
        arg = Cast::make(api_type, arg);
      }
      prep_seq_.emplace_back(TVMStructSet(
          stack_value_, static_cast<int>(arg_stack_begin + i - 1),
          intrinsic::kTVMValueContent, arg));
      int arg_tcode = api_type.code();
      if (api_type.is_handle() && arg.as<StringImm>()) {
        arg_tcode = kStr;
      }
      if (IsArrayHandle(arg)) arg_tcode = kArrayHandle;
      prep_seq_.emplace_back(
          Store::make(stack_tcode_,
                      ConstInt32(arg_tcode),
                      stack_index, const_true(1)));
    }
    // Update the stack usage statistics and pop this call's entries.
    max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_);
    max_shape_stack_ = std::max(run_shape_stack_, max_shape_stack_);
    max_array_stack_ = std::max(run_array_stack_, max_array_stack_);
    run_shape_stack_ = restore_shape_stack;
    run_array_stack_ = restore_array_stack;
    run_arg_stack_ = arg_stack_begin;
    Array<Expr> packed_args = {
      op->args[0],
      stack_value_,
      stack_tcode_,
      ConstInt32(arg_stack_begin),
      ConstInt32(arg_stack_begin + op->args.size() - 1)
    };
    return Call::make(
        Int(32), intrinsic::tvm_call_packed_lowered,
        packed_args, Call::Intrinsic);
  }

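  // Same as MakeCallPacked, but the traced value (the last argument)
  // is kept on the stack and forwarded to the lowered trace call.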
  Expr MakeCallTracePacked(const Call* op, const Expr& e) {
    size_t restore_shape_stack = run_shape_stack_;
    size_t restore_array_stack = run_array_stack_;
    size_t arg_stack_begin = run_arg_stack_;
    run_arg_stack_ += op->args.size();
    size_t args_size = op->args.size();
    CHECK_GT(args_size, 0);
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Call>();
    for (size_t i = 1; i < op->args.size(); ++i) {
      Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
      Expr arg = op->args[i];
      Type t = arg.type();
      Type api_type = APIType(t);
      if (t != api_type) {
        arg = Cast::make(api_type, arg);
      }
      prep_seq_.emplace_back(TVMStructSet(
          stack_value_, static_cast<int>(arg_stack_begin + i - 1),
          intrinsic::kTVMValueContent, arg));
      int arg_tcode = api_type.code();
      CHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers";
      prep_seq_.emplace_back(
          Store::make(stack_tcode_,
                      ConstInt32(arg_tcode),
                      stack_index, const_true(1)));
    }
    // Update the stack usage statistics and restore the running stacks.
    max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_);
    max_shape_stack_ = std::max(run_shape_stack_, max_shape_stack_);
    max_array_stack_ = std::max(run_array_stack_, max_array_stack_);
    run_shape_stack_ = restore_shape_stack;
    run_array_stack_ = restore_array_stack;
    // Update the top of the argument stack so that more than one packed
    // function's arguments can share the same stack.
    run_arg_stack_ = arg_stack_begin + args_size - 1;
    Array<Expr> packed_args = {
      op->args[0],
      stack_value_,
      stack_tcode_,
      ConstInt32(arg_stack_begin),
      ConstInt32(arg_stack_begin + op->args.size() - 1),
      // Pass the traced value.
      op->args[args_size - 1]
    };
    return Call::make(
        op->type, intrinsic::tvm_call_trace_packed_lowered,
        packed_args, Call::Intrinsic);
  }

 private:
  bool IsArrayHandle(const Expr& arg) {
    // Handles produced by tvm_struct_get(..., kArrAddr) refer to an
    // array, so they get the array-handle type code.
    if (const Call* buf = arg.as<Call>()) {
      if (buf->is_intrinsic(intrinsic::tvm_struct_get) &&
          buf->args[2].as<IntImm>()->value == intrinsic::kArrAddr) {
        return true;
      }
    }
    return false;
  }

  // The preparation sequence to be emitted before the statement.
  std::vector<Stmt> prep_seq_;
  Expr device_type_;
  Expr device_id_;
  // Var handle for each stack.
  Var stack_shape_;
  Var stack_array_;
  Var stack_tcode_;
  Var stack_value_;
  // The running stack-usage counters.
  uint64_t run_shape_stack_{0};
  uint64_t run_array_stack_{0};
  uint64_t run_arg_stack_{0};
  // The maximum stack sizes observed, used to size the allocations.
  uint64_t max_shape_stack_{0};
  uint64_t max_array_stack_{0};
  uint64_t max_arg_stack_{0};
};

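// Entry point of the pass: rewrite the body of the lowered function.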
LoweredFunc LowerTVMBuiltin(LoweredFunc f) {
  auto n = make_node<LoweredFuncNode>(*f.operator->());
  n->body = BuiltinLower().Build(n->body);
  return LoweredFunc(n);
}

}  // namespace ir
}  // namespace tvm