// RUN: mlir-opt --tosa-make-broadcastable --split-input-file %s | FileCheck %s

// Tests for the TOSA make-broadcastable pass on elementwise binary ops.
//
// When the two operands have different ranks, the expected output is a
// tosa.reshape applied to the lower-rank operand, and the binary op then takes
// the reshaped value in that operand's original position. When both operands
// already have the same rank, no reshape is expected.
//
// Each case below is its own split-input chunk. Labels include the opening
// parenthesis so that, e.g., @test_broadcast1 cannot match a prefix of
// @test_broadcast10.

// -----
// CHECK-LABEL: @test_broadcast0(
func @test_broadcast0(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
  // Equal ranks: the op must be left untouched.
  // CHECK-NOT: tosa.reshape
  // CHECK: "tosa.add"(%arg0, %arg1)
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
  return %0 : tensor<1xf32>
}

// -----
// CHECK-LABEL: @test_broadcast1(
func @test_broadcast1(%arg0: tensor<1xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x1xf32> {
  // CHECK: %[[RESHAPE:.*]] = "tosa.reshape"(%arg0)
  // CHECK: "tosa.add"(%[[RESHAPE]], %arg1)
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<2x1xf32>) -> tensor<2x1xf32>
  return %0 : tensor<2x1xf32>
}

// -----
// CHECK-LABEL: @test_broadcast2(
func @test_broadcast2(%arg0: tensor<2x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1xf32> {
  // CHECK: %[[RESHAPE:.*]] = "tosa.reshape"(%arg1)
  // CHECK: "tosa.add"(%arg0, %[[RESHAPE]])
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<2x1xf32>, tensor<1xf32>) -> tensor<2x1xf32>
  return %0 : tensor<2x1xf32>
}

// -----
// CHECK-LABEL: @test_broadcast3(
func @test_broadcast3(%arg0: tensor<2x1x1x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1x1x1xf32> {
  // CHECK: %[[RESHAPE:.*]] = "tosa.reshape"(%arg1)
  // CHECK: "tosa.add"(%arg0, %[[RESHAPE]])
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<2x1x1x1xf32>, tensor<1xf32>) -> tensor<2x1x1x1xf32>
  return %0 : tensor<2x1x1x1xf32>
}

// -----
// CHECK-LABEL: @test_broadcast4(
func @test_broadcast4(%arg0: tensor<1x1x1x2xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x1x2xf32> {
  // CHECK: %[[RESHAPE:.*]] = "tosa.reshape"(%arg1)
  // CHECK: "tosa.add"(%arg0, %[[RESHAPE]])
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1x1x2xf32>, tensor<1xf32>) -> tensor<1x1x1x2xf32>
  return %0 : tensor<1x1x1x2xf32>
}

// -----
// CHECK-LABEL: @test_broadcast5(
func @test_broadcast5(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x2x1xf32> {
  // CHECK: %[[RESHAPE:.*]] = "tosa.reshape"(%arg1)
  // CHECK: "tosa.add"(%arg0, %[[RESHAPE]])
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1x2x1xf32>, tensor<1xf32>) -> tensor<1x1x2x1xf32>
  return %0 : tensor<1x1x2x1xf32>
}

// -----
// CHECK-LABEL: @test_broadcast6(
func @test_broadcast6(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1xf32>) -> tensor<17x16x15x14xf32> {
  // CHECK: %[[RESHAPE:.*]] = "tosa.reshape"(%arg1)
  // CHECK: "tosa.add"(%arg0, %[[RESHAPE]])
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<1xf32>) -> tensor<17x16x15x14xf32>
  return %0 : tensor<17x16x15x14xf32>
}

// -----
// CHECK-LABEL: @test_broadcast7(
func @test_broadcast7(%arg0: tensor<17x16x1x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x1x14xf32> {
  // CHECK: %[[RESHAPE:.*]] = "tosa.reshape"(%arg1)
  // CHECK: "tosa.add"(%arg0, %[[RESHAPE]])
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x1x14xf32>, tensor<1x1xf32>) -> tensor<17x16x1x14xf32>
  return %0 : tensor<17x16x1x14xf32>
}

// -----
// CHECK-LABEL: @test_broadcast8(
func @test_broadcast8(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x15x14xf32> {
  // CHECK: %[[RESHAPE:.*]] = "tosa.reshape"(%arg1)
  // CHECK: "tosa.add"(%arg0, %[[RESHAPE]])
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<1x1xf32>) -> tensor<17x16x15x14xf32>
  return %0 : tensor<17x16x15x14xf32>
}

