// RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s

func @generalize_matmul_tensor_f32(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
                     outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_matmul_tensor_f32
// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
// CHECK-NEXT:   %[[MUL:.+]] = mulf %[[A_ARG]], %[[B_ARG]] : f32
// CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>

// -----

func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
                     outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
  return %0: tensor<16x32xi32>
}

// CHECK-LABEL: @generalize_matmul_tensor_i32
// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i32)
// CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_ARG]], %[[B_ARG]] : i32
// CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
// CHECK-NEXT:   linalg.yield %[[ADD]] : i32
// CHECK-NEXT: -> tensor<16x32xi32>

// -----

func @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_f32(%input : tensor<1x4x16x1xf32>, %filter: tensor<2x2x1xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
  %0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %filter : tensor<1x4x16x1xf32>, tensor<2x2x1xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
  return %0: tensor<1x2x4x1xf32>
}

// CHECK-LABEL: @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_f32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[FILTER_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
// CHECK-NEXT:   %[[MUL:.+]] = mulf %[[IN_ARG]], %[[FILTER_ARG]] : f32
// CHECK-NEXT:   %[[ADD:.+]] = addf %[[OUT_ARG]], %[[MUL]] : f32
// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<1x2x4x1xf32>

// -----

func @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_i32(%input : tensor<1x4x16x1xi32>, %filter: tensor<2x2x1xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
  %0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %filter : tensor<1x4x16x1xi32>, tensor<2x2x1xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
  return %0: tensor<1x2x4x1xi32>
}

// CHECK-LABEL: @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_i32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[FILTER_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
// CHECK-NEXT:   %[[MUL:.+]] = muli %[[IN_ARG]], %[[FILTER_ARG]] : i32
// CHECK-NEXT:   %[[ADD:.+]] = addi %[[OUT_ARG]], %[[MUL]] : i32
// CHECK-NEXT:   linalg.yield %[[ADD]] : i32
// CHECK-NEXT: -> tensor<1x2x4x1xi32>

// -----

