// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
// RUN: mlir-opt %s -test-vector-unrolling-patterns | FileCheck %s

// Regression tests for the vector-to-vector conversion and unrolling patterns.
// Each test pairs a small input function with the IR that FileCheck expects
// after the pass named on the RUN lines has been applied.

// CHECK-DAG: #[[MAP0:map[0-9]+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>

// Elementwise addf on vector<4x2xf32>: expected to be split into two 2x2
// slices that are added independently and reassembled with insert_slices.

// CHECK-LABEL: func @add4x2
// CHECK: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[ES2:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG1:.*]] = vector.tuple_get %[[ES1]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG2:.*]] = vector.tuple_get %[[ES2]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[A1:.*]] = addf %[[TG1]], %[[TG2]] : vector<2x2xf32>
// CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES1]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES2]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[A2:.*]] = addf %[[TG3]], %[[TG4]] : vector<2x2xf32>
// CHECK-NEXT: %[[R1:.*]] = vector.tuple %[[A1]], %[[A2]] : vector<2x2xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[R2:.*]] = vector.insert_slices %[[R1]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<4x2xf32>
// The return must use (not redefine) R2 so that the returned value is tied to
// the insert_slices result above.
// CHECK-NEXT: return %[[R2]] : vector<4x2xf32>

func @add4x2(%0: vector<4x2xf32>) -> vector<4x2xf32> {
  %1 = addf %0, %0: vector<4x2xf32>
  return %1: vector<4x2xf32>
}

// Two chained addf ops on vector<4x4xf32>: the 2x2 partial results of the
// first addf (A1..A4) are expected to feed the second addf directly, with no
// insert_slices/extract_slices round trip in between.

// CHECK-LABEL: func @add4x4
// CHECK: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[ES2:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>

// CHECK-NEXT: %[[TG1:.*]] = vector.tuple_get %[[ES1]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG2:.*]] = vector.tuple_get %[[ES2]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[A1:.*]] = addf %[[TG1]], %[[TG2]] : vector<2x2xf32>

// CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES1]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES2]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[A2:.*]] = addf %[[TG3]], %[[TG4]] : vector<2x2xf32>

// CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES1]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES2]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[A3:.*]] = addf %[[TG5]], %[[TG6]] : vector<2x2xf32>

// CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES1]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES2]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[A4:.*]] = addf %[[TG7]], %[[TG8]] : vector<2x2xf32>

// CHECK-NEXT: %[[ES3:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>

// CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES3]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[A5:.*]] = addf %[[TG9]], %[[A1]] : vector<2x2xf32>

// CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES3]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[A6:.*]] = addf %[[TG11]], %[[A2]] : vector<2x2xf32>

// CHECK-NEXT: %[[TG13:.*]] = vector.tuple_get %[[ES3]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[A7:.*]] = addf %[[TG13]], %[[A3]] : vector<2x2xf32>

// CHECK-NEXT: %[[TG15:.*]] = vector.tuple_get %[[ES3]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[A8:.*]] = addf %[[TG15]], %[[A4]] : vector<2x2xf32>

// CHECK-NEXT: %[[R3:.*]] = vector.tuple %[[A5]], %[[A6]], %[[A7]], %[[A8]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[R4:.*]] = vector.insert_slices %[[R3]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
// CHECK-NEXT: return %[[R4]] : vector<4x4xf32>

func @add4x4(%0: vector<4x4xf32>, %1: vector<4x4xf32>) -> vector<4x4xf32> {
  %2 = addf %0, %1: vector<4x4xf32>
  %3 = addf %1, %2: vector<4x4xf32>
  return %3: vector<4x4xf32>
}

#contraction_accesses0 = [
  affine_map<(i, j, k) -> (i, k)>,
  affine_map<(i, j, k) -> (k, j)>,
  affine_map<(i, j, k) -> (i, j)>
]
#contraction_trait0 = {
  indexing_maps = #contraction_accesses0,
  iterator_types = ["parallel", "parallel", "reduction"]
}

// Masked 4x6 by 6x4 contraction into 4x4 with (i, j, k) iteration order:
// every 2x2 output tile is expected to become a chain of three 2x2 contracts
// (one per slice along the reduction dimension), each accumulating into the
// result of the previous one.

// CHECK-LABEL: func @contraction4x4_ijk

