1 // -*- C++ -*-
2 /***************************************************************************
3 * blitz/meta/matmat.h TinyMatrix matrix-matrix product metaprogram
4 *
5 * $Id$
6 *
7 * Copyright (C) 1997-2011 Todd Veldhuizen <tveldhui@acm.org>
8 *
9 * This file is a part of Blitz.
10 *
11 * Blitz is free software: you can redistribute it and/or modify
12 * it under the terms of the GNU Lesser General Public License
13 * as published by the Free Software Foundation, either version 3
14 * of the License, or (at your option) any later version.
15 *
16 * Blitz is distributed in the hope that it will be useful,
17 * but WITHOUT ANY WARRANTY; without even the implied warranty of
18 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19 * GNU Lesser General Public License for more details.
20 *
21 * You should have received a copy of the GNU Lesser General Public
22 * License along with Blitz. If not, see <http://www.gnu.org/licenses/>.
23 *
24 * Suggestions: blitz-devel@lists.sourceforge.net
25 * Bugs: blitz-support@lists.sourceforge.net
26 *
27 * For more information, please see the Blitz++ Home Page:
28 * https://sourceforge.net/projects/blitz/
29 *
30 ***************************************************************************/
31
32 #ifndef BZ_META_MATMAT_H
33 #define BZ_META_MATMAT_H
34
35 #ifndef BZ_TINYMAT_H
36 #error <blitz/meta/matmat.h> must be included via <blitz/tinymat.h>
37 #endif
38
39 #include <blitz/meta/metaprog.h>
40 #include <blitz/tinymatexpr.h>
41
42 namespace blitz {
43
44 // Template metaprogram for matrix-matrix multiplication
45 template<int N_rows1, int N_columns, int N_columns2, int N_rowStride1,
46 int N_colStride1, int N_rowStride2, int N_colStride2, int K>
47 class _bz_meta_matrixMatrixProduct {
48 public:
49 static const int go = (K != N_columns - 1) ? 1 : 0;
50
51 template<typename T_numtype1, typename T_numtype2>
BZ_PROMOTE(T_numtype1,T_numtype2)52 static inline BZ_PROMOTE(T_numtype1, T_numtype2)
53 f(const T_numtype1* matrix1, const T_numtype2* matrix2, int i, int j)
54 {
55 return matrix1[i * N_rowStride1 + K * N_colStride1]
56 * matrix2[K * N_rowStride2 + j * N_colStride2]
57 + _bz_meta_matrixMatrixProduct<N_rows1 * go, N_columns * go,
58 N_columns2 * go, N_rowStride1 * go, N_colStride1 * go,
59 N_rowStride2 * go, N_colStride2 * go, (K+1) * go>
60 ::f(matrix1, matrix2, i, j);
61 }
62 };
63
64 template<>
65 class _bz_meta_matrixMatrixProduct<0,0,0,0,0,0,0,0> {
66 public:
f(const void *,const void *,int,int)67 static inline _bz_meta_nullOperand f(const void*, const void*, int, int)
68 { return _bz_meta_nullOperand(); }
69 };
70
71
72
73
74 template<typename T_numtype1, typename T_numtype2, int N_rows1, int N_columns,
75 int N_columns2, int N_rowStride1, int N_colStride1,
76 int N_rowStride2, int N_colStride2>
77 class _bz_tinyMatrixMatrixProduct {
78 public:
79 typedef BZ_PROMOTE(T_numtype1, T_numtype2) T_numtype;
80
81 static const int rows = N_rows1, columns = N_columns2;
82
_bz_tinyMatrixMatrixProduct(const T_numtype1 * matrix1,const T_numtype2 * matrix2)83 _bz_tinyMatrixMatrixProduct(const T_numtype1* matrix1,
84 const T_numtype2* matrix2)
85 : matrix1_(matrix1), matrix2_(matrix2)
86 { }
87
_bz_tinyMatrixMatrixProduct(const _bz_tinyMatrixMatrixProduct<T_numtype1,T_numtype2,N_rows1,N_columns,N_columns2,N_rowStride1,N_colStride1,N_rowStride2,N_colStride2> & x)88 _bz_tinyMatrixMatrixProduct(const _bz_tinyMatrixMatrixProduct<T_numtype1,
89 T_numtype2, N_rows1, N_columns, N_columns2, N_rowStride1, N_colStride1,
90 N_rowStride2, N_colStride2>& x)
91 : matrix1_(x.matrix1_), matrix2_(x.matrix2_)
92 { }
93
matrix1()94 const T_numtype1* matrix1() const
95 { return matrix1_; }
96
matrix2()97 const T_numtype2* matrix2() const
98 { return matrix2_; }
99
operator()100 T_numtype operator()(int i, int j) const
101 {
102 return _bz_meta_matrixMatrixProduct<N_rows1, N_columns,
103 N_columns2, N_rowStride1, N_colStride1, N_rowStride2,
104 N_colStride2, 0>::f(matrix1_, matrix2_, i, j);
105 }
106
107 protected:
108 const T_numtype1* matrix1_;
109 const T_numtype2* matrix2_;
110 };
111
112 template<typename T_numtype1, typename T_numtype2, int N_rows1, int N_columns1,
113 int N_columns2>
114 inline
115 _bz_tinyMatExpr<_bz_tinyMatrixMatrixProduct<T_numtype1, T_numtype2, N_rows1,
116 N_columns1, N_columns2, N_columns1, 1, N_columns2, 1> >
product(const TinyMatrix<T_numtype1,N_rows1,N_columns1> & a,const TinyMatrix<T_numtype2,N_columns1,N_columns2> & b)117 product(const TinyMatrix<T_numtype1, N_rows1, N_columns1>& a,
118 const TinyMatrix<T_numtype2, N_columns1, N_columns2>& b)
119 {
120 typedef _bz_tinyMatrixMatrixProduct<T_numtype1, T_numtype2,
121 N_rows1, N_columns1, N_columns2, N_columns1, 1, N_columns2, 1> T_expr;
122 return _bz_tinyMatExpr<T_expr>(T_expr(a.data(), b.data()));
123 }
124
125 }
126
127 #endif // BZ_META_MATMAT_H
128
129