1 /*
2  * This module provides a BLAS optimized matrix multiply,
3  * inner product and dot for numpy arrays
4  */
5 
6 #define NPY_NO_DEPRECATED_API NPY_API_VERSION
7 #define _MULTIARRAYMODULE
8 
9 #include <Python.h>
10 #include <assert.h>
11 #include <numpy/arrayobject.h>
12 #include "npy_cblas.h"
13 #include "arraytypes.h"
14 #include "common.h"
15 
16 
/*
 * Complex one and complex zero, each stored as a {real, imag} pair.  The
 * complex BLAS calls in this file receive them in the scale-factor positions
 * where the real-valued calls pass 1.0 and 0.0.
 */
static const double oneD[2] = {1.0, 0.0};
static const double zeroD[2] = {0.0, 0.0};
static const float oneF[2] = {1.0f, 0.0f};
static const float zeroF[2] = {0.0f, 0.0f};
19 
20 
21 /*
22  * Helper: dispatch to appropriate cblas_?gemm for typenum.
23  */
24 static void
gemm(int typenum,enum CBLAS_ORDER order,enum CBLAS_TRANSPOSE transA,enum CBLAS_TRANSPOSE transB,npy_intp m,npy_intp n,npy_intp k,PyArrayObject * A,npy_intp lda,PyArrayObject * B,npy_intp ldb,PyArrayObject * R)25 gemm(int typenum, enum CBLAS_ORDER order,
26      enum CBLAS_TRANSPOSE transA, enum CBLAS_TRANSPOSE transB,
27      npy_intp m, npy_intp n, npy_intp k,
28      PyArrayObject *A, npy_intp lda, PyArrayObject *B, npy_intp ldb, PyArrayObject *R)
29 {
30     const void *Adata = PyArray_DATA(A), *Bdata = PyArray_DATA(B);
31     void *Rdata = PyArray_DATA(R);
32     npy_intp ldc = PyArray_DIM(R, 1) > 1 ? PyArray_DIM(R, 1) : 1;
33 
34     switch (typenum) {
35         case NPY_DOUBLE:
36             CBLAS_FUNC(cblas_dgemm)(order, transA, transB, m, n, k, 1.,
37                         Adata, lda, Bdata, ldb, 0., Rdata, ldc);
38             break;
39         case NPY_FLOAT:
40             CBLAS_FUNC(cblas_sgemm)(order, transA, transB, m, n, k, 1.f,
41                         Adata, lda, Bdata, ldb, 0.f, Rdata, ldc);
42             break;
43         case NPY_CDOUBLE:
44             CBLAS_FUNC(cblas_zgemm)(order, transA, transB, m, n, k, oneD,
45                         Adata, lda, Bdata, ldb, zeroD, Rdata, ldc);
46             break;
47         case NPY_CFLOAT:
48             CBLAS_FUNC(cblas_cgemm)(order, transA, transB, m, n, k, oneF,
49                         Adata, lda, Bdata, ldb, zeroF, Rdata, ldc);
50             break;
51     }
52 }
53 
54 
55 /*
56  * Helper: dispatch to appropriate cblas_?gemv for typenum.
57  */
58 static void
gemv(int typenum,enum CBLAS_ORDER order,enum CBLAS_TRANSPOSE trans,PyArrayObject * A,npy_intp lda,PyArrayObject * X,npy_intp incX,PyArrayObject * R)59 gemv(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans,
60      PyArrayObject *A, npy_intp lda, PyArrayObject *X, npy_intp incX,
61      PyArrayObject *R)
62 {
63     const void *Adata = PyArray_DATA(A), *Xdata = PyArray_DATA(X);
64     void *Rdata = PyArray_DATA(R);
65 
66     npy_intp m = PyArray_DIM(A, 0), n = PyArray_DIM(A, 1);
67 
68     switch (typenum) {
69         case NPY_DOUBLE:
70             CBLAS_FUNC(cblas_dgemv)(order, trans, m, n, 1., Adata, lda, Xdata, incX,
71                         0., Rdata, 1);
72             break;
73         case NPY_FLOAT:
74             CBLAS_FUNC(cblas_sgemv)(order, trans, m, n, 1.f, Adata, lda, Xdata, incX,
75                         0.f, Rdata, 1);
76             break;
77         case NPY_CDOUBLE:
78             CBLAS_FUNC(cblas_zgemv)(order, trans, m, n, oneD, Adata, lda, Xdata, incX,
79                         zeroD, Rdata, 1);
80             break;
81         case NPY_CFLOAT:
82             CBLAS_FUNC(cblas_cgemv)(order, trans, m, n, oneF, Adata, lda, Xdata, incX,
83                         zeroF, Rdata, 1);
84             break;
85     }
86 }
87 
88 
89 /*
90  * Helper: dispatch to appropriate cblas_?syrk for typenum.
91  */
92 static void
syrk(int typenum,enum CBLAS_ORDER order,enum CBLAS_TRANSPOSE trans,npy_intp n,npy_intp k,PyArrayObject * A,npy_intp lda,PyArrayObject * R)93 syrk(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans,
94      npy_intp n, npy_intp k,
95      PyArrayObject *A, npy_intp lda, PyArrayObject *R)
96 {
97     const void *Adata = PyArray_DATA(A);
98     void *Rdata = PyArray_DATA(R);
99     npy_intp ldc = PyArray_DIM(R, 1) > 1 ? PyArray_DIM(R, 1) : 1;
100 
101     npy_intp i;
102     npy_intp j;
103 
104     switch (typenum) {
105         case NPY_DOUBLE:
106             CBLAS_FUNC(cblas_dsyrk)(order, CblasUpper, trans, n, k, 1.,
107                         Adata, lda, 0., Rdata, ldc);
108 
109             for (i = 0; i < n; i++) {
110                 for (j = i + 1; j < n; j++) {
111                     *((npy_double*)PyArray_GETPTR2(R, j, i)) =
112                             *((npy_double*)PyArray_GETPTR2(R, i, j));
113                 }
114             }
115             break;
116         case NPY_FLOAT:
117             CBLAS_FUNC(cblas_ssyrk)(order, CblasUpper, trans, n, k, 1.f,
118                         Adata, lda, 0.f, Rdata, ldc);
119 
120             for (i = 0; i < n; i++) {
121                 for (j = i + 1; j < n; j++) {
122                     *((npy_float*)PyArray_GETPTR2(R, j, i)) =
123                             *((npy_float*)PyArray_GETPTR2(R, i, j));
124                 }
125             }
126             break;
127         case NPY_CDOUBLE:
128             CBLAS_FUNC(cblas_zsyrk)(order, CblasUpper, trans, n, k, oneD,
129                         Adata, lda, zeroD, Rdata, ldc);
130 
131             for (i = 0; i < n; i++) {
132                 for (j = i + 1; j < n; j++) {
133                     *((npy_cdouble*)PyArray_GETPTR2(R, j, i)) =
134                             *((npy_cdouble*)PyArray_GETPTR2(R, i, j));
135                 }
136             }
137             break;
138         case NPY_CFLOAT:
139             CBLAS_FUNC(cblas_csyrk)(order, CblasUpper, trans, n, k, oneF,
140                         Adata, lda, zeroF, Rdata, ldc);
141 
142             for (i = 0; i < n; i++) {
143                 for (j = i + 1; j < n; j++) {
144                     *((npy_cfloat*)PyArray_GETPTR2(R, j, i)) =
145                             *((npy_cfloat*)PyArray_GETPTR2(R, i, j));
146                 }
147             }
148             break;
149     }
150 }
151 
152 
/* Shape classes that _select_matrix_shape assigns to an operand. */
typedef enum {
    _scalar,    /* 0-d, or every examined dimension is at most 1 */
    _column,    /* 1-d with more than one element, or 2-d of shape (N>1, 1) */
    _row,       /* 2-d with at most one row and a second dimension != 1 */
    _matrix     /* everything else */
} MatrixShape;
154 
155 
156 static MatrixShape
_select_matrix_shape(PyArrayObject * array)157 _select_matrix_shape(PyArrayObject *array)
158 {
159     switch (PyArray_NDIM(array)) {
160         case 0:
161             return _scalar;
162         case 1:
163             if (PyArray_DIM(array, 0) > 1)
164                 return _column;
165             return _scalar;
166         case 2:
167             if (PyArray_DIM(array, 0) > 1) {
168                 if (PyArray_DIM(array, 1) == 1)
169                     return _column;
170                 else
171                     return _matrix;
172             }
173             if (PyArray_DIM(array, 1) == 1)
174                 return _scalar;
175             return _row;
176     }
177     return _matrix;
178 }
179 
180 
181 /*
182  * This also makes sure that the data segment is aligned with
183  * an itemsize address as well by returning one if not true.
184  */
185 NPY_NO_EXPORT int
_bad_strides(PyArrayObject * ap)186 _bad_strides(PyArrayObject *ap)
187 {
188     int itemsize = PyArray_ITEMSIZE(ap);
189     int i, N=PyArray_NDIM(ap);
190     npy_intp *strides = PyArray_STRIDES(ap);
191     npy_intp *dims = PyArray_DIMS(ap);
192 
193     if (((npy_intp)(PyArray_DATA(ap)) % itemsize) != 0) {
194         return 1;
195     }
196     for (i = 0; i < N; i++) {
197         if ((strides[i] < 0) || (strides[i] % itemsize) != 0) {
198             return 1;
199         }
200         if ((strides[i] == 0 && dims[i] > 1)) {
201             return 1;
202         }
203     }
204 
205     return 0;
206 }
207 
/*
 * dot(a,b)
 * Returns the dot product of a and b for arrays of floating point types.
 * Like the generic numpy equivalent, the product sum is over
 * the last dimension of a and the second-to-last dimension of b.
 * NB: The first argument is not conjugated.
 *
 * This is for use by PyArray_MatrixProduct2. It is assumed on entry that
 * the arrays ap1 and ap2 have a common data type given by typenum that is
 * float, double, cfloat, or cdouble and have dimension <= 2. The
 * __array_ufunc__ handling is also assumed to have been taken care of.
 *
 * Reference handling: one reference to ap1 and one to ap2 is released on
 * every exit path (success and failure), so the caller gives up the
 * references it passes in.  `out` is only forwarded to new_array_for_sum.
 *
 * Returns PyArray_Return(result) on success, or NULL on failure.
 */