// CHECK: %[[LMASK:.*]] = vector.constant_mask [4, 6] : vector<4x6xi1>
// CHECK-NEXT: %[[RMASK:.*]] = vector.constant_mask [6, 4] : vector<6x4xi1>

// Reducing output vector [0, 0]

// CHECK-NEXT: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x6xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[ES2:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<6x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[ES3:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[ES4:.*]] = vector.extract_slices %[[LMASK]], [2, 2], [1, 1] : vector<4x6xi1> into tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[ES5:.*]] = vector.extract_slices %[[RMASK]], [2, 2], [1, 1] : vector<6x4xi1> into tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>

// CHECK-NEXT: %[[TG1:.*]] = vector.tuple_get %[[ES1]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG2:.*]] = vector.tuple_get %[[ES2]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES3]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES4]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES5]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

// CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES1]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES2]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES4]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES5]], 2 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG6]], %[[TG7]], %[[R1S00]], %[[TG8]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

// CHECK-NEXT: %[[TG10:.*]] = vector.tuple_get %[[ES1]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES2]], 4 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG12:.*]] = vector.tuple_get %[[ES4]], 2 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[TG13:.*]] = vector.tuple_get %[[ES5]], 4 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG10]], %[[TG11]], %[[R2S00]], %[[TG12]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

// Reducing output vector [0, 2]

// CHECK-NEXT: %[[TG14:.*]] = vector.tuple_get %[[ES2]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG15:.*]] = vector.tuple_get %[[ES3]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG16:.*]] = vector.tuple_get %[[ES5]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG1]], %[[TG14]], %[[TG15]], %[[TG4]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

// CHECK-NEXT: %[[TG17:.*]] = vector.tuple_get %[[ES2]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG18:.*]] = vector.tuple_get %[[ES5]], 3 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG6]], %[[TG17]], %[[R1S02]], %[[TG8]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

// CHECK-NEXT: %[[TG19:.*]] = vector.tuple_get %[[ES2]], 5 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG20:.*]] = vector.tuple_get %[[ES5]], 5 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG10]], %[[TG19]], %[[R2S02]], %[[TG12]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

// Reducing output vector [2, 0]

// CHECK-NEXT: %[[TG21:.*]] = vector.tuple_get %[[ES1]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG22:.*]] = vector.tuple_get %[[ES3]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG23:.*]] = vector.tuple_get %[[ES4]], 3 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG21]], %[[TG2]], %[[TG22]], %[[TG23]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

// CHECK-NEXT: %[[TG24:.*]] = vector.tuple_get %[[ES1]], 4 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG25:.*]] = vector.tuple_get %[[ES4]], 4 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG24]], %[[TG7]], %[[R1S20]], %[[TG25]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

// CHECK-NEXT: %[[TG26:.*]] = vector.tuple_get %[[ES1]], 5 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG27:.*]] = vector.tuple_get %[[ES4]], 5 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG26]], %[[TG11]], %[[R2S20]], %[[TG27]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

// Reducing output vector [2, 2]

// CHECK-NEXT: %[[TG28:.*]] = vector.tuple_get %[[ES3]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG21]], %[[TG14]], %[[TG28]], %[[TG23]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R2S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG24]], %[[TG17]], %[[R1S22]], %[[TG25]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R3S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG26]], %[[TG19]], %[[R2S22]], %[[TG27]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

// CHECK-NEXT: %[[RES0:.*]] = vector.tuple %[[R3S00]], %[[R3S02]], %[[R3S20]], %[[R3S22]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[RES1:.*]] = vector.insert_slices %[[RES0]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
// CHECK-NEXT: return %[[RES1]] : vector<4x4xf32>

func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
                         %arg2 : vector<4x4xf32>, %arg3 : index)
                         -> (vector<4x4xf32>) {
  %lhsm = vector.constant_mask [4, 6] : vector<4x6xi1>
  %rhsm = vector.constant_mask [6, 4] : vector<6x4xi1>
  %0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2, %lhsm, %rhsm
      : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>

  return %0 : vector<4x4xf32>
}

#contraction_accesses1 = [
  affine_map<(i, k, j) -> (i, k)>,
  affine_map<(i, k, j) -> (k, j)>,
  affine_map<(i, k, j) -> (i, j)>
]
#contraction_trait1 = {
  indexing_maps = #contraction_accesses1,
  iterator_types = ["parallel", "reduction", "parallel"]
}

