1 /*************************************************************************** 2 * Copyright (c) Wolf Vollprecht, Johan Mabille and Sylvain Corlay * 3 * Copyright (c) QuantStack * 4 * * 5 * Distributed under the terms of the BSD 3-Clause License. * 6 * * 7 * The full license is in the file LICENSE, distributed with this software. * 8 ****************************************************************************/ 9 10 #ifndef XBLAS_HPP 11 #define XBLAS_HPP 12 13 #include <algorithm> 14 15 #include "xtensor/xarray.hpp" 16 #include "xtensor/xcomplex.hpp" 17 #include "xtensor/xio.hpp" 18 #include "xtensor/xtensor.hpp" 19 #include "xtensor/xutils.hpp" 20 21 #include "xtensor-blas/xblas_config.hpp" 22 #include "xtensor-blas/xblas_utils.hpp" 23 24 #include "xflens/cxxblas/cxxblas.cxx" 25 26 namespace xt 27 { 28 29 namespace blas 30 { 31 /** 32 * Calculate the 1-norm of a vector 33 * 34 * @param a vector of n elements 35 * @returns scalar result 36 */ 37 template <class E, class R> asum(const xexpression<E> & a,R & result)38 void asum(const xexpression<E>& a, R& result) 39 { 40 auto&& ad = view_eval<E::static_layout>(a.derived_cast()); 41 XTENSOR_ASSERT(ad.dimension() == 1); 42 43 cxxblas::asum<blas_index_t>( 44 static_cast<blas_index_t>(ad.shape()[0]), 45 ad.data() + ad.data_offset(), 46 stride_front(ad), 47 result 48 ); 49 } 50 51 /** 52 * Calculate the 2-norm of a vector 53 * 54 * @param a vector of n elements 55 * @returns scalar result 56 */ 57 template <class E, class R> nrm2(const xexpression<E> & a,R & result)58 void nrm2(const xexpression<E>& a, R& result) 59 { 60 auto&& ad = view_eval<E::static_layout>(a.derived_cast()); 61 XTENSOR_ASSERT(ad.dimension() == 1); 62 63 cxxblas::nrm2<blas_index_t>( 64 static_cast<blas_index_t>(ad.shape()[0]), 65 ad.data() + ad.data_offset(), 66 stride_front(ad), 67 result 68 ); 69 } 70 71 /** 72 * Calculate the dot product between two vectors, conjugating 73 * the first argument \em a in the case of complex vectors. 74 * 75 * @param a vector of n elements 76 * @param b vector of n elements 77 * @returns scalar result 78 */ 79 template <class E1, class E2, class R> dot(const xexpression<E1> & a,const xexpression<E2> & b,R & result)80 void dot(const xexpression<E1>& a, const xexpression<E2>& b, 81 R& result) 82 { 83 auto&& ad = view_eval<E1::static_layout>(a.derived_cast()); 84 auto&& bd = view_eval<E2::static_layout>(b.derived_cast()); 85 XTENSOR_ASSERT(ad.dimension() == 1); 86 87 blas_index_t stride_a = stride_front(ad); 88 blas_index_t stride_b = stride_front(bd); 89 90 auto* adt = ad.data() + ad.data_offset(); 91 auto* bdt = bd.data() + bd.data_offset(); 92 93 // we need to have a pointer that points to the "real" start of the memory 94 // not to the first element (BLAS is doing that transformation itself) 95 if (stride_a < 0) { 96 adt += (static_cast<blas_index_t>(ad.shape()[0]) - 1) * stride_a; // go back to the start 97 } 98 if (stride_b < 0) { 99 bdt += (static_cast<blas_index_t>(ad.shape()[0]) - 1) * stride_b; // go back to the start 100 } 101 102 cxxblas::dot<blas_index_t>( 103 static_cast<blas_index_t>(ad.shape()[0]), 104 adt, 105 stride_a, 106 bdt, 107 stride_b, 108 result 109 ); 110 } 111 112 /** 113 * Calculate the dot product between two complex vectors, not conjugating the 114 * first argument \em a. 115 * 116 * @param a vector of n elements 117 * @param b vector of n elements 118 * @returns scalar result 119 */ 120 template <class E1, class E2, class R> dotu(const xexpression<E1> & a,const xexpression<E2> & b,R & result)121 void dotu(const xexpression<E1>& a, const xexpression<E2>& b, R& result) 122 { 123 auto&& ad = view_eval<E1::static_layout>(a.derived_cast()); 124 auto&& bd = view_eval<E2::static_layout>(b.derived_cast()); 125 XTENSOR_ASSERT(ad.dimension() == 1); 126 127 blas_index_t stride_a = stride_front(ad); 128 blas_index_t stride_b = stride_front(bd); 129 130 auto* adt = ad.data() + ad.data_offset(); 131 auto* bdt = bd.data() + bd.data_offset(); 132 133 // we need to have a pointer that points to the "real" start of the memory 134 // not to the first element (BLAS is doing that transformation itself) 135 if (stride_a < 0) { 136 adt += (static_cast<blas_index_t>(ad.shape()[0]) - 1) * stride_a; // go back to the start 137 } 138 if (stride_b < 0) { 139 bdt += (static_cast<blas_index_t>(ad.shape()[0]) - 1) * stride_b; // go back to the start 140 } 141 142 cxxblas::dotu<blas_index_t>( 143 static_cast<blas_index_t>(ad.shape()[0]), 144 adt, 145 stride_a, 146 bdt, 147 stride_b, 148 result 149 ); 150 } 151 152 /** 153 * Calculate the general matrix times vector product according to 154 * ``y := alpha * A * x + beta * y``. 155 * 156 * @param A matrix of n x m elements 157 * @param x vector of n elements 158 * @param transpose select if A should be transposed 159 * @param alpha scalar scale factor 160 * @returns the resulting vector 161 */ 162 template <class E1, class E2, class R, class value_type = typename E1::value_type> gemv(const xexpression<E1> & A,const xexpression<E2> & x,R & result,bool transpose_A=false,const value_type & alpha=value_type (1.0),const value_type & beta=value_type (0.0))163 void gemv(const xexpression<E1>& A, const xexpression<E2>& x, 164 R& result, 165 bool transpose_A = false, 166 const value_type& alpha = value_type(1.0), 167 const value_type& beta = value_type(0.0)) 168 { 169 auto&& dA = view_eval<E1::static_layout>(A.derived_cast()); 170 auto&& dx = view_eval<E2::static_layout>(x.derived_cast()); 171 172 cxxblas::gemv<blas_index_t>( 173 get_blas_storage_order(result), 174 transpose_A ? cxxblas::Transpose::Trans : cxxblas::Transpose::NoTrans, 175 static_cast<blas_index_t>(dA.shape()[0]), 176 static_cast<blas_index_t>(dA.shape()[1]), 177 alpha, 178 dA.data() + dA.data_offset(), 179 get_leading_stride(dA), 180 dx.data() + dx.data_offset(), 181 get_leading_stride(dx), 182 beta, 183 result.data() + result.data_offset(), 184 get_leading_stride(result) 185 ); 186 } 187 188 /** 189 * Calculate the matrix-matrix product of matrix @A and matrix @B 190 * 191 * C := alpha * A * B + beta * C 192 * 193 * @param A matrix of m-by-n elements 194 * @param B matrix of n-by-k elements 195 * @param transpose_A transpose A on the fly 196 * @param transpose_B transpose B on the fly 197 * @param alpha scale factor for A * B (defaults to 1) 198 * @param beta scale factor for C (defaults to 0) 199 */ 200 template <class E, class F, class R, class value_type = typename E::value_type> gemm(const xexpression<E> & A,const xexpression<F> & B,R & result,char transpose_A=false,char transpose_B=false,const value_type & alpha=value_type (1.0),const value_type & beta=value_type (0.0))201 void gemm(const xexpression<E>& A, const xexpression<F>& B, R& result, 202 char transpose_A = false, 203 char transpose_B = false, 204 const value_type& alpha = value_type(1.0), 205 const value_type& beta = value_type(0.0)) 206 { 207 static_assert(R::static_layout != layout_type::dynamic, "GEMM result layout cannot be dynamic."); 208 auto&& dA = view_eval<R::static_layout>(A.derived_cast()); 209 auto&& dB = view_eval<R::static_layout>(B.derived_cast()); 210 211 XTENSOR_ASSERT(dA.layout() == dB.layout()); 212 XTENSOR_ASSERT(result.layout() == dA.layout()); 213 XTENSOR_ASSERT(dA.dimension() == 2); 214 XTENSOR_ASSERT(dB.dimension() == 2); 215 216 cxxblas::gemm<blas_index_t>( 217 get_blas_storage_order(result), 218 transpose_A ? cxxblas::Transpose::Trans : cxxblas::Transpose::NoTrans, 219 transpose_B ? cxxblas::Transpose::Trans : cxxblas::Transpose::NoTrans, 220 static_cast<blas_index_t>(transpose_A ? dA.shape()[1] : dA.shape()[0]), 221 static_cast<blas_index_t>(transpose_B ? dB.shape()[0] : dB.shape()[1]), 222 static_cast<blas_index_t>(transpose_B ? dB.shape()[1] : dB.shape()[0]), 223 alpha, 224 dA.data() + dA.data_offset(), 225 get_leading_stride(dA), 226 dB.data() + dB.data_offset(), 227 get_leading_stride(dB), 228 beta, 229 result.data() + result.data_offset(), 230 get_leading_stride(result) 231 ); 232 } 233 234 /** 235 * Calculate the outer product of vector x and y. 236 * According to A:= alpha * x * y' + A 237 * 238 * @param x vector of n elements 239 * @param y vector of m elements 240 * @param alpha scalar scale factor 241 * @returns matrix of n-by-m elements 242 */ 243 template <class E1, class E2, class R, class value_type = typename E1::value_type> ger(const xexpression<E1> & x,const xexpression<E2> & y,R & result,const value_type & alpha=value_type (1.0))244 void ger(const xexpression<E1>& x, const xexpression<E2>& y, 245 R& result, 246 const value_type& alpha = value_type(1.0)) 247 { 248 auto&& dx = view_eval(x.derived_cast()); 249 auto&& dy = view_eval(y.derived_cast()); 250 251 XTENSOR_ASSERT(dx.dimension() == 1); 252 XTENSOR_ASSERT(dy.dimension() == 1); 253 254 cxxblas::ger<blas_index_t>( 255 get_blas_storage_order(result), 256 static_cast<blas_index_t>(dx.shape()[0]), 257 static_cast<blas_index_t>(dy.shape()[0]), 258 alpha, 259 dx.data() + dx.data_offset(), 260 get_leading_stride(dx), 261 dy.data() + dy.data_offset(), 262 get_leading_stride(dy), 263 result.data() + result.data_offset(), 264 get_leading_stride(result) 265 ); 266 } 267 } 268 } 269 #endif 270