1 #ifndef VIENNACL_TOOLS_MATRIX_SIZE_DEDUCER_HPP_ 2 #define VIENNACL_TOOLS_MATRIX_SIZE_DEDUCER_HPP_ 3 4 /* ========================================================================= 5 Copyright (c) 2010-2016, Institute for Microelectronics, 6 Institute for Analysis and Scientific Computing, 7 TU Wien. 8 Portions of this software are copyright by UChicago Argonne, LLC. 9 10 ----------------- 11 ViennaCL - The Vienna Computing Library 12 ----------------- 13 14 Project Head: Karl Rupp rupp@iue.tuwien.ac.at 15 16 (A list of authors and contributors can be found in the manual) 17 18 License: MIT (X11), see file LICENSE in the base directory 19 ============================================================================= */ 20 21 /** @file viennacl/tools/matrix_size_deducer.hpp 22 @brief Helper implementations that deduce the dimensions of the supplied matrix-valued expressions. 23 */ 24 25 #include <string> 26 #include <fstream> 27 #include <sstream> 28 #include <cmath> 29 #include <vector> 30 #include <map> 31 32 #include "viennacl/forwards.h" 33 #include "viennacl/tools/adapter.hpp" 34 35 namespace viennacl 36 { 37 namespace tools 38 { 39 40 /** @brief Deduces the size of the resulting vector represented by a vector_expression from the operands 41 * 42 * @tparam LHS The left hand side operand 43 * @tparam RHS The right hand side operand 44 * @tparam OP The operation tag 45 */ 46 template<typename LHS, typename RHS, typename OP> 47 struct MATRIX_SIZE_DEDUCER 48 { 49 //Standard case: size1 from lhs, size2 from rhs (fits most cases) size1viennacl::tools::MATRIX_SIZE_DEDUCER50 static vcl_size_t size1(LHS & lhs, RHS & /*rhs*/) { return lhs.size1(); } size2viennacl::tools::MATRIX_SIZE_DEDUCER51 static vcl_size_t size2(LHS & /*lhs*/, RHS & rhs) { return rhs.size2(); } 52 }; 53 54 /** \cond */ 55 //special case: outer vector product: 56 template<typename ScalarType> 57 struct MATRIX_SIZE_DEDUCER<const viennacl::vector_base<ScalarType>, 58 const viennacl::vector_base<ScalarType>, 59 viennacl::op_prod> 60 { size1viennacl::tools::MATRIX_SIZE_DEDUCER61 static vcl_size_t size1(viennacl::vector_base<ScalarType> const & lhs, 62 viennacl::vector_base<ScalarType> const & /*rhs*/) { return lhs.size(); } 63 size2viennacl::tools::MATRIX_SIZE_DEDUCER64 static vcl_size_t size2(viennacl::vector_base<ScalarType> const & /*lhs*/, 65 viennacl::vector_base<ScalarType> const & rhs) { return rhs.size(); } 66 }; 67 68 69 //special case: multiplication with a scalar 70 template<typename LHS, typename RHS, typename OP, typename ScalarType> 71 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_expression<const LHS, const RHS, OP>, 72 const ScalarType, 73 viennacl::op_mult> 74 { size1viennacl::tools::MATRIX_SIZE_DEDUCER75 static vcl_size_t size1(viennacl::matrix_expression<const LHS, const RHS, OP> const & lhs, 76 ScalarType const & /*rhs*/) { return MATRIX_SIZE_DEDUCER<const LHS, const RHS, OP>::size1(lhs.lhs(), lhs.rhs()); } 77 size2viennacl::tools::MATRIX_SIZE_DEDUCER78 static vcl_size_t size2(viennacl::matrix_expression<const LHS, const RHS, OP> const & lhs, 79 ScalarType const & /*rhs*/) { return MATRIX_SIZE_DEDUCER<const LHS, const RHS, OP>::size2(lhs.lhs(), lhs.rhs()); } 80 }; 81 82 //special case: multiplication with a scalar 83 template<typename T, typename ScalarType> 84 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_base<T>, 85 const ScalarType, 86 viennacl::op_mult> 87 { size1viennacl::tools::MATRIX_SIZE_DEDUCER88 static vcl_size_t size1(viennacl::matrix_base<T> const & lhs, 89 ScalarType const & /*rhs*/) { return lhs.size1(); } 90 size2viennacl::tools::MATRIX_SIZE_DEDUCER91 static vcl_size_t size2(viennacl::matrix_base<T> const & lhs, 92 ScalarType const & /*rhs*/) { return lhs.size2(); } 93 }; 94 95 96 //special case: division with a scalar 97 template<typename LHS, typename RHS, typename OP, typename ScalarType> 98 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_expression<const LHS, const RHS, OP>, 99 const ScalarType, 100 viennacl::op_div> 101 { size1viennacl::tools::MATRIX_SIZE_DEDUCER102 static vcl_size_t size1(viennacl::matrix_expression<const LHS, const RHS, OP> const & lhs, 103 ScalarType const & /*rhs*/) { return MATRIX_SIZE_DEDUCER<const LHS, const RHS, OP>::size1(lhs.lhs(), lhs.rhs()); } 104 size2viennacl::tools::MATRIX_SIZE_DEDUCER105 static vcl_size_t size2(viennacl::matrix_expression<const LHS, const RHS, OP> const & lhs, 106 ScalarType const & /*rhs*/) { return MATRIX_SIZE_DEDUCER<const LHS, const RHS, OP>::size2(lhs.lhs(), lhs.rhs()); } 107 }; 108 109 //special case: division with a scalar 110 template<typename T, typename ScalarType> 111 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_base<T>, 112 const ScalarType, 113 viennacl::op_div> 114 { size1viennacl::tools::MATRIX_SIZE_DEDUCER115 static vcl_size_t size1(viennacl::matrix_base<T> const & lhs, 116 ScalarType const & /*rhs*/) { return lhs.size1(); } 117 size2viennacl::tools::MATRIX_SIZE_DEDUCER118 static vcl_size_t size2(viennacl::matrix_base<T> const & lhs, 119 ScalarType const & /*rhs*/) { return lhs.size2(); } 120 }; 121 122 //special case: diagonal from vector 123 template<typename T> 124 struct MATRIX_SIZE_DEDUCER<const viennacl::vector_base<T>, 125 const int, 126 viennacl::op_vector_diag> 127 { size1viennacl::tools::MATRIX_SIZE_DEDUCER128 static vcl_size_t size1(viennacl::vector_base<T> const & lhs, 129 const int k) { return lhs.size() + static_cast<vcl_size_t>(std::fabs(double(k))); } 130 size2viennacl::tools::MATRIX_SIZE_DEDUCER131 static vcl_size_t size2(viennacl::vector_base<T> const & lhs, 132 const int k) { return lhs.size() + static_cast<vcl_size_t>(std::fabs(double(k))); } 133 }; 134 135 //special case: transposed matrix-vector product: Return the number of rows of the matrix 136 template<typename MatrixType> 137 struct MATRIX_SIZE_DEDUCER<MatrixType, 138 MatrixType, 139 viennacl::op_trans> 140 { size1viennacl::tools::MATRIX_SIZE_DEDUCER141 static vcl_size_t size1(const MatrixType & lhs, 142 const MatrixType & /*rhs*/) { return lhs.size2(); } size2viennacl::tools::MATRIX_SIZE_DEDUCER143 static vcl_size_t size2(const MatrixType & lhs, 144 const MatrixType & /*rhs*/) { return lhs.size1(); } 145 }; 146 147 // A^T * B 148 template<typename ScalarType, typename T1> 149 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_expression<T1, 150 T1, op_trans>, 151 const viennacl::matrix_base<ScalarType>, 152 viennacl::op_mat_mat_prod> 153 { size1viennacl::tools::MATRIX_SIZE_DEDUCER154 static vcl_size_t size1(viennacl::matrix_expression<T1, 155 T1, 156 op_trans> const & lhs, 157 viennacl::matrix_base<ScalarType> const & /*rhs*/) { return lhs.lhs().size2(); } size2viennacl::tools::MATRIX_SIZE_DEDUCER158 static vcl_size_t size2(viennacl::matrix_expression<T1, 159 T1, 160 op_trans> const & /*lhs*/, 161 viennacl::matrix_base<ScalarType> const & rhs) { return rhs.size2(); } 162 }; 163 164 165 // A * B^T 166 template<typename ScalarType, typename T2> 167 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_base<ScalarType>, 168 const viennacl::matrix_expression<T2, 169 T2, op_trans>, 170 viennacl::op_mat_mat_prod> 171 { size1viennacl::tools::MATRIX_SIZE_DEDUCER172 static vcl_size_t size1(viennacl::matrix_base<ScalarType> const & lhs, 173 viennacl::matrix_expression<T2, 174 T2, 175 op_trans> const & /*rhs*/) { return lhs.size1(); } size2viennacl::tools::MATRIX_SIZE_DEDUCER176 static vcl_size_t size2(viennacl::matrix_base<ScalarType> const & /*lhs*/, 177 viennacl::matrix_expression<T2, 178 T2, 179 op_trans> const & rhs) { return rhs.lhs().size1(); } 180 }; 181 182 // A^T * B^T 183 template<typename T1, typename T2> 184 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_expression<T1, 185 T1, op_trans>, 186 const viennacl::matrix_expression<T2, 187 T2, op_trans>, 188 viennacl::op_mat_mat_prod> 189 { 190 typedef viennacl::matrix_expression<T1, T1, op_trans> LHSType; 191 typedef viennacl::matrix_expression<T2, T2, op_trans> RHSType; 192 size1viennacl::tools::MATRIX_SIZE_DEDUCER193 static vcl_size_t size1(LHSType const & lhs, 194 RHSType const & /*rhs*/) { return lhs.lhs().size2(); } size2viennacl::tools::MATRIX_SIZE_DEDUCER195 static vcl_size_t size2(LHSType const & /*lhs*/, 196 RHSType const & rhs) { return rhs.lhs().size1(); } 197 }; 198 /** \endcond */ 199 200 } 201 } 202 203 #endif 204 205