// Masked 4x2 by 2x4 contraction into 4x4 with (i, k, j) iteration order: the
// reduction dimension fits in a single 2-wide slice, so one 2x2 contract per
// output tile is expected.

// CHECK-LABEL: func @contraction4x4_ikj

// CHECK: %[[LMASK:.*]] = vector.constant_mask [4, 2] : vector<4x2xi1>
// CHECK-NEXT: %[[RMASK:.*]] = vector.constant_mask [2, 4] : vector<2x4xi1>

// Reducing output vector [0, 0]

// CHECK-NEXT: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[ES2:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<2x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[ES3:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[ES4:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xi1> into tuple<vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[ES5:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<2x4xi1> into tuple<vector<2x2xi1>, vector<2x2xi1>>

// CHECK-NEXT: %[[TG1:.*]] = vector.tuple_get %[[ES1]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG2:.*]] = vector.tuple_get %[[ES2]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES3]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES4]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES5]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

// Reducing output vector [0, 2]

// CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES2]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES3]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES5]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG1]], %[[TG6]], %[[TG7]], %[[TG4]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

// Reducing output vector [2, 0]

// CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES1]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG10:.*]] = vector.tuple_get %[[ES3]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES4]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG9]], %[[TG2]], %[[TG10]], %[[TG11]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

// Reducing output vector [2, 2]

// CHECK-NEXT: %[[TG12:.*]] = vector.tuple_get %[[ES3]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG9]], %[[TG6]], %[[TG12]], %[[TG11]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

// CHECK-NEXT: %[[RES0:.*]] = vector.tuple %[[R1S00]], %[[R1S02]], %[[R1S20]], %[[R1S22]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[RES1:.*]] = vector.insert_slices %[[RES0]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
// CHECK-NEXT: return %[[RES1]] : vector<4x4xf32>

func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>,
                         %arg2 : vector<4x4xf32>, %arg3 : index)
                         -> (vector<4x4xf32>) {
  %lhsm = vector.constant_mask [4, 2] : vector<4x2xi1>
  %rhsm = vector.constant_mask [2, 4] : vector<2x4xi1>
  %0 = vector.contract #contraction_trait1 %arg0, %arg1, %arg2, %lhsm, %rhsm
      : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>

  return %0 : vector<4x4xf32>
}

// Same (i, k, j) contraction, but with operands produced by
// vector.transfer_read and the result consumed by vector.transfer_write: the
// transfers themselves are expected to be split into 2x2 pieces.

// CHECK-LABEL: func @contraction4x4_ikj_xfer_read

// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C2:.*]] = constant 2 : index

// Check LHS vector.transfer read is split for each user.

// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x2xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x2xf32>, vector<2x2xf32>

// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<2x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<2x4xf32>, vector<2x2xf32>

// CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>

// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>

// CHECK-NEXT: vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: vector.transfer_write %[[R1]], %{{.*}}[%[[C0]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: vector.transfer_write %[[R2]], %{{.*}}[%[[C2]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: vector.transfer_write %[[R3]], %{{.*}}[%[[C2]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: return

