1 // Copyright 2015-2017 the openage authors. See copying.md for legal info.
2 
3 #pragma once
4 
5 #include <array>
6 #include <cmath>
7 #include <cstring>
8 #include <iostream>
9 #include <type_traits>
10 
11 #include "vector.h"
12 
13 namespace openage {
14 namespace util {
15 
16 /**
17  * Matrix class with arithmetic. M rows, N columns.
18  */
19 template<size_t M, size_t N>
20 class Matrix : public std::array<std::array<float, N>, M> {
21 public:
22 	static_assert(M > 0 and N > 0, "0-dimensional matrix not allowed");
23 
24 	static constexpr float default_eps = 1e-5;
25 
26 	static constexpr size_t rows = M;
27 	static constexpr size_t cols = N;
28 	static constexpr bool is_square = (M == N);
29 	static constexpr bool is_row_vector = (M == 1);
30 	static constexpr bool is_column_vector = (N == 1);
31 
32 	/**
33 	 * Initialize the matrix to zeroes.
34 	 */
Matrix()35 	Matrix() {
36 		for (size_t i = 0; i < M; i++) {
37 			for (size_t j = 0; j < N; j++) {
38 				(*this)[i][j] = 0;
39 			}
40 		}
41 	}
42 
43 	~Matrix() = default;
44 
45 	/**
46 	 * Constructor from Vector
47 	 */
48 	template <bool cond=is_column_vector,
49 	          typename=typename std::enable_if<cond>::type>
Matrix(const Vector<M> & vec)50 	Matrix(const Vector<M> &vec) {
51 		for (size_t i = 0; i < M; i++) {
52 			(*this)[i][0] = vec[i];
53 		}
54 	}
55 
56 	/**
57 	 * Constructor with N*M values
58 	 */
59 	template <typename ... T>
Matrix(T...args)60 	Matrix(T ... args) {
61 		static_assert(sizeof...(args) == N*M, "not all values supplied");
62 
63 		std::array<float, N*M> temp{{static_cast<float>(args)...}};
64 		for (size_t i = 0; i < N*M; i++) {
65 			(*this)[i / (N*M)][i % (N*M)] = temp[i];
66 		}
67 	}
68 
69 	/**
70 	 * Constructs the identity matrix for the current size.
71 	 */
72 	template <bool cond=is_square,
73 	          typename=typename std::enable_if<cond>::type>
identity()74 	static Matrix<N, M> identity() {
75 		Matrix<N, M> res;
76 
77 		for (size_t i = 0; i < N; i++) {
78 			res[i][i] = 1;
79 		}
80 
81 		return res;
82 	}
83 
84 	/**
85 	 * Test if both matrices contain the same values within epsilon.
86 	 */
87 	bool equals(const Matrix<N, M> &other, float eps=default_eps) const {
88 		for (size_t i = 0; i < N; i++) {
89 			for (size_t j = 0; j < M; j++) {
90 				if (std::abs((*this)[i][j] - other[i][j]) >= eps) {
91 					return false;
92 				}
93 			}
94 		}
95 		return true;
96 	}
97 
98 	/**
99 	 * Matrix multiplication
100 	 */
101 	template <size_t P>
102 	Matrix<M, P> operator *(const Matrix<N, P> &other) const {
103 		Matrix<M, P> res;
104 		for (size_t i = 0; i < M; i++) {
105 			for (size_t j = 0; j < P; j++) {
106 				res[i][j] = 0;
107 				for (size_t k = 0; k < N; k++) {
108 					res[i][j] += (*this)[i][k] * other[k][j];
109 				}
110 			}
111 		}
112 		return res;
113 	}
114 
115 	/**
116 	 * Matrix-Vector multiplication
117 	 */
118 	Matrix <M, 1> operator *(const Vector<M> &vec) const {
119 		return (*this) * static_cast<Matrix<M, 1>>(vec);
120 	}
121 
122 	/**
123 	 * Matrix addition
124 	 */
125 	Matrix<M, N> operator +(const Matrix<M, N> &other) const {
126 		Matrix<M, N> res;
127 		for (size_t i = 0; i < M; i++) {
128 			for (size_t j = 0; j < N; j++) {
129 				res[i][j] = (*this)[i][j] + other[i][j];
130 			}
131 		}
132 		return res;
133 	}
134 
135 	/**
136 	 * Matrix subtraction
137 	 */
138 	Matrix<M, N> operator -(const Matrix<M, N> &other) const {
139 		Matrix<M, N> res;
140 		for (size_t i = 0; i < M; i++) {
141 			for (size_t j = 0; j < N; j++) {
142 				res[i][j] = (*this)[i][j] - other[i][j];
143 			}
144 		}
145 		return res;
146 	}
147 
148 	/**
149 	 * Scalar multiplication with assignment
150 	 */
151 	void operator *=(float other) {
152 		for (size_t i = 0; i < M; i++) {
153 			for (size_t j = 0; j < N; j++) {
154 				(*this)[i][j] *= other;
155 			}
156 		}
157 	}
158 
159 	/**
160 	 * Scalar multiplication
161 	 */
162 	Matrix<M, N> operator *(float other) const {
163 		Matrix<M, N> res(*this);
164 		res *= other;
165 		return res;
166 	}
167 
168 	/**
169 	 * Scalar division with assignment
170 	 */
171 	void operator /=(float other) {
172 		for (size_t i = 0; i < M; i++) {
173 			for (size_t j = 0; j < N; j++) {
174 				(*this)[i][j] /= other;
175 			}
176 		}
177 	}
178 
179 	/**
180 	 * Scalar division
181 	 */
182 	Matrix<M, N> operator /(float other) const {
183 		Matrix<M, N> res(*this);
184 		res /= other;
185 		return res;
186 	}
187 
188 	/**
189 	 * Transposition
190 	 */
transpose()191 	Matrix<N, M> transpose() const {
192 		Matrix <N, M> res;
193 		for (size_t i = 0; i < M; i++) {
194 			for (size_t j = 0; j < N; j++) {
195 				res[j][i] = (*this)[i][j];
196 			}
197 		}
198 		return res;
199 	}
200 
201 	/**
202 	 * Conversion to Vector
203 	 */
204 	template<bool cond=is_column_vector,
205 	         typename=typename std::enable_if<cond>::type>
to_vector()206 	Vector<M> to_vector() const {
207 		auto res = Vector<M>();
208 		for (size_t i = 0; i < M; i++) {
209 			res[i] = (*this)[i][0];
210 		}
211 		return res;
212 	}
213 
214 	/**
215 	 * Matrix trace: the sum of all diagonal entries
216 	 */
217 	template<bool cond=is_square,
218 	         typename=typename std::enable_if<cond>::type>
trace()219 	float trace() const {
220 		float res = 0.0f;
221 
222 		for (size_t i = 0; i < N; i++) {
223 			res += (*this)[i][i];
224 		}
225 
226 		return res;
227 	}
228 
229 	/**
230 	 * Print to output stream using '<<'
231 	 */
232 	friend std::ostream &operator <<(std::ostream &o,
233 	                                 const Matrix<M, N> &mat) {
234 		o << "(";
235 		for (size_t j = 0; j < M-1; j++) {
236 			o << "(";
237 			for (size_t i = 0; i < N-1; i++) {
238 				o << mat[j][i] << ",\t";
239 			}
240 			o << mat[j][N-1] << ")";
241 			o << "," << std::endl << " ";
242 		}
243 		o << "(";
244 		for (size_t i = 0; i < N-1; i++) {
245 			o << mat[M-1][i] << ",\t";
246 		}
247 		o << mat[M-1][N-1] << "))";
248 		return o;
249 	}
250 };
251 
252 /**
253  * Scalar multiplication with swapped arguments
254  */
255 template<size_t M, size_t N>
256 Matrix<M, N> operator *(float a, const Matrix<M, N> &mat) {
257 	return mat * a;
258 }
259 
/** Square matrix with 2 rows and 2 columns. */
using Matrix2 = Matrix<2, 2>;

/** Square matrix with 3 rows and 3 columns. */
using Matrix3 = Matrix<3, 3>;

/** Square matrix with 4 rows and 4 columns. */
using Matrix4 = Matrix<4, 4>;
263 
264 }} // openage::util
265