/***************************************************************************
* Copyright (c) Wolf Vollprecht, Johan Mabille and Sylvain Corlay          *
* Copyright (c) QuantStack                                                 *
*                                                                          *
* Distributed under the terms of the BSD 3-Clause License.                 *
*                                                                          *
* The full license is in the file LICENSE, distributed with this software. *
****************************************************************************/

#ifndef XBLAS_HPP
#define XBLAS_HPP

#include <algorithm>

#include "xtensor/xarray.hpp"
#include "xtensor/xcomplex.hpp"
#include "xtensor/xio.hpp"
#include "xtensor/xtensor.hpp"
#include "xtensor/xutils.hpp"

#include "xtensor-blas/xblas_config.hpp"
#include "xtensor-blas/xblas_utils.hpp"

#include "xflens/cxxblas/cxxblas.cxx"

namespace xt
{

namespace blas
{
    /**
     * Calculate the 1-norm of a vector
     *
     * @param a vector of n elements
     * @returns scalar result
     */
    template <class E, class R>
    void asum(const xexpression<E>& a, R& result)
    {
        auto&& ad = view_eval<E::static_layout>(a.derived_cast());
        XTENSOR_ASSERT(ad.dimension() == 1);

        cxxblas::asum<blas_index_t>(
            static_cast<blas_index_t>(ad.shape()[0]),
            ad.data() + ad.data_offset(),
            stride_front(ad),
            result
        );
    }
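
    // A minimal usage sketch for asum (illustrative only, not part of the
    // original header); it assumes a real-valued 1-D container:
    //
    //     xt::xtensor<double, 1> v = {1., -2., 3.};
    //     double res;
    //     xt::blas::asum(v, res);  // res == 6.0, the sum of absolute values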

    /**
     * Calculate the 2-norm of a vector
     *
     * @param a vector of n elements
     * @returns scalar result
     */
    template <class E, class R>
    void nrm2(const xexpression<E>& a, R& result)
    {
        auto&& ad = view_eval<E::static_layout>(a.derived_cast());
        XTENSOR_ASSERT(ad.dimension() == 1);

        cxxblas::nrm2<blas_index_t>(
            static_cast<blas_index_t>(ad.shape()[0]),
            ad.data() + ad.data_offset(),
            stride_front(ad),
            result
        );
    }
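
    // A minimal usage sketch for nrm2 (illustrative only), assuming a
    // real-valued 1-D container:
    //
    //     xt::xtensor<double, 1> v = {3., 4.};
    //     double res;
    //     xt::blas::nrm2(v, res);  // res == 5.0, the Euclidean norm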

    /**
     * Calculate the dot product between two vectors, conjugating
     * the first argument \em a in the case of complex vectors.
     *
     * @param a vector of n elements
     * @param b vector of n elements
     * @returns scalar result
     */
    template <class E1, class E2, class R>
    void dot(const xexpression<E1>& a, const xexpression<E2>& b,
             R& result)
    {
        auto&& ad = view_eval<E1::static_layout>(a.derived_cast());
        auto&& bd = view_eval<E2::static_layout>(b.derived_cast());
        XTENSOR_ASSERT(ad.dimension() == 1);
        XTENSOR_ASSERT(bd.dimension() == 1);

        blas_index_t stride_a = stride_front(ad);
        blas_index_t stride_b = stride_front(bd);

        auto* adt = ad.data() + ad.data_offset();
        auto* bdt = bd.data() + bd.data_offset();

        // we need a pointer to the "real" start of the memory, not to the
        // first element (BLAS applies that transformation itself)
        if (stride_a < 0) {
            adt += (static_cast<blas_index_t>(ad.shape()[0]) - 1) * stride_a; // go back to the start
        }
        if (stride_b < 0) {
            bdt += (static_cast<blas_index_t>(bd.shape()[0]) - 1) * stride_b; // go back to the start
        }

        cxxblas::dot<blas_index_t>(
            static_cast<blas_index_t>(ad.shape()[0]),
            adt,
            stride_a,
            bdt,
            stride_b,
            result
        );
    }
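
    // A minimal usage sketch for dot (illustrative only); for complex value
    // types the first argument is conjugated, as documented above:
    //
    //     xt::xtensor<double, 1> a = {1., 2., 3.};
    //     xt::xtensor<double, 1> b = {4., 5., 6.};
    //     double res;
    //     xt::blas::dot(a, b, res);  // res == 32.0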

    /**
     * Calculate the dot product between two complex vectors, without
     * conjugating the first argument \em a.
     *
     * @param a vector of n elements
     * @param b vector of n elements
     * @returns scalar result
     */
    template <class E1, class E2, class R>
    void dotu(const xexpression<E1>& a, const xexpression<E2>& b, R& result)
    {
        auto&& ad = view_eval<E1::static_layout>(a.derived_cast());
        auto&& bd = view_eval<E2::static_layout>(b.derived_cast());
        XTENSOR_ASSERT(ad.dimension() == 1);
        XTENSOR_ASSERT(bd.dimension() == 1);

        blas_index_t stride_a = stride_front(ad);
        blas_index_t stride_b = stride_front(bd);

        auto* adt = ad.data() + ad.data_offset();
        auto* bdt = bd.data() + bd.data_offset();

        // we need a pointer to the "real" start of the memory, not to the
        // first element (BLAS applies that transformation itself)
        if (stride_a < 0) {
            adt += (static_cast<blas_index_t>(ad.shape()[0]) - 1) * stride_a; // go back to the start
        }
        if (stride_b < 0) {
            bdt += (static_cast<blas_index_t>(bd.shape()[0]) - 1) * stride_b; // go back to the start
        }

        cxxblas::dotu<blas_index_t>(
            static_cast<blas_index_t>(ad.shape()[0]),
            adt,
            stride_a,
            bdt,
            stride_b,
            result
        );
    }
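
    // A sketch contrasting dotu with dot (illustrative only); with
    // a = b = {i}, dotu yields i * i == -1, while dot yields conj(i) * i == 1:
    //
    //     using cd = std::complex<double>;
    //     xt::xtensor<cd, 1> a = {cd(0., 1.)};
    //     xt::xtensor<cd, 1> b = {cd(0., 1.)};
    //     cd res;
    //     xt::blas::dotu(a, b, res);  // res == cd(-1., 0.)
    //     xt::blas::dot(a, b, res);   // res == cd(1., 0.)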

    /**
     * Calculate the general matrix-vector product according to
     * ``y := alpha * A * x + beta * y``.
     *
     * @param A matrix of n x m elements
     * @param x vector of m elements
     * @param transpose_A select if A should be transposed
     * @param alpha scalar scale factor for A * x (defaults to 1)
     * @param beta scalar scale factor for y (defaults to 0)
     * @returns the resulting vector
     */
    template <class E1, class E2, class R, class value_type = typename E1::value_type>
    void gemv(const xexpression<E1>& A, const xexpression<E2>& x,
              R& result,
              bool transpose_A = false,
              const value_type& alpha = value_type(1.0),
              const value_type& beta = value_type(0.0))
    {
        auto&& dA = view_eval<E1::static_layout>(A.derived_cast());
        auto&& dx = view_eval<E2::static_layout>(x.derived_cast());

        cxxblas::gemv<blas_index_t>(
            get_blas_storage_order(result),
            transpose_A ? cxxblas::Transpose::Trans : cxxblas::Transpose::NoTrans,
            static_cast<blas_index_t>(dA.shape()[0]),
            static_cast<blas_index_t>(dA.shape()[1]),
            alpha,
            dA.data() + dA.data_offset(),
            get_leading_stride(dA),
            dx.data() + dx.data_offset(),
            get_leading_stride(dx),
            beta,
            result.data() + result.data_offset(),
            get_leading_stride(result)
        );
    }
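
    // A minimal usage sketch for gemv (illustrative only, assuming the
    // xt::zeros builder is available); the result vector must be allocated
    // with the right shape beforehand, since gemv writes through
    // result.data():
    //
    //     xt::xtensor<double, 2> A = {{1., 2.}, {3., 4.}};
    //     xt::xtensor<double, 1> x = {1., 1.};
    //     xt::xtensor<double, 1> y = xt::zeros<double>({2});
    //     xt::blas::gemv(A, x, y);  // y == {3., 7.}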

    /**
     * Calculate the matrix-matrix product of matrix \em A and matrix \em B
     * according to ``C := alpha * A * B + beta * C``.
     *
     * @param A matrix of m-by-n elements
     * @param B matrix of n-by-k elements
     * @param transpose_A transpose A on the fly
     * @param transpose_B transpose B on the fly
     * @param alpha scale factor for A * B (defaults to 1)
     * @param beta scale factor for C (defaults to 0)
     */
    template <class E, class F, class R, class value_type = typename E::value_type>
    void gemm(const xexpression<E>& A, const xexpression<F>& B, R& result,
              bool transpose_A = false,
              bool transpose_B = false,
              const value_type& alpha = value_type(1.0),
              const value_type& beta = value_type(0.0))
    {
        static_assert(R::static_layout != layout_type::dynamic, "GEMM result layout cannot be dynamic.");
        auto&& dA = view_eval<R::static_layout>(A.derived_cast());
        auto&& dB = view_eval<R::static_layout>(B.derived_cast());

        XTENSOR_ASSERT(dA.layout() == dB.layout());
        XTENSOR_ASSERT(result.layout() == dA.layout());
        XTENSOR_ASSERT(dA.dimension() == 2);
        XTENSOR_ASSERT(dB.dimension() == 2);

        cxxblas::gemm<blas_index_t>(
            get_blas_storage_order(result),
            transpose_A ? cxxblas::Transpose::Trans : cxxblas::Transpose::NoTrans,
            transpose_B ? cxxblas::Transpose::Trans : cxxblas::Transpose::NoTrans,
            static_cast<blas_index_t>(transpose_A ? dA.shape()[1] : dA.shape()[0]),
            static_cast<blas_index_t>(transpose_B ? dB.shape()[0] : dB.shape()[1]),
            static_cast<blas_index_t>(transpose_B ? dB.shape()[1] : dB.shape()[0]),
            alpha,
            dA.data() + dA.data_offset(),
            get_leading_stride(dA),
            dB.data() + dB.data_offset(),
            get_leading_stride(dB),
            beta,
            result.data() + result.data_offset(),
            get_leading_stride(result)
        );
    }
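
    // A minimal usage sketch for gemm (illustrative only); as with gemv, the
    // result matrix must be allocated with the right shape beforehand:
    //
    //     xt::xtensor<double, 2> A = {{1., 2.}, {3., 4.}};
    //     xt::xtensor<double, 2> B = {{5., 6.}, {7., 8.}};
    //     xt::xtensor<double, 2> C = xt::zeros<double>({2, 2});
    //     xt::blas::gemm(A, B, C);  // C == {{19., 22.}, {43., 50.}}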

    /**
     * Calculate the outer product of vectors \em x and \em y according to
     * ``A := alpha * x * y' + A``.
     *
     * @param x vector of n elements
     * @param y vector of m elements
     * @param alpha scalar scale factor
     * @returns matrix of n-by-m elements
     */
    template <class E1, class E2, class R, class value_type = typename E1::value_type>
    void ger(const xexpression<E1>& x, const xexpression<E2>& y,
             R& result,
             const value_type& alpha = value_type(1.0))
    {
        auto&& dx = view_eval(x.derived_cast());
        auto&& dy = view_eval(y.derived_cast());

        XTENSOR_ASSERT(dx.dimension() == 1);
        XTENSOR_ASSERT(dy.dimension() == 1);

        cxxblas::ger<blas_index_t>(
            get_blas_storage_order(result),
            static_cast<blas_index_t>(dx.shape()[0]),
            static_cast<blas_index_t>(dy.shape()[0]),
            alpha,
            dx.data() + dx.data_offset(),
            get_leading_stride(dx),
            dy.data() + dy.data_offset(),
            get_leading_stride(dy),
            result.data() + result.data_offset(),
            get_leading_stride(result)
        );
    }
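
    // A minimal usage sketch for ger (illustrative only); since ger updates
    // the result matrix in place, zero-initialize it to obtain the plain
    // outer product x * y':
    //
    //     xt::xtensor<double, 1> x = {1., 2.};
    //     xt::xtensor<double, 1> y = {3., 4., 5.};
    //     xt::xtensor<double, 2> A = xt::zeros<double>({2, 3});
    //     xt::blas::ger(x, y, A);  // A == {{3., 4., 5.}, {6., 8., 10.}}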
}
}
#endif