func @contraction4x4_ikj_xfer_read(%arg0 : memref<4x2xf32>,
                                   %arg1 : memref<2x4xf32>,
                                   %arg2 : memref<4x4xf32>) {
  %c0 = constant 0 : index
  %cf0 = constant 0.0 : f32

  %0 = vector.transfer_read %arg0[%c0, %c0], %cf0
    { permutation_map = affine_map<(d0, d1) -> (d0, d1)> }
      : memref<4x2xf32>, vector<4x2xf32>

  %1 = vector.transfer_read %arg1[%c0, %c0], %cf0
    { permutation_map = affine_map<(d0, d1) -> (d0, d1)> }
      : memref<2x4xf32>, vector<2x4xf32>

  %2 = vector.transfer_read %arg2[%c0, %c0], %cf0
    { permutation_map = affine_map<(d0, d1) -> (d0, d1)> }
      : memref<4x4xf32>, vector<4x4xf32>

  %3 = vector.contract #contraction_trait1 %0, %1, %2
      : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>

  vector.transfer_write %3, %arg2[%c0, %c0]
    {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
      : vector<4x4xf32>, memref<4x4xf32>
  return
}

// Transfer read -> addf -> transfer write on 4x4 vectors inside an affine loop
// nest; only the number of resulting ops is checked here.
// TODO: Update test with VTR split transform.
// CHECK-LABEL: func @vector_transfers
// CHECK-COUNT-8: vector.transfer_read
// CHECK-COUNT-4: addf
// CHECK-COUNT-4: vector.transfer_write

func @vector_transfers(%arg0: index, %arg1: index) {
  %cst = constant 0.000000e+00 : f32
  %0 = alloc(%arg0, %arg1) : memref<?x?xf32>
  %1 = alloc(%arg0, %arg1) : memref<?x?xf32>
  %2 = alloc(%arg0, %arg1) : memref<?x?xf32>
  %cst_0 = constant 1.000000e+00 : f32
  %cst_1 = constant 2.000000e+00 : f32
  affine.for %arg2 = 0 to %arg0 step 4 {
    affine.for %arg3 = 0 to %arg1 step 4 {
      %4 = vector.transfer_read %0[%arg2, %arg3], %cst {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref<?x?xf32>, vector<4x4xf32>
      %5 = vector.transfer_read %1[%arg2, %arg3], %cst {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref<?x?xf32>, vector<4x4xf32>
      %6 = addf %4, %5 : vector<4x4xf32>
      vector.transfer_write %6, %2[%arg2, %arg3] {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : vector<4x4xf32>, memref<?x?xf32>
    }
  }
  return
}

// A tuple_get applied directly to a vector.tuple is expected to fold to the
// selected tuple operand.

// CHECK-LABEL: func @tuple_get(%arg0: vector<4xf32>, %arg1: vector<8xf32>)
// CHECK: return %arg1

func @tuple_get(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
  %0 = vector.tuple %arg0, %arg1 : vector<4xf32>, vector<8xf32>
  %1 = vector.tuple_get %0, 1 : tuple<vector<4xf32>, vector<8xf32>>
  return %1 : vector<8xf32>
}

// The final tuple_get is expected to be forwarded through the chain of
// insert_slices / extract_slices / shape_cast ops back to %arg7. The comments
// in the function body track where %arg7 lives after each op.

// CHECK-LABEL: func @tuple_get_producer_consumer
// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>,
// CHECK-SAME: %[[A1:.*1]]: vector<2x4xf32>,
// CHECK-SAME: %[[A2:.*2]]: vector<2x4xf32>,
// CHECK-SAME: %[[A3:.*3]]: vector<2x4xf32>,
// CHECK-SAME: %[[A4:.*4]]: vector<2x4xf32>,
// CHECK-SAME: %[[A5:.*5]]: vector<2x4xf32>,
// CHECK-SAME: %[[A6:.*6]]: vector<2x4xf32>,
// CHECK-SAME: %[[A7:.*7]]: vector<2x4xf32>
// CHECK: return %[[A7]] : vector<2x4xf32>

func @tuple_get_producer_consumer(
  %arg0 : vector<2x4xf32>, %arg1 : vector<2x4xf32>,
  %arg2 : vector<2x4xf32>, %arg3 : vector<2x4xf32>,
  %arg4 : vector<2x4xf32>, %arg5 : vector<2x4xf32>,
  %arg6 : vector<2x4xf32>, %arg7 : vector<2x4xf32>) -> vector<2x4xf32> {
  %0 = vector.tuple %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7
    : vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
      vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>
  // %arg7 == %0 at tupleIndex = 7, offsets = [0, 0]
  %1 = vector.insert_slices %0, [2, 4], [1, 1]
    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
            vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
      into vector<4x16xf32>
  // %arg7 == %1 at tupleIndex = -1, offsets = [2, 12]
  %2 = vector.extract_slices %1, [4, 8], [1, 1]
    : vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
  // %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
  %3 = vector.shape_cast %2 : tuple<vector<4x8xf32>, vector<4x8xf32>> to
                              tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
  // %arg7 == %3 at tupleIndex = 1, offsets = [0, 0, 2, 4]
  %4 = vector.tuple_get %3, 1 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
  // %arg7 == %4 at tupleIndex = -1, offsets = [0, 0, 2, 4]
  %5 = vector.shape_cast %4 : vector<1x1x4x8xf32> to vector<4x8xf32>
  // %arg7 == %5 at tupleIndex = -1, offsets = [2, 4]
  %6 = vector.extract_slices %5, [2, 4], [1, 1]
    : vector<4x8xf32> into
      tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
  // %arg7 == %6 at tupleIndex = 3, offsets = [0, 0]
  %7 = vector.tuple_get %6, 3
    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
  // %arg7 == %7
  return %7 : vector<2x4xf32>
}

// Same producer/consumer chain, with the two tuple elements swapped through an
// intermediate vector.tuple before the final extraction.

// CHECK-LABEL: func @tuple_get_producer_consumer_swizzle
// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>,
// CHECK-SAME: %[[A1:.*1]]: vector<2x4xf32>,
// CHECK-SAME: %[[A2:.*2]]: vector<2x4xf32>,
// CHECK-SAME: %[[A3:.*3]]: vector<2x4xf32>,
// CHECK-SAME: %[[A4:.*4]]: vector<2x4xf32>,
// CHECK-SAME: %[[A5:.*5]]: vector<2x4xf32>,
// CHECK-SAME: %[[A6:.*6]]: vector<2x4xf32>,
// CHECK-SAME: %[[A7:.*7]]: vector<2x4xf32>
// CHECK: return %[[A7]] : vector<2x4xf32>

func @tuple_get_producer_consumer_swizzle(
  %arg0 : vector<2x4xf32>, %arg1 : vector<2x4xf32>,
  %arg2 : vector<2x4xf32>, %arg3 : vector<2x4xf32>,
  %arg4 : vector<2x4xf32>, %arg5 : vector<2x4xf32>,
  %arg6 : vector<2x4xf32>, %arg7 : vector<2x4xf32>) -> vector<2x4xf32> {
  %0 = vector.tuple %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7
    : vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
      vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>
  // %arg7 == %0 at tupleIndex = 7, offsets = [0, 0]
  %1 = vector.insert_slices %0, [2, 4], [1, 1]
    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
            vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
      into vector<4x16xf32>
  // %arg7 == %1 at tupleIndex = -1, offsets = [2, 12]
  %2 = vector.extract_slices %1, [4, 8], [1, 1]
    : vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
  // %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
  %3 = vector.shape_cast %2 : tuple<vector<4x8xf32>, vector<4x8xf32>> to
                              tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
  // %arg7 == %3 at tupleIndex = 1, offsets = [0, 0, 2, 4]

  // Extract tuple elements.
  %4 = vector.tuple_get %3, 0 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
  %5 = vector.tuple_get %3, 1 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
  // %arg7 == %5 at tupleIndex = -1, offsets = [0, 0, 2, 4]

  // Swizzle tuple elements.
  %6 = vector.tuple %5, %4 : vector<1x1x4x8xf32>, vector<1x1x4x8xf32>
  // %arg7 == %6 at tupleIndex = 0, offsets = [0, 0, 2, 4]
  %7 = vector.shape_cast %6 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>> to
                              tuple<vector<4x8xf32>, vector<4x8xf32>>
  // %arg7 == %7 at tupleIndex = 0, offsets = [2, 4]
  %8 = vector.tuple_get %7, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
  // %arg7 == %8 at tupleIndex = -1, offsets = [2, 4]
  %9 = vector.extract_slices %8, [2, 4], [1, 1]
    : vector<4x8xf32> into
      tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
  // %arg7 == %9 at tupleIndex = 3, offsets = [0, 0]
  %10 = vector.tuple_get %9, 3
    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
  // %arg7 == %10
  return %10 : vector<2x4xf32>
}

// A shape_cast immediately undone by the inverse shape_cast is expected to
// fold away, leaving the original argument.

// CHECK-LABEL: func @cancelling_shape_cast_ops
// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>
// CHECK: return %[[A0]] : vector<2x4xf32>
func @cancelling_shape_cast_ops(%arg0 : vector<2x4xf32>) -> vector<2x4xf32> {
  %0 = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
  %1 = vector.shape_cast %0 : vector<8xf32> to vector<2x4xf32>
  return %1 : vector<2x4xf32>
}

// Transfers on a memref whose element type is itself a vector: the
// extract_slices/insert_slices pair around the transfer is expected to be
// replaced by one transfer_read/transfer_write per slice.

// CHECK-LABEL: func @vector_transfers_vector_element_type
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} {masked = [false, false]} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C1]], %[[C0]]], %{{.*}} {masked = [false, false]} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
// CHECK-NEXT: vector.transfer_write %[[VTR0]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {masked = [false, false]} : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
// CHECK-NEXT: vector.transfer_write %[[VTR1]], %{{.*}}[%[[C0]], %[[C1]], %[[C0]]] {masked = [false, false]} : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>

