1 #define NPY_NO_DEPRECATED_API NPY_API_VERSION
2 #define _MULTIARRAYMODULE
3 
4 #include <Python.h>
5 #include "common.h"
6 #include "vdot.h"
7 #include "npy_cblas.h"
8 
9 
10 /*
11  * All data is assumed aligned.
12  */
13 NPY_NO_EXPORT void
CFLOAT_vdot(char * ip1,npy_intp is1,char * ip2,npy_intp is2,char * op,npy_intp n,void * NPY_UNUSED (ignore))14 CFLOAT_vdot(char *ip1, npy_intp is1, char *ip2, npy_intp is2,
15             char *op, npy_intp n, void *NPY_UNUSED(ignore))
16 {
17 #if defined(HAVE_CBLAS)
18     CBLAS_INT is1b = blas_stride(is1, sizeof(npy_cfloat));
19     CBLAS_INT is2b = blas_stride(is2, sizeof(npy_cfloat));
20 
21     if (is1b && is2b) {
22         double sum[2] = {0., 0.};  /* double for stability */
23 
24         while (n > 0) {
25             CBLAS_INT chunk = n < NPY_CBLAS_CHUNK ? n : NPY_CBLAS_CHUNK;
26             float tmp[2];
27 
28             CBLAS_FUNC(cblas_cdotc_sub)((CBLAS_INT)n, ip1, is1b, ip2, is2b, tmp);
29             sum[0] += (double)tmp[0];
30             sum[1] += (double)tmp[1];
31             /* use char strides here */
32             ip1 += chunk * is1;
33             ip2 += chunk * is2;
34             n -= chunk;
35         }
36         ((float *)op)[0] = (float)sum[0];
37         ((float *)op)[1] = (float)sum[1];
38     }
39     else
40 #endif
41     {
42         float sumr = (float)0.0;
43         float sumi = (float)0.0;
44         npy_intp i;
45 
46         for (i = 0; i < n; i++, ip1 += is1, ip2 += is2) {
47             const float ip1r = ((float *)ip1)[0];
48             const float ip1i = ((float *)ip1)[1];
49             const float ip2r = ((float *)ip2)[0];
50             const float ip2i = ((float *)ip2)[1];
51 
52             sumr += ip1r * ip2r + ip1i * ip2i;
53             sumi += ip1r * ip2i - ip1i * ip2r;
54         }
55         ((float *)op)[0] = sumr;
56         ((float *)op)[1] = sumi;
57     }
58 }
59 
60 
61 /*
62  * All data is assumed aligned.
63  */
64 NPY_NO_EXPORT void
CDOUBLE_vdot(char * ip1,npy_intp is1,char * ip2,npy_intp is2,char * op,npy_intp n,void * NPY_UNUSED (ignore))65 CDOUBLE_vdot(char *ip1, npy_intp is1, char *ip2, npy_intp is2,
66              char *op, npy_intp n, void *NPY_UNUSED(ignore))
67 {
68 #if defined(HAVE_CBLAS)
69     CBLAS_INT is1b = blas_stride(is1, sizeof(npy_cdouble));
70     CBLAS_INT is2b = blas_stride(is2, sizeof(npy_cdouble));
71 
72     if (is1b && is2b) {
73         double sum[2] = {0., 0.};  /* double for stability */
74 
75         while (n > 0) {
76             CBLAS_INT chunk = n < NPY_CBLAS_CHUNK ? n : NPY_CBLAS_CHUNK;
77             double tmp[2];
78 
79             CBLAS_FUNC(cblas_zdotc_sub)((CBLAS_INT)n, ip1, is1b, ip2, is2b, tmp);
80             sum[0] += (double)tmp[0];
81             sum[1] += (double)tmp[1];
82             /* use char strides here */
83             ip1 += chunk * is1;
84             ip2 += chunk * is2;
85             n -= chunk;
86         }
87         ((double *)op)[0] = (double)sum[0];
88         ((double *)op)[1] = (double)sum[1];
89     }
90     else
91 #endif
92     {
93         double sumr = (double)0.0;
94         double sumi = (double)0.0;
95         npy_intp i;
96 
97         for (i = 0; i < n; i++, ip1 += is1, ip2 += is2) {
98             const double ip1r = ((double *)ip1)[0];
99             const double ip1i = ((double *)ip1)[1];
100             const double ip2r = ((double *)ip2)[0];
101             const double ip2i = ((double *)ip2)[1];
102 
103             sumr += ip1r * ip2r + ip1i * ip2i;
104             sumi += ip1r * ip2i - ip1i * ip2r;
105         }
106         ((double *)op)[0] = sumr;
107         ((double *)op)[1] = sumi;
108     }
109 }
110 
111 
112 /*
113  * All data is assumed aligned.
114  */
115 NPY_NO_EXPORT void
CLONGDOUBLE_vdot(char * ip1,npy_intp is1,char * ip2,npy_intp is2,char * op,npy_intp n,void * NPY_UNUSED (ignore))116 CLONGDOUBLE_vdot(char *ip1, npy_intp is1, char *ip2, npy_intp is2,
117                  char *op, npy_intp n, void *NPY_UNUSED(ignore))
118 {
119     npy_longdouble tmpr = 0.0L;
120     npy_longdouble tmpi = 0.0L;
121     npy_intp i;
122 
123     for (i = 0; i < n; i++, ip1 += is1, ip2 += is2) {
124         const npy_longdouble ip1r = ((npy_longdouble *)ip1)[0];
125         const npy_longdouble ip1i = ((npy_longdouble *)ip1)[1];
126         const npy_longdouble ip2r = ((npy_longdouble *)ip2)[0];
127         const npy_longdouble ip2i = ((npy_longdouble *)ip2)[1];
128 
129         tmpr += ip1r * ip2r + ip1i * ip2i;
130         tmpi += ip1r * ip2i - ip1i * ip2r;
131     }
132     ((npy_longdouble *)op)[0] = tmpr;
133     ((npy_longdouble *)op)[1] = tmpi;
134 }
135 
136 /*
137  * All data is assumed aligned.
138  */
139 NPY_NO_EXPORT void
OBJECT_vdot(char * ip1,npy_intp is1,char * ip2,npy_intp is2,char * op,npy_intp n,void * NPY_UNUSED (ignore))140 OBJECT_vdot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, npy_intp n,
141             void *NPY_UNUSED(ignore))
142 {
143     npy_intp i;
144     PyObject *tmp0, *tmp1, *tmp2, *tmp = NULL;
145     PyObject **tmp3;
146     for (i = 0; i < n; i++, ip1 += is1, ip2 += is2) {
147         if ((*((PyObject **)ip1) == NULL) || (*((PyObject **)ip2) == NULL)) {
148             tmp1 = Py_False;
149             Py_INCREF(Py_False);
150         }
151         else {
152             tmp0 = PyObject_CallMethod(*((PyObject **)ip1), "conjugate", NULL);
153             if (tmp0 == NULL) {
154                 Py_XDECREF(tmp);
155                 return;
156             }
157             tmp1 = PyNumber_Multiply(tmp0, *((PyObject **)ip2));
158             Py_DECREF(tmp0);
159             if (tmp1 == NULL) {
160                 Py_XDECREF(tmp);
161                 return;
162             }
163         }
164         if (i == 0) {
165             tmp = tmp1;
166         }
167         else {
168             tmp2 = PyNumber_Add(tmp, tmp1);
169             Py_XDECREF(tmp);
170             Py_XDECREF(tmp1);
171             if (tmp2 == NULL) {
172                 return;
173             }
174             tmp = tmp2;
175         }
176     }
177     tmp3 = (PyObject**) op;
178     tmp2 = *tmp3;
179     *((PyObject **)op) = tmp;
180     Py_XDECREF(tmp2);
181 }
182