# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for the ``tvm.ir_pass.VectorizeLoop`` pass."""
import tvm


def test_vectorize_loop():
    """A 4-wide ``vectorize`` loop nested in a serial loop collapses to one store.

    After the pass, the outer serial loop remains, and its body is a single
    store whose index is a ``Ramp`` and whose value is a ``Broadcast``.
    """
    n = tvm.var('n')
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    with ib.for_range(0, n) as i:
        with ib.for_range(0, 4, for_type="vectorize") as j:
            A[j] = tvm.const(1, A.dtype)
    stmt = ib.get()

    assert isinstance(stmt.body, tvm.stmt.For)
    stmt = tvm.ir_pass.VectorizeLoop(stmt)
    assert isinstance(stmt, tvm.stmt.For)
    # The inner vectorize loop must be gone, replaced by one vector store.
    assert not isinstance(stmt.body, tvm.stmt.For)
    assert isinstance(stmt.body.index, tvm.expr.Ramp)
    assert isinstance(stmt.body.value, tvm.expr.Broadcast)


def test_vectorize_vector():
    """Same as ``test_vectorize_loop`` but the buffer element type is ``float32x4``.

    Checks that a store into an already-vector-typed pointer is still
    vectorized into a ``Ramp``-indexed store of a ``Broadcast`` value.
    """
    n = tvm.var('n')
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32x4", name="A")
    with ib.for_range(0, n) as i:
        with ib.for_range(0, 4, for_type="vectorize") as j:
            A[j] = tvm.const(1, A.dtype)
    stmt = ib.get()
    assert isinstance(stmt.body, tvm.stmt.For)
    stmt = tvm.ir_pass.VectorizeLoop(stmt)
    assert isinstance(stmt, tvm.stmt.For)
    assert not isinstance(stmt.body, tvm.stmt.For)
    assert isinstance(stmt.body.index, tvm.expr.Ramp)
    assert isinstance(stmt.body.value, tvm.expr.Broadcast)


def test_vectorize_with_if():
    """An ``if`` whose condition does not use the loop variable becomes the outer statement.

    The condition ``x < n`` does not mention ``i``. After the pass the
    ``IfThenElse`` is the outermost statement, and its then-branch is a
    ``float32x4`` vector store. The else-branch contains a condition on
    ``i`` itself and is asserted to still be a ``For`` loop.
    """
    n = tvm.var('n')
    x = tvm.var('x')
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    with ib.for_range(0, 4, for_type="vectorize") as i:
        with ib.if_scope(x < n):
            A[i] = A[i] + 1
        with ib.else_scope():
            # Condition depends on the vectorized index, so this branch is
            # expected to stay a scalar loop (asserted below).
            with ib.if_scope(i < n):
                A[i] = 2.0
    stmt = ib.get()
    stmt = tvm.ir_pass.VectorizeLoop(stmt)
    assert isinstance(stmt, tvm.stmt.IfThenElse)
    assert isinstance(stmt.then_case.index, tvm.expr.Ramp)
    assert isinstance(stmt.then_case.value, tvm.expr.Add)
    assert stmt.then_case.value.dtype == "float32x4"
    assert isinstance(stmt.else_case, tvm.stmt.For)


def test_vectorize_with_le_cond():
    """A ``<=`` condition on the vectorized index leaves a ``For`` loop in place."""
    n = tvm.var('n')
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    with ib.for_range(0, 4, for_type="vectorize") as i:
        with ib.if_scope(i <= n):
            A[i] = A[i] + 1
    stmt = ib.get()
    stmt = tvm.ir_pass.VectorizeLoop(stmt)
    assert isinstance(stmt, tvm.stmt.For)


def test_vectorize_with_ge_cond():
    """A ``>=`` condition on the vectorized index leaves a ``For`` loop in place."""
    n = tvm.var('n')
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    with ib.for_range(0, 4, for_type="vectorize") as i:
        with ib.if_scope(i >= n):
            A[i] = A[i] + 1
    stmt = ib.get()
    stmt = tvm.ir_pass.VectorizeLoop(stmt)
    assert isinstance(stmt, tvm.stmt.For)


def test_vectorize_if_then_else():
    """Vectorization around the ``tvm_if_then_else`` intrinsic.

    Case 1: the intrinsic's condition depends on the vectorized index
    (``i > 0``), and the result is asserted to still be a ``For`` loop.

    Case 2: the condition depends only on the outer serial index
    (``k > 0``). The inner loop is vectorized, and the scalar false-value
    ``0`` becomes a ``Broadcast``.
    """
    n = tvm.var('n')
    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    with ib.for_range(0, 4, for_type="vectorize") as i:
        A[i] = tvm.call_intrin("float32", "tvm_if_then_else",
                               i > 0,
                               A[i] + 1, A[i])
    stmt = ib.get()
    stmt = tvm.ir_pass.VectorizeLoop(stmt)
    assert isinstance(stmt, tvm.stmt.For)

    ib = tvm.ir_builder.create()
    A = ib.pointer("float32", name="A")
    with ib.for_range(0, n) as k:
        with ib.for_range(0, 4, for_type="vectorize") as i:
            A[k * 4 + i] = tvm.call_intrin("float32", "tvm_if_then_else",
                                           k > 0,
                                           A[k * 4 + i], 0)
    stmt = ib.get()
    assert isinstance(stmt.body, tvm.stmt.For)
    stmt = tvm.ir_pass.VectorizeLoop(stmt)
    assert not isinstance(stmt.body, tvm.stmt.For)
    # args[2] is the false-value operand of tvm_if_then_else.
    assert isinstance(stmt.body.value.args[2], tvm.expr.Broadcast)


if __name__ == "__main__":
    test_vectorize_vector()
    test_vectorize_with_if()
    test_vectorize_loop()
    test_vectorize_if_then_else()
    test_vectorize_with_le_cond()
    test_vectorize_with_ge_cond()