func @vector_transfers_vector_element_type() {
  %c0 = constant 0 : index
  %cf0 = constant 0.000000e+00 : f32
  %vf0 = splat %cf0 : vector<2x4xf32>

  %0 = alloc() : memref<6x2x1xvector<2x4xf32>>

  %1 = vector.transfer_read %0[%c0, %c0, %c0], %vf0
      {permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}
        : memref<6x2x1xvector<2x4xf32>>, vector<2x1x2x4xf32>

  %2 = vector.extract_slices %1, [1, 1, 2, 4], [1, 1, 1, 1]
    : vector<2x1x2x4xf32> into tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
  %3 = vector.tuple_get %2, 0 : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
  %4 = vector.tuple_get %2, 1 : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
  %5 = vector.tuple %3, %4 : vector<1x1x2x4xf32>, vector<1x1x2x4xf32>
  %6 = vector.insert_slices %5, [1, 1, 2, 4], [1, 1, 1, 1]
    : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>> into vector<2x1x2x4xf32>

  vector.transfer_write %6, %0[%c0, %c0, %c0]
    {permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}
      : vector<2x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>

  return
}

// Test that ShapeCastOp on tuple of vectors, decomposes to multiple
// ShapeCastOps on vectors.
// CHECK-LABEL: func @shape_cast_decomposition
// CHECK: %[[V0:.*]] = vector.shape_cast %{{.*}} : vector<5x4x2xf32> to vector<20x2xf32>
// CHECK-NEXT: %[[V1:.*]] = vector.shape_cast %{{.*}} : vector<3x4x2xf32> to vector<12x2xf32>
// CHECK-NEXT: return %[[V0]], %[[V1]] : vector<20x2xf32>, vector<12x2xf32>

