1 #ifndef VIENNACL_TOOLS_MATRIX_SIZE_DEDUCER_HPP_
2 #define VIENNACL_TOOLS_MATRIX_SIZE_DEDUCER_HPP_
3 
4 /* =========================================================================
5    Copyright (c) 2010-2016, Institute for Microelectronics,
6                             Institute for Analysis and Scientific Computing,
7                             TU Wien.
8    Portions of this software are copyright by UChicago Argonne, LLC.
9 
10                             -----------------
11                   ViennaCL - The Vienna Computing Library
12                             -----------------
13 
14    Project Head:    Karl Rupp                   rupp@iue.tuwien.ac.at
15 
16    (A list of authors and contributors can be found in the manual)
17 
18    License:         MIT (X11), see file LICENSE in the base directory
19 ============================================================================= */
20 
21 /** @file viennacl/tools/matrix_size_deducer.hpp
22     @brief Helper implementations that deduce the dimensions of the supplied matrix-valued expressions.
23 */
24 
25 #include <string>
26 #include <fstream>
27 #include <sstream>
28 #include <cmath>
29 #include <vector>
30 #include <map>
31 
32 #include "viennacl/forwards.h"
33 #include "viennacl/tools/adapter.hpp"
34 
35 namespace viennacl
36 {
37 namespace tools
38 {
39 
/** @brief Deduces the dimensions (number of rows and number of columns) of the matrix represented by a matrix-valued expression from its operands
*
* @tparam LHS   The left hand side operand
* @tparam RHS   The right hand side operand
* @tparam OP    The operation tag
*/
46 template<typename LHS, typename RHS, typename OP>
47 struct MATRIX_SIZE_DEDUCER
48 {
49   //Standard case: size1 from lhs, size2 from rhs (fits most cases)
size1viennacl::tools::MATRIX_SIZE_DEDUCER50   static vcl_size_t size1(LHS & lhs, RHS & /*rhs*/) { return lhs.size1(); }
size2viennacl::tools::MATRIX_SIZE_DEDUCER51   static vcl_size_t size2(LHS & /*lhs*/, RHS & rhs) { return rhs.size2(); }
52 };
53 
54 /** \cond */
55 //special case: outer vector product:
56 template<typename ScalarType>
57 struct MATRIX_SIZE_DEDUCER<const viennacl::vector_base<ScalarType>,
58     const viennacl::vector_base<ScalarType>,
59     viennacl::op_prod>
60 {
size1viennacl::tools::MATRIX_SIZE_DEDUCER61   static vcl_size_t size1(viennacl::vector_base<ScalarType> const & lhs,
62                           viennacl::vector_base<ScalarType> const & /*rhs*/) { return lhs.size(); }
63 
size2viennacl::tools::MATRIX_SIZE_DEDUCER64   static vcl_size_t size2(viennacl::vector_base<ScalarType> const & /*lhs*/,
65                           viennacl::vector_base<ScalarType> const & rhs) { return rhs.size(); }
66 };
67 
68 
69 //special case: multiplication with a scalar
70 template<typename LHS, typename RHS, typename OP, typename ScalarType>
71 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_expression<const LHS, const RHS, OP>,
72     const ScalarType,
73     viennacl::op_mult>
74 {
size1viennacl::tools::MATRIX_SIZE_DEDUCER75   static vcl_size_t size1(viennacl::matrix_expression<const LHS, const RHS, OP> const & lhs,
76                           ScalarType const & /*rhs*/) { return MATRIX_SIZE_DEDUCER<const LHS, const RHS, OP>::size1(lhs.lhs(), lhs.rhs()); }
77 
size2viennacl::tools::MATRIX_SIZE_DEDUCER78   static vcl_size_t size2(viennacl::matrix_expression<const LHS, const RHS, OP> const & lhs,
79                           ScalarType const & /*rhs*/) { return MATRIX_SIZE_DEDUCER<const LHS, const RHS, OP>::size2(lhs.lhs(), lhs.rhs()); }
80 };
81 
82 //special case: multiplication with a scalar
83 template<typename T, typename ScalarType>
84 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_base<T>,
85     const ScalarType,
86     viennacl::op_mult>
87 {
size1viennacl::tools::MATRIX_SIZE_DEDUCER88   static vcl_size_t size1(viennacl::matrix_base<T> const & lhs,
89                           ScalarType const & /*rhs*/) { return lhs.size1(); }
90 
size2viennacl::tools::MATRIX_SIZE_DEDUCER91   static vcl_size_t size2(viennacl::matrix_base<T> const & lhs,
92                           ScalarType const & /*rhs*/) { return lhs.size2(); }
93 };
94 
95 
96 //special case: division with a scalar
97 template<typename LHS, typename RHS, typename OP, typename ScalarType>
98 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_expression<const LHS, const RHS, OP>,
99     const ScalarType,
100     viennacl::op_div>
101 {
size1viennacl::tools::MATRIX_SIZE_DEDUCER102   static vcl_size_t size1(viennacl::matrix_expression<const LHS, const RHS, OP> const & lhs,
103                           ScalarType const & /*rhs*/) { return MATRIX_SIZE_DEDUCER<const LHS, const RHS, OP>::size1(lhs.lhs(), lhs.rhs()); }
104 
size2viennacl::tools::MATRIX_SIZE_DEDUCER105   static vcl_size_t size2(viennacl::matrix_expression<const LHS, const RHS, OP> const & lhs,
106                           ScalarType const & /*rhs*/) { return MATRIX_SIZE_DEDUCER<const LHS, const RHS, OP>::size2(lhs.lhs(), lhs.rhs()); }
107 };
108 
109 //special case: division with a scalar
110 template<typename T, typename ScalarType>
111 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_base<T>,
112     const ScalarType,
113     viennacl::op_div>
114 {
size1viennacl::tools::MATRIX_SIZE_DEDUCER115   static vcl_size_t size1(viennacl::matrix_base<T> const & lhs,
116                           ScalarType const & /*rhs*/) { return lhs.size1(); }
117 
size2viennacl::tools::MATRIX_SIZE_DEDUCER118   static vcl_size_t size2(viennacl::matrix_base<T> const & lhs,
119                           ScalarType const & /*rhs*/) { return lhs.size2(); }
120 };
121 
122 //special case: diagonal from vector
123 template<typename T>
124 struct MATRIX_SIZE_DEDUCER<const viennacl::vector_base<T>,
125     const int,
126     viennacl::op_vector_diag>
127 {
size1viennacl::tools::MATRIX_SIZE_DEDUCER128   static vcl_size_t size1(viennacl::vector_base<T> const & lhs,
129                           const int k) { return lhs.size() + static_cast<vcl_size_t>(std::fabs(double(k))); }
130 
size2viennacl::tools::MATRIX_SIZE_DEDUCER131   static vcl_size_t size2(viennacl::vector_base<T> const & lhs,
132                           const int k) { return lhs.size() + static_cast<vcl_size_t>(std::fabs(double(k))); }
133 };
134 
//special case: transposed matrix: the number of rows and the number of columns of the operand are swapped
136 template<typename MatrixType>
137 struct MATRIX_SIZE_DEDUCER<MatrixType,
138     MatrixType,
139     viennacl::op_trans>
140 {
size1viennacl::tools::MATRIX_SIZE_DEDUCER141   static vcl_size_t size1(const MatrixType & lhs,
142                           const MatrixType & /*rhs*/) { return lhs.size2(); }
size2viennacl::tools::MATRIX_SIZE_DEDUCER143   static vcl_size_t size2(const MatrixType & lhs,
144                           const MatrixType & /*rhs*/) { return lhs.size1(); }
145 };
146 
147 // A^T * B
148 template<typename ScalarType, typename T1>
149 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_expression<T1,
150     T1, op_trans>,
151     const viennacl::matrix_base<ScalarType>,
152     viennacl::op_mat_mat_prod>
153 {
size1viennacl::tools::MATRIX_SIZE_DEDUCER154   static vcl_size_t size1(viennacl::matrix_expression<T1,
155                           T1,
156                           op_trans> const & lhs,
157                           viennacl::matrix_base<ScalarType> const & /*rhs*/) { return lhs.lhs().size2(); }
size2viennacl::tools::MATRIX_SIZE_DEDUCER158   static vcl_size_t size2(viennacl::matrix_expression<T1,
159                           T1,
160                           op_trans> const & /*lhs*/,
161                           viennacl::matrix_base<ScalarType> const & rhs) { return rhs.size2(); }
162 };
163 
164 
165 // A * B^T
166 template<typename ScalarType, typename T2>
167 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_base<ScalarType>,
168     const viennacl::matrix_expression<T2,
169     T2, op_trans>,
170     viennacl::op_mat_mat_prod>
171 {
size1viennacl::tools::MATRIX_SIZE_DEDUCER172   static vcl_size_t size1(viennacl::matrix_base<ScalarType> const & lhs,
173                           viennacl::matrix_expression<T2,
174                           T2,
175                           op_trans> const & /*rhs*/) { return lhs.size1(); }
size2viennacl::tools::MATRIX_SIZE_DEDUCER176   static vcl_size_t size2(viennacl::matrix_base<ScalarType> const & /*lhs*/,
177                           viennacl::matrix_expression<T2,
178                           T2,
179                           op_trans> const & rhs) { return rhs.lhs().size1(); }
180 };
181 
182 // A^T * B^T
183 template<typename T1, typename T2>
184 struct MATRIX_SIZE_DEDUCER<const viennacl::matrix_expression<T1,
185     T1, op_trans>,
186     const viennacl::matrix_expression<T2,
187     T2, op_trans>,
188     viennacl::op_mat_mat_prod>
189 {
190   typedef viennacl::matrix_expression<T1, T1, op_trans>   LHSType;
191   typedef viennacl::matrix_expression<T2, T2, op_trans>   RHSType;
192 
size1viennacl::tools::MATRIX_SIZE_DEDUCER193   static vcl_size_t size1(LHSType const & lhs,
194                           RHSType const & /*rhs*/) { return lhs.lhs().size2(); }
size2viennacl::tools::MATRIX_SIZE_DEDUCER195   static vcl_size_t size2(LHSType const & /*lhs*/,
196                           RHSType const & rhs) { return rhs.lhs().size1(); }
197 };
198 /** \endcond */
199 
200 }
201 }
202 
203 #endif
204 
205