1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file vectorize_loop.cc
22  */
23 // Loop vectorizer as in Halide pipeline.
24 #include <tvm/arith/analyzer.h>
25 #include <tvm/runtime/registry.h>
26 #include <tvm/tir/analysis.h>
27 #include <tvm/tir/builtin.h>
28 #include <tvm/tir/expr.h>
29 #include <tvm/tir/op.h>
30 #include <tvm/tir/op_attr_types.h>
31 #include <tvm/tir/stmt_functor.h>
32 #include <tvm/tir/transform.h>
33 
34 #include <unordered_map>
35 #include <unordered_set>
36 #include <vector>
37 
38 namespace tvm {
39 namespace tir {
40 
BroadcastTo(PrimExpr e,int lanes)41 inline PrimExpr BroadcastTo(PrimExpr e, int lanes) {
42   if (e.dtype().lanes() == lanes) return e;
43   if (const BroadcastNode* op = e.as<BroadcastNode>()) {
44     if (lanes % op->lanes == 0) {
45       return Broadcast(op->value, lanes);
46     }
47   }
48   CHECK_EQ(e.dtype().lanes(), 1) << "Cannot broadcast lane=" << e.dtype().lanes() << " to "
49                                  << lanes;
50   return Broadcast(e, lanes);
51 }
52 
53 // Rewrite vectorized allocation access
// This is necessary so that each vector component has its own workspace.
55 // Originates from Halide's loop vectorizer
56 //
57 // s[i] = s[i * lanes + var]
58 //
// The same principle applies when using one thread to simulate multiple contexts.
60 //
61 class VecAllocAccess : public StmtExprMutator {
62  public:
VecAllocAccess(const VarNode * buf,Var var,int var_lanes)63   VecAllocAccess(const VarNode* buf, Var var, int var_lanes)
64       : buf_(buf), var_(var), var_lanes_(var_lanes) {}
65   // Load
VisitExpr_(const LoadNode * op)66   PrimExpr VisitExpr_(const LoadNode* op) final {
67     PrimExpr expr = StmtExprMutator::VisitExpr_(op);
68     op = expr.as<LoadNode>();
69     if (op->buffer_var.get() == buf_) {
70       return Load(op->dtype, op->buffer_var, op->index * var_lanes_ + var_, op->predicate);
71     } else {
72       return expr;
73     }
74   }
75   // Store
VisitStmt_(const StoreNode * op)76   Stmt VisitStmt_(const StoreNode* op) final {
77     Stmt stmt = StmtExprMutator::VisitStmt_(op);
78     op = stmt.as<StoreNode>();
79     if (op->buffer_var.get() == buf_) {
80       return Store(op->buffer_var, op->value, op->index * var_lanes_ + var_, op->predicate);
81     } else {
82       return stmt;
83     }
84   }
85 
86  private:
87   // buffer var
88   const VarNode* buf_;
89   // variable to be replaced
90   Var var_;
91   // the lanes.
92   int var_lanes_;
93 };
94 
// We use ExprFunctor directly instead of StmtExprMutator because the
// transformation can change the dtype of an Expr, and the existing
// ExprMutator rewrite rules may not be well defined in that case.
98 class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExpr&)> {
99  public:
100   using ExprFunctor::VisitExpr;
101   using StmtMutator::operator();
102 
Vectorizer(Var var,int var_lanes)103   Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) {
104     ramp_ = Ramp(0, 1, var_lanes);
105   }
106 
VisitStmt(const Stmt & stmt)107   Stmt VisitStmt(const Stmt& stmt) final {
108     CHECK(!need_scalarize_);
109     Stmt ret = StmtMutator::VisitStmt(stmt);
110     if (need_scalarize_) {
111       need_scalarize_ = false;
112       return Scalarize(stmt);
113     } else {
114       return ret;
115     }
116   }
117 
VisitExpr(const PrimExpr & e)118   PrimExpr VisitExpr(const PrimExpr& e) final { return ExprFunctor::VisitExpr(e); }
119 
VisitExpr_(const AddNode * op)120   PrimExpr VisitExpr_(const AddNode* op) final {
121     return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a + b; });
122   }
123 
VisitExpr_(const SubNode * op)124   PrimExpr VisitExpr_(const SubNode* op) final {
125     return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a - b; });
126   }
127 
VisitExpr_(const MulNode * op)128   PrimExpr VisitExpr_(const MulNode* op) final {
129     PrimExpr a = this->VisitExpr(op->a);
130     PrimExpr b = this->VisitExpr(op->b);
131     if (a.same_as(op->a) && b.same_as(op->b)) {
132       return GetRef<PrimExpr>(op);
133     } else {
134       int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
135       if (lanes != 1) {
136         const RampNode* b_ramp = b.as<RampNode>();
137         const RampNode* a_ramp = a.as<RampNode>();
138         if (a_ramp && b.dtype().lanes() == 1 && analyzer_.CanProve(b > 0)) {
139           return Ramp(a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes);
140         }
141         if (b_ramp && a.dtype().lanes() == 1 && analyzer_.CanProve(a > 0)) {
142           return Ramp(b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes);
143         }
144       }
145       return Mul(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
146     }
147     return BinaryVec<Mul>(op);
148   }
VisitExpr_(const DivNode * op)149   PrimExpr VisitExpr_(const DivNode* op) final { return BinaryVec<Div>(op); }
VisitExpr_(const ModNode * op)150   PrimExpr VisitExpr_(const ModNode* op) final { return BinaryVec<Mod>(op); }
VisitExpr_(const FloorDivNode * op)151   PrimExpr VisitExpr_(const FloorDivNode* op) final { return BinaryVec<FloorDiv>(op); }
VisitExpr_(const FloorModNode * op)152   PrimExpr VisitExpr_(const FloorModNode* op) final { return BinaryVec<FloorMod>(op); }
VisitExpr_(const MinNode * op)153   PrimExpr VisitExpr_(const MinNode* op) final { return BinaryVec<Min>(op); }
VisitExpr_(const MaxNode * op)154   PrimExpr VisitExpr_(const MaxNode* op) final { return BinaryVec<Max>(op); }
VisitExpr_(const EQNode * op)155   PrimExpr VisitExpr_(const EQNode* op) final { return BinaryVec<EQ>(op); }
VisitExpr_(const NENode * op)156   PrimExpr VisitExpr_(const NENode* op) final { return BinaryVec<NE>(op); }
VisitExpr_(const LTNode * op)157   PrimExpr VisitExpr_(const LTNode* op) final { return BinaryVec<LT>(op); }
VisitExpr_(const LENode * op)158   PrimExpr VisitExpr_(const LENode* op) final { return BinaryVec<LE>(op); }
VisitExpr_(const GTNode * op)159   PrimExpr VisitExpr_(const GTNode* op) final { return BinaryVec<GT>(op); }
VisitExpr_(const GENode * op)160   PrimExpr VisitExpr_(const GENode* op) final { return BinaryVec<GE>(op); }
VisitExpr_(const AndNode * op)161   PrimExpr VisitExpr_(const AndNode* op) final { return BinaryVec<And>(op); }
VisitExpr_(const OrNode * op)162   PrimExpr VisitExpr_(const OrNode* op) final { return BinaryVec<Or>(op); }
163 
VisitExpr_(const NotNode * op)164   PrimExpr VisitExpr_(const NotNode* op) final {
165     PrimExpr a = this->VisitExpr(op->a);
166     if (a.same_as(op->a)) {
167       return GetRef<PrimExpr>(op);
168     } else {
169       return !(a);
170     }
171   }
172 
VisitExpr_(const RampNode * op)173   PrimExpr VisitExpr_(const RampNode* op) final {
174     PrimExpr base = this->VisitExpr(op->base);
175     PrimExpr stride = this->VisitExpr(op->stride);
176     if (base.dtype().lanes() > 1 && stride.dtype().lanes() == 1) {
177       const RampNode* base_ramp = base.as<RampNode>();
178       if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.dtype(), op->lanes))) {
179         return Ramp(base_ramp->base, stride, op->lanes * base_ramp->lanes);
180       }
181     }
182     int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes());
183     base = BroadcastTo(base, lanes);
184     stride = BroadcastTo(stride, lanes);
185     Array<PrimExpr> elems;
186     for (int i = 0; i < lanes; ++i) {
187       elems.push_back(
188           Ramp(Shuffle::ExtractElement(base, i), Shuffle::ExtractElement(stride, i), op->lanes));
189     }
190     return Shuffle::Concat(elems);
191   }
192 
VisitExpr_(const BroadcastNode * op)193   PrimExpr VisitExpr_(const BroadcastNode* op) final {
194     PrimExpr value = this->VisitExpr(op->value);
195     if (value.dtype().lanes() != 1) {
196       need_scalarize_ = true;
197       return GetRef<PrimExpr>(op);
198     }
199     if (value.same_as(op->value)) {
200       return GetRef<PrimExpr>(op);
201     } else {
202       return Broadcast(op->value, op->lanes);
203     }
204   }
205 
VisitExpr_(const SelectNode * op)206   PrimExpr VisitExpr_(const SelectNode* op) final {
207     PrimExpr cond = this->VisitExpr(op->condition);
208     PrimExpr t = this->VisitExpr(op->true_value);
209     PrimExpr f = this->VisitExpr(op->false_value);
210     if (cond.same_as(op->condition) && t.same_as(op->true_value) && f.same_as(op->false_value)) {
211       return GetRef<PrimExpr>(op);
212     } else {
213       int lanes = std::max(std::max(cond.dtype().lanes(), t.dtype().lanes()), f.dtype().lanes());
214       return Select(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes));
215     }
216   }
VisitExpr_(const CastNode * op)217   PrimExpr VisitExpr_(const CastNode* op) final {
218     PrimExpr value = this->VisitExpr(op->value);
219     if (value.same_as(op->value)) {
220       return GetRef<PrimExpr>(op);
221     } else {
222       return Cast(op->dtype.with_lanes(value.dtype().lanes()), value);
223     }
224   }
225 
VisitExpr_(const FloatImmNode * op)226   PrimExpr VisitExpr_(const FloatImmNode* op) final { return GetRef<PrimExpr>(op); }
227 
VisitExpr_(const IntImmNode * op)228   PrimExpr VisitExpr_(const IntImmNode* op) final { return GetRef<PrimExpr>(op); }
229 
VisitExpr_(const StringImmNode * op)230   PrimExpr VisitExpr_(const StringImmNode* op) final { return GetRef<PrimExpr>(op); }
231 
232   // Variable
VisitExpr_(const VarNode * op)233   PrimExpr VisitExpr_(const VarNode* op) final {
234     Var var = GetRef<Var>(op);
235 
236     if (var.same_as(var_)) {
237       return ramp_;
238     }
239     auto it = let_binding_.find(var);
240     if (it != let_binding_.end()) {
241       return it->second;
242     } else {
243       return std::move(var);
244     }
245   }
246   // IfThenElse expr
MutateIfThenElseExpr_(const CallNode * op)247   PrimExpr MutateIfThenElseExpr_(const CallNode* op) {
248     PrimExpr cond = this->VisitExpr(op->args[0]);
249     if (cond.dtype().is_vector()) {
250       need_scalarize_ = true;
251       return GetRef<PrimExpr>(op);
252     }
253     PrimExpr t = this->VisitExpr(op->args[1]);
254     PrimExpr f = this->VisitExpr(op->args[2]);
255     if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && f.same_as(op->args[2])) {
256       return GetRef<PrimExpr>(op);
257     } else {
258       int lanes = std::max(t.dtype().lanes(), f.dtype().lanes());
259       t = BroadcastTo(t, lanes);
260       f = BroadcastTo(f, lanes);
261       return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f});
262     }
263   }
264   // Call
VisitExpr_(const CallNode * op)265   PrimExpr VisitExpr_(const CallNode* op) final {
266     if (op->op.same_as(builtin::if_then_else())) {
267       return MutateIfThenElseExpr_(op);
268     }
269     auto* op_ptr = op->op.as<OpNode>();
270     bool vectorizable = op_ptr && op_vectorizable_.get(GetRef<Op>(op_ptr), false);
271 
272     if (!vectorizable) {
273       // Cannot vectorize this op
274       Array<PrimExpr> new_args;
275       for (auto arg : op->args) {
276         auto new_arg = this->VisitExpr(arg);
277         if (new_arg.dtype().is_vector()) {
278           need_scalarize_ = true;
279           return GetRef<PrimExpr>(op);
280         }
281         new_args.push_back(new_arg);
282       }
283       if (op->args.same_as(new_args)) {
284         return GetRef<PrimExpr>(op);
285       } else {
286         return Call(op->dtype, op->op, new_args);
287       }
288     } else {
289       int lane = 0;
290       Array<PrimExpr> new_args = MutateArray(op->args, &lane);
291       // normal code path.
292       if (op->args.same_as(new_args)) {
293         return GetRef<PrimExpr>(op);
294       } else {
295         return Call(op->dtype.with_lanes(lane), op->op, new_args);
296       }
297     }
298   }
299   // Load
VisitExpr_(const LoadNode * op)300   PrimExpr VisitExpr_(const LoadNode* op) final {
301     PrimExpr index = this->VisitExpr(op->index);
302     PrimExpr pred = this->VisitExpr(op->predicate);
303     if (index.same_as(op->index) && pred.same_as(op->predicate)) {
304       return GetRef<PrimExpr>(op);
305     } else {
306       int lanes = std::max(index.dtype().lanes(), pred.dtype().lanes());
307       return Load(op->dtype.with_lanes(lanes), op->buffer_var, BroadcastTo(index, lanes),
308                   BroadcastTo(pred, lanes));
309     }
310   }
311   // Let
VisitExpr_(const LetNode * op)312   PrimExpr VisitExpr_(const LetNode* op) final {
313     PrimExpr value = this->VisitExpr(op->value);
314     // Weaker SSA condition
315     // A single var can be binded in multiple lets
316     // but they have to bind to the same value.
317     // This is used to allow cases when we reuse a single let
318     // expression to cosntruct a nested expr.
319     // (let x = 1 in x + 1) * (let x = 1 in x + 1)
320     auto it = let_binding_.find(op->var);
321     if (it != let_binding_.end()) {
322       CHECK(deep_equal_(it->second, value))
323           << "Let cannot bind the same var to two different values";
324     }
325     if (value.dtype().lanes() != op->value.dtype().lanes()) {
326       Var new_var(op->var->name_hint, value.dtype());
327       let_binding_[op->var] = new_var;
328       return Let(new_var, value, this->VisitExpr(op->body));
329     } else {
330       let_binding_[op->var] = op->var;
331       PrimExpr body = this->VisitExpr(op->body);
332       if (value.same_as(op->value) && body.same_as(op->body)) {
333         return GetRef<PrimExpr>(op);
334       } else {
335         return Let(op->var, value, body);
336       }
337     }
338   }
339   // Store
VisitStmt_(const StoreNode * op)340   Stmt VisitStmt_(const StoreNode* op) final {
341     PrimExpr value = this->VisitExpr(op->value);
342     PrimExpr index = this->VisitExpr(op->index);
343     PrimExpr pred = this->VisitExpr(op->predicate);
344     if (value.same_as(op->value) && index.same_as(op->index)) {
345       return GetRef<Stmt>(op);
346     } else {
347       int lanes = std::max(value.dtype().lanes(), index.dtype().lanes());
348       lanes = std::max(lanes, pred.dtype().lanes());
349       return Store(op->buffer_var, BroadcastTo(value, lanes), BroadcastTo(index, lanes),
350                    BroadcastTo(pred, lanes));
351     }
352   }
353   // For
VisitStmt_(const ForNode * op)354   Stmt VisitStmt_(const ForNode* op) final {
355     if (op->for_type == ForType::Vectorized) {
356       LOG(WARNING) << "Detect vectorize inside vectorized loop, ignoring...";
357     }
358     CHECK(is_zero(op->min));
359     CHECK(!op->extent.dtype().is_vector());
360     PrimExpr extent = this->VisitExpr(op->extent);
361     if (extent.dtype().is_vector()) {
362       return Scalarize(GetRef<Stmt>(op));
363     }
364     Stmt body = this->VisitStmt(op->body);
365     if (extent.same_as(op->extent) && body.same_as(op->body)) {
366       return GetRef<Stmt>(op);
367     } else {
368       return For(op->loop_var, op->min, extent, op->for_type, op->device_api, body);
369     }
370   }
371   // IfThenElse
VisitStmt_(const IfThenElseNode * op)372   Stmt VisitStmt_(const IfThenElseNode* op) final {
373     CHECK(!op->condition.dtype().is_vector());
374     PrimExpr condition = this->VisitExpr(op->condition);
375     if (condition.dtype().is_vector()) {
376       return Scalarize(GetRef<Stmt>(op));
377     }
378     Stmt then_case = this->VisitStmt(op->then_case);
379     Stmt else_case;
380     if (op->else_case.defined()) {
381       else_case = this->VisitStmt(op->else_case);
382     }
383     if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
384         else_case.same_as(op->else_case)) {
385       return GetRef<Stmt>(op);
386     } else {
387       return IfThenElse(condition, then_case, else_case);
388     }
389   }
390   // LetStmt
VisitStmt_(const LetStmtNode * op)391   Stmt VisitStmt_(const LetStmtNode* op) final {
392     PrimExpr value = this->VisitExpr(op->value);
393     CHECK(!let_binding_.count(op->var)) << "SSA violation, a single var is binded twice";
394     let_binding_[op->var] = value;
395 
396     if (value.dtype().lanes() != op->value.dtype().lanes()) {
397       Var new_var(op->var->name_hint, value.dtype());
398       let_binding_[op->var] = new_var;
399       return LetStmt(new_var, value, this->VisitStmt(op->body));
400     } else {
401       let_binding_[op->var] = op->var;
402       Stmt body = this->VisitStmt(op->body);
403       if (value.same_as(op->value) && body.same_as(op->body)) {
404         return GetRef<Stmt>(op);
405       } else {
406         return LetStmt(op->var, value, body);
407       }
408     }
409   }
410   // Allocate
VisitStmt_(const AllocateNode * op)411   Stmt VisitStmt_(const AllocateNode* op) final {
412     PrimExpr condition = this->VisitExpr(op->condition);
413     if (condition.dtype().is_vector()) {
414       LOG(WARNING) << "Cannot handle vector extent in alloc ";
415       return Scalarize(GetRef<Stmt>(op));
416     }
417     Array<PrimExpr> extents;
418     for (size_t i = 0; i < op->extents.size(); i++) {
419       PrimExpr new_ext = this->VisitExpr(op->extents[i]);
420       if (new_ext.dtype().is_vector()) {
421         LOG(WARNING) << "Cannot handle vector extent in alloc ";
422         return Scalarize(GetRef<Stmt>(op));
423       }
424       extents.push_back(new_ext);
425     }
426     // place the vector lanes in least significant dimension.
427     extents.push_back(var_lanes_);
428     // rewrite access to buffer internally.
429     Stmt body = VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body);
430     body = this->VisitStmt(body);
431     return Allocate(op->buffer_var, op->dtype, extents, condition, body);
432   }
433 
434   // scalarize the statment
Scalarize(Stmt stmt)435   Stmt Scalarize(Stmt stmt) {
436     Var idx(var_->name_hint + ".s", var_->dtype);
437     Map<Var, PrimExpr> values{{var_, idx}};
438     stmt = Substitute(stmt, values);
439     return For(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
440   }
441   // ProducerStore
VisitStmt_(const ProducerStoreNode * op)442   Stmt VisitStmt_(const ProducerStoreNode* op) final {
443     LOG(FATAL) << "ProducerProvide is cannot appear in a TIR PrimFunc";
444     return Stmt();
445   }
446 
447  private:
448   // analyzer
449   arith::Analyzer analyzer_;
450   // deep equal
451   ExprDeepEqual deep_equal_;
452   // variable to be replaced
453   Var var_;
454   // the lanes.
455   int var_lanes_;
456   // ramp representing the var.
457   PrimExpr ramp_;
458   // flag to mark requirment of scalarization.
459   bool need_scalarize_{false};
460   // Let binding
461   std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_;
462   // vectorizable property
463   OpAttrMap<TVectorizable> op_vectorizable_ = Op::GetAttrMap<TVectorizable>("TVectorizable");
464 
465   // mutate array, with given lane requirement
466   // when finished, p_lane updates the lane requirement.
MutateArray(Array<PrimExpr> arr,int * p_lanes)467   Array<PrimExpr> MutateArray(Array<PrimExpr> arr, int* p_lanes) {
468     if (arr.size() == 0) return arr;
469     int& lanes = *p_lanes;
470     bool changed = false;
471     std::vector<PrimExpr> new_arr(arr.size());
472     for (size_t i = 0; i < arr.size(); i++) {
473       PrimExpr old_elem = arr[i];
474       PrimExpr new_elem = this->VisitExpr(old_elem);
475       if (!new_elem.same_as(old_elem)) changed = true;
476       new_arr[i] = new_elem;
477       lanes = std::max(lanes, new_elem.dtype().lanes());
478     }
479 
480     for (size_t i = 0; i < arr.size(); ++i) {
481       if (new_arr[i].dtype().lanes() != lanes) {
482         new_arr[i] = BroadcastTo(new_arr[i], lanes);
483         changed = true;
484       }
485     }
486     if (!changed) return arr;
487     return Array<PrimExpr>(new_arr);
488   }
489   template <typename TOp, typename T>
BinaryVec(const T * op)490   PrimExpr BinaryVec(const T* op) {
491     static_assert(std::is_same<typename TOp::ContainerType, T>::value, "constraint");
492     PrimExpr a = this->VisitExpr(op->a);
493     PrimExpr b = this->VisitExpr(op->b);
494     if (a.same_as(op->a) && b.same_as(op->b)) {
495       return GetRef<PrimExpr>(op);
496     } else {
497       int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
498       return TOp(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
499     }
500   }
501   template <typename T, typename FCompute>
AddSubVec(const T * op,FCompute fcompute)502   PrimExpr AddSubVec(const T* op, FCompute fcompute) {
503     PrimExpr a = this->VisitExpr(op->a);
504     PrimExpr b = this->VisitExpr(op->b);
505     if (a.same_as(op->a) && b.same_as(op->b)) {
506       return GetRef<PrimExpr>(op);
507     } else {
508       int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
509       if (lanes != 1) {
510         const RampNode* b_ramp = b.as<RampNode>();
511         const RampNode* a_ramp = a.as<RampNode>();
512         if (a.dtype().lanes() == 1 && b_ramp) {
513           return Ramp(fcompute(a, b_ramp->base),
514                       fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride), b_ramp->lanes);
515         }
516         if (b.dtype().lanes() == 1 && a_ramp) {
517           return Ramp(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
518         }
519       }
520       return fcompute(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
521     }
522   }
523 };
524 
525 class LoopVectorizer : public StmtMutator {
526  public:
VisitStmt_(const ForNode * op)527   Stmt VisitStmt_(const ForNode* op) final {
528     if (op->for_type == ForType::Vectorized) {
529       CHECK(is_zero(op->min));
530       auto* extent_as_int = op->extent.as<IntImmNode>();
531       if (!extent_as_int || extent_as_int->value < 1) {
532         LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent;
533       }
534       return Vectorizer(op->loop_var, static_cast<int>(extent_as_int->value))(op->body);
535     } else {
536       return StmtMutator::VisitStmt_(op);
537     }
538   }
539 };
540 
VectorizeLoop(Stmt stmt)541 Stmt VectorizeLoop(Stmt stmt) { return LoopVectorizer()(std::move(stmt)); }
542 
543 class VectorizeSkipper : public StmtMutator {
544  public:
VisitStmt_(const ForNode * op)545   Stmt VisitStmt_(const ForNode* op) final {
546     Stmt stmt = StmtMutator::VisitStmt_(op);
547     op = stmt.as<ForNode>();
548     if (op->for_type == ForType::Vectorized) {
549       return For(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api, op->body);
550     } else {
551       return stmt;
552     }
553   }
554 };
555 
SkipVectorize(Stmt stmt)556 Stmt SkipVectorize(Stmt stmt) { return VectorizeSkipper()(std::move(stmt)); }
557 
558 namespace transform {
559 
560 // TODO(tvm-team): Make it as a target property.
VectorizeLoop(bool enable_vectorize)561 Pass VectorizeLoop(bool enable_vectorize) {
562   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
563     auto* n = f.CopyOnWrite();
564     if (enable_vectorize) {
565       n->body = LoopVectorizer()(std::move(n->body));
566     } else {
567       n->body = VectorizeSkipper()(std::move(n->body));
568     }
569     return f;
570   };
571   return CreatePrimFuncPass(pass_func, 0, "tir.VectorizeLoop", {});
572 }
573 
574 TVM_REGISTER_GLOBAL("tir.transform.VectorizeLoop").set_body_typed(VectorizeLoop);
575 
576 }  // namespace transform
577 
578 }  // namespace tir
579 }  // namespace tvm
580