NPY_NO_EXPORT PyObject *
cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2,
                    PyArrayObject *out)
{
    PyArrayObject *result = NULL, *out_buf = NULL;
    npy_intp j, lda, ldb;
    npy_intp l;
    int nd;
    npy_intp ap1stride = 0;
    npy_intp dimensions[NPY_MAXDIMS];
    npy_intp numbytes;
    MatrixShape ap1shape, ap2shape;

    /*
     * Replace an operand by a fresh copy when _bad_strides rejects its
     * layout.  The reference to the original is dropped here; if the copy
     * failed the variable is NULL and the Py_XDECREF under `fail` skips it.
     */
    if (_bad_strides(ap1)) {
            PyObject *op1 = PyArray_NewCopy(ap1, NPY_ANYORDER);

            Py_DECREF(ap1);
            ap1 = (PyArrayObject *)op1;
            if (ap1 == NULL) {
                goto fail;
            }
    }
    if (_bad_strides(ap2)) {
            PyObject *op2 = PyArray_NewCopy(ap2, NPY_ANYORDER);

            Py_DECREF(ap2);
            ap2 = (PyArrayObject *)op2;
            if (ap2 == NULL) {
                goto fail;
            }
    }
    ap1shape = _select_matrix_shape(ap1);
    ap2shape = _select_matrix_shape(ap2);

    /*
     * First stage: work out the result shape (nd, dimensions) and the
     * length l used by the kernels below, and validate that the operand
     * shapes are compatible.
     */
    if (ap1shape == _scalar || ap2shape == _scalar) {
        PyArrayObject *oap1, *oap2;
        /* Remember the original operand order; ap1/ap2 may be swapped. */
        oap1 = ap1; oap2 = ap2;
        /* One of ap1 or ap2 is a scalar */
        if (ap1shape == _scalar) {
            /* Make ap2 the scalar */
            PyArrayObject *t = ap1;
            ap1 = ap2;
            ap2 = t;
            ap1shape = ap2shape;
            ap2shape = _scalar;
        }

        /*
         * Byte stride used to walk the non-scalar operand: along axis 1
         * for a row, otherwise along axis 0 (left at 0 for a 0-d array).
         */
        if (ap1shape == _row) {
            ap1stride = PyArray_STRIDE(ap1, 1);
        }
        else if (PyArray_NDIM(ap1) > 0) {
            ap1stride = PyArray_STRIDE(ap1, 0);
        }

        if (PyArray_NDIM(ap1) == 0 || PyArray_NDIM(ap2) == 0) {
            /*
             * A true 0-d operand: the result takes the shape of the other
             * operand and l is its total number of elements.
             */
            npy_intp *thisdims;
            if (PyArray_NDIM(ap1) == 0) {
                nd = PyArray_NDIM(ap2);
                thisdims = PyArray_DIMS(ap2);
            }
            else {
                nd = PyArray_NDIM(ap1);
                thisdims = PyArray_DIMS(ap1);
            }
            l = 1;
            for (j = 0; j < nd; j++) {
                dimensions[j] = thisdims[j];
                l *= dimensions[j];
            }
        }
        else {
            /* Shape checks use the operands in their original order. */
            l = PyArray_DIM(oap1, PyArray_NDIM(oap1) - 1);

            if (PyArray_DIM(oap2, 0) != l) {
                dot_alignment_error(oap1, PyArray_NDIM(oap1) - 1, oap2, 0);
                goto fail;
            }
            nd = PyArray_NDIM(ap1) + PyArray_NDIM(ap2) - 2;
            /*
             * nd = 0 or 1 or 2. If nd == 0 do nothing ...
             */
            if (nd == 1) {
                /*
                 * Either PyArray_NDIM(ap1) is 1 dim or PyArray_NDIM(ap2) is
                 * 1 dim and the other is 2 dim
                 */
                dimensions[0] = (PyArray_NDIM(oap1) == 2) ?
                                PyArray_DIM(oap1, 0) : PyArray_DIM(oap2, 1);
                l = dimensions[0];
                /*
                 * Fix it so that dot(shape=(N,1), shape=(1,))
                 * and dot(shape=(1,), shape=(1,N)) both return
                 * an (N,) array (but use the fast scalar code)
                 */
            }
            else if (nd == 2) {
                dimensions[0] = PyArray_DIM(oap1, 0);
                dimensions[1] = PyArray_DIM(oap2, 1);
                /*
                 * We need to make sure that dot(shape=(1,1), shape=(1,N))
                 * and dot(shape=(N,1),shape=(1,1)) uses
                 * scalar multiplication appropriately
                 */
                if (ap1shape == _row) {
                    l = dimensions[1];
                }
                else {
                    l = dimensions[0];
                }
            }

            /* Check if the summation dimension is 0-sized */
            if (PyArray_DIM(oap1, PyArray_NDIM(oap1) - 1) == 0) {
                l = 0;
            }
        }
    }
    else {
        /*
         * (PyArray_NDIM(ap1) <= 2 && PyArray_NDIM(ap2) <= 2)
         * Both ap1 and ap2 are vectors or matrices
         */
        l = PyArray_DIM(ap1, PyArray_NDIM(ap1) - 1);

        if (PyArray_DIM(ap2, 0) != l) {
            dot_alignment_error(ap1, PyArray_NDIM(ap1) - 1, ap2, 0);
            goto fail;
        }
        nd = PyArray_NDIM(ap1) + PyArray_NDIM(ap2) - 2;

        if (nd == 1) {
            dimensions[0] = (PyArray_NDIM(ap1) == 2) ?
                            PyArray_DIM(ap1, 0) : PyArray_DIM(ap2, 1);
        }
        else if (nd == 2) {
            dimensions[0] = PyArray_DIM(ap1, 0);
            dimensions[1] = PyArray_DIM(ap2, 1);
        }
    }

    /*
     * out_buf is the array the kernels write into; `result` is what gets
     * returned.  NOTE(review): new_array_for_sum is defined elsewhere --
     * presumably out_buf is either `result` itself or a temporary that is
     * copied back into it (see the write-back call near the end); confirm.
     */
    out_buf = new_array_for_sum(ap1, ap2, out, nd, dimensions, typenum, &result);
    if (out_buf == NULL) {
        goto fail;
    }

    /*
     * Zero the output first: the axpy kernels below add into out_buf, and
     * an empty summation must yield zeros.
     */
    numbytes = PyArray_NBYTES(out_buf);
    memset(PyArray_DATA(out_buf), 0, numbytes);
    if (numbytes == 0 || l == 0) {
            /*
             * Nothing to compute.
             * NOTE(review): unlike the normal exit, this path does not call
             * PyArray_ResolveWritebackIfCopy(out_buf) -- confirm that
             * out_buf can never be a write-back temporary here.
             */
            Py_DECREF(ap1);
            Py_DECREF(ap2);
            Py_DECREF(out_buf);
            return PyArray_Return(result);
    }

    if (ap2shape == _scalar) {
        /*
         * Multiplication by a scalar -- Level 1 BLAS
         * if ap1shape is a matrix and we are not contiguous, then we can't
         * just blast through the entire array using a single striding factor
         */
        NPY_BEGIN_ALLOW_THREADS;

        /*
         * Each type below follows the same three-way split:
         *   l == 1            -> multiply the two single elements directly;
         *   ap1 not a matrix  -> one axpy call over l elements, with the
         *                        byte stride converted to an element count;
         *   ap1 a matrix      -> one axpy call per line along the longer
         *                        axis (maxind), stepping through the shorter
         *                        axis (oind) by its byte stride.
         */
        if (typenum == NPY_DOUBLE) {
            if (l == 1) {
                *((double *)PyArray_DATA(out_buf)) = *((double *)PyArray_DATA(ap2)) *
                                                 *((double *)PyArray_DATA(ap1));
            }
            else if (ap1shape != _matrix) {
                CBLAS_FUNC(cblas_daxpy)(l,
                            *((double *)PyArray_DATA(ap2)),
                            (double *)PyArray_DATA(ap1),
                            ap1stride/sizeof(double),
                            (double *)PyArray_DATA(out_buf), 1);
            }
            else {
                int maxind, oind;
                npy_intp i, a1s, outs;
                char *ptr, *optr;
                double val;

                maxind = (PyArray_DIM(ap1, 0) >= PyArray_DIM(ap1, 1) ? 0 : 1);
                oind = 1 - maxind;
                ptr = PyArray_DATA(ap1);
                optr = PyArray_DATA(out_buf);
                l = PyArray_DIM(ap1, maxind);
                val = *((double *)PyArray_DATA(ap2));
                a1s = PyArray_STRIDE(ap1, maxind) / sizeof(double);
                outs = PyArray_STRIDE(out_buf, maxind) / sizeof(double);
                for (i = 0; i < PyArray_DIM(ap1, oind); i++) {
                    CBLAS_FUNC(cblas_daxpy)(l, val, (double *)ptr, a1s,
                                (double *)optr, outs);
                    ptr += PyArray_STRIDE(ap1, oind);
                    optr += PyArray_STRIDE(out_buf, oind);
                }
            }
        }
        else if (typenum == NPY_CDOUBLE) {
            if (l == 1) {
                npy_cdouble *ptr1, *ptr2, *res;

                /* Plain (non-conjugating) complex product of two elements. */
                ptr1 = (npy_cdouble *)PyArray_DATA(ap2);
                ptr2 = (npy_cdouble *)PyArray_DATA(ap1);
                res = (npy_cdouble *)PyArray_DATA(out_buf);
                res->real = ptr1->real * ptr2->real - ptr1->imag * ptr2->imag;
                res->imag = ptr1->real * ptr2->imag + ptr1->imag * ptr2->real;
            }
            else if (ap1shape != _matrix) {
                /* The complex scalar is passed by pointer, not by value. */
                CBLAS_FUNC(cblas_zaxpy)(l,
                            (double *)PyArray_DATA(ap2),
                            (double *)PyArray_DATA(ap1),
                            ap1stride/sizeof(npy_cdouble),
                            (double *)PyArray_DATA(out_buf), 1);
            }
            else {
                int maxind, oind;
                npy_intp i, a1s, outs;
                char *ptr, *optr;
                double *pval;

                maxind = (PyArray_DIM(ap1, 0) >= PyArray_DIM(ap1, 1) ? 0 : 1);
                oind = 1 - maxind;
                ptr = PyArray_DATA(ap1);
                optr = PyArray_DATA(out_buf);
                l = PyArray_DIM(ap1, maxind);
                pval = (double *)PyArray_DATA(ap2);
                a1s = PyArray_STRIDE(ap1, maxind) / sizeof(npy_cdouble);
                outs = PyArray_STRIDE(out_buf, maxind) / sizeof(npy_cdouble);
                for (i = 0; i < PyArray_DIM(ap1, oind); i++) {
                    CBLAS_FUNC(cblas_zaxpy)(l, pval, (double *)ptr, a1s,
                                (double *)optr, outs);
                    ptr += PyArray_STRIDE(ap1, oind);
                    optr += PyArray_STRIDE(out_buf, oind);
                }
            }
        }
        else if (typenum == NPY_FLOAT) {
            if (l == 1) {
                *((float *)PyArray_DATA(out_buf)) = *((float *)PyArray_DATA(ap2)) *
                    *((float *)PyArray_DATA(ap1));
            }
            else if (ap1shape != _matrix) {
                CBLAS_FUNC(cblas_saxpy)(l,
                            *((float *)PyArray_DATA(ap2)),
                            (float *)PyArray_DATA(ap1),
                            ap1stride/sizeof(float),
                            (float *)PyArray_DATA(out_buf), 1);
            }
            else {
                int maxind, oind;
                npy_intp i, a1s, outs;
                char *ptr, *optr;
                float val;

                maxind = (PyArray_DIM(ap1, 0) >= PyArray_DIM(ap1, 1) ? 0 : 1);
                oind = 1 - maxind;
                ptr = PyArray_DATA(ap1);
                optr = PyArray_DATA(out_buf);
                l = PyArray_DIM(ap1, maxind);
                val = *((float *)PyArray_DATA(ap2));
                a1s = PyArray_STRIDE(ap1, maxind) / sizeof(float);
                outs = PyArray_STRIDE(out_buf, maxind) / sizeof(float);
                for (i = 0; i < PyArray_DIM(ap1, oind); i++) {
                    CBLAS_FUNC(cblas_saxpy)(l, val, (float *)ptr, a1s,
                                (float *)optr, outs);
                    ptr += PyArray_STRIDE(ap1, oind);
                    optr += PyArray_STRIDE(out_buf, oind);
                }
            }
        }
        else if (typenum == NPY_CFLOAT) {
            if (l == 1) {
                npy_cfloat *ptr1, *ptr2, *res;

                /* Plain (non-conjugating) complex product of two elements. */
                ptr1 = (npy_cfloat *)PyArray_DATA(ap2);
                ptr2 = (npy_cfloat *)PyArray_DATA(ap1);
                res = (npy_cfloat *)PyArray_DATA(out_buf);
                res->real = ptr1->real * ptr2->real - ptr1->imag * ptr2->imag;
                res->imag = ptr1->real * ptr2->imag + ptr1->imag * ptr2->real;
            }
            else if (ap1shape != _matrix) {
                /* The complex scalar is passed by pointer, not by value. */
                CBLAS_FUNC(cblas_caxpy)(l,
                            (float *)PyArray_DATA(ap2),
                            (float *)PyArray_DATA(ap1),
                            ap1stride/sizeof(npy_cfloat),
                            (float *)PyArray_DATA(out_buf), 1);
            }
            else {
                int maxind, oind;
                npy_intp i, a1s, outs;
                char *ptr, *optr;
                float *pval;

                maxind = (PyArray_DIM(ap1, 0) >= PyArray_DIM(ap1, 1) ? 0 : 1);
                oind = 1 - maxind;
                ptr = PyArray_DATA(ap1);
                optr = PyArray_DATA(out_buf);
                l = PyArray_DIM(ap1, maxind);
                pval = (float *)PyArray_DATA(ap2);
                a1s = PyArray_STRIDE(ap1, maxind) / sizeof(npy_cfloat);
                outs = PyArray_STRIDE(out_buf, maxind) / sizeof(npy_cfloat);
                for (i = 0; i < PyArray_DIM(ap1, oind); i++) {
                    CBLAS_FUNC(cblas_caxpy)(l, pval, (float *)ptr, a1s,
                                (float *)optr, outs);
                    ptr += PyArray_STRIDE(ap1, oind);
                    optr += PyArray_STRIDE(out_buf, oind);
                }
            }
        }
        NPY_END_ALLOW_THREADS;
    }
    else if ((ap2shape == _column) && (ap1shape != _matrix)) {
        NPY_BEGIN_ALLOW_THREADS;

        /* Dot product between two vectors -- Level 1 BLAS */
        /*
         * Uses the dtype's own dotfunc with byte strides.  The stride index
         * for ap1 is the comparison result: 1 when ap1 is a row, else 0.
         */
        PyArray_DESCR(out_buf)->f->dotfunc(
                 PyArray_DATA(ap1), PyArray_STRIDE(ap1, (ap1shape == _row)),
                 PyArray_DATA(ap2), PyArray_STRIDE(ap2, 0),
                 PyArray_DATA(out_buf), l, NULL);
        NPY_END_ALLOW_THREADS;
    }
    else if (ap1shape == _matrix && ap2shape != _matrix) {
        /* Matrix vector multiplication -- Level 2 BLAS */
        /* lda must be MAX(M,1) */
        enum CBLAS_ORDER Order;
        npy_intp ap2s;

        /* gemv needs the matrix in a single memory segment; copy if not. */
        if (!PyArray_ISONESEGMENT(ap1)) {
            PyObject *new;
            new = PyArray_Copy(ap1);
            Py_DECREF(ap1);
            ap1 = (PyArrayObject *)new;
            if (new == NULL) {
                goto fail;
            }
        }
        NPY_BEGIN_ALLOW_THREADS
        /* C-contiguous -> row-major; otherwise treat as column-major. */
        if (PyArray_ISCONTIGUOUS(ap1)) {
            Order = CblasRowMajor;
            lda = (PyArray_DIM(ap1, 1) > 1 ? PyArray_DIM(ap1, 1) : 1);
        }
        else {
            Order = CblasColMajor;
            lda = (PyArray_DIM(ap1, 0) > 1 ? PyArray_DIM(ap1, 0) : 1);
        }
        /* Vector increment in elements rather than bytes. */
        ap2s = PyArray_STRIDE(ap2, 0) / PyArray_ITEMSIZE(ap2);
        gemv(typenum, Order, CblasNoTrans, ap1, lda, ap2, ap2s, out_buf);
        NPY_END_ALLOW_THREADS;
    }
    else if (ap1shape != _matrix && ap2shape == _matrix) {
        /* Vector matrix multiplication -- Level 2 BLAS */
        enum CBLAS_ORDER Order;
        npy_intp ap1s;

        /* gemv needs the matrix in a single memory segment; copy if not. */
        if (!PyArray_ISONESEGMENT(ap2)) {
            PyObject *new;
            new = PyArray_Copy(ap2);
            Py_DECREF(ap2);
            ap2 = (PyArrayObject *)new;
            if (new == NULL) {
                goto fail;
            }
        }
        NPY_BEGIN_ALLOW_THREADS
        /* C-contiguous -> row-major; otherwise treat as column-major. */
        if (PyArray_ISCONTIGUOUS(ap2)) {
            Order = CblasRowMajor;
            lda = (PyArray_DIM(ap2, 1) > 1 ? PyArray_DIM(ap2, 1) : 1);
        }
        else {
            Order = CblasColMajor;
            lda = (PyArray_DIM(ap2, 0) > 1 ? PyArray_DIM(ap2, 0) : 1);
        }
        /* Vector increment in elements, taken along the vector's long axis. */
        if (ap1shape == _row) {
            ap1s = PyArray_STRIDE(ap1, 1) / PyArray_ITEMSIZE(ap1);
        }
        else {
            ap1s = PyArray_STRIDE(ap1, 0) / PyArray_ITEMSIZE(ap1);
        }
        /* The matrix is the right operand, so gemv is asked to transpose it. */
        gemv(typenum, Order, CblasTrans, ap2, lda, ap1, ap1s, out_buf);
        NPY_END_ALLOW_THREADS;
    }
    else {
        /*
         * (PyArray_NDIM(ap1) == 2 && PyArray_NDIM(ap2) == 2)
         * Matrix matrix multiplication -- Level 3 BLAS
         *  L x M  multiplied by M x N
         */
        enum CBLAS_ORDER Order;
        enum CBLAS_TRANSPOSE Trans1, Trans2;
        npy_intp M, N, L;

        /* Optimization possible: */
        /*
         * We may be able to handle single-segment arrays here
         * using appropriate values of Order, Trans1, and Trans2.
         */
        /* Operands that are neither C- nor F-contiguous are copied first. */
        if (!PyArray_IS_C_CONTIGUOUS(ap2) && !PyArray_IS_F_CONTIGUOUS(ap2)) {
            PyObject *new = PyArray_Copy(ap2);

            Py_DECREF(ap2);
            ap2 = (PyArrayObject *)new;
            if (new == NULL) {
                goto fail;
            }
        }
        if (!PyArray_IS_C_CONTIGUOUS(ap1) && !PyArray_IS_F_CONTIGUOUS(ap1)) {
            PyObject *new = PyArray_Copy(ap1);

            Py_DECREF(ap1);
            ap1 = (PyArrayObject *)new;
            if (new == NULL) {
                goto fail;
            }
        }

        NPY_BEGIN_ALLOW_THREADS;

        /* Defaults assume both operands are C-contiguous. */
        Order = CblasRowMajor;
        Trans1 = CblasNoTrans;
        Trans2 = CblasNoTrans;
        L = PyArray_DIM(ap1, 0);
        N = PyArray_DIM(ap2, 1);
        M = PyArray_DIM(ap2, 0);
        lda = (PyArray_DIM(ap1, 1) > 1 ? PyArray_DIM(ap1, 1) : 1);
        ldb = (PyArray_DIM(ap2, 1) > 1 ? PyArray_DIM(ap2, 1) : 1);

        /*
         * Avoid temporary copies for arrays in Fortran order
         */
        if (PyArray_IS_F_CONTIGUOUS(ap1)) {
            Trans1 = CblasTrans;
            lda = (PyArray_DIM(ap1, 0) > 1 ? PyArray_DIM(ap1, 0) : 1);
        }
        if (PyArray_IS_F_CONTIGUOUS(ap2)) {
            Trans2 = CblasTrans;
            ldb = (PyArray_DIM(ap2, 0) > 1 ? PyArray_DIM(ap2, 0) : 1);
        }

        /*
         * Use syrk if we have a case of a matrix times its transpose.
         * Otherwise, use gemm for all other cases.
         *
         * "Matrix times its transpose" is detected as: same data pointer,
         * shapes and strides that are mirror images of each other, and
         * exactly one of the two operands marked as transposed.
         */
        if (
            (PyArray_BYTES(ap1) == PyArray_BYTES(ap2)) &&
            (PyArray_DIM(ap1, 0) == PyArray_DIM(ap2, 1)) &&
            (PyArray_DIM(ap1, 1) == PyArray_DIM(ap2, 0)) &&
            (PyArray_STRIDE(ap1, 0) == PyArray_STRIDE(ap2, 1)) &&
            (PyArray_STRIDE(ap1, 1) == PyArray_STRIDE(ap2, 0)) &&
            ((Trans1 == CblasTrans) ^ (Trans2 == CblasTrans)) &&
            ((Trans1 == CblasNoTrans) ^ (Trans2 == CblasNoTrans))
        ) {
            /* Pass whichever operand is the untransposed view of the data. */
            if (Trans1 == CblasNoTrans) {
                syrk(typenum, Order, Trans1, N, M, ap1, lda, out_buf);
            }
            else {
                syrk(typenum, Order, Trans1, N, M, ap2, ldb, out_buf);
            }
        }
        else {
            gemm(typenum, Order, Trans1, Trans2, L, N, M, ap1, lda, ap2, ldb,
                 out_buf);
        }
        NPY_END_ALLOW_THREADS;
    }


    /* Normal exit: release the operand references taken over from the caller. */
    Py_DECREF(ap1);
    Py_DECREF(ap2);

    /* Trigger possible copyback into `result` */
    PyArray_ResolveWritebackIfCopy(out_buf);
    Py_DECREF(out_buf);

    return PyArray_Return(result);

fail:
    /* Any of these may be NULL depending on how far the function got. */
    Py_XDECREF(ap1);
    Py_XDECREF(ap2);
    Py_XDECREF(out_buf);
    Py_XDECREF(result);
    return NULL;
}
701