1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file vectorize_loop.cc
22 */
23 // Loop vectorizer as in Halide pipeline.
24 #include <tvm/arith/analyzer.h>
25 #include <tvm/runtime/registry.h>
26 #include <tvm/tir/analysis.h>
27 #include <tvm/tir/builtin.h>
28 #include <tvm/tir/expr.h>
29 #include <tvm/tir/op.h>
30 #include <tvm/tir/op_attr_types.h>
31 #include <tvm/tir/stmt_functor.h>
32 #include <tvm/tir/transform.h>
33
34 #include <unordered_map>
35 #include <unordered_set>
36 #include <vector>
37
38 namespace tvm {
39 namespace tir {
40
BroadcastTo(PrimExpr e,int lanes)41 inline PrimExpr BroadcastTo(PrimExpr e, int lanes) {
42 if (e.dtype().lanes() == lanes) return e;
43 if (const BroadcastNode* op = e.as<BroadcastNode>()) {
44 if (lanes % op->lanes == 0) {
45 return Broadcast(op->value, lanes);
46 }
47 }
48 CHECK_EQ(e.dtype().lanes(), 1) << "Cannot broadcast lane=" << e.dtype().lanes() << " to "
49 << lanes;
50 return Broadcast(e, lanes);
51 }
52
// Rewrite accesses to an allocation made inside a vectorized loop.
// This is necessary so that each vector lane gets its own workspace.
// Originates from Halide's loop vectorizer.
//
// s[i] = s[i * lanes + var]
//
// The same principle applies when using one thread to simulate multiple contexts.
//
61 class VecAllocAccess : public StmtExprMutator {
62 public:
VecAllocAccess(const VarNode * buf,Var var,int var_lanes)63 VecAllocAccess(const VarNode* buf, Var var, int var_lanes)
64 : buf_(buf), var_(var), var_lanes_(var_lanes) {}
65 // Load
VisitExpr_(const LoadNode * op)66 PrimExpr VisitExpr_(const LoadNode* op) final {
67 PrimExpr expr = StmtExprMutator::VisitExpr_(op);
68 op = expr.as<LoadNode>();
69 if (op->buffer_var.get() == buf_) {
70 return Load(op->dtype, op->buffer_var, op->index * var_lanes_ + var_, op->predicate);
71 } else {
72 return expr;
73 }
74 }
75 // Store
VisitStmt_(const StoreNode * op)76 Stmt VisitStmt_(const StoreNode* op) final {
77 Stmt stmt = StmtExprMutator::VisitStmt_(op);
78 op = stmt.as<StoreNode>();
79 if (op->buffer_var.get() == buf_) {
80 return Store(op->buffer_var, op->value, op->index * var_lanes_ + var_, op->predicate);
81 } else {
82 return stmt;
83 }
84 }
85
86 private:
87 // buffer var
88 const VarNode* buf_;
89 // variable to be replaced
90 Var var_;
91 // the lanes.
92 int var_lanes_;
93 };
94
95 // We use ExprFunctor directly instead of StmtExprMutator
96 // This is because the transformation can change the dtype of the Expr
97 // The existing ExprMutator transformation rules may not be well defined.
98 class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExpr&)> {
99 public:
100 using ExprFunctor::VisitExpr;
101 using StmtMutator::operator();
102
Vectorizer(Var var,int var_lanes)103 Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) {
104 ramp_ = Ramp(0, 1, var_lanes);
105 }
106
VisitStmt(const Stmt & stmt)107 Stmt VisitStmt(const Stmt& stmt) final {
108 CHECK(!need_scalarize_);
109 Stmt ret = StmtMutator::VisitStmt(stmt);
110 if (need_scalarize_) {
111 need_scalarize_ = false;
112 return Scalarize(stmt);
113 } else {
114 return ret;
115 }
116 }
117
VisitExpr(const PrimExpr & e)118 PrimExpr VisitExpr(const PrimExpr& e) final { return ExprFunctor::VisitExpr(e); }
119
VisitExpr_(const AddNode * op)120 PrimExpr VisitExpr_(const AddNode* op) final {
121 return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a + b; });
122 }
123
VisitExpr_(const SubNode * op)124 PrimExpr VisitExpr_(const SubNode* op) final {
125 return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a - b; });
126 }
127
VisitExpr_(const MulNode * op)128 PrimExpr VisitExpr_(const MulNode* op) final {
129 PrimExpr a = this->VisitExpr(op->a);
130 PrimExpr b = this->VisitExpr(op->b);
131 if (a.same_as(op->a) && b.same_as(op->b)) {
132 return GetRef<PrimExpr>(op);
133 } else {
134 int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
135 if (lanes != 1) {
136 const RampNode* b_ramp = b.as<RampNode>();
137 const RampNode* a_ramp = a.as<RampNode>();
138 if (a_ramp && b.dtype().lanes() == 1 && analyzer_.CanProve(b > 0)) {
139 return Ramp(a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes);
140 }
141 if (b_ramp && a.dtype().lanes() == 1 && analyzer_.CanProve(a > 0)) {
142 return Ramp(b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes);
143 }
144 }
145 return Mul(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
146 }
147 return BinaryVec<Mul>(op);
148 }
VisitExpr_(const DivNode * op)149 PrimExpr VisitExpr_(const DivNode* op) final { return BinaryVec<Div>(op); }
VisitExpr_(const ModNode * op)150 PrimExpr VisitExpr_(const ModNode* op) final { return BinaryVec<Mod>(op); }
VisitExpr_(const FloorDivNode * op)151 PrimExpr VisitExpr_(const FloorDivNode* op) final { return BinaryVec<FloorDiv>(op); }
VisitExpr_(const FloorModNode * op)152 PrimExpr VisitExpr_(const FloorModNode* op) final { return BinaryVec<FloorMod>(op); }
VisitExpr_(const MinNode * op)153 PrimExpr VisitExpr_(const MinNode* op) final { return BinaryVec<Min>(op); }
VisitExpr_(const MaxNode * op)154 PrimExpr VisitExpr_(const MaxNode* op) final { return BinaryVec<Max>(op); }
VisitExpr_(const EQNode * op)155 PrimExpr VisitExpr_(const EQNode* op) final { return BinaryVec<EQ>(op); }
VisitExpr_(const NENode * op)156 PrimExpr VisitExpr_(const NENode* op) final { return BinaryVec<NE>(op); }
VisitExpr_(const LTNode * op)157 PrimExpr VisitExpr_(const LTNode* op) final { return BinaryVec<LT>(op); }
VisitExpr_(const LENode * op)158 PrimExpr VisitExpr_(const LENode* op) final { return BinaryVec<LE>(op); }
VisitExpr_(const GTNode * op)159 PrimExpr VisitExpr_(const GTNode* op) final { return BinaryVec<GT>(op); }
VisitExpr_(const GENode * op)160 PrimExpr VisitExpr_(const GENode* op) final { return BinaryVec<GE>(op); }
VisitExpr_(const AndNode * op)161 PrimExpr VisitExpr_(const AndNode* op) final { return BinaryVec<And>(op); }
VisitExpr_(const OrNode * op)162 PrimExpr VisitExpr_(const OrNode* op) final { return BinaryVec<Or>(op); }
163
VisitExpr_(const NotNode * op)164 PrimExpr VisitExpr_(const NotNode* op) final {
165 PrimExpr a = this->VisitExpr(op->a);
166 if (a.same_as(op->a)) {
167 return GetRef<PrimExpr>(op);
168 } else {
169 return !(a);
170 }
171 }
172
VisitExpr_(const RampNode * op)173 PrimExpr VisitExpr_(const RampNode* op) final {
174 PrimExpr base = this->VisitExpr(op->base);
175 PrimExpr stride = this->VisitExpr(op->stride);
176 if (base.dtype().lanes() > 1 && stride.dtype().lanes() == 1) {
177 const RampNode* base_ramp = base.as<RampNode>();
178 if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.dtype(), op->lanes))) {
179 return Ramp(base_ramp->base, stride, op->lanes * base_ramp->lanes);
180 }
181 }
182 int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes());
183 base = BroadcastTo(base, lanes);
184 stride = BroadcastTo(stride, lanes);
185 Array<PrimExpr> elems;
186 for (int i = 0; i < lanes; ++i) {
187 elems.push_back(
188 Ramp(Shuffle::ExtractElement(base, i), Shuffle::ExtractElement(stride, i), op->lanes));
189 }
190 return Shuffle::Concat(elems);
191 }
192
VisitExpr_(const BroadcastNode * op)193 PrimExpr VisitExpr_(const BroadcastNode* op) final {
194 PrimExpr value = this->VisitExpr(op->value);
195 if (value.dtype().lanes() != 1) {
196 need_scalarize_ = true;
197 return GetRef<PrimExpr>(op);
198 }
199 if (value.same_as(op->value)) {
200 return GetRef<PrimExpr>(op);
201 } else {
202 return Broadcast(op->value, op->lanes);
203 }
204 }
205
VisitExpr_(const SelectNode * op)206 PrimExpr VisitExpr_(const SelectNode* op) final {
207 PrimExpr cond = this->VisitExpr(op->condition);
208 PrimExpr t = this->VisitExpr(op->true_value);
209 PrimExpr f = this->VisitExpr(op->false_value);
210 if (cond.same_as(op->condition) && t.same_as(op->true_value) && f.same_as(op->false_value)) {
211 return GetRef<PrimExpr>(op);
212 } else {
213 int lanes = std::max(std::max(cond.dtype().lanes(), t.dtype().lanes()), f.dtype().lanes());
214 return Select(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes));
215 }
216 }
VisitExpr_(const CastNode * op)217 PrimExpr VisitExpr_(const CastNode* op) final {
218 PrimExpr value = this->VisitExpr(op->value);
219 if (value.same_as(op->value)) {
220 return GetRef<PrimExpr>(op);
221 } else {
222 return Cast(op->dtype.with_lanes(value.dtype().lanes()), value);
223 }
224 }
225
VisitExpr_(const FloatImmNode * op)226 PrimExpr VisitExpr_(const FloatImmNode* op) final { return GetRef<PrimExpr>(op); }
227
VisitExpr_(const IntImmNode * op)228 PrimExpr VisitExpr_(const IntImmNode* op) final { return GetRef<PrimExpr>(op); }
229
VisitExpr_(const StringImmNode * op)230 PrimExpr VisitExpr_(const StringImmNode* op) final { return GetRef<PrimExpr>(op); }
231
232 // Variable
VisitExpr_(const VarNode * op)233 PrimExpr VisitExpr_(const VarNode* op) final {
234 Var var = GetRef<Var>(op);
235
236 if (var.same_as(var_)) {
237 return ramp_;
238 }
239 auto it = let_binding_.find(var);
240 if (it != let_binding_.end()) {
241 return it->second;
242 } else {
243 return std::move(var);
244 }
245 }
246 // IfThenElse expr
MutateIfThenElseExpr_(const CallNode * op)247 PrimExpr MutateIfThenElseExpr_(const CallNode* op) {
248 PrimExpr cond = this->VisitExpr(op->args[0]);
249 if (cond.dtype().is_vector()) {
250 need_scalarize_ = true;
251 return GetRef<PrimExpr>(op);
252 }
253 PrimExpr t = this->VisitExpr(op->args[1]);
254 PrimExpr f = this->VisitExpr(op->args[2]);
255 if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && f.same_as(op->args[2])) {
256 return GetRef<PrimExpr>(op);
257 } else {
258 int lanes = std::max(t.dtype().lanes(), f.dtype().lanes());
259 t = BroadcastTo(t, lanes);
260 f = BroadcastTo(f, lanes);
261 return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f});
262 }
263 }
264 // Call
VisitExpr_(const CallNode * op)265 PrimExpr VisitExpr_(const CallNode* op) final {
266 if (op->op.same_as(builtin::if_then_else())) {
267 return MutateIfThenElseExpr_(op);
268 }
269 auto* op_ptr = op->op.as<OpNode>();
270 bool vectorizable = op_ptr && op_vectorizable_.get(GetRef<Op>(op_ptr), false);
271
272 if (!vectorizable) {
273 // Cannot vectorize this op
274 Array<PrimExpr> new_args;
275 for (auto arg : op->args) {
276 auto new_arg = this->VisitExpr(arg);
277 if (new_arg.dtype().is_vector()) {
278 need_scalarize_ = true;
279 return GetRef<PrimExpr>(op);
280 }
281 new_args.push_back(new_arg);
282 }
283 if (op->args.same_as(new_args)) {
284 return GetRef<PrimExpr>(op);
285 } else {
286 return Call(op->dtype, op->op, new_args);
287 }
288 } else {
289 int lane = 0;
290 Array<PrimExpr> new_args = MutateArray(op->args, &lane);
291 // normal code path.
292 if (op->args.same_as(new_args)) {
293 return GetRef<PrimExpr>(op);
294 } else {
295 return Call(op->dtype.with_lanes(lane), op->op, new_args);
296 }
297 }
298 }
299 // Load
VisitExpr_(const LoadNode * op)300 PrimExpr VisitExpr_(const LoadNode* op) final {
301 PrimExpr index = this->VisitExpr(op->index);
302 PrimExpr pred = this->VisitExpr(op->predicate);
303 if (index.same_as(op->index) && pred.same_as(op->predicate)) {
304 return GetRef<PrimExpr>(op);
305 } else {
306 int lanes = std::max(index.dtype().lanes(), pred.dtype().lanes());
307 return Load(op->dtype.with_lanes(lanes), op->buffer_var, BroadcastTo(index, lanes),
308 BroadcastTo(pred, lanes));
309 }
310 }
311 // Let
VisitExpr_(const LetNode * op)312 PrimExpr VisitExpr_(const LetNode* op) final {
313 PrimExpr value = this->VisitExpr(op->value);
314 // Weaker SSA condition
315 // A single var can be binded in multiple lets
316 // but they have to bind to the same value.
317 // This is used to allow cases when we reuse a single let
318 // expression to cosntruct a nested expr.
319 // (let x = 1 in x + 1) * (let x = 1 in x + 1)
320 auto it = let_binding_.find(op->var);
321 if (it != let_binding_.end()) {
322 CHECK(deep_equal_(it->second, value))
323 << "Let cannot bind the same var to two different values";
324 }
325 if (value.dtype().lanes() != op->value.dtype().lanes()) {
326 Var new_var(op->var->name_hint, value.dtype());
327 let_binding_[op->var] = new_var;
328 return Let(new_var, value, this->VisitExpr(op->body));
329 } else {
330 let_binding_[op->var] = op->var;
331 PrimExpr body = this->VisitExpr(op->body);
332 if (value.same_as(op->value) && body.same_as(op->body)) {
333 return GetRef<PrimExpr>(op);
334 } else {
335 return Let(op->var, value, body);
336 }
337 }
338 }
339 // Store
VisitStmt_(const StoreNode * op)340 Stmt VisitStmt_(const StoreNode* op) final {
341 PrimExpr value = this->VisitExpr(op->value);
342 PrimExpr index = this->VisitExpr(op->index);
343 PrimExpr pred = this->VisitExpr(op->predicate);
344 if (value.same_as(op->value) && index.same_as(op->index)) {
345 return GetRef<Stmt>(op);
346 } else {
347 int lanes = std::max(value.dtype().lanes(), index.dtype().lanes());
348 lanes = std::max(lanes, pred.dtype().lanes());
349 return Store(op->buffer_var, BroadcastTo(value, lanes), BroadcastTo(index, lanes),
350 BroadcastTo(pred, lanes));
351 }
352 }
353 // For
VisitStmt_(const ForNode * op)354 Stmt VisitStmt_(const ForNode* op) final {
355 if (op->for_type == ForType::Vectorized) {
356 LOG(WARNING) << "Detect vectorize inside vectorized loop, ignoring...";
357 }
358 CHECK(is_zero(op->min));
359 CHECK(!op->extent.dtype().is_vector());
360 PrimExpr extent = this->VisitExpr(op->extent);
361 if (extent.dtype().is_vector()) {
362 return Scalarize(GetRef<Stmt>(op));
363 }
364 Stmt body = this->VisitStmt(op->body);
365 if (extent.same_as(op->extent) && body.same_as(op->body)) {
366 return GetRef<Stmt>(op);
367 } else {
368 return For(op->loop_var, op->min, extent, op->for_type, op->device_api, body);
369 }
370 }
371 // IfThenElse
VisitStmt_(const IfThenElseNode * op)372 Stmt VisitStmt_(const IfThenElseNode* op) final {
373 CHECK(!op->condition.dtype().is_vector());
374 PrimExpr condition = this->VisitExpr(op->condition);
375 if (condition.dtype().is_vector()) {
376 return Scalarize(GetRef<Stmt>(op));
377 }
378 Stmt then_case = this->VisitStmt(op->then_case);
379 Stmt else_case;
380 if (op->else_case.defined()) {
381 else_case = this->VisitStmt(op->else_case);
382 }
383 if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
384 else_case.same_as(op->else_case)) {
385 return GetRef<Stmt>(op);
386 } else {
387 return IfThenElse(condition, then_case, else_case);
388 }
389 }
390 // LetStmt
VisitStmt_(const LetStmtNode * op)391 Stmt VisitStmt_(const LetStmtNode* op) final {
392 PrimExpr value = this->VisitExpr(op->value);
393 CHECK(!let_binding_.count(op->var)) << "SSA violation, a single var is binded twice";
394 let_binding_[op->var] = value;
395
396 if (value.dtype().lanes() != op->value.dtype().lanes()) {
397 Var new_var(op->var->name_hint, value.dtype());
398 let_binding_[op->var] = new_var;
399 return LetStmt(new_var, value, this->VisitStmt(op->body));
400 } else {
401 let_binding_[op->var] = op->var;
402 Stmt body = this->VisitStmt(op->body);
403 if (value.same_as(op->value) && body.same_as(op->body)) {
404 return GetRef<Stmt>(op);
405 } else {
406 return LetStmt(op->var, value, body);
407 }
408 }
409 }
410 // Allocate
VisitStmt_(const AllocateNode * op)411 Stmt VisitStmt_(const AllocateNode* op) final {
412 PrimExpr condition = this->VisitExpr(op->condition);
413 if (condition.dtype().is_vector()) {
414 LOG(WARNING) << "Cannot handle vector extent in alloc ";
415 return Scalarize(GetRef<Stmt>(op));
416 }
417 Array<PrimExpr> extents;
418 for (size_t i = 0; i < op->extents.size(); i++) {
419 PrimExpr new_ext = this->VisitExpr(op->extents[i]);
420 if (new_ext.dtype().is_vector()) {
421 LOG(WARNING) << "Cannot handle vector extent in alloc ";
422 return Scalarize(GetRef<Stmt>(op));
423 }
424 extents.push_back(new_ext);
425 }
426 // place the vector lanes in least significant dimension.
427 extents.push_back(var_lanes_);
428 // rewrite access to buffer internally.
429 Stmt body = VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body);
430 body = this->VisitStmt(body);
431 return Allocate(op->buffer_var, op->dtype, extents, condition, body);
432 }
433
434 // scalarize the statment
Scalarize(Stmt stmt)435 Stmt Scalarize(Stmt stmt) {
436 Var idx(var_->name_hint + ".s", var_->dtype);
437 Map<Var, PrimExpr> values{{var_, idx}};
438 stmt = Substitute(stmt, values);
439 return For(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
440 }
441 // ProducerStore
VisitStmt_(const ProducerStoreNode * op)442 Stmt VisitStmt_(const ProducerStoreNode* op) final {
443 LOG(FATAL) << "ProducerProvide is cannot appear in a TIR PrimFunc";
444 return Stmt();
445 }
446
447 private:
448 // analyzer
449 arith::Analyzer analyzer_;
450 // deep equal
451 ExprDeepEqual deep_equal_;
452 // variable to be replaced
453 Var var_;
454 // the lanes.
455 int var_lanes_;
456 // ramp representing the var.
457 PrimExpr ramp_;
458 // flag to mark requirment of scalarization.
459 bool need_scalarize_{false};
460 // Let binding
461 std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_;
462 // vectorizable property
463 OpAttrMap<TVectorizable> op_vectorizable_ = Op::GetAttrMap<TVectorizable>("TVectorizable");
464
465 // mutate array, with given lane requirement
466 // when finished, p_lane updates the lane requirement.
MutateArray(Array<PrimExpr> arr,int * p_lanes)467 Array<PrimExpr> MutateArray(Array<PrimExpr> arr, int* p_lanes) {
468 if (arr.size() == 0) return arr;
469 int& lanes = *p_lanes;
470 bool changed = false;
471 std::vector<PrimExpr> new_arr(arr.size());
472 for (size_t i = 0; i < arr.size(); i++) {
473 PrimExpr old_elem = arr[i];
474 PrimExpr new_elem = this->VisitExpr(old_elem);
475 if (!new_elem.same_as(old_elem)) changed = true;
476 new_arr[i] = new_elem;
477 lanes = std::max(lanes, new_elem.dtype().lanes());
478 }
479
480 for (size_t i = 0; i < arr.size(); ++i) {
481 if (new_arr[i].dtype().lanes() != lanes) {
482 new_arr[i] = BroadcastTo(new_arr[i], lanes);
483 changed = true;
484 }
485 }
486 if (!changed) return arr;
487 return Array<PrimExpr>(new_arr);
488 }
489 template <typename TOp, typename T>
BinaryVec(const T * op)490 PrimExpr BinaryVec(const T* op) {
491 static_assert(std::is_same<typename TOp::ContainerType, T>::value, "constraint");
492 PrimExpr a = this->VisitExpr(op->a);
493 PrimExpr b = this->VisitExpr(op->b);
494 if (a.same_as(op->a) && b.same_as(op->b)) {
495 return GetRef<PrimExpr>(op);
496 } else {
497 int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
498 return TOp(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
499 }
500 }
501 template <typename T, typename FCompute>
AddSubVec(const T * op,FCompute fcompute)502 PrimExpr AddSubVec(const T* op, FCompute fcompute) {
503 PrimExpr a = this->VisitExpr(op->a);
504 PrimExpr b = this->VisitExpr(op->b);
505 if (a.same_as(op->a) && b.same_as(op->b)) {
506 return GetRef<PrimExpr>(op);
507 } else {
508 int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
509 if (lanes != 1) {
510 const RampNode* b_ramp = b.as<RampNode>();
511 const RampNode* a_ramp = a.as<RampNode>();
512 if (a.dtype().lanes() == 1 && b_ramp) {
513 return Ramp(fcompute(a, b_ramp->base),
514 fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride), b_ramp->lanes);
515 }
516 if (b.dtype().lanes() == 1 && a_ramp) {
517 return Ramp(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
518 }
519 }
520 return fcompute(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
521 }
522 }
523 };
524
525 class LoopVectorizer : public StmtMutator {
526 public:
VisitStmt_(const ForNode * op)527 Stmt VisitStmt_(const ForNode* op) final {
528 if (op->for_type == ForType::Vectorized) {
529 CHECK(is_zero(op->min));
530 auto* extent_as_int = op->extent.as<IntImmNode>();
531 if (!extent_as_int || extent_as_int->value < 1) {
532 LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent;
533 }
534 return Vectorizer(op->loop_var, static_cast<int>(extent_as_int->value))(op->body);
535 } else {
536 return StmtMutator::VisitStmt_(op);
537 }
538 }
539 };
540
VectorizeLoop(Stmt stmt)541 Stmt VectorizeLoop(Stmt stmt) { return LoopVectorizer()(std::move(stmt)); }
542
543 class VectorizeSkipper : public StmtMutator {
544 public:
VisitStmt_(const ForNode * op)545 Stmt VisitStmt_(const ForNode* op) final {
546 Stmt stmt = StmtMutator::VisitStmt_(op);
547 op = stmt.as<ForNode>();
548 if (op->for_type == ForType::Vectorized) {
549 return For(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api, op->body);
550 } else {
551 return stmt;
552 }
553 }
554 };
555
SkipVectorize(Stmt stmt)556 Stmt SkipVectorize(Stmt stmt) { return VectorizeSkipper()(std::move(stmt)); }
557
558 namespace transform {
559
560 // TODO(tvm-team): Make it as a target property.
VectorizeLoop(bool enable_vectorize)561 Pass VectorizeLoop(bool enable_vectorize) {
562 auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
563 auto* n = f.CopyOnWrite();
564 if (enable_vectorize) {
565 n->body = LoopVectorizer()(std::move(n->body));
566 } else {
567 n->body = VectorizeSkipper()(std::move(n->body));
568 }
569 return f;
570 };
571 return CreatePrimFuncPass(pass_func, 0, "tir.VectorizeLoop", {});
572 }
573
// Register the pass constructor under the global name "tir.transform.VectorizeLoop".
TVM_REGISTER_GLOBAL("tir.transform.VectorizeLoop").set_body_typed(VectorizeLoop);
575
576 } // namespace transform
577
578 } // namespace tir
579 } // namespace tvm
580