1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17import tvm
18
def test_vectorize_loop():
    """Check VectorizeLoop on a 4-wide vectorized loop inside a serial loop.

    The IR stores a float32 constant to ``A[j]`` for ``j`` in a
    ``for_type="vectorize"`` loop of extent 4, nested in a serial loop of
    symbolic extent ``n``. After the pass the outer ``For`` must remain,
    the inner ``For`` must be gone, and the remaining store must use a
    ``Ramp`` index with a ``Broadcast`` value.
    """
    n = tvm.var('n')
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    with ib.for_range(0, n) as i:
        with ib.for_range(0, 4, for_type="vectorize") as j:
            A[j] = tvm.const(1, A.dtype)
    stmt = ib.get()

    # Sanity check on the input: the outer loop still wraps an inner For.
    assert isinstance(stmt.body, tvm.stmt.For)
    stmt = tvm.ir_pass.VectorizeLoop(stmt)
    # The outer serial loop is kept; only the inner loop is replaced.
    assert isinstance(stmt, tvm.stmt.For)
    assert not isinstance(stmt.body, tvm.stmt.For)
    assert isinstance(stmt.body.index, tvm.expr.Ramp)
    assert isinstance(stmt.body.value, tvm.expr.Broadcast)
35
def test_vectorize_vector():
    """Check VectorizeLoop when the buffer element type is ``float32x4``.

    Builds the same loop nest as a scalar store test (serial loop of
    symbolic extent ``n`` around a ``for_type="vectorize"`` loop of extent
    4), but the pointer ``A`` is declared with dtype ``"float32x4"``.
    After the pass the outer ``For`` must remain, the inner ``For`` must
    be gone, and the remaining store must use a ``Ramp`` index with a
    ``Broadcast`` value.
    """
    n = tvm.var('n')
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32x4", name="A")
    with ib.for_range(0, n) as i:
        with ib.for_range(0, 4, for_type="vectorize") as j:
            A[j] = tvm.const(1, A.dtype)
    stmt = ib.get()
    # Sanity check on the input: the outer loop still wraps an inner For.
    assert isinstance(stmt.body, tvm.stmt.For)
    stmt = tvm.ir_pass.VectorizeLoop(stmt)
    # The outer serial loop is kept; only the inner loop is replaced.
    assert isinstance(stmt, tvm.stmt.For)
    assert not isinstance(stmt.body, tvm.stmt.For)
    assert isinstance(stmt.body.index, tvm.expr.Ramp)
    assert isinstance(stmt.body.value, tvm.expr.Broadcast)
51
52
def test_vectorize_with_if():
    """Check VectorizeLoop on a vectorized loop containing if/else scopes.

    The ``if`` condition (``x < n``) does not involve the loop variable,
    while the nested condition in the ``else`` branch (``i < n``) does.
    The test expects an ``IfThenElse`` at the top level whose then-branch
    is a store with a ``Ramp`` index and a ``float32x4`` ``Add`` value,
    and whose else-branch is still a ``For``.
    """
    bound = tvm.var('n')
    scalar = tvm.var('x')
    builder = tvm.ir_builder.create()
    buf = builder.pointer("float32", name="A")
    with builder.for_range(0, 4, for_type="vectorize") as lane:
        with builder.if_scope(scalar < bound):
            buf[lane] = buf[lane] + 1
        with builder.else_scope():
            with builder.if_scope(lane < bound):
                buf[lane] = 2.0

    result = tvm.ir_pass.VectorizeLoop(builder.get())
    assert isinstance(result, tvm.stmt.IfThenElse)

    # Then-branch: expected to be a vector store.
    then_store = result.then_case
    assert isinstance(then_store.index, tvm.expr.Ramp)
    assert isinstance(then_store.value, tvm.expr.Add)
    assert then_store.value.dtype == "float32x4"

    # Else-branch: expected to still be a loop.
    assert isinstance(result.else_case, tvm.stmt.For)
71
def test_vectorize_with_le_cond():
    """Check VectorizeLoop with an ``i <= n`` condition on the loop variable.

    The test expects the statement returned by the pass to still be a
    ``For``.
    """
    bound = tvm.var('n')
    builder = tvm.ir_builder.create()
    buf = builder.pointer("float32", name="A")
    with builder.for_range(0, 4, for_type="vectorize") as lane:
        with builder.if_scope(lane <= bound):
            buf[lane] = buf[lane] + 1

    result = tvm.ir_pass.VectorizeLoop(builder.get())
    assert isinstance(result, tvm.stmt.For)
82
def test_vectorize_with_ge_cond():
    """Check VectorizeLoop with an ``i >= n`` condition on the loop variable.

    The test expects the statement returned by the pass to still be a
    ``For``.
    """
    bound = tvm.var('n')
    builder = tvm.ir_builder.create()
    buf = builder.pointer("float32", name="A")
    with builder.for_range(0, 4, for_type="vectorize") as lane:
        with builder.if_scope(lane >= bound):
            buf[lane] = buf[lane] + 1

    result = tvm.ir_pass.VectorizeLoop(builder.get())
    assert isinstance(result, tvm.stmt.For)
93
def test_vectorize_if_then_else():
    """Check VectorizeLoop on stores that use the ``tvm_if_then_else`` intrinsic.

    Two scenarios are covered:

    1. The intrinsic's condition (``i > 0``) uses the vectorized loop
       variable; the result of the pass is expected to still be a ``For``.
    2. The condition (``k > 0``) uses only the outer serial loop variable;
       the inner loop is expected to be gone and the intrinsic's third
       argument is expected to be a ``Broadcast``.
    """
    n = tvm.var('n')

    # Scenario 1: condition depends on the vectorized loop variable.
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    with ib.for_range(0, 4, for_type="vectorize") as i:
        A[i] = tvm.call_intrin("float32", "tvm_if_then_else",
                               i > 0,
                               A[i] + 1, A[i])
    stmt = ib.get()
    stmt = tvm.ir_pass.VectorizeLoop(stmt)
    assert isinstance(stmt, tvm.stmt.For)

    # Scenario 2: condition depends only on the outer loop variable.
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    with ib.for_range(0, n) as k:
        with ib.for_range(0, 4, for_type="vectorize") as i:
            A[k * 4 + i] = tvm.call_intrin("float32", "tvm_if_then_else",
                                           k > 0,
                                           A[k * 4 + i], 0)
    stmt = ib.get()
    # Sanity check on the input: the outer loop still wraps an inner For.
    assert isinstance(stmt.body, tvm.stmt.For)
    stmt = tvm.ir_pass.VectorizeLoop(stmt)
    assert not isinstance(stmt.body, tvm.stmt.For)
    # args[2] is the intrinsic's third argument, written as the scalar 0 above.
    assert isinstance(stmt.body.value.args[2], tvm.expr.Broadcast)
120
121
if __name__ == "__main__":
    # Run every test in this module, in a fixed order, when executed as a script.
    for run_case in (
        test_vectorize_vector,
        test_vectorize_with_if,
        test_vectorize_loop,
        test_vectorize_if_then_else,
        test_vectorize_with_le_cond,
        test_vectorize_with_ge_cond,
    ):
        run_case()
129