1 #ifndef PYTHONIC_NUMPY_DOT_HPP
2 #define PYTHONIC_NUMPY_DOT_HPP
3 
4 #include "pythonic/include/numpy/dot.hpp"
5 
6 #include "pythonic/types/ndarray.hpp"
7 #include "pythonic/numpy/sum.hpp"
8 #include "pythonic/numpy/multiply.hpp"
9 #include "pythonic/types/traits.hpp"
10 
11 #ifdef PYTHRAN_BLAS_NONE
12 #error pythran configured without BLAS but BLAS seem needed
13 #endif
14 
15 #if defined(PYTHRAN_BLAS_ATLAS) || defined(PYTHRAN_BLAS_SATLAS)
16 extern "C" {
17 #endif
18 #include <cblas.h>
19 #if defined(PYTHRAN_BLAS_ATLAS) || defined(PYTHRAN_BLAS_SATLAS)
20 }
21 #endif
22 
23 PYTHONIC_NS_BEGIN
24 
25 namespace numpy
26 {
27   template <class E, class F>
28   typename std::enable_if<types::is_dtype<E>::value &&
29                               types::is_dtype<F>::value,
30                           decltype(std::declval<E>() * std::declval<F>())>::type
dot(E const & e,F const & f)31   dot(E const &e, F const &f)
32   {
33     return e * f;
34   }
35 
36   template <class E>
37   struct blas_buffer_t {
operator ()numpy::blas_buffer_t38     typename E::dtype const *operator()(E const &e) const
39     {
40       return e.buffer;
41     }
42   };
43   template <class T>
44   struct blas_buffer_t<types::list<T>> {
operator ()numpy::blas_buffer_t45     T const *operator()(types::list<T> const &e) const
46     {
47       return &e.fast(0);
48     }
49   };
50   template <class T, size_t N>
51   struct blas_buffer_t<types::array<T, N>> {
operator ()numpy::blas_buffer_t52     T const *operator()(types::array<T, N> const &e) const
53     {
54       return e.data();
55     }
56   };
57 
58   template <class E>
blas_buffer(E const & e)59   auto blas_buffer(E const &e) -> decltype(blas_buffer_t<E>{}(e))
60   {
61     return blas_buffer_t<E>{}(e);
62   }
63 
64   template <class E, class F>
65   typename std::enable_if<
66       types::is_numexpr_arg<E>::value &&
67           types::is_numexpr_arg<F>::value   // Arguments are array_like
68           && E::value == 1 && F::value == 1 // It is a two vectors.
69           && (!is_blas_array<E>::value || !is_blas_array<F>::value ||
70               !std::is_same<typename E::dtype, typename F::dtype>::value),
71       typename __combined<typename E::dtype, typename F::dtype>::type>::type
dot(E const & e,F const & f)72   dot(E const &e, F const &f)
73   {
74     return sum(functor::multiply{}(e, f));
75   }
76 
77   template <class E, class F>
78   typename std::enable_if<E::value == 1 && F::value == 1 &&
79                               std::is_same<typename E::dtype, float>::value &&
80                               std::is_same<typename F::dtype, float>::value &&
81                               is_blas_array<E>::value &&
82                               is_blas_array<F>::value,
83                           float>::type
dot(E const & e,F const & f)84   dot(E const &e, F const &f)
85   {
86     return cblas_sdot(e.size(), blas_buffer(e), 1, blas_buffer(f), 1);
87   }
88 
89   template <class E, class F>
90   typename std::enable_if<E::value == 1 && F::value == 1 &&
91                               std::is_same<typename E::dtype, double>::value &&
92                               std::is_same<typename F::dtype, double>::value &&
93                               is_blas_array<E>::value &&
94                               is_blas_array<F>::value,
95                           double>::type
dot(E const & e,F const & f)96   dot(E const &e, F const &f)
97   {
98     return cblas_ddot(e.size(), blas_buffer(e), 1, blas_buffer(f), 1);
99   }
100 
101   template <class E, class F>
102   typename std::enable_if<
103       E::value == 1 && F::value == 1 &&
104           std::is_same<typename E::dtype, std::complex<float>>::value &&
105           std::is_same<typename F::dtype, std::complex<float>>::value &&
106           is_blas_array<E>::value && is_blas_array<F>::value,
107       std::complex<float>>::type
dot(E const & e,F const & f)108   dot(E const &e, F const &f)
109   {
110     std::complex<float> out;
111     cblas_cdotu_sub(e.size(), blas_buffer(e), 1, blas_buffer(f), 1, &out);
112     return out;
113   }
114 
115   template <class E, class F>
116   typename std::enable_if<
117       E::value == 1 && F::value == 1 &&
118           std::is_same<typename E::dtype, std::complex<double>>::value &&
119           std::is_same<typename F::dtype, std::complex<double>>::value &&
120           is_blas_array<E>::value && is_blas_array<F>::value,
121       std::complex<double>>::type
dot(E const & e,F const & f)122   dot(E const &e, F const &f)
123   {
124     std::complex<double> out;
125     cblas_zdotu_sub(e.size(), blas_buffer(e), 1, blas_buffer(f), 1, &out);
126     return out;
127   }
128 
129 /// Matrice / Vector multiplication
130 
131 #define MV_DEF(T, L)                                                           \
132   void mv(int m, int n, T *A, T *B, T *C)                                      \
133   {                                                                            \
134     cblas_##L##gemv(CblasRowMajor, CblasNoTrans, n, m, 1, A, m, B, 1, 0, C,    \
135                     1);                                                        \
136   }
137 
MV_DEF(double,d)138   MV_DEF(double, d)
139   MV_DEF(float, s)
140 
141 #undef MV_DEF
142 
143 #define TV_DEF(T, L)                                                           \
144   void tv(int m, int n, T *A, T *B, T *C)                                      \
145   {                                                                            \
146     cblas_##L##gemv(CblasRowMajor, CblasTrans, m, n, 1, A, n, B, 1, 0, C, 1);  \
147   }
148 
149   TV_DEF(double, d)
150   TV_DEF(float, s)
151 
152 #undef TV_DEF
153 
154 #define MV_DEF(T, K, L)                                                        \
155   void mv(int m, int n, T *A, T *B, T *C)                                      \
156   {                                                                            \
157     T alpha = 1, beta = 0;                                                     \
158     cblas_##L##gemv(CblasRowMajor, CblasNoTrans, n, m, (K *)&alpha, (K *)A, m, \
159                     (K *)B, 1, (K *)&beta, (K *)C, 1);                         \
160   }
161   MV_DEF(std::complex<float>, float, c)
162   MV_DEF(std::complex<double>, double, z)
163 #undef MV_DEF
164 
165   template <class E, class pS0, class pS1>
166   typename std::enable_if<is_blas_type<E>::value &&
167                               std::tuple_size<pS0>::value == 2 &&
168                               std::tuple_size<pS1>::value == 1,
169                           types::ndarray<E, types::pshape<long>>>::type
170   dot(types::ndarray<E, pS0> const &f, types::ndarray<E, pS1> const &e)
171   {
172     types::ndarray<E, types::pshape<long>> out(
173         types::pshape<long>{f.template shape<0>()}, builtins::None);
174     const int m = f.template shape<1>(), n = f.template shape<0>();
175     mv(m, n, f.buffer, e.buffer, out.buffer);
176     return out;
177   }
178 
179   template <class E, class pS0, class pS1>
180   typename std::enable_if<is_blas_type<E>::value &&
181                               std::tuple_size<pS0>::value == 2 &&
182                               std::tuple_size<pS1>::value == 1,
183                           types::ndarray<E, types::pshape<long>>>::type
dot(types::numpy_texpr<types::ndarray<E,pS0>> const & f,types::ndarray<E,pS1> const & e)184   dot(types::numpy_texpr<types::ndarray<E, pS0>> const &f,
185       types::ndarray<E, pS1> const &e)
186   {
187     types::ndarray<E, types::pshape<long>> out(
188         types::pshape<long>{f.template shape<0>()}, builtins::None);
189     const int m = f.template shape<1>(), n = f.template shape<0>();
190     tv(m, n, f.arg.buffer, e.buffer, out.buffer);
191     return out;
192   }
193 
194 // The trick is to not transpose the matrix so that MV become VM
195 #define VM_DEF(T, L)                                                           \
196   void vm(int m, int n, T *A, T *B, T *C)                                      \
197   {                                                                            \
198     cblas_##L##gemv(CblasRowMajor, CblasTrans, n, m, 1, A, m, B, 1, 0, C, 1);  \
199   }
200 
VM_DEF(double,d)201   VM_DEF(double, d)
202   VM_DEF(float, s)
203 
204 #undef VM_DEF
205 #define VT_DEF(T, L)                                                           \
206   void vt(int m, int n, T *A, T *B, T *C)                                      \
207   {                                                                            \
208     cblas_##L##gemv(CblasRowMajor, CblasNoTrans, m, n, 1, A, n, B, 1, 0, C,    \
209                     1);                                                        \
210   }
211 
212   VT_DEF(double, d)
213   VT_DEF(float, s)
214 
215 #undef VM_DEF
216 #define VM_DEF(T, K, L)                                                        \
217   void vm(int m, int n, T *A, T *B, T *C)                                      \
218   {                                                                            \
219     T alpha = 1, beta = 0;                                                     \
220     cblas_##L##gemv(CblasRowMajor, CblasTrans, n, m, (K *)&alpha, (K *)A, m,   \
221                     (K *)B, 1, (K *)&beta, (K *)C, 1);                         \
222   }
223   VM_DEF(std::complex<float>, float, c)
224   VM_DEF(std::complex<double>, double, z)
225 #undef VM_DEF
226 
227   template <class E, class pS0, class pS1>
228   typename std::enable_if<is_blas_type<E>::value &&
229                               std::tuple_size<pS0>::value == 1 &&
230                               std::tuple_size<pS1>::value == 2,
231                           types::ndarray<E, types::pshape<long>>>::type
232   dot(types::ndarray<E, pS0> const &e, types::ndarray<E, pS1> const &f)
233   {
234     types::ndarray<E, types::pshape<long>> out(
235         types::pshape<long>{f.template shape<1>()}, builtins::None);
236     const int m = f.template shape<1>(), n = f.template shape<0>();
237     vm(m, n, f.buffer, e.buffer, out.buffer);
238     return out;
239   }
240 
241   template <class E, class pS0, class pS1>
242   typename std::enable_if<is_blas_type<E>::value &&
243                               std::tuple_size<pS0>::value == 1 &&
244                               std::tuple_size<pS1>::value == 2,
245                           types::ndarray<E, types::pshape<long>>>::type
dot(types::ndarray<E,pS0> const & e,types::numpy_texpr<types::ndarray<E,pS1>> const & f)246   dot(types::ndarray<E, pS0> const &e,
247       types::numpy_texpr<types::ndarray<E, pS1>> const &f)
248   {
249     types::ndarray<E, types::pshape<long>> out(
250         types::pshape<long>{f.template shape<1>()}, builtins::None);
251     const int m = f.template shape<1>(), n = f.template shape<0>();
252     vt(m, n, f.arg.buffer, e.buffer, out.buffer);
253     return out;
254   }
255 
256   // If arguments could be use with blas, we evaluate them as we need pointer
257   // on array for blas
258   template <class E, class F>
259   typename std::enable_if<
260       types::is_numexpr_arg<E>::value &&
261           types::is_numexpr_arg<F>::value // It is an array_like
262           && (!(types::is_ndarray<E>::value && types::is_ndarray<F>::value) ||
263               !std::is_same<typename E::dtype, typename F::dtype>::value) &&
264           is_blas_type<typename E::dtype>::value &&
265           is_blas_type<typename F::dtype>::value // With dtype compatible with
266                                                  // blas
267           &&
268           E::value == 2 && F::value == 1, // And it is matrix / vect
269       types::ndarray<
270           typename __combined<typename E::dtype, typename F::dtype>::type,
271           types::pshape<long>>>::type
dot(E const & e,F const & f)272   dot(E const &e, F const &f)
273   {
274     types::ndarray<
275         typename __combined<typename E::dtype, typename F::dtype>::type,
276         typename E::shape_t> e_ = e;
277     types::ndarray<
278         typename __combined<typename E::dtype, typename F::dtype>::type,
279         typename F::shape_t> f_ = f;
280     return dot(e_, f_);
281   }
282 
283   // If arguments could be use with blas, we evaluate them as we need pointer
284   // on array for blas
285   template <class E, class F>
286   typename std::enable_if<
287       types::is_numexpr_arg<E>::value &&
288           types::is_numexpr_arg<F>::value // It is an array_like
289           && (!(types::is_ndarray<E>::value && types::is_ndarray<F>::value) ||
290               !std::is_same<typename E::dtype, typename F::dtype>::value) &&
291           is_blas_type<typename E::dtype>::value &&
292           is_blas_type<typename F::dtype>::value // With dtype compatible with
293                                                  // blas
294           &&
295           E::value == 1 && F::value == 2, // And it is vect / matrix
296       types::ndarray<
297           typename __combined<typename E::dtype, typename F::dtype>::type,
298           types::pshape<long>>>::type
dot(E const & e,F const & f)299   dot(E const &e, F const &f)
300   {
301     types::ndarray<
302         typename __combined<typename E::dtype, typename F::dtype>::type,
303         typename E::shape_t> e_ = e;
304     types::ndarray<
305         typename __combined<typename E::dtype, typename F::dtype>::type,
306         typename F::shape_t> f_ = f;
307     return dot(e_, f_);
308   }
309 
310   // If one of the arg doesn't have a "blas compatible type", we use a slow
311   // matrix vector multiplication.
312   template <class E, class F>
313   typename std::enable_if<
314       (!is_blas_type<typename E::dtype>::value ||
315        !is_blas_type<typename F::dtype>::value) &&
316           E::value == 1 && F::value == 2, // And it is vect / matrix
317       types::ndarray<
318           typename __combined<typename E::dtype, typename F::dtype>::type,
319           types::pshape<long>>>::type
dot(E const & e,F const & f)320   dot(E const &e, F const &f)
321   {
322     types::ndarray<
323         typename __combined<typename E::dtype, typename F::dtype>::type,
324         types::pshape<long>>
325     out(types::pshape<long>{f.template shape<1>()}, 0);
326     for (long i = 0; i < out.template shape<0>(); i++)
327       for (long j = 0; j < f.template shape<0>(); j++)
328         out[i] += e[j] * f[types::array<long, 2>{{j, i}}];
329     return out;
330   }
331 
332   // If one of the arg doesn't have a "blas compatible type", we use a slow
333   // matrix vector multiplication.
334   template <class E, class F>
335   typename std::enable_if<
336       (!is_blas_type<typename E::dtype>::value ||
337        !is_blas_type<typename F::dtype>::value) &&
338           E::value == 2 && F::value == 1, // And it is vect / matrix
339       types::ndarray<
340           typename __combined<typename E::dtype, typename F::dtype>::type,
341           types::pshape<long>>>::type
dot(E const & e,F const & f)342   dot(E const &e, F const &f)
343   {
344     types::ndarray<
345         typename __combined<typename E::dtype, typename F::dtype>::type,
346         types::pshape<long>>
347     out(types::pshape<long>{e.template shape<0>()}, 0);
348     for (long i = 0; i < out.template shape<0>(); i++)
349       for (long j = 0; j < f.template shape<0>(); j++)
350         out[i] += e[types::array<long, 2>{{i, j}}] * f[j];
351     return out;
352   }
353 
354 /// Matrix / Matrix multiplication
355 
356 #define MM_DEF(T, L)                                                           \
357   void mm(int m, int n, int k, T *A, T *B, T *C)                               \
358   {                                                                            \
359     cblas_##L##gemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1, A,  \
360                     k, B, n, 0, C, n);                                         \
361   }
MM_DEF(double,d)362   MM_DEF(double, d)
363   MM_DEF(float, s)
364 #undef MM_DEF
365 #define MM_DEF(T, K, L)                                                        \
366   void mm(int m, int n, int k, T *A, T *B, T *C)                               \
367   {                                                                            \
368     T alpha = 1, beta = 0;                                                     \
369     cblas_##L##gemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k,        \
370                     (K *)&alpha, (K *)A, k, (K *)B, n, (K *)&beta, (K *)C, n); \
371   }
372   MM_DEF(std::complex<float>, float, c)
373   MM_DEF(std::complex<double>, double, z)
374 #undef MM_DEF
375 
376   template <class E, class pS0, class pS1>
377   typename std::enable_if<is_blas_type<E>::value &&
378                               std::tuple_size<pS0>::value == 2 &&
379                               std::tuple_size<pS1>::value == 2,
380                           types::ndarray<E, types::array<long, 2>>>::type
381   dot(types::ndarray<E, pS0> const &a, types::ndarray<E, pS1> const &b)
382   {
383     int n = b.template shape<1>(), m = a.template shape<0>(),
384         k = b.template shape<0>();
385 
386     types::ndarray<E, types::array<long, 2>> out(types::array<long, 2>{{m, n}},
387                                                  builtins::None);
388     mm(m, n, k, a.buffer, b.buffer, out.buffer);
389     return out;
390   }
391 
392   template <class E, class pS0, class pS1, class pS2>
393   typename std::enable_if<
394       is_blas_type<E>::value && std::tuple_size<pS0>::value == 2 &&
395           std::tuple_size<pS1>::value == 2 && std::tuple_size<pS2>::value == 2,
396       types::ndarray<E, pS2>>::type &
dot(types::ndarray<E,pS0> const & a,types::ndarray<E,pS1> const & b,types::ndarray<E,pS2> & c)397   dot(types::ndarray<E, pS0> const &a, types::ndarray<E, pS1> const &b,
398       types::ndarray<E, pS2> &c)
399   {
400     int n = b.template shape<1>(), m = a.template shape<0>(),
401         k = b.template shape<0>();
402 
403     mm(m, n, k, a.buffer, b.buffer, c.buffer);
404     return c;
405   }
406 
407 #define TM_DEF(T, L)                                                           \
408   void tm(int m, int n, int k, T *A, T *B, T *C)                               \
409   {                                                                            \
410     cblas_##L##gemm(CblasRowMajor, CblasTrans, CblasNoTrans, m, n, k, 1, A, m, \
411                     B, n, 0, C, n);                                            \
412   }
TM_DEF(double,d)413   TM_DEF(double, d)
414   TM_DEF(float, s)
415 #undef TM_DEF
416 #define TM_DEF(T, K, L)                                                        \
417   void tm(int m, int n, int k, T *A, T *B, T *C)                               \
418   {                                                                            \
419     T alpha = 1, beta = 0;                                                     \
420     cblas_##L##gemm(CblasRowMajor, CblasTrans, CblasNoTrans, m, n, k,          \
421                     (K *)&alpha, (K *)A, m, (K *)B, n, (K *)&beta, (K *)C, n); \
422   }
423   TM_DEF(std::complex<float>, float, c)
424   TM_DEF(std::complex<double>, double, z)
425 #undef TM_DEF
426 
427   template <class E, class pS0, class pS1>
428   typename std::enable_if<is_blas_type<E>::value &&
429                               std::tuple_size<pS0>::value == 2 &&
430                               std::tuple_size<pS1>::value == 2,
431                           types::ndarray<E, types::array<long, 2>>>::type
432   dot(types::numpy_texpr<types::ndarray<E, pS0>> const &a,
433       types::ndarray<E, pS1> const &b)
434   {
435     int n = b.template shape<1>(), m = a.template shape<0>(),
436         k = b.template shape<0>();
437 
438     types::ndarray<E, types::array<long, 2>> out(types::array<long, 2>{{m, n}},
439                                                  builtins::None);
440     tm(m, n, k, a.arg.buffer, b.buffer, out.buffer);
441     return out;
442   }
443 
444 #define MT_DEF(T, L)                                                           \
445   void mt(int m, int n, int k, T *A, T *B, T *C)                               \
446   {                                                                            \
447     cblas_##L##gemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, 1, A, k, \
448                     B, k, 0, C, n);                                            \
449   }
MT_DEF(double,d)450   MT_DEF(double, d)
451   MT_DEF(float, s)
452 #undef MT_DEF
453 #define MT_DEF(T, K, L)                                                        \
454   void mt(int m, int n, int k, T *A, T *B, T *C)                               \
455   {                                                                            \
456     T alpha = 1, beta = 0;                                                     \
457     cblas_##L##gemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k,          \
458                     (K *)&alpha, (K *)A, k, (K *)B, k, (K *)&beta, (K *)C, n); \
459   }
460   MT_DEF(std::complex<float>, float, c)
461   MT_DEF(std::complex<double>, double, z)
462 #undef MT_DEF
463 
464   template <class E, class pS0, class pS1>
465   typename std::enable_if<is_blas_type<E>::value &&
466                               std::tuple_size<pS0>::value == 2 &&
467                               std::tuple_size<pS1>::value == 2,
468                           types::ndarray<E, types::array<long, 2>>>::type
469   dot(types::ndarray<E, pS0> const &a,
470       types::numpy_texpr<types::ndarray<E, pS1>> const &b)
471   {
472     int n = b.template shape<1>(), m = a.template shape<0>(),
473         k = b.template shape<0>();
474 
475     types::ndarray<E, types::array<long, 2>> out(types::array<long, 2>{{m, n}},
476                                                  builtins::None);
477     mt(m, n, k, a.buffer, b.arg.buffer, out.buffer);
478     return out;
479   }
480 
481 #define TT_DEF(T, L)                                                           \
482   void tt(int m, int n, int k, T *A, T *B, T *C)                               \
483   {                                                                            \
484     cblas_##L##gemm(CblasRowMajor, CblasTrans, CblasTrans, m, n, k, 1, A, m,   \
485                     B, k, 0, C, n);                                            \
486   }
TT_DEF(double,d)487   TT_DEF(double, d)
488   TT_DEF(float, s)
489 #undef TT_DEF
490 #define TT_DEF(T, K, L)                                                        \
491   void tt(int m, int n, int k, T *A, T *B, T *C)                               \
492   {                                                                            \
493     T alpha = 1, beta = 0;                                                     \
494     cblas_##L##gemm(CblasRowMajor, CblasTrans, CblasTrans, m, n, k,            \
495                     (K *)&alpha, (K *)A, m, (K *)B, k, (K *)&beta, (K *)C, n); \
496   }
497   TT_DEF(std::complex<float>, float, c)
498   TT_DEF(std::complex<double>, double, z)
499 #undef TT_DEF
500 
501   template <class E, class pS0, class pS1>
502   typename std::enable_if<is_blas_type<E>::value &&
503                               std::tuple_size<pS0>::value == 2 &&
504                               std::tuple_size<pS1>::value == 2,
505                           types::ndarray<E, types::array<long, 2>>>::type
506   dot(types::numpy_texpr<types::ndarray<E, pS0>> const &a,
507       types::numpy_texpr<types::ndarray<E, pS1>> const &b)
508   {
509     int n = b.template shape<1>(), m = a.template shape<0>(),
510         k = b.template shape<0>();
511 
512     types::ndarray<E, types::array<long, 2>> out(types::array<long, 2>{{m, n}},
513                                                  builtins::None);
514     tt(m, n, k, a.arg.buffer, b.arg.buffer, out.buffer);
515     return out;
516   }
517 
518   // If arguments could be use with blas, we evaluate them as we need pointer
519   // on array for blas
520   template <class E, class F>
521   typename std::enable_if<
522       types::is_numexpr_arg<E>::value &&
523           types::is_numexpr_arg<F>::value // It is an array_like
524           && (!(types::is_ndarray<E>::value && types::is_ndarray<F>::value) ||
525               !std::is_same<typename E::dtype, typename F::dtype>::value) &&
526           is_blas_type<typename E::dtype>::value &&
527           is_blas_type<typename F::dtype>::value // With dtype compatible with
528                                                  // blas
529           &&
530           E::value == 2 && F::value == 2, // And both are matrix
531       types::ndarray<
532           typename __combined<typename E::dtype, typename F::dtype>::type,
533           types::array<long, 2>>>::type
dot(E const & e,F const & f)534   dot(E const &e, F const &f)
535   {
536     types::ndarray<
537         typename __combined<typename E::dtype, typename F::dtype>::type,
538         typename E::shape_t> e_ = e;
539     types::ndarray<
540         typename __combined<typename E::dtype, typename F::dtype>::type,
541         typename F::shape_t> f_ = f;
542     return dot(e_, f_);
543   }
544 
545   // If one of the arg doesn't have a "blas compatible type", we use a slow
546   // matrix multiplication.
547   template <class E, class F>
548   typename std::enable_if<
549       (!is_blas_type<typename E::dtype>::value ||
550        !is_blas_type<typename F::dtype>::value) &&
551           E::value == 2 && F::value == 2, // And it is matrix / matrix
552       types::ndarray<
553           typename __combined<typename E::dtype, typename F::dtype>::type,
554           types::array<long, 2>>>::type
dot(E const & e,F const & f)555   dot(E const &e, F const &f)
556   {
557     types::ndarray<
558         typename __combined<typename E::dtype, typename F::dtype>::type,
559         types::array<long, 2>>
560     out(types::array<long, 2>{{e.template shape<0>(), f.template shape<1>()}},
561         0);
562     for (long i = 0; i < out.template shape<0>(); i++)
563       for (long j = 0; j < out.template shape<1>(); j++)
564         for (long k = 0; k < e.template shape<1>(); k++)
565           out[types::array<long, 2>{{i, j}}] +=
566               e[types::array<long, 2>{{i, k}}] *
567               f[types::array<long, 2>{{k, j}}];
568     return out;
569   }
570 
571   template <class E, class F>
572   typename std::enable_if<
573       (E::value >= 3 && F::value == 1), // And it is matrix / matrix
574       types::ndarray<
575           typename __combined<typename E::dtype, typename F::dtype>::type,
576           types::array<long, E::value - 1>>>::type
dot(E const & e,F const & f)577   dot(E const &e, F const &f)
578   {
579     auto out = dot(
580         e.reshape(types::array<long, 2>{{sutils::prod_head(e), f.size()}}), f);
581     types::array<long, E::value - 1> out_shape;
582     auto tmp = sutils::getshape(e);
583     std::copy(tmp.begin(), tmp.end() - 1, out_shape.begin());
584     return out.reshape(out_shape);
585   }
586 
587   template <class E, class F>
588   typename std::enable_if<
589       (E::value >= 3 && F::value >= 2),
590       types::ndarray<
591           typename __combined<typename E::dtype, typename F::dtype>::type,
592           types::array<long, E::value - 1>>>::type
dot(E const & e,F const & f)593   dot(E const &e, F const &f)
594   {
595     static_assert(E::value == 0, "not implemented yet");
596   }
597 }
598 PYTHONIC_NS_END
599 
600 #endif
601