1 #ifndef PYTHONIC_NUMPY_DOT_HPP 2 #define PYTHONIC_NUMPY_DOT_HPP 3 4 #include "pythonic/include/numpy/dot.hpp" 5 6 #include "pythonic/types/ndarray.hpp" 7 #include "pythonic/numpy/sum.hpp" 8 #include "pythonic/numpy/multiply.hpp" 9 #include "pythonic/types/traits.hpp" 10 11 #ifdef PYTHRAN_BLAS_NONE 12 #error pythran configured without BLAS but BLAS seem needed 13 #endif 14 15 #if defined(PYTHRAN_BLAS_ATLAS) || defined(PYTHRAN_BLAS_SATLAS) 16 extern "C" { 17 #endif 18 #include <cblas.h> 19 #if defined(PYTHRAN_BLAS_ATLAS) || defined(PYTHRAN_BLAS_SATLAS) 20 } 21 #endif 22 23 PYTHONIC_NS_BEGIN 24 25 namespace numpy 26 { 27 template <class E, class F> 28 typename std::enable_if<types::is_dtype<E>::value && 29 types::is_dtype<F>::value, 30 decltype(std::declval<E>() * std::declval<F>())>::type dot(E const & e,F const & f)31 dot(E const &e, F const &f) 32 { 33 return e * f; 34 } 35 36 template <class E> 37 struct blas_buffer_t { operator ()numpy::blas_buffer_t38 typename E::dtype const *operator()(E const &e) const 39 { 40 return e.buffer; 41 } 42 }; 43 template <class T> 44 struct blas_buffer_t<types::list<T>> { operator ()numpy::blas_buffer_t45 T const *operator()(types::list<T> const &e) const 46 { 47 return &e.fast(0); 48 } 49 }; 50 template <class T, size_t N> 51 struct blas_buffer_t<types::array<T, N>> { operator ()numpy::blas_buffer_t52 T const *operator()(types::array<T, N> const &e) const 53 { 54 return e.data(); 55 } 56 }; 57 58 template <class E> blas_buffer(E const & e)59 auto blas_buffer(E const &e) -> decltype(blas_buffer_t<E>{}(e)) 60 { 61 return blas_buffer_t<E>{}(e); 62 } 63 64 template <class E, class F> 65 typename std::enable_if< 66 types::is_numexpr_arg<E>::value && 67 types::is_numexpr_arg<F>::value // Arguments are array_like 68 && E::value == 1 && F::value == 1 // It is a two vectors. 69 && (!is_blas_array<E>::value || !is_blas_array<F>::value || 70 !std::is_same<typename E::dtype, typename F::dtype>::value), 71 typename __combined<typename E::dtype, typename F::dtype>::type>::type dot(E const & e,F const & f)72 dot(E const &e, F const &f) 73 { 74 return sum(functor::multiply{}(e, f)); 75 } 76 77 template <class E, class F> 78 typename std::enable_if<E::value == 1 && F::value == 1 && 79 std::is_same<typename E::dtype, float>::value && 80 std::is_same<typename F::dtype, float>::value && 81 is_blas_array<E>::value && 82 is_blas_array<F>::value, 83 float>::type dot(E const & e,F const & f)84 dot(E const &e, F const &f) 85 { 86 return cblas_sdot(e.size(), blas_buffer(e), 1, blas_buffer(f), 1); 87 } 88 89 template <class E, class F> 90 typename std::enable_if<E::value == 1 && F::value == 1 && 91 std::is_same<typename E::dtype, double>::value && 92 std::is_same<typename F::dtype, double>::value && 93 is_blas_array<E>::value && 94 is_blas_array<F>::value, 95 double>::type dot(E const & e,F const & f)96 dot(E const &e, F const &f) 97 { 98 return cblas_ddot(e.size(), blas_buffer(e), 1, blas_buffer(f), 1); 99 } 100 101 template <class E, class F> 102 typename std::enable_if< 103 E::value == 1 && F::value == 1 && 104 std::is_same<typename E::dtype, std::complex<float>>::value && 105 std::is_same<typename F::dtype, std::complex<float>>::value && 106 is_blas_array<E>::value && is_blas_array<F>::value, 107 std::complex<float>>::type dot(E const & e,F const & f)108 dot(E const &e, F const &f) 109 { 110 std::complex<float> out; 111 cblas_cdotu_sub(e.size(), blas_buffer(e), 1, blas_buffer(f), 1, &out); 112 return out; 113 } 114 115 template <class E, class F> 116 typename std::enable_if< 117 E::value == 1 && F::value == 1 && 118 std::is_same<typename E::dtype, std::complex<double>>::value && 119 std::is_same<typename F::dtype, std::complex<double>>::value && 120 is_blas_array<E>::value && is_blas_array<F>::value, 121 std::complex<double>>::type dot(E const & e,F const & f)122 dot(E const &e, F const &f) 123 { 124 std::complex<double> out; 125 cblas_zdotu_sub(e.size(), blas_buffer(e), 1, blas_buffer(f), 1, &out); 126 return out; 127 } 128 129 /// Matrice / Vector multiplication 130 131 #define MV_DEF(T, L) \ 132 void mv(int m, int n, T *A, T *B, T *C) \ 133 { \ 134 cblas_##L##gemv(CblasRowMajor, CblasNoTrans, n, m, 1, A, m, B, 1, 0, C, \ 135 1); \ 136 } 137 MV_DEF(double,d)138 MV_DEF(double, d) 139 MV_DEF(float, s) 140 141 #undef MV_DEF 142 143 #define TV_DEF(T, L) \ 144 void tv(int m, int n, T *A, T *B, T *C) \ 145 { \ 146 cblas_##L##gemv(CblasRowMajor, CblasTrans, m, n, 1, A, n, B, 1, 0, C, 1); \ 147 } 148 149 TV_DEF(double, d) 150 TV_DEF(float, s) 151 152 #undef TV_DEF 153 154 #define MV_DEF(T, K, L) \ 155 void mv(int m, int n, T *A, T *B, T *C) \ 156 { \ 157 T alpha = 1, beta = 0; \ 158 cblas_##L##gemv(CblasRowMajor, CblasNoTrans, n, m, (K *)&alpha, (K *)A, m, \ 159 (K *)B, 1, (K *)&beta, (K *)C, 1); \ 160 } 161 MV_DEF(std::complex<float>, float, c) 162 MV_DEF(std::complex<double>, double, z) 163 #undef MV_DEF 164 165 template <class E, class pS0, class pS1> 166 typename std::enable_if<is_blas_type<E>::value && 167 std::tuple_size<pS0>::value == 2 && 168 std::tuple_size<pS1>::value == 1, 169 types::ndarray<E, types::pshape<long>>>::type 170 dot(types::ndarray<E, pS0> const &f, types::ndarray<E, pS1> const &e) 171 { 172 types::ndarray<E, types::pshape<long>> out( 173 types::pshape<long>{f.template shape<0>()}, builtins::None); 174 const int m = f.template shape<1>(), n = f.template shape<0>(); 175 mv(m, n, f.buffer, e.buffer, out.buffer); 176 return out; 177 } 178 179 template <class E, class pS0, class pS1> 180 typename std::enable_if<is_blas_type<E>::value && 181 std::tuple_size<pS0>::value == 2 && 182 std::tuple_size<pS1>::value == 1, 183 types::ndarray<E, types::pshape<long>>>::type dot(types::numpy_texpr<types::ndarray<E,pS0>> const & f,types::ndarray<E,pS1> const & e)184 dot(types::numpy_texpr<types::ndarray<E, pS0>> const &f, 185 types::ndarray<E, pS1> const &e) 186 { 187 types::ndarray<E, types::pshape<long>> out( 188 types::pshape<long>{f.template shape<0>()}, builtins::None); 189 const int m = f.template shape<1>(), n = f.template shape<0>(); 190 tv(m, n, f.arg.buffer, e.buffer, out.buffer); 191 return out; 192 } 193 194 // The trick is to not transpose the matrix so that MV become VM 195 #define VM_DEF(T, L) \ 196 void vm(int m, int n, T *A, T *B, T *C) \ 197 { \ 198 cblas_##L##gemv(CblasRowMajor, CblasTrans, n, m, 1, A, m, B, 1, 0, C, 1); \ 199 } 200 VM_DEF(double,d)201 VM_DEF(double, d) 202 VM_DEF(float, s) 203 204 #undef VM_DEF 205 #define VT_DEF(T, L) \ 206 void vt(int m, int n, T *A, T *B, T *C) \ 207 { \ 208 cblas_##L##gemv(CblasRowMajor, CblasNoTrans, m, n, 1, A, n, B, 1, 0, C, \ 209 1); \ 210 } 211 212 VT_DEF(double, d) 213 VT_DEF(float, s) 214 215 #undef VM_DEF 216 #define VM_DEF(T, K, L) \ 217 void vm(int m, int n, T *A, T *B, T *C) \ 218 { \ 219 T alpha = 1, beta = 0; \ 220 cblas_##L##gemv(CblasRowMajor, CblasTrans, n, m, (K *)&alpha, (K *)A, m, \ 221 (K *)B, 1, (K *)&beta, (K *)C, 1); \ 222 } 223 VM_DEF(std::complex<float>, float, c) 224 VM_DEF(std::complex<double>, double, z) 225 #undef VM_DEF 226 227 template <class E, class pS0, class pS1> 228 typename std::enable_if<is_blas_type<E>::value && 229 std::tuple_size<pS0>::value == 1 && 230 std::tuple_size<pS1>::value == 2, 231 types::ndarray<E, types::pshape<long>>>::type 232 dot(types::ndarray<E, pS0> const &e, types::ndarray<E, pS1> const &f) 233 { 234 types::ndarray<E, types::pshape<long>> out( 235 types::pshape<long>{f.template shape<1>()}, builtins::None); 236 const int m = f.template shape<1>(), n = f.template shape<0>(); 237 vm(m, n, f.buffer, e.buffer, out.buffer); 238 return out; 239 } 240 241 template <class E, class pS0, class pS1> 242 typename std::enable_if<is_blas_type<E>::value && 243 std::tuple_size<pS0>::value == 1 && 244 std::tuple_size<pS1>::value == 2, 245 types::ndarray<E, types::pshape<long>>>::type dot(types::ndarray<E,pS0> const & e,types::numpy_texpr<types::ndarray<E,pS1>> const & f)246 dot(types::ndarray<E, pS0> const &e, 247 types::numpy_texpr<types::ndarray<E, pS1>> const &f) 248 { 249 types::ndarray<E, types::pshape<long>> out( 250 types::pshape<long>{f.template shape<1>()}, builtins::None); 251 const int m = f.template shape<1>(), n = f.template shape<0>(); 252 vt(m, n, f.arg.buffer, e.buffer, out.buffer); 253 return out; 254 } 255 256 // If arguments could be use with blas, we evaluate them as we need pointer 257 // on array for blas 258 template <class E, class F> 259 typename std::enable_if< 260 types::is_numexpr_arg<E>::value && 261 types::is_numexpr_arg<F>::value // It is an array_like 262 && (!(types::is_ndarray<E>::value && types::is_ndarray<F>::value) || 263 !std::is_same<typename E::dtype, typename F::dtype>::value) && 264 is_blas_type<typename E::dtype>::value && 265 is_blas_type<typename F::dtype>::value // With dtype compatible with 266 // blas 267 && 268 E::value == 2 && F::value == 1, // And it is matrix / vect 269 types::ndarray< 270 typename __combined<typename E::dtype, typename F::dtype>::type, 271 types::pshape<long>>>::type dot(E const & e,F const & f)272 dot(E const &e, F const &f) 273 { 274 types::ndarray< 275 typename __combined<typename E::dtype, typename F::dtype>::type, 276 typename E::shape_t> e_ = e; 277 types::ndarray< 278 typename __combined<typename E::dtype, typename F::dtype>::type, 279 typename F::shape_t> f_ = f; 280 return dot(e_, f_); 281 } 282 283 // If arguments could be use with blas, we evaluate them as we need pointer 284 // on array for blas 285 template <class E, class F> 286 typename std::enable_if< 287 types::is_numexpr_arg<E>::value && 288 types::is_numexpr_arg<F>::value // It is an array_like 289 && (!(types::is_ndarray<E>::value && types::is_ndarray<F>::value) || 290 !std::is_same<typename E::dtype, typename F::dtype>::value) && 291 is_blas_type<typename E::dtype>::value && 292 is_blas_type<typename F::dtype>::value // With dtype compatible with 293 // blas 294 && 295 E::value == 1 && F::value == 2, // And it is vect / matrix 296 types::ndarray< 297 typename __combined<typename E::dtype, typename F::dtype>::type, 298 types::pshape<long>>>::type dot(E const & e,F const & f)299 dot(E const &e, F const &f) 300 { 301 types::ndarray< 302 typename __combined<typename E::dtype, typename F::dtype>::type, 303 typename E::shape_t> e_ = e; 304 types::ndarray< 305 typename __combined<typename E::dtype, typename F::dtype>::type, 306 typename F::shape_t> f_ = f; 307 return dot(e_, f_); 308 } 309 310 // If one of the arg doesn't have a "blas compatible type", we use a slow 311 // matrix vector multiplication. 312 template <class E, class F> 313 typename std::enable_if< 314 (!is_blas_type<typename E::dtype>::value || 315 !is_blas_type<typename F::dtype>::value) && 316 E::value == 1 && F::value == 2, // And it is vect / matrix 317 types::ndarray< 318 typename __combined<typename E::dtype, typename F::dtype>::type, 319 types::pshape<long>>>::type dot(E const & e,F const & f)320 dot(E const &e, F const &f) 321 { 322 types::ndarray< 323 typename __combined<typename E::dtype, typename F::dtype>::type, 324 types::pshape<long>> 325 out(types::pshape<long>{f.template shape<1>()}, 0); 326 for (long i = 0; i < out.template shape<0>(); i++) 327 for (long j = 0; j < f.template shape<0>(); j++) 328 out[i] += e[j] * f[types::array<long, 2>{{j, i}}]; 329 return out; 330 } 331 332 // If one of the arg doesn't have a "blas compatible type", we use a slow 333 // matrix vector multiplication. 334 template <class E, class F> 335 typename std::enable_if< 336 (!is_blas_type<typename E::dtype>::value || 337 !is_blas_type<typename F::dtype>::value) && 338 E::value == 2 && F::value == 1, // And it is vect / matrix 339 types::ndarray< 340 typename __combined<typename E::dtype, typename F::dtype>::type, 341 types::pshape<long>>>::type dot(E const & e,F const & f)342 dot(E const &e, F const &f) 343 { 344 types::ndarray< 345 typename __combined<typename E::dtype, typename F::dtype>::type, 346 types::pshape<long>> 347 out(types::pshape<long>{e.template shape<0>()}, 0); 348 for (long i = 0; i < out.template shape<0>(); i++) 349 for (long j = 0; j < f.template shape<0>(); j++) 350 out[i] += e[types::array<long, 2>{{i, j}}] * f[j]; 351 return out; 352 } 353 354 /// Matrix / Matrix multiplication 355 356 #define MM_DEF(T, L) \ 357 void mm(int m, int n, int k, T *A, T *B, T *C) \ 358 { \ 359 cblas_##L##gemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1, A, \ 360 k, B, n, 0, C, n); \ 361 } MM_DEF(double,d)362 MM_DEF(double, d) 363 MM_DEF(float, s) 364 #undef MM_DEF 365 #define MM_DEF(T, K, L) \ 366 void mm(int m, int n, int k, T *A, T *B, T *C) \ 367 { \ 368 T alpha = 1, beta = 0; \ 369 cblas_##L##gemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, \ 370 (K *)&alpha, (K *)A, k, (K *)B, n, (K *)&beta, (K *)C, n); \ 371 } 372 MM_DEF(std::complex<float>, float, c) 373 MM_DEF(std::complex<double>, double, z) 374 #undef MM_DEF 375 376 template <class E, class pS0, class pS1> 377 typename std::enable_if<is_blas_type<E>::value && 378 std::tuple_size<pS0>::value == 2 && 379 std::tuple_size<pS1>::value == 2, 380 types::ndarray<E, types::array<long, 2>>>::type 381 dot(types::ndarray<E, pS0> const &a, types::ndarray<E, pS1> const &b) 382 { 383 int n = b.template shape<1>(), m = a.template shape<0>(), 384 k = b.template shape<0>(); 385 386 types::ndarray<E, types::array<long, 2>> out(types::array<long, 2>{{m, n}}, 387 builtins::None); 388 mm(m, n, k, a.buffer, b.buffer, out.buffer); 389 return out; 390 } 391 392 template <class E, class pS0, class pS1, class pS2> 393 typename std::enable_if< 394 is_blas_type<E>::value && std::tuple_size<pS0>::value == 2 && 395 std::tuple_size<pS1>::value == 2 && std::tuple_size<pS2>::value == 2, 396 types::ndarray<E, pS2>>::type & dot(types::ndarray<E,pS0> const & a,types::ndarray<E,pS1> const & b,types::ndarray<E,pS2> & c)397 dot(types::ndarray<E, pS0> const &a, types::ndarray<E, pS1> const &b, 398 types::ndarray<E, pS2> &c) 399 { 400 int n = b.template shape<1>(), m = a.template shape<0>(), 401 k = b.template shape<0>(); 402 403 mm(m, n, k, a.buffer, b.buffer, c.buffer); 404 return c; 405 } 406 407 #define TM_DEF(T, L) \ 408 void tm(int m, int n, int k, T *A, T *B, T *C) \ 409 { \ 410 cblas_##L##gemm(CblasRowMajor, CblasTrans, CblasNoTrans, m, n, k, 1, A, m, \ 411 B, n, 0, C, n); \ 412 } TM_DEF(double,d)413 TM_DEF(double, d) 414 TM_DEF(float, s) 415 #undef TM_DEF 416 #define TM_DEF(T, K, L) \ 417 void tm(int m, int n, int k, T *A, T *B, T *C) \ 418 { \ 419 T alpha = 1, beta = 0; \ 420 cblas_##L##gemm(CblasRowMajor, CblasTrans, CblasNoTrans, m, n, k, \ 421 (K *)&alpha, (K *)A, m, (K *)B, n, (K *)&beta, (K *)C, n); \ 422 } 423 TM_DEF(std::complex<float>, float, c) 424 TM_DEF(std::complex<double>, double, z) 425 #undef TM_DEF 426 427 template <class E, class pS0, class pS1> 428 typename std::enable_if<is_blas_type<E>::value && 429 std::tuple_size<pS0>::value == 2 && 430 std::tuple_size<pS1>::value == 2, 431 types::ndarray<E, types::array<long, 2>>>::type 432 dot(types::numpy_texpr<types::ndarray<E, pS0>> const &a, 433 types::ndarray<E, pS1> const &b) 434 { 435 int n = b.template shape<1>(), m = a.template shape<0>(), 436 k = b.template shape<0>(); 437 438 types::ndarray<E, types::array<long, 2>> out(types::array<long, 2>{{m, n}}, 439 builtins::None); 440 tm(m, n, k, a.arg.buffer, b.buffer, out.buffer); 441 return out; 442 } 443 444 #define MT_DEF(T, L) \ 445 void mt(int m, int n, int k, T *A, T *B, T *C) \ 446 { \ 447 cblas_##L##gemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, 1, A, k, \ 448 B, k, 0, C, n); \ 449 } MT_DEF(double,d)450 MT_DEF(double, d) 451 MT_DEF(float, s) 452 #undef MT_DEF 453 #define MT_DEF(T, K, L) \ 454 void mt(int m, int n, int k, T *A, T *B, T *C) \ 455 { \ 456 T alpha = 1, beta = 0; \ 457 cblas_##L##gemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, \ 458 (K *)&alpha, (K *)A, k, (K *)B, k, (K *)&beta, (K *)C, n); \ 459 } 460 MT_DEF(std::complex<float>, float, c) 461 MT_DEF(std::complex<double>, double, z) 462 #undef MT_DEF 463 464 template <class E, class pS0, class pS1> 465 typename std::enable_if<is_blas_type<E>::value && 466 std::tuple_size<pS0>::value == 2 && 467 std::tuple_size<pS1>::value == 2, 468 types::ndarray<E, types::array<long, 2>>>::type 469 dot(types::ndarray<E, pS0> const &a, 470 types::numpy_texpr<types::ndarray<E, pS1>> const &b) 471 { 472 int n = b.template shape<1>(), m = a.template shape<0>(), 473 k = b.template shape<0>(); 474 475 types::ndarray<E, types::array<long, 2>> out(types::array<long, 2>{{m, n}}, 476 builtins::None); 477 mt(m, n, k, a.buffer, b.arg.buffer, out.buffer); 478 return out; 479 } 480 481 #define TT_DEF(T, L) \ 482 void tt(int m, int n, int k, T *A, T *B, T *C) \ 483 { \ 484 cblas_##L##gemm(CblasRowMajor, CblasTrans, CblasTrans, m, n, k, 1, A, m, \ 485 B, k, 0, C, n); \ 486 } TT_DEF(double,d)487 TT_DEF(double, d) 488 TT_DEF(float, s) 489 #undef TT_DEF 490 #define TT_DEF(T, K, L) \ 491 void tt(int m, int n, int k, T *A, T *B, T *C) \ 492 { \ 493 T alpha = 1, beta = 0; \ 494 cblas_##L##gemm(CblasRowMajor, CblasTrans, CblasTrans, m, n, k, \ 495 (K *)&alpha, (K *)A, m, (K *)B, k, (K *)&beta, (K *)C, n); \ 496 } 497 TT_DEF(std::complex<float>, float, c) 498 TT_DEF(std::complex<double>, double, z) 499 #undef TT_DEF 500 501 template <class E, class pS0, class pS1> 502 typename std::enable_if<is_blas_type<E>::value && 503 std::tuple_size<pS0>::value == 2 && 504 std::tuple_size<pS1>::value == 2, 505 types::ndarray<E, types::array<long, 2>>>::type 506 dot(types::numpy_texpr<types::ndarray<E, pS0>> const &a, 507 types::numpy_texpr<types::ndarray<E, pS1>> const &b) 508 { 509 int n = b.template shape<1>(), m = a.template shape<0>(), 510 k = b.template shape<0>(); 511 512 types::ndarray<E, types::array<long, 2>> out(types::array<long, 2>{{m, n}}, 513 builtins::None); 514 tt(m, n, k, a.arg.buffer, b.arg.buffer, out.buffer); 515 return out; 516 } 517 518 // If arguments could be use with blas, we evaluate them as we need pointer 519 // on array for blas 520 template <class E, class F> 521 typename std::enable_if< 522 types::is_numexpr_arg<E>::value && 523 types::is_numexpr_arg<F>::value // It is an array_like 524 && (!(types::is_ndarray<E>::value && types::is_ndarray<F>::value) || 525 !std::is_same<typename E::dtype, typename F::dtype>::value) && 526 is_blas_type<typename E::dtype>::value && 527 is_blas_type<typename F::dtype>::value // With dtype compatible with 528 // blas 529 && 530 E::value == 2 && F::value == 2, // And both are matrix 531 types::ndarray< 532 typename __combined<typename E::dtype, typename F::dtype>::type, 533 types::array<long, 2>>>::type dot(E const & e,F const & f)534 dot(E const &e, F const &f) 535 { 536 types::ndarray< 537 typename __combined<typename E::dtype, typename F::dtype>::type, 538 typename E::shape_t> e_ = e; 539 types::ndarray< 540 typename __combined<typename E::dtype, typename F::dtype>::type, 541 typename F::shape_t> f_ = f; 542 return dot(e_, f_); 543 } 544 545 // If one of the arg doesn't have a "blas compatible type", we use a slow 546 // matrix multiplication. 547 template <class E, class F> 548 typename std::enable_if< 549 (!is_blas_type<typename E::dtype>::value || 550 !is_blas_type<typename F::dtype>::value) && 551 E::value == 2 && F::value == 2, // And it is matrix / matrix 552 types::ndarray< 553 typename __combined<typename E::dtype, typename F::dtype>::type, 554 types::array<long, 2>>>::type dot(E const & e,F const & f)555 dot(E const &e, F const &f) 556 { 557 types::ndarray< 558 typename __combined<typename E::dtype, typename F::dtype>::type, 559 types::array<long, 2>> 560 out(types::array<long, 2>{{e.template shape<0>(), f.template shape<1>()}}, 561 0); 562 for (long i = 0; i < out.template shape<0>(); i++) 563 for (long j = 0; j < out.template shape<1>(); j++) 564 for (long k = 0; k < e.template shape<1>(); k++) 565 out[types::array<long, 2>{{i, j}}] += 566 e[types::array<long, 2>{{i, k}}] * 567 f[types::array<long, 2>{{k, j}}]; 568 return out; 569 } 570 571 template <class E, class F> 572 typename std::enable_if< 573 (E::value >= 3 && F::value == 1), // And it is matrix / matrix 574 types::ndarray< 575 typename __combined<typename E::dtype, typename F::dtype>::type, 576 types::array<long, E::value - 1>>>::type dot(E const & e,F const & f)577 dot(E const &e, F const &f) 578 { 579 auto out = dot( 580 e.reshape(types::array<long, 2>{{sutils::prod_head(e), f.size()}}), f); 581 types::array<long, E::value - 1> out_shape; 582 auto tmp = sutils::getshape(e); 583 std::copy(tmp.begin(), tmp.end() - 1, out_shape.begin()); 584 return out.reshape(out_shape); 585 } 586 587 template <class E, class F> 588 typename std::enable_if< 589 (E::value >= 3 && F::value >= 2), 590 types::ndarray< 591 typename __combined<typename E::dtype, typename F::dtype>::type, 592 types::array<long, E::value - 1>>>::type dot(E const & e,F const & f)593 dot(E const &e, F const &f) 594 { 595 static_assert(E::value == 0, "not implemented yet"); 596 } 597 } 598 PYTHONIC_NS_END 599 600 #endif 601