func @generalize_pooling_nhwc_max_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
  %0 = linalg.pooling_nhwc_max {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
  return %0: tensor<1x2x4x1xf32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_max_f32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
// CHECK-NEXT:   %[[COND:.+]] = cmpf ogt, %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT:   %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT:   linalg.yield %[[MAX]] : f32
// CHECK-NEXT: -> tensor<1x2x4x1xf32>

// -----

func @generalize_pooling_nhwc_max_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
  %0 = linalg.pooling_nhwc_max {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
  return %0: tensor<1x2x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_max_i32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
// CHECK-NEXT:   %[[COND:.+]] = cmpi sgt, %[[OUT_ARG]], %[[IN_ARG]] : i32
// CHECK-NEXT:   %[[MAX:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : i32
// CHECK-NEXT:   linalg.yield %[[MAX]] : i32
// CHECK-NEXT: -> tensor<1x2x4x1xi32>

// -----

func @generalize_pooling_nhwc_min_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
  %0 = linalg.pooling_nhwc_min {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
  return %0: tensor<1x2x4x1xf32>
}

// NOTE(review): capture renamed MAX -> MIN below; this op selects the minimum.
// CHECK-LABEL: @generalize_pooling_nhwc_min_f32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
// CHECK-NEXT:   %[[COND:.+]] = cmpf olt, %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT:   %[[MIN:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT:   linalg.yield %[[MIN]] : f32
// CHECK-NEXT: -> tensor<1x2x4x1xf32>

// -----

func @generalize_pooling_nhwc_min_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
  %0 = linalg.pooling_nhwc_min {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
  return %0: tensor<1x2x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_min_i32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
// CHECK-NEXT:   %[[COND:.+]] = cmpi slt, %[[OUT_ARG]], %[[IN_ARG]] : i32
// CHECK-NEXT:   %[[MIN:.+]] = select %[[COND]], %[[OUT_ARG]], %[[IN_ARG]] : i32
// CHECK-NEXT:   linalg.yield %[[MIN]] : i32
// CHECK-NEXT: -> tensor<1x2x4x1xi32>

// -----

func @generalize_pooling_nhwc_sum_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
  %0 = linalg.pooling_nhwc_sum {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
  return %0: tensor<1x2x4x1xf32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_sum_f32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
// CHECK-NEXT:   %[[ADD:.+]] = addf %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<1x2x4x1xf32>

// -----

func @generalize_pooling_nhwc_sum_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
  %0 = linalg.pooling_nhwc_sum {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
    ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
  return %0: tensor<1x2x4x1xi32>
}

// CHECK-LABEL: @generalize_pooling_nhwc_sum_i32
// CHECK:      ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
// CHECK-NEXT:   %[[ADD:.+]] = addi %[[OUT_ARG]], %[[IN_ARG]] : i32
// CHECK-NEXT:   linalg.yield %[[ADD]] : i32
// CHECK-NEXT: -> tensor<1x2x4x1xi32>

// -----

func @generalize_fill_rng_2d_f32(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_fill_rng_2d_f32
// CHECK-DAG: ^{{.*}}(%[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32, %[[O:.+]]: f32
// CHECK-DAG:   %[[IDX0:.+]] = linalg.index 0 : index
// CHECK-DAG:   %[[IDX1:.+]] = linalg.index 1 : index
// CHECK-DAG:   %[[IDX0_CAST:.+]] = index_cast %[[IDX0]] : index to i32
// CHECK-DAG:   %[[IDX1_CAST:.+]] = index_cast %[[IDX1]] : index to i32
// CHECK-DAG:   %[[VAL0:.+]] = addi %[[IDX0_CAST]], %[[SEED]] : i32
// CHECK-DAG:   %[[CST0:.+]] = constant 1103515245 : i32
// CHECK-DAG:   %[[CST1:.+]] = constant 12345 : i32
// CHECK-DAG:   %[[VAL1:.+]] = muli %[[VAL0]], %[[CST0]] : i32
// CHECK-DAG:   %[[VAL2:.+]] = addi %[[VAL1]], %[[CST1]] : i32
// Skip random number computation for the second index.
// CHECK-DAG:   %[[DIFF:.+]] = subf %[[MAX]], %[[MIN]] : f64
// CHECK-DAG:   %[[CST2:.+]] = constant 2.3283063999999999E-10 : f64
// CHECK-DAG:   %[[FACT:.+]] = mulf %[[DIFF]], %[[CST2]] : f64
// CHECK-DAG:   %[[VAL4:.+]] = mulf %{{.+}}, %[[FACT]] : f64
// CHECK-DAG:   %[[VAL5:.+]] = addf %[[VAL4]], %[[MIN]] : f64
// CHECK-DAG:   %[[VAL6:.+]] = fptrunc %[[VAL5]] : f64 to f32
// CHECK-NEXT:   linalg.yield %[[VAL6]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>

// -----

func @generalize_fill_rng_2d_i32(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xi32>) -> tensor<16x32xi32> {
  %0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xi32>) -> tensor<16x32xi32>
  return %0: tensor<16x32xi32>
}

// CHECK-LABEL: @generalize_fill_rng_2d_i32
// CHECK: ^{{.*}}(%[[MIN:.+]]: f64, %[[MAX:.+]]: f64, %[[SEED:.+]]: i32, %[[O:.+]]: i32
// Verifies floating point to integer cast.
// CHECK:        %[[VAL6:.+]] = fptosi %{{.+}} : f64 to i32
// CHECK-NEXT:   linalg.yield %[[VAL6]] : i32
// CHECK-NEXT: -> tensor<16x32xi32>

// -----

func @generalize_soft_plus_2d_f32(%input: tensor<16x32xf32>, %output: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.soft_plus_2d ins(%input: tensor<16x32xf32>) outs(%output: tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_soft_plus_2d_f32
// CHECK:      %[[C1:.+]] = constant 1.000000e+00 : f64
// CHECK:      ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32
// CHECK-NEXT:   %[[C1_CAST:.+]] = fptrunc %[[C1]] : f64 to f32
// CHECK-NEXT:   %[[EXP:.+]] = math.exp %[[IN]] : f32
// CHECK-NEXT:   %[[SUM:.+]] = addf %[[C1_CAST]], %[[EXP]] : f32
// CHECK-NEXT:   %[[LOG:.+]] = math.log %[[SUM]] : f32
// CHECK-NEXT:   linalg.yield %[[LOG]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>

// -----
// Verifies floating point to integer cast.
func @generalize_matmul_tensor_f32_f32_i16(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
                     outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
  return %0: tensor<16x32xi16>
}

// CHECK-LABEL: @generalize_matmul_tensor_f32_f32_i16
// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: i16)
// CHECK-NEXT:   %[[A_CAST:.+]] = fptosi %[[A_ARG]] : f32 to i16
// CHECK-NEXT:   %[[B_CAST:.+]] = fptosi %[[B_ARG]] : f32 to i16
// CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i16
// CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16
// CHECK-NEXT:   linalg.yield %[[ADD]] : i16
// CHECK-NEXT: -> tensor<16x32xi16>

// -----
// Verifies sign extension cast.
func @generalize_matmul_tensor_i8_i8_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
                     outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
  return %0: tensor<16x32xi32>
}

// CHECK-LABEL: @generalize_matmul_tensor_i8_i8_i32
// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: i32)
// CHECK-NEXT:   %[[A_CAST:.+]] = sexti %[[A_ARG]] : i8 to i32
// CHECK-NEXT:   %[[B_CAST:.+]] = sexti %[[B_ARG]] : i8 to i32
// CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i32
// CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
// CHECK-NEXT:   linalg.yield %[[ADD]] : i32
// CHECK-NEXT: -> tensor<16x32xi32>

// -----
// Verifies that different argument types is legal.
func @generalize_matmul_tensor_i8_i16_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi16>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi16>)
                     outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
  return %0: tensor<16x32xi32>
}

// CHECK-LABEL: @generalize_matmul_tensor_i8_i16_i32
// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i16, %[[C_ARG:.+]]: i32)
// CHECK-NEXT:   %[[A_CAST:.+]] = sexti %[[A_ARG]] : i8 to i32
// CHECK-NEXT:   %[[B_CAST:.+]] = sexti %[[B_ARG]] : i16 to i32
// CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i32
// CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
// CHECK-NEXT:   linalg.yield %[[ADD]] : i32
// CHECK-NEXT: -> tensor<16x32xi32>

// -----
// Somewhat non-sensical but checks integer truncation cast.
func @generalize_matmul_tensor_i32_i32_i16(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
                     outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
  return %0: tensor<16x32xi16>
}

// CHECK-LABEL: @generalize_matmul_tensor_i32_i32_i16
// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i16)
// CHECK-NEXT:   %[[A_CAST:.+]] = trunci %[[A_ARG]] : i32 to i16
// CHECK-NEXT:   %[[B_CAST:.+]] = trunci %[[B_ARG]] : i32 to i16
// CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_CAST]], %[[B_CAST]] : i16
// CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16
// CHECK-NEXT:   linalg.yield %[[ADD]] : i16
// CHECK-NEXT: -> tensor<16x32xi16>

// -----
// Verifies integer to floating point cast.
func @generalize_matmul_tensor_i8_i8_f32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
                     outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_matmul_tensor_i8_i8_f32
// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i8, %[[B_ARG:.+]]: i8, %[[C_ARG:.+]]: f32)
// CHECK-NEXT:   %[[A_CAST:.+]] = sitofp %[[A_ARG]] : i8 to f32
// CHECK-NEXT:   %[[B_CAST:.+]] = sitofp %[[B_ARG]] : i8 to f32
// CHECK-NEXT:   %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
// CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>

// -----
// Verifies floating point extension cast.
func @generalize_matmul_tensor_f16_f16_f32(%A : tensor<16x8xf16>, %B: tensor<8x32xf16>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf16>)
                     outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_matmul_tensor_f16_f16_f32
// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f16, %[[C_ARG:.+]]: f32)
// CHECK-NEXT:   %[[A_CAST:.+]] = fpext %[[A_ARG]] : f16 to f32
// CHECK-NEXT:   %[[B_CAST:.+]] = fpext %[[B_ARG]] : f16 to f32
// CHECK-NEXT:   %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
// CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>

// -----
// Verifies floating point truncation.
func @generalize_matmul_tensor_f64_f64_f32(%A : tensor<16x8xf64>, %B: tensor<8x32xf64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf64>, tensor<8x32xf64>)
                     outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0: tensor<16x32xf32>
}

// CHECK-LABEL: @generalize_matmul_tensor_f64_f64_f32
// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f64, %[[B_ARG:.+]]: f64, %[[C_ARG:.+]]: f32)
// CHECK-NEXT:   %[[A_CAST:.+]] = fptrunc %[[A_ARG]] : f64 to f32
// CHECK-NEXT:   %[[B_CAST:.+]] = fptrunc %[[B_ARG]] : f64 to f32
// CHECK-NEXT:   %[[MUL:.+]] = mulf %[[A_CAST]], %[[B_CAST]] : f32
// CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
// CHECK-NEXT: -> tensor<16x32xf32>