func @shape_cast_decomposition(%arg0 : vector<5x4x2xf32>,
                               %arg1 : vector<3x4x2xf32>)
  -> (vector<20x2xf32>, vector<12x2xf32>) {
  %0 = vector.tuple %arg0, %arg1 : vector<5x4x2xf32>, vector<3x4x2xf32>
  %1 = vector.shape_cast %0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
                              tuple<vector<20x2xf32>, vector<12x2xf32>>
  %2 = vector.tuple_get %1, 0 : tuple<vector<20x2xf32>, vector<12x2xf32>>
  %3 = vector.tuple_get %1, 1 : tuple<vector<20x2xf32>, vector<12x2xf32>>
  return %2, %3 : vector<20x2xf32>, vector<12x2xf32>
}

// Test that cancelling ShapeCastOps are canonicalized away.
// EX:
//
//  The following MLIR with cancelling ShapeCastOps:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
//   %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
//   %3 = user %2 : vector<5x4x2xf32>
//
//  Should canonicalize to the following:
//
//   %0 = source : vector<5x4x2xf32>
//   %1 = user %0 : vector<5x4x2xf32>
//

// ShapeCastOps on tuples of vectors.
// CHECK-LABEL: func @shape_cast_fold
// CHECK: return %{{.*}}, %{{.*}} : vector<5x4x2xf32>, vector<3x4x2xf32>

func @shape_cast_fold(%arg0 : vector<5x4x2xf32>, %arg1 : vector<3x4x2xf32>)
  -> (vector<5x4x2xf32>, vector<3x4x2xf32>) {
  %0 = vector.tuple %arg0, %arg1 : vector<5x4x2xf32>, vector<3x4x2xf32>

  %1 = vector.shape_cast %0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
                              tuple<vector<20x2xf32>, vector<12x2xf32>>

  %2 = vector.tuple_get %1, 0 : tuple<vector<20x2xf32>, vector<12x2xf32>>
  %3 = vector.tuple_get %1, 1 : tuple<vector<20x2xf32>, vector<12x2xf32>>

  %4 = vector.tuple %2, %3 : vector<20x2xf32>, vector<12x2xf32>
  %5 = vector.shape_cast %4 : tuple<vector<20x2xf32>, vector<12x2xf32>> to
                              tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>

  %6 = vector.tuple_get %5, 0 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>
  %7 = vector.tuple_get %5, 1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>

  return %6, %7 : vector<5x4x2xf32>, vector<3x4x2xf32>
}