// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// output_msa.h: optimized MSA specializations of the templates in output.h.

#ifndef GEMMLOWP_INTERNAL_OUTPUT_MSA_H_
#define GEMMLOWP_INTERNAL_OUTPUT_MSA_H_

#include "output.h"

#include <msa.h>

namespace gemmlowp {

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<4>> {
  typedef RegBufferInt32<4> InputType;
  typedef RegBufferUint8<4> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    // Signed saturate each 32-bit element to 9 bits
    // (this takes full care of non-negative elements).
    v4i32 tmp = __builtin_msa_sat_s_w(input.reg[0], 8);
    // Zero out negative elements.
    tmp = __builtin_msa_maxi_s_w(tmp, 0);
    // Pack every 32-bit element into 16 bits.
    tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
        reinterpret_cast<v8i16>(tmp), reinterpret_cast<v8i16>(tmp)));
    // Pack every element into 8 bits.
    tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
        reinterpret_cast<v16i8>(tmp), reinterpret_cast<v16i8>(tmp)));
    // Return 4 uint8_t elements as uint32_t.
    output.reg[0] = __builtin_msa_copy_s_w(tmp, 0);
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<8>> {
  typedef RegBufferInt32<8> InputType;
  typedef RegBufferUint8<8> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    // Signed saturate each 32-bit element to 9 bits
    // (this takes full care of non-negative elements).
    v4i32 tmp_lo = __builtin_msa_sat_s_w(input.reg[0], 8);
    v4i32 tmp_hi = __builtin_msa_sat_s_w(input.reg[1], 8);
    // Pack every 32-bit element into 16 bits,
    // combining all 8 elements into one vector.
    tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
        reinterpret_cast<v8i16>(tmp_hi), reinterpret_cast<v8i16>(tmp_lo)));
    // Zero out negative elements.
    tmp_lo = reinterpret_cast<v4i32>(
        __builtin_msa_maxi_s_h(reinterpret_cast<v8i16>(tmp_lo), 0));
    // Pack every element into 8 bits.
    tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
        reinterpret_cast<v16i8>(tmp_lo), reinterpret_cast<v16i8>(tmp_lo)));
    // Return 8 uint8_t elements as 2 uint32_t's.
    output.reg[0] = __builtin_msa_copy_s_w(tmp_lo, 0);
    output.reg[1] = __builtin_msa_copy_s_w(tmp_lo, 1);
    return output;
  }
};
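
// Per 32-bit lane, the two specializations above compute the equivalent of
// the following scalar clamp (an illustrative sketch, not a gemmlowp helper):
//
//   std::uint8_t SaturatingCastToUint8(std::int32_t x) {
//     return static_cast<std::uint8_t>(std::min(std::max(x, 0), 255));
//   }
//
// Saturating to 9 signed bits ([-256, 255]) already clamps the positive side
// to 255 while keeping negative lanes negative, so the max-with-zero step
// only has to clear the negatives; the pckev steps then narrow each 32-bit
// lane to its low 8 bits.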

#define GEMMLOWP_MIPS_SAT_U8_16(out, in0, in1, in2, in3)                \
  {                                                                     \
    v4i32 tmp0 = __builtin_msa_sat_s_w(in0, 8);                         \
    v4i32 tmp1 = __builtin_msa_sat_s_w(in1, 8);                         \
    v4i32 tmp2 = __builtin_msa_sat_s_w(in2, 8);                         \
    v4i32 tmp3 = __builtin_msa_sat_s_w(in3, 8);                         \
    tmp0 = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(               \
        reinterpret_cast<v8i16>(tmp1), reinterpret_cast<v8i16>(tmp0))); \
    tmp2 = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(               \
        reinterpret_cast<v8i16>(tmp3), reinterpret_cast<v8i16>(tmp2))); \
    tmp0 = reinterpret_cast<v4i32>(                                     \
        __builtin_msa_maxi_s_h(reinterpret_cast<v8i16>(tmp0), 0));      \
    tmp2 = reinterpret_cast<v4i32>(                                     \
        __builtin_msa_maxi_s_h(reinterpret_cast<v8i16>(tmp2), 0));      \
    tmp0 = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(               \
        reinterpret_cast<v16i8>(tmp2), reinterpret_cast<v16i8>(tmp0))); \
    out = reinterpret_cast<v16i8>(tmp0);                                \
  }

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<16>> {
  typedef RegBufferInt32<16> InputType;
  typedef RegBufferUint8<16> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    GEMMLOWP_MIPS_SAT_U8_16(output.reg[0], input.reg[0], input.reg[1],
                            input.reg[2], input.reg[3]);
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<32>> {
  typedef RegBufferInt32<32> InputType;
  typedef RegBufferUint8<32> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    GEMMLOWP_MIPS_SAT_U8_16(output.reg[0], input.reg[0], input.reg[1],
                            input.reg[2], input.reg[3]);
    GEMMLOWP_MIPS_SAT_U8_16(output.reg[1], input.reg[4], input.reg[5],
                            input.reg[6], input.reg[7]);
    return output;
  }
};

#undef GEMMLOWP_MIPS_SAT_U8_16
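
// The saturating casts to int16 below follow the same pckev narrowing
// pattern, but the clamp is symmetric: __builtin_msa_sat_s_w(x, 15)
// saturates each 32-bit lane to 16 signed bits, i.e. to [-32768, 32767],
// so no separate max-with-zero step is needed.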

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
                                 RegBufferInt32<4>> {
  typedef RegBufferInt32<4> InputType;
  typedef RegBufferInt16<4> OutputType;

  typedef OutputStageSaturatingCastToInt16 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    // Signed saturate each 32-bit element to 16 bits.
    v8i16 tmp =
        reinterpret_cast<v8i16>(__builtin_msa_sat_s_w(input.reg[0], 15));
    output.reg[0] = __builtin_msa_copy_s_h(tmp, 0);
    output.reg[1] = __builtin_msa_copy_s_h(tmp, 2);
    output.reg[2] = __builtin_msa_copy_s_h(tmp, 4);
    output.reg[3] = __builtin_msa_copy_s_h(tmp, 6);
    return output;
  }
};

#define GEMMLOWP_MIPS_SAT_I16_8(out, in0, in1)                  \
  {                                                             \
    v4i32 tmp0 = __builtin_msa_sat_s_w(in0, 15);                \
    v4i32 tmp1 = __builtin_msa_sat_s_w(in1, 15);                \
    out = __builtin_msa_pckev_h(reinterpret_cast<v8i16>(tmp1),  \
                                reinterpret_cast<v8i16>(tmp0)); \
  }

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
                                 RegBufferInt32<8>> {
  typedef RegBufferInt32<8> InputType;
  typedef RegBufferInt16<8> OutputType;

  typedef OutputStageSaturatingCastToInt16 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    GEMMLOWP_MIPS_SAT_I16_8(output.reg[0], input.reg[0], input.reg[1]);
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
                                 RegBufferInt32<16>> {
  typedef RegBufferInt32<16> InputType;
  typedef RegBufferInt16<16> OutputType;

  typedef OutputStageSaturatingCastToInt16 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    GEMMLOWP_MIPS_SAT_I16_8(output.reg[0], input.reg[0], input.reg[1]);
    GEMMLOWP_MIPS_SAT_I16_8(output.reg[1], input.reg[2], input.reg[3]);
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
                                 RegBufferInt32<32>> {
  typedef RegBufferInt32<32> InputType;
  typedef RegBufferInt16<32> OutputType;

  typedef OutputStageSaturatingCastToInt16 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    GEMMLOWP_MIPS_SAT_I16_8(output.reg[0], input.reg[0], input.reg[1]);
    GEMMLOWP_MIPS_SAT_I16_8(output.reg[1], input.reg[2], input.reg[3]);
    GEMMLOWP_MIPS_SAT_I16_8(output.reg[2], input.reg[4], input.reg[5]);
    GEMMLOWP_MIPS_SAT_I16_8(output.reg[3], input.reg[6], input.reg[7]);
    return output;
  }
};

#undef GEMMLOWP_MIPS_SAT_I16_8
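
// The truncating casts below do no clamping at all: per lane they compute
// the scalar equivalent of static_cast<std::uint8_t>(x), i.e. x & 0xff.
// Picking the even-indexed elements with pckev.h and then pckev.b keeps
// the low 16 bits and then the low 8 bits of each lane, which is exactly
// that truncation, so no sat/max steps appear here.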

template <>
struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8,
                                 RegBufferInt32<4>> {
  typedef RegBufferInt32<4> InputType;
  typedef RegBufferUint8<4> OutputType;

  typedef OutputStageTruncatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    // Pack every 32-bit element into 16 bits.
    v4i32 tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
        reinterpret_cast<v8i16>(input.reg[0]),
        reinterpret_cast<v8i16>(input.reg[0])));
    // Pack every element into 8 bits.
    tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
        reinterpret_cast<v16i8>(tmp), reinterpret_cast<v16i8>(tmp)));
    // Return 4 uint8_t elements as uint32_t.
    output.reg[0] = __builtin_msa_copy_s_w(tmp, 0);
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8,
                                 RegBufferInt32<8>> {
  typedef RegBufferInt32<8> InputType;
  typedef RegBufferUint8<8> OutputType;

  typedef OutputStageTruncatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    // Pack every 32-bit element into 16 bits.
    v4i32 tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
        reinterpret_cast<v8i16>(input.reg[1]),
        reinterpret_cast<v8i16>(input.reg[0])));
    // Pack every element into 8 bits.
    tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
        reinterpret_cast<v16i8>(tmp), reinterpret_cast<v16i8>(tmp)));
    // Return 8 uint8_t elements as 2 uint32_t's.
    output.reg[0] = __builtin_msa_copy_s_w(tmp, 0);
    output.reg[1] = __builtin_msa_copy_s_w(tmp, 1);
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8,
                                 RegBufferInt32<16>> {
  typedef RegBufferInt32<16> InputType;
  typedef RegBufferUint8<16> OutputType;

  typedef OutputStageTruncatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    // Pack every 32-bit element into 16 bits.
    v8i16 tmp0 = __builtin_msa_pckev_h(reinterpret_cast<v8i16>(input.reg[1]),
                                       reinterpret_cast<v8i16>(input.reg[0]));
    v8i16 tmp1 = __builtin_msa_pckev_h(reinterpret_cast<v8i16>(input.reg[3]),
                                       reinterpret_cast<v8i16>(input.reg[2]));
    // Pack every element into 8 bits.
    output.reg[0] = __builtin_msa_pckev_b(reinterpret_cast<v16i8>(tmp1),
                                          reinterpret_cast<v16i8>(tmp0));
    return output;
  }
};

template <>
struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8,
                                 RegBufferInt32<32>> {
  typedef RegBufferInt32<32> InputType;
  typedef RegBufferUint8<32> OutputType;

  typedef OutputStageTruncatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    // Pack every 32-bit element into 16 bits.
    v8i16 tmp0 = __builtin_msa_pckev_h(reinterpret_cast<v8i16>(input.reg[1]),
                                       reinterpret_cast<v8i16>(input.reg[0]));
    v8i16 tmp1 = __builtin_msa_pckev_h(reinterpret_cast<v8i16>(input.reg[3]),
                                       reinterpret_cast<v8i16>(input.reg[2]));
    v8i16 tmp2 = __builtin_msa_pckev_h(reinterpret_cast<v8i16>(input.reg[5]),
                                       reinterpret_cast<v8i16>(input.reg[4]));
    v8i16 tmp3 = __builtin_msa_pckev_h(reinterpret_cast<v8i16>(input.reg[7]),
                                       reinterpret_cast<v8i16>(input.reg[6]));
    // Pack every element into 8 bits.
    output.reg[0] = __builtin_msa_pckev_b(reinterpret_cast<v16i8>(tmp1),
                                          reinterpret_cast<v16i8>(tmp0));
    output.reg[1] = __builtin_msa_pckev_b(reinterpret_cast<v16i8>(tmp3),
                                          reinterpret_cast<v16i8>(tmp2));
    return output;
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
  static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
    } else {
      *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
      *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
      *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
      *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
  static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
      StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]);
    } else {
      *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
      *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
      *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
      *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
      *dst->data(row + 4, col) = GetLane<0>(src.buf.reg[1]);
      *dst->data(row + 5, col) = GetLane<1>(src.buf.reg[1]);
      *dst->data(row + 6, col) = GetLane<2>(src.buf.reg[1]);
      *dst->data(row + 7, col) = GetLane<3>(src.buf.reg[1]);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<4, 1>, DstType> {
  static void Run(const RegBlockInt16<4, 1>& src, DstType* dst, int row,
                  int col) {
    *dst->data(row + 0, col) = src.buf.reg[0];
    *dst->data(row + 1, col) = src.buf.reg[1];
    *dst->data(row + 2, col) = src.buf.reg[2];
    *dst->data(row + 3, col) = src.buf.reg[3];
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<8, 1>, DstType> {
  static void Run(const RegBlockInt16<8, 1>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      StoreInt16x8(dst->data(row, col), src.buf.reg[0]);
    } else {
      *dst->data(row + 0, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 0);
      *dst->data(row + 1, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 1);
      *dst->data(row + 2, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 2);
      *dst->data(row + 3, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 3);
      *dst->data(row + 4, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 4);
      *dst->data(row + 5, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 5);
      *dst->data(row + 6, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 6);
      *dst->data(row + 7, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 7);
    }
  }
};
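
// The 4x4 transpose below works purely with interleaves. With column
// vectors c0..c3 (c_i = src.buf.reg[i], lane j holding row j):
//   __builtin_msa_ilvr_w(c1, c0) = {c0[0], c1[0], c0[1], c1[1]}
//   __builtin_msa_ilvl_w(c1, c0) = {c0[2], c1[2], c0[3], c1[3]}
// A second level of 64-bit interleaves (ilvr.d/ilvl.d) on two such vectors
// then yields complete output rows, e.g. {c0[0], c1[0], c2[0], c3[0]}.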

inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) {
  RegBlockInt32<4, 4> result;
  v4i32 tmp0, tmp1;
  tmp0 = __builtin_msa_ilvr_w(src.buf.reg[1], src.buf.reg[0]);
  tmp1 = __builtin_msa_ilvr_w(src.buf.reg[3], src.buf.reg[2]);
  result.buf.reg[0] = reinterpret_cast<v4i32>(__builtin_msa_ilvr_d(
      reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
  result.buf.reg[1] = reinterpret_cast<v4i32>(__builtin_msa_ilvl_d(
      reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
  tmp0 = __builtin_msa_ilvl_w(src.buf.reg[1], src.buf.reg[0]);
  tmp1 = __builtin_msa_ilvl_w(src.buf.reg[3], src.buf.reg[2]);
  result.buf.reg[2] = reinterpret_cast<v4i32>(__builtin_msa_ilvr_d(
      reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
  result.buf.reg[3] = reinterpret_cast<v4i32>(__builtin_msa_ilvl_d(
      reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
  return result;
}

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> {
  static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row, col + i), src.buf.reg[i]);
      }
    } else {
      const auto transpose = Transpose(src);
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row + i, col), transpose.buf.reg[i]);
      }
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<4, 4>, DstType> {
  static void Run(const RegBlockInt16<4, 4>& src, DstType* dst, int row,
                  int col) {
    std::int16_t buf[16];
    StoreInt16x8(buf + 0, src.buf.reg[0]);
    StoreInt16x8(buf + 8, src.buf.reg[1]);
    for (int i = 0; i < 4; i++) {
      for (int j = 0; j < 4; j++) {
        *dst->data(row + i, col + j) = buf[i + 4 * j];
      }
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
  static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
        StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
      }
    } else {
      RegBlockInt32<4, 4> top;
      top.buf.reg[0] = src.buf.reg[0];
      top.buf.reg[1] = src.buf.reg[2];
      top.buf.reg[2] = src.buf.reg[4];
      top.buf.reg[3] = src.buf.reg[6];
      const auto transpose_top = Transpose(top);
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row + i, col), transpose_top.buf.reg[i]);
      }
      RegBlockInt32<4, 4> bottom;
      bottom.buf.reg[0] = src.buf.reg[1];
      bottom.buf.reg[1] = src.buf.reg[3];
      bottom.buf.reg[2] = src.buf.reg[5];
      bottom.buf.reg[3] = src.buf.reg[7];
      const auto transpose_bottom = Transpose(bottom);
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row + 4 + i, col), transpose_bottom.buf.reg[i]);
      }
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<8, 4>, DstType> {
  static void Run(const RegBlockInt16<8, 4>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      for (int i = 0; i < 4; i++) {
        StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]);
      }
    } else {
      std::int16_t buf[32];
      StoreInt16x8(buf + 0, src.buf.reg[0]);
      StoreInt16x8(buf + 8, src.buf.reg[1]);
      StoreInt16x8(buf + 16, src.buf.reg[2]);
      StoreInt16x8(buf + 24, src.buf.reg[3]);
      for (int i = 0; i < 8; i++) {
        for (int j = 0; j < 4; j++) {
          *dst->data(row + i, col + j) = buf[i + 8 * j];
        }
      }
    }
  }
};
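
// An 8x8 block of int32 does not fit the 4x4 Transpose directly, so the
// row-major path below splits it into four 4x4 quadrants, transposes each
// with Transpose() above, and stores every quadrant at its mirrored
// position.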

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
  static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      for (int i = 0; i < 8; i++) {
        StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
        StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
      }
    } else {
      RegBlockInt32<4, 4> top_left;
      top_left.buf.reg[0] = src.buf.reg[0];
      top_left.buf.reg[1] = src.buf.reg[2];
      top_left.buf.reg[2] = src.buf.reg[4];
      top_left.buf.reg[3] = src.buf.reg[6];
      const auto transpose_top_left = Transpose(top_left);
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row + i, col), transpose_top_left.buf.reg[i]);
      }
      RegBlockInt32<4, 4> bottom_left;
      bottom_left.buf.reg[0] = src.buf.reg[1];
      bottom_left.buf.reg[1] = src.buf.reg[3];
      bottom_left.buf.reg[2] = src.buf.reg[5];
      bottom_left.buf.reg[3] = src.buf.reg[7];
      const auto transpose_bottom_left = Transpose(bottom_left);
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row + 4 + i, col),
                     transpose_bottom_left.buf.reg[i]);
      }
      RegBlockInt32<4, 4> top_right;
      top_right.buf.reg[0] = src.buf.reg[8];
      top_right.buf.reg[1] = src.buf.reg[10];
      top_right.buf.reg[2] = src.buf.reg[12];
      top_right.buf.reg[3] = src.buf.reg[14];
      const auto transpose_top_right = Transpose(top_right);
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row + i, col + 4),
                     transpose_top_right.buf.reg[i]);
      }
      RegBlockInt32<4, 4> bottom_right;
      bottom_right.buf.reg[0] = src.buf.reg[9];
      bottom_right.buf.reg[1] = src.buf.reg[11];
      bottom_right.buf.reg[2] = src.buf.reg[13];
      bottom_right.buf.reg[3] = src.buf.reg[15];
      const auto transpose_bottom_right = Transpose(bottom_right);
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row + 4 + i, col + 4),
                     transpose_bottom_right.buf.reg[i]);
      }
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> {
  static void Run(const RegBlockInt16<8, 8>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      for (int i = 0; i < 8; i++) {
        StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]);
      }
    } else {
      // top-left 4x4
      v4i32 t0 = reinterpret_cast<v4i32>(
          __builtin_msa_ilvr_h(src.buf.reg[1], src.buf.reg[0]));
      v4i32 t1 = reinterpret_cast<v4i32>(
          __builtin_msa_ilvr_h(src.buf.reg[3], src.buf.reg[2]));
      v2i64 u0 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t1, t0));
      v2i64 u1 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t1, t0));
      // top-right 4x4
      v4i32 t2 = reinterpret_cast<v4i32>(
          __builtin_msa_ilvr_h(src.buf.reg[5], src.buf.reg[4]));
      v4i32 t3 = reinterpret_cast<v4i32>(
          __builtin_msa_ilvr_h(src.buf.reg[7], src.buf.reg[6]));
      v2i64 u2 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t3, t2));
      v2i64 u3 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t3, t2));
      // bottom-left 4x4
      v4i32 t4 = reinterpret_cast<v4i32>(
          __builtin_msa_ilvl_h(src.buf.reg[1], src.buf.reg[0]));
      v4i32 t5 = reinterpret_cast<v4i32>(
          __builtin_msa_ilvl_h(src.buf.reg[3], src.buf.reg[2]));
      v2i64 u4 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t5, t4));
      v2i64 u5 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t5, t4));
      // bottom-right 4x4
      v4i32 t6 = reinterpret_cast<v4i32>(
          __builtin_msa_ilvl_h(src.buf.reg[5], src.buf.reg[4]));
      v4i32 t7 = reinterpret_cast<v4i32>(
          __builtin_msa_ilvl_h(src.buf.reg[7], src.buf.reg[6]));
      v2i64 u6 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t7, t6));
      v2i64 u7 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t7, t6));

      StoreInt16x8(dst->data(row + 0, col),
                   reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u2, u0)));
      StoreInt16x8(dst->data(row + 1, col),
                   reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u2, u0)));
      StoreInt16x8(dst->data(row + 2, col),
                   reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u3, u1)));
      StoreInt16x8(dst->data(row + 3, col),
                   reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u3, u1)));
      StoreInt16x8(dst->data(row + 4, col),
                   reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u6, u4)));
      StoreInt16x8(dst->data(row + 5, col),
                   reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u6, u4)));
      StoreInt16x8(dst->data(row + 6, col),
                   reinterpret_cast<v8i16>(__builtin_msa_ilvr_d(u7, u5)));
      StoreInt16x8(dst->data(row + 7, col),
                   reinterpret_cast<v8i16>(__builtin_msa_ilvl_d(u7, u5)));
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> {
  static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      *dst->data(row, col + 0) = GetLane<0>(src.buf.reg[0]);
      *dst->data(row, col + 1) = GetLane<1>(src.buf.reg[0]);
      *dst->data(row, col + 2) = GetLane<2>(src.buf.reg[0]);
      *dst->data(row, col + 3) = GetLane<3>(src.buf.reg[0]);
    } else {
      StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> {
  static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row,
                  int col) {
    const std::uint32_t src_reg = src.buf.reg[0];
    for (int i = 0; i < 4; i++) {
      *dst->data(row + i, col) = (src_reg >> (8 * i));
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> {
  static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row,
                  int col) {
    for (int i = 0; i < 4; i++) {
      *dst->data(row + i, col) = (src.buf.reg[0] >> (8 * i));
    }
    for (int i = 0; i < 4; i++) {
      *dst->data(row + 4 + i, col) = (src.buf.reg[1] >> (8 * i));
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> {
  static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row,
                  int col) {
    for (int i = 0; i < 4; i++) {
      *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> {
  static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row,
                  int col) {
    std::uint8_t buf[16];
    StoreUint8x16(buf, src.buf.reg[0]);
    for (int c = 0; c < 4; c++) {
      for (int r = 0; r < 4; r++) {
        *dst->data(row + r, col + c) = buf[r + 4 * c];
      }
    }
  }
};

// There's no way to express in C++ the desired machine code for
// StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> and
// StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType>.
// Hence, if we can, we use inline assembly, which takes advantage
// of little-endian byte order and specifics of different CPU revisions.
// Note, clang currently can't derive MSA register names from floating-
// point register names and vice versa in inline assembly.
#if defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__) && \
    !defined(__clang__)
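
// A few notes on the assembly below. MSA vector registers overlap the FPU
// registers, which is what makes the "f" constraints together with sdc1
// (64-bit FPU store) and swc1 (32-bit FPU store) usable for storing half
// and quarter vectors. On pre-R6 ISA revisions the destination rows may be
// unaligned, hence the swr/swl pairs, the classic little-endian MIPS idiom
// for an unaligned 32-bit store; R6 removed swl/swr and instead expects
// ordinary loads/stores to handle unaligned addresses.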

// Instructions for pointer-sized operands.
#ifdef GEMMLOWP_MIPS_64
#define GEMMLOWP_MIPS_XADDU "daddu"
#define GEMMLOWP_MIPS_XLSA "dlsa"
#else
#define GEMMLOWP_MIPS_XADDU "addu"
#define GEMMLOWP_MIPS_XLSA "lsa"
#endif

// Stores 4 8-byte half-vectors with a stride.
inline void MipsMsaStore4x8(const RegBlockUint8<8, 4>& src,
                            std::uint8_t* dst_ptr, int stride) {
#if (__mips_isa_rev >= 6)
  // Assembly temporaries that will be handily referred to by their names.
  std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3;
  v16i8 vtmp0, vtmp1;
  asm volatile(
      GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
      "ilvl.d %w[vtmp0], %w[src0], %w[src0]\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
      "ilvl.d %w[vtmp1], %w[src1], %w[src1]\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
      "sdc1 %[src0], 0(%[dst_ptr0])\n"
      "sdc1 %[vtmp0], 0(%[dst_ptr1])\n"
      "sdc1 %[src1], 0(%[dst_ptr2])\n"
      "sdc1 %[vtmp1], 0(%[dst_ptr3])\n"
      :
      // Outputs.
      [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
      [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3),
      [vtmp0] "=&f"(vtmp0), [vtmp1] "=&f"(vtmp1)
      :
      // Inputs.
      [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
      [stride] "r"(stride)
      :
      // Clobbers.
      "memory");
#else
  // Assembly temporaries that will be handily referred to by their names.
  std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3;
  int tmp0, tmp1, tmp2, tmp3;
  asm volatile(
      GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
      "copy_s.w %[tmp0], %w[src0][0]\n"
      "copy_s.w %[tmp1], %w[src0][1]\n"
      "copy_s.w %[tmp2], %w[src0][2]\n"
      "copy_s.w %[tmp3], %w[src0][3]\n"
      "swr %[tmp0], 0(%[dst_ptr0])\n"
      "swl %[tmp0], 3(%[dst_ptr0])\n"
      "swr %[tmp1], 4(%[dst_ptr0])\n"
      "swl %[tmp1], 7(%[dst_ptr0])\n"
      "swr %[tmp2], 0(%[dst_ptr1])\n"
      "swl %[tmp2], 3(%[dst_ptr1])\n"
      "swr %[tmp3], 4(%[dst_ptr1])\n"
      "swl %[tmp3], 7(%[dst_ptr1])\n"
      "copy_s.w %[tmp0], %w[src1][0]\n"
      "copy_s.w %[tmp1], %w[src1][1]\n"
      "copy_s.w %[tmp2], %w[src1][2]\n"
      "copy_s.w %[tmp3], %w[src1][3]\n"
      "swr %[tmp0], 0(%[dst_ptr2])\n"
      "swl %[tmp0], 3(%[dst_ptr2])\n"
      "swr %[tmp1], 4(%[dst_ptr2])\n"
      "swl %[tmp1], 7(%[dst_ptr2])\n"
      "swr %[tmp2], 0(%[dst_ptr3])\n"
      "swl %[tmp2], 3(%[dst_ptr3])\n"
      "swr %[tmp3], 4(%[dst_ptr3])\n"
      "swl %[tmp3], 7(%[dst_ptr3])\n"
      :
      // Outputs.
      [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
      [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3),
      [tmp0] "=&r"(tmp0), [tmp1] "=&r"(tmp1), [tmp2] "=&r"(tmp2),
      [tmp3] "=&r"(tmp3)
      :
      // Inputs.
      [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
      [stride] "r"(stride)
      :
      // Clobbers.
      "memory");
#endif
}

// Stores 8 4-byte quarter-vectors with a stride.
inline void MipsMsaStore8x4(const RegBlockUint8<4, 8>& src,
                            std::uint8_t* dst_ptr, int stride) {
#if (__mips_isa_rev >= 6)
  // Assembly temporaries that will be handily referred to by their names.
  std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3, *dst_ptr4, *dst_ptr5,
      *dst_ptr6, *dst_ptr7;
  int tmp1, tmp2, tmp3;
  asm volatile(
      GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr4], %[stride], %[dst_ptr0], 2\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr5], %[stride], %[dst_ptr1], 2\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr6], %[stride], %[dst_ptr2], 2\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr7], %[stride], %[dst_ptr3], 2\n"
      "copy_s.w %[tmp1], %w[src0][1]\n"
      "copy_s.w %[tmp2], %w[src0][2]\n"
      "copy_s.w %[tmp3], %w[src0][3]\n"
      "swc1 %[src0], 0(%[dst_ptr0])\n"
      "sw %[tmp1], 0(%[dst_ptr1])\n"
      "sw %[tmp2], 0(%[dst_ptr2])\n"
      "sw %[tmp3], 0(%[dst_ptr3])\n"
      "copy_s.w %[tmp1], %w[src1][1]\n"
      "copy_s.w %[tmp2], %w[src1][2]\n"
      "copy_s.w %[tmp3], %w[src1][3]\n"
      "swc1 %[src1], 0(%[dst_ptr4])\n"
      "sw %[tmp1], 0(%[dst_ptr5])\n"
      "sw %[tmp2], 0(%[dst_ptr6])\n"
      "sw %[tmp3], 0(%[dst_ptr7])\n"
      :
      // Outputs.
      [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
      [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3),
      [dst_ptr4] "=&r"(dst_ptr4), [dst_ptr5] "=&r"(dst_ptr5),
      [dst_ptr6] "=&r"(dst_ptr6), [dst_ptr7] "=&r"(dst_ptr7),
      [tmp1] "=&r"(tmp1), [tmp2] "=&r"(tmp2), [tmp3] "=&r"(tmp3)
      :
      // Inputs.
      [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
      [stride] "r"(stride)
      :
      // Clobbers.
      "memory");
#else
  // Assembly temporaries that will be handily referred to by their names.
  std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3, *dst_ptr4, *dst_ptr5,
      *dst_ptr6, *dst_ptr7;
  int tmp0, tmp1, tmp2, tmp3;
  asm volatile(
      GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr4], %[stride], %[dst_ptr0], 2\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr5], %[stride], %[dst_ptr1], 2\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr6], %[stride], %[dst_ptr2], 2\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr7], %[stride], %[dst_ptr3], 2\n"
      "copy_s.w %[tmp0], %w[src0][0]\n"
      "copy_s.w %[tmp1], %w[src0][1]\n"
      "copy_s.w %[tmp2], %w[src0][2]\n"
      "copy_s.w %[tmp3], %w[src0][3]\n"
      "swr %[tmp0], 0(%[dst_ptr0])\n"
      "swl %[tmp0], 3(%[dst_ptr0])\n"
      "swr %[tmp1], 0(%[dst_ptr1])\n"
      "swl %[tmp1], 3(%[dst_ptr1])\n"
      "swr %[tmp2], 0(%[dst_ptr2])\n"
      "swl %[tmp2], 3(%[dst_ptr2])\n"
      "swr %[tmp3], 0(%[dst_ptr3])\n"
      "swl %[tmp3], 3(%[dst_ptr3])\n"
      "copy_s.w %[tmp0], %w[src1][0]\n"
      "copy_s.w %[tmp1], %w[src1][1]\n"
      "copy_s.w %[tmp2], %w[src1][2]\n"
      "copy_s.w %[tmp3], %w[src1][3]\n"
      "swr %[tmp0], 0(%[dst_ptr4])\n"
      "swl %[tmp0], 3(%[dst_ptr4])\n"
      "swr %[tmp1], 0(%[dst_ptr5])\n"
      "swl %[tmp1], 3(%[dst_ptr5])\n"
      "swr %[tmp2], 0(%[dst_ptr6])\n"
      "swl %[tmp2], 3(%[dst_ptr6])\n"
      "swr %[tmp3], 0(%[dst_ptr7])\n"
      "swl %[tmp3], 3(%[dst_ptr7])\n"
      :
      // Outputs.
      [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
      [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3),
      [dst_ptr4] "=&r"(dst_ptr4), [dst_ptr5] "=&r"(dst_ptr5),
      [dst_ptr6] "=&r"(dst_ptr6), [dst_ptr7] "=&r"(dst_ptr7),
      [tmp0] "=&r"(tmp0), [tmp1] "=&r"(tmp1), [tmp2] "=&r"(tmp2),
      [tmp3] "=&r"(tmp3)
      :
      // Inputs.
      [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
      [stride] "r"(stride)
      :
      // Clobbers.
      "memory");
#endif
}

// Stores 8 8-byte half-vectors with a stride.
inline void MipsMsaStore8x8(const RegBlockUint8<8, 8>& src,
                            std::uint8_t* dst_ptr, int stride) {
#if (__mips_isa_rev >= 6)
  // Assembly temporaries that will be handily referred to by their names.
  std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3, *dst_ptr4, *dst_ptr5,
      *dst_ptr6, *dst_ptr7;
  v16i8 vtmp0, vtmp1, vtmp2, vtmp3;
  asm volatile(
      "ilvl.d %w[vtmp0], %w[src0], %w[src0]\n"
      GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
      "ilvl.d %w[vtmp1], %w[src1], %w[src1]\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr4], %[stride], %[dst_ptr0], 2\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
      "ilvl.d %w[vtmp2], %w[src2], %w[src2]\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr5], %[stride], %[dst_ptr1], 2\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr6], %[stride], %[dst_ptr2], 2\n"
      "ilvl.d %w[vtmp3], %w[src3], %w[src3]\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr7], %[stride], %[dst_ptr3], 2\n"
      "sdc1 %[src0], 0(%[dst_ptr0])\n"
      "sdc1 %[vtmp0], 0(%[dst_ptr1])\n"
      "sdc1 %[src1], 0(%[dst_ptr2])\n"
      "sdc1 %[vtmp1], 0(%[dst_ptr3])\n"
      "sdc1 %[src2], 0(%[dst_ptr4])\n"
      "sdc1 %[vtmp2], 0(%[dst_ptr5])\n"
      "sdc1 %[src3], 0(%[dst_ptr6])\n"
      "sdc1 %[vtmp3], 0(%[dst_ptr7])\n"
      :
      // Outputs.
      [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
      [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3),
      [dst_ptr4] "=&r"(dst_ptr4), [dst_ptr5] "=&r"(dst_ptr5),
      [dst_ptr6] "=&r"(dst_ptr6), [dst_ptr7] "=&r"(dst_ptr7),
      [vtmp0] "=&f"(vtmp0), [vtmp1] "=&f"(vtmp1), [vtmp2] "=&f"(vtmp2),
      [vtmp3] "=&f"(vtmp3)
      :
      // Inputs.
      [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
      [src2] "f"(src.buf.reg[2]), [src3] "f"(src.buf.reg[3]),
      [stride] "r"(stride)
      :
      // Clobbers.
      "memory");
#else
  // Assembly temporaries that will be handily referred to by their names.
  std::uint8_t *dst_ptr1, *dst_ptr2, *dst_ptr3, *dst_ptr4, *dst_ptr5,
      *dst_ptr6, *dst_ptr7;
  int tmp0, tmp1, tmp2, tmp3;
  asm volatile(
      GEMMLOWP_MIPS_XADDU " %[dst_ptr1], %[dst_ptr0], %[stride]\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr2], %[stride], %[dst_ptr0], 1\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr4], %[stride], %[dst_ptr0], 2\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr3], %[stride], %[dst_ptr1], 1\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr5], %[stride], %[dst_ptr1], 2\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr6], %[stride], %[dst_ptr2], 2\n"
      GEMMLOWP_MIPS_XLSA " %[dst_ptr7], %[stride], %[dst_ptr3], 2\n"
      "copy_s.w %[tmp0], %w[src0][0]\n"
      "copy_s.w %[tmp1], %w[src0][1]\n"
      "copy_s.w %[tmp2], %w[src0][2]\n"
      "copy_s.w %[tmp3], %w[src0][3]\n"
      "swr %[tmp0], 0(%[dst_ptr0])\n"
      "swl %[tmp0], 3(%[dst_ptr0])\n"
      "swr %[tmp1], 4(%[dst_ptr0])\n"
      "swl %[tmp1], 7(%[dst_ptr0])\n"
      "swr %[tmp2], 0(%[dst_ptr1])\n"
      "swl %[tmp2], 3(%[dst_ptr1])\n"
      "swr %[tmp3], 4(%[dst_ptr1])\n"
      "swl %[tmp3], 7(%[dst_ptr1])\n"
      "copy_s.w %[tmp0], %w[src1][0]\n"
      "copy_s.w %[tmp1], %w[src1][1]\n"
      "copy_s.w %[tmp2], %w[src1][2]\n"
      "copy_s.w %[tmp3], %w[src1][3]\n"
      "swr %[tmp0], 0(%[dst_ptr2])\n"
      "swl %[tmp0], 3(%[dst_ptr2])\n"
      "swr %[tmp1], 4(%[dst_ptr2])\n"
      "swl %[tmp1], 7(%[dst_ptr2])\n"
      "swr %[tmp2], 0(%[dst_ptr3])\n"
      "swl %[tmp2], 3(%[dst_ptr3])\n"
      "swr %[tmp3], 4(%[dst_ptr3])\n"
      "swl %[tmp3], 7(%[dst_ptr3])\n"
      "copy_s.w %[tmp0], %w[src2][0]\n"
      "copy_s.w %[tmp1], %w[src2][1]\n"
      "copy_s.w %[tmp2], %w[src2][2]\n"
      "copy_s.w %[tmp3], %w[src2][3]\n"
      "swr %[tmp0], 0(%[dst_ptr4])\n"
      "swl %[tmp0], 3(%[dst_ptr4])\n"
      "swr %[tmp1], 4(%[dst_ptr4])\n"
      "swl %[tmp1], 7(%[dst_ptr4])\n"
      "swr %[tmp2], 0(%[dst_ptr5])\n"
      "swl %[tmp2], 3(%[dst_ptr5])\n"
      "swr %[tmp3], 4(%[dst_ptr5])\n"
      "swl %[tmp3], 7(%[dst_ptr5])\n"
      "copy_s.w %[tmp0], %w[src3][0]\n"
      "copy_s.w %[tmp1], %w[src3][1]\n"
      "copy_s.w %[tmp2], %w[src3][2]\n"
      "copy_s.w %[tmp3], %w[src3][3]\n"
      "swr %[tmp0], 0(%[dst_ptr6])\n"
      "swl %[tmp0], 3(%[dst_ptr6])\n"
      "swr %[tmp1], 4(%[dst_ptr6])\n"
      "swl %[tmp1], 7(%[dst_ptr6])\n"
      "swr %[tmp2], 0(%[dst_ptr7])\n"
      "swl %[tmp2], 3(%[dst_ptr7])\n"
      "swr %[tmp3], 4(%[dst_ptr7])\n"
      "swl %[tmp3], 7(%[dst_ptr7])\n"
      :
      // Outputs.
      [dst_ptr0] "+r"(dst_ptr), [dst_ptr1] "=&r"(dst_ptr1),
      [dst_ptr2] "=&r"(dst_ptr2), [dst_ptr3] "=&r"(dst_ptr3),
      [dst_ptr4] "=&r"(dst_ptr4), [dst_ptr5] "=&r"(dst_ptr5),
      [dst_ptr6] "=&r"(dst_ptr6), [dst_ptr7] "=&r"(dst_ptr7),
      [tmp0] "=&r"(tmp0), [tmp1] "=&r"(tmp1), [tmp2] "=&r"(tmp2),
      [tmp3] "=&r"(tmp3)
      :
      // Inputs.
      [src0] "f"(src.buf.reg[0]), [src1] "f"(src.buf.reg[1]),
      [src2] "f"(src.buf.reg[2]), [src3] "f"(src.buf.reg[3]),
      [stride] "r"(stride)
      :
      // Clobbers.
      "memory");
#endif
}

#undef GEMMLOWP_MIPS_XADDU
#undef GEMMLOWP_MIPS_XLSA
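
// The Transpose overloads below use the same interleave strategy as the
// int32 Transpose above, just starting at byte granularity (ilvr.b/ilvl.b)
// and, for the 8x8 case, finishing with 32-bit interleaves.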

// Transposes a column-major 8x4 block for storage into a row-major matrix.
inline RegBlockUint8<4, 8> Transpose(const RegBlockUint8<8, 4>& src) {
  v16i8 tmp0 = __builtin_msa_ilvr_b(src.buf.reg[1], src.buf.reg[0]);
  v16i8 tmp1 = __builtin_msa_ilvl_b(src.buf.reg[1], src.buf.reg[0]);
  RegBlockUint8<4, 8> result;
  result.buf.reg[0] = __builtin_msa_ilvr_b(tmp1, tmp0);
  result.buf.reg[1] = __builtin_msa_ilvl_b(tmp1, tmp0);
  return result;
}

inline RegBlockUint8<8, 8> Transpose(const RegBlockUint8<8, 8>& src) {
  v16i8 tmp0[4];
  tmp0[0] = __builtin_msa_ilvr_b(src.buf.reg[1], src.buf.reg[0]);
  tmp0[1] = __builtin_msa_ilvl_b(src.buf.reg[1], src.buf.reg[0]);
  tmp0[2] = __builtin_msa_ilvr_b(src.buf.reg[3], src.buf.reg[2]);
  tmp0[3] = __builtin_msa_ilvl_b(src.buf.reg[3], src.buf.reg[2]);
  v16i8 tmp1[4];
  tmp1[0] = __builtin_msa_ilvr_b(tmp0[1], tmp0[0]);
  tmp1[1] = __builtin_msa_ilvl_b(tmp0[1], tmp0[0]);
  tmp1[2] = __builtin_msa_ilvr_b(tmp0[3], tmp0[2]);
  tmp1[3] = __builtin_msa_ilvl_b(tmp0[3], tmp0[2]);
  RegBlockUint8<8, 8> result;
  result.buf.reg[0] = reinterpret_cast<v16i8>(__builtin_msa_ilvr_w(
      reinterpret_cast<v4i32>(tmp1[2]), reinterpret_cast<v4i32>(tmp1[0])));
  result.buf.reg[1] = reinterpret_cast<v16i8>(__builtin_msa_ilvl_w(
      reinterpret_cast<v4i32>(tmp1[2]), reinterpret_cast<v4i32>(tmp1[0])));
  result.buf.reg[2] = reinterpret_cast<v16i8>(__builtin_msa_ilvr_w(
      reinterpret_cast<v4i32>(tmp1[3]), reinterpret_cast<v4i32>(tmp1[1])));
  result.buf.reg[3] = reinterpret_cast<v16i8>(__builtin_msa_ilvl_w(
      reinterpret_cast<v4i32>(tmp1[3]), reinterpret_cast<v4i32>(tmp1[1])));
  return result;
}

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
  static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      std::uint8_t* dst_ptr = dst->data(row, col);
      int col_stride = dst->cols_stride();
      MipsMsaStore4x8(src, dst_ptr, col_stride);
    } else {
      const auto& block = Transpose(src);
      std::uint8_t* dst_ptr = dst->data(row, col);
      int row_stride = dst->rows_stride();
      MipsMsaStore8x4(block, dst_ptr, row_stride);
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
  static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
                  int col) {
    const auto& block =
        (DstType::kOrder == MapOrder::ColMajor) ? src : Transpose(src);
    std::uint8_t* dst_ptr = dst->data(row, col);
    int stride = dst->stride();
    MipsMsaStore8x8(block, dst_ptr, stride);
  }
};

#else

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
  static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
                  int col) {
    std::uint8_t buf[32];
    StoreUint8x16(buf, src.buf.reg[0]);
    StoreUint8x16(buf + 16, src.buf.reg[1]);
    for (int c = 0; c < 4; c++) {
      for (int r = 0; r < 8; r++) {
        *dst->data(row + r, col + c) = buf[r + 8 * c];
      }
    }
  }
};

template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
  static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
                  int col) {
    std::uint8_t buf[64];
    StoreUint8x16(buf, src.buf.reg[0]);
    StoreUint8x16(buf + 16, src.buf.reg[1]);
    StoreUint8x16(buf + 32, src.buf.reg[2]);
    StoreUint8x16(buf + 48, src.buf.reg[3]);
    for (int c = 0; c < 8; c++) {
      for (int r = 0; r < 8; r++) {
        *dst->data(row + r, col + c) = buf[r + 8 * c];
      }
    }
  }
};

#endif  // Endianness, compiler.

}  // namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_OUTPUT_MSA_H_