// -----
// CHECK-LABEL: @test_broadcast9(
func @test_broadcast9(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x1xf32>) -> tensor<17x16x15x14xf32> {
  // CHECK: %[[RESHAPE:.*]] = "tosa.reshape"(%arg1)
  // CHECK: "tosa.add"(%arg0, %[[RESHAPE]])
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<15x1xf32>) -> tensor<17x16x15x14xf32>
  return %0 : tensor<17x16x15x14xf32>
}

// -----
// CHECK-LABEL: @test_broadcast10(
func @test_broadcast10(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x14xf32>) -> tensor<17x16x15x14xf32> {
  // CHECK: %[[RESHAPE:.*]] = "tosa.reshape"(%arg1)
  // CHECK: "tosa.add"(%arg0, %[[RESHAPE]])
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<15x14xf32>) -> tensor<17x16x15x14xf32>
  return %0 : tensor<17x16x15x14xf32>
}

// -----
// CHECK-LABEL: @test_broadcast13(
func @test_broadcast13(%arg0: tensor<1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
  // CHECK: %[[RESHAPE:.*]] = "tosa.reshape"(%arg0)
  // CHECK: "tosa.add"(%[[RESHAPE]], %arg1)
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
  return %0 : tensor<17x16x15x14xf32>
}

// -----
// CHECK-LABEL: @test_broadcast14(
func @test_broadcast14(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32> {
  // CHECK: %[[RESHAPE:.*]] = "tosa.reshape"(%arg0)
  // CHECK: "tosa.add"(%[[RESHAPE]], %arg1)
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1xf32>, tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32>
  return %0 : tensor<17x16x1x14xf32>
}

// -----
// CHECK-LABEL: @test_broadcast15(
func @test_broadcast15(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
  // CHECK: %[[RESHAPE:.*]] = "tosa.reshape"(%arg0)
  // CHECK: "tosa.add"(%[[RESHAPE]], %arg1)
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
  return %0 : tensor<17x16x15x14xf32>
}

// -----
// CHECK-LABEL: @test_broadcast16(
func @test_broadcast16(%arg0: tensor<15x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
  // CHECK: %[[RESHAPE:.*]] = "tosa.reshape"(%arg0)
  // CHECK: "tosa.add"(%[[RESHAPE]], %arg1)
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<15x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
  return %0 : tensor<17x16x15x14xf32>
}

// -----
// CHECK-LABEL: @test_broadcast17(
func @test_broadcast17(%arg0: tensor<15x14xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
  // CHECK: %[[RESHAPE:.*]] = "tosa.reshape"(%arg0)
  // CHECK: "tosa.add"(%[[RESHAPE]], %arg1)
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<15x14xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
  return %0 : tensor<17x16x15x14xf32>
}

// -----
// CHECK-LABEL: @test_broadcast18(
func @test_broadcast18(%arg0: tensor<14x1xf32>, %arg1: tensor<1x15xf32>) -> tensor<14x15xf32> {
  // Equal ranks (both 2-D): broadcasting is left to the op, no reshape.
  // CHECK-NOT: tosa.reshape
  // CHECK: "tosa.add"(%arg0, %arg1)
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<14x1xf32>, tensor<1x15xf32>) -> tensor<14x15xf32>
  return %0 : tensor<14x15xf32>
}

// -----
// CHECK-LABEL: @test_broadcast_mul(
func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
  // CHECK: %[[RESHAPE:.*]] = "tosa.reshape"(%arg0)
  // CHECK: "tosa.mul"(%[[RESHAPE]], %arg1)
  %0 = "tosa.mul"(%arg0, %arg1) {shift = 1 : i32 } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
  return %0 : tensor<17x16x15x14xi32>
}

// -----
// CHECK-LABEL: @test_broadcast_arithmetic_right_shift(
func @test_broadcast_arithmetic_right_shift(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
  // CHECK: %[[RESHAPE:.*]] = "tosa.reshape"(%arg0)
  // CHECK: "tosa.arithmetic_right_shift"(%[[RESHAPE]], %arg1)
  %0 = "tosa.arithmetic_right_shift"(%arg0, %arg1) { round = true } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
  return %0 : tensor<17x16x15x14xi32>
}

// -----
// CHECK-LABEL: @test_broadcast_scalar(
func @test_broadcast_scalar(%arg0: tensor<i32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
  // Rank-0 operand: the reshape must be the very first op in the body.
  // CHECK-NEXT: %[[RESHAPE:.*]] = "tosa.reshape"(%arg0)
  // CHECK: "tosa.add"(%[[RESHAPE]], %arg1)
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<i32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
  return %0 : tensor<17x16x15x14xi32>
}