1 // -*- C++ -*-
2 /***************************************************************************
3  * blitz/meta/matmat.h   TinyMatrix matrix-matrix product metaprogram
4  *
5  * $Id$
6  *
7  * Copyright (C) 1997-2011 Todd Veldhuizen <tveldhui@acm.org>
8  *
9  * This file is a part of Blitz.
10  *
11  * Blitz is free software: you can redistribute it and/or modify
12  * it under the terms of the GNU Lesser General Public License
13  * as published by the Free Software Foundation, either version 3
14  * of the License, or (at your option) any later version.
15  *
16  * Blitz is distributed in the hope that it will be useful,
17  * but WITHOUT ANY WARRANTY; without even the implied warranty of
18  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
19  * GNU Lesser General Public License for more details.
20  *
21  * You should have received a copy of the GNU Lesser General Public
22  * License along with Blitz.  If not, see <http://www.gnu.org/licenses/>.
23  *
24  * Suggestions:          blitz-devel@lists.sourceforge.net
25  * Bugs:                 blitz-support@lists.sourceforge.net
26  *
27  * For more information, please see the Blitz++ Home Page:
28  *    https://sourceforge.net/projects/blitz/
29  *
30  ***************************************************************************/
31 
32 #ifndef BZ_META_MATMAT_H
33 #define BZ_META_MATMAT_H
34 
35 #ifndef BZ_TINYMAT_H
36  #error <blitz/meta/matmat.h> must be included via <blitz/tinymat.h>
37 #endif
38 
39 #include <blitz/meta/metaprog.h>
40 #include <blitz/tinymatexpr.h>
41 
42 namespace blitz {
43 
44 // Template metaprogram for matrix-matrix multiplication
45 template<int N_rows1, int N_columns, int N_columns2, int N_rowStride1,
46     int N_colStride1, int N_rowStride2, int N_colStride2, int K>
47 class _bz_meta_matrixMatrixProduct {
48 public:
49     static const int go = (K != N_columns - 1) ? 1 : 0;
50 
51     template<typename T_numtype1, typename T_numtype2>
BZ_PROMOTE(T_numtype1,T_numtype2)52     static inline BZ_PROMOTE(T_numtype1, T_numtype2)
53     f(const T_numtype1* matrix1, const T_numtype2* matrix2, int i, int j)
54     {
55         return matrix1[i * N_rowStride1 + K * N_colStride1]
56             * matrix2[K * N_rowStride2 + j * N_colStride2]
57             + _bz_meta_matrixMatrixProduct<N_rows1 * go, N_columns * go,
58                 N_columns2 * go, N_rowStride1 * go, N_colStride1 * go,
59                 N_rowStride2 * go, N_colStride2 * go, (K+1) * go>
60               ::f(matrix1, matrix2, i, j);
61     }
62 };
63 
64 template<>
65 class _bz_meta_matrixMatrixProduct<0,0,0,0,0,0,0,0> {
66 public:
f(const void *,const void *,int,int)67     static inline _bz_meta_nullOperand f(const void*, const void*, int, int)
68     { return _bz_meta_nullOperand(); }
69 };
70 
71 
72 
73 
74 template<typename T_numtype1, typename T_numtype2, int N_rows1, int N_columns,
75     int N_columns2, int N_rowStride1, int N_colStride1,
76     int N_rowStride2, int N_colStride2>
77 class _bz_tinyMatrixMatrixProduct {
78 public:
79     typedef BZ_PROMOTE(T_numtype1, T_numtype2) T_numtype;
80 
81     static const int rows = N_rows1, columns = N_columns2;
82 
_bz_tinyMatrixMatrixProduct(const T_numtype1 * matrix1,const T_numtype2 * matrix2)83     _bz_tinyMatrixMatrixProduct(const T_numtype1* matrix1,
84         const T_numtype2* matrix2)
85         : matrix1_(matrix1), matrix2_(matrix2)
86     { }
87 
_bz_tinyMatrixMatrixProduct(const _bz_tinyMatrixMatrixProduct<T_numtype1,T_numtype2,N_rows1,N_columns,N_columns2,N_rowStride1,N_colStride1,N_rowStride2,N_colStride2> & x)88     _bz_tinyMatrixMatrixProduct(const _bz_tinyMatrixMatrixProduct<T_numtype1,
89         T_numtype2, N_rows1, N_columns, N_columns2, N_rowStride1, N_colStride1,
90         N_rowStride2, N_colStride2>& x)
91         : matrix1_(x.matrix1_), matrix2_(x.matrix2_)
92     { }
93 
matrix1()94     const T_numtype1* matrix1() const
95     { return matrix1_; }
96 
matrix2()97     const T_numtype2* matrix2() const
98     { return matrix2_; }
99 
operator()100     T_numtype operator()(int i, int j) const
101     {
102         return _bz_meta_matrixMatrixProduct<N_rows1, N_columns,
103             N_columns2, N_rowStride1, N_colStride1, N_rowStride2,
104             N_colStride2, 0>::f(matrix1_, matrix2_, i, j);
105     }
106 
107 protected:
108     const T_numtype1* matrix1_;
109     const T_numtype2* matrix2_;
110 };
111 
112 template<typename T_numtype1, typename T_numtype2, int N_rows1, int N_columns1,
113     int N_columns2>
114 inline
115 _bz_tinyMatExpr<_bz_tinyMatrixMatrixProduct<T_numtype1, T_numtype2, N_rows1,
116     N_columns1, N_columns2, N_columns1, 1, N_columns2, 1> >
product(const TinyMatrix<T_numtype1,N_rows1,N_columns1> & a,const TinyMatrix<T_numtype2,N_columns1,N_columns2> & b)117 product(const TinyMatrix<T_numtype1, N_rows1, N_columns1>& a,
118     const TinyMatrix<T_numtype2, N_columns1, N_columns2>& b)
119 {
120     typedef _bz_tinyMatrixMatrixProduct<T_numtype1, T_numtype2,
121         N_rows1, N_columns1, N_columns2, N_columns1, 1, N_columns2, 1> T_expr;
122     return _bz_tinyMatExpr<T_expr>(T_expr(a.data(), b.data()));
123 }
124 
125 }
126 
127 #endif // BZ_META_MATMAT_H
128 
129