1 //===- Math.h - PBQP Vector and Matrix classes ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef LLVM_CODEGEN_PBQP_MATH_H 10 #define LLVM_CODEGEN_PBQP_MATH_H 11 12 #include "llvm/ADT/Hashing.h" 13 #include "llvm/ADT/STLExtras.h" 14 #include <algorithm> 15 #include <cassert> 16 #include <functional> 17 #include <memory> 18 19 namespace llvm { 20 namespace PBQP { 21 22 using PBQPNum = float; 23 24 /// PBQP Vector class. 25 class Vector { 26 friend hash_code hash_value(const Vector &); 27 28 public: 29 /// Construct a PBQP vector of the given size. 30 explicit Vector(unsigned Length) 31 : Length(Length), Data(std::make_unique<PBQPNum []>(Length)) {} 32 33 /// Construct a PBQP vector with initializer. 34 Vector(unsigned Length, PBQPNum InitVal) 35 : Length(Length), Data(std::make_unique<PBQPNum []>(Length)) { 36 std::fill(Data.get(), Data.get() + Length, InitVal); 37 } 38 39 /// Copy construct a PBQP vector. 40 Vector(const Vector &V) 41 : Length(V.Length), Data(std::make_unique<PBQPNum []>(Length)) { 42 std::copy(V.Data.get(), V.Data.get() + Length, Data.get()); 43 } 44 45 /// Move construct a PBQP vector. 46 Vector(Vector &&V) 47 : Length(V.Length), Data(std::move(V.Data)) { 48 V.Length = 0; 49 } 50 51 /// Comparison operator. 52 bool operator==(const Vector &V) const { 53 assert(Length != 0 && Data && "Invalid vector"); 54 if (Length != V.Length) 55 return false; 56 return std::equal(Data.get(), Data.get() + Length, V.Data.get()); 57 } 58 59 /// Return the length of the vector 60 unsigned getLength() const { 61 assert(Length != 0 && Data && "Invalid vector"); 62 return Length; 63 } 64 65 /// Element access. 66 PBQPNum& operator[](unsigned Index) { 67 assert(Length != 0 && Data && "Invalid vector"); 68 assert(Index < Length && "Vector element access out of bounds."); 69 return Data[Index]; 70 } 71 72 /// Const element access. 73 const PBQPNum& operator[](unsigned Index) const { 74 assert(Length != 0 && Data && "Invalid vector"); 75 assert(Index < Length && "Vector element access out of bounds."); 76 return Data[Index]; 77 } 78 79 /// Add another vector to this one. 80 Vector& operator+=(const Vector &V) { 81 assert(Length != 0 && Data && "Invalid vector"); 82 assert(Length == V.Length && "Vector length mismatch."); 83 std::transform(Data.get(), Data.get() + Length, V.Data.get(), Data.get(), 84 std::plus<PBQPNum>()); 85 return *this; 86 } 87 88 /// Returns the index of the minimum value in this vector 89 unsigned minIndex() const { 90 assert(Length != 0 && Data && "Invalid vector"); 91 return std::min_element(Data.get(), Data.get() + Length) - Data.get(); 92 } 93 94 private: 95 unsigned Length; 96 std::unique_ptr<PBQPNum []> Data; 97 }; 98 99 /// Return a hash_value for the given vector. 100 inline hash_code hash_value(const Vector &V) { 101 unsigned *VBegin = reinterpret_cast<unsigned*>(V.Data.get()); 102 unsigned *VEnd = reinterpret_cast<unsigned*>(V.Data.get() + V.Length); 103 return hash_combine(V.Length, hash_combine_range(VBegin, VEnd)); 104 } 105 106 /// Output a textual representation of the given vector on the given 107 /// output stream. 108 template <typename OStream> 109 OStream& operator<<(OStream &OS, const Vector &V) { 110 assert((V.getLength() != 0) && "Zero-length vector badness."); 111 112 OS << "[ " << V[0]; 113 for (unsigned i = 1; i < V.getLength(); ++i) 114 OS << ", " << V[i]; 115 OS << " ]"; 116 117 return OS; 118 } 119 120 /// PBQP Matrix class 121 class Matrix { 122 private: 123 friend hash_code hash_value(const Matrix &); 124 125 public: 126 /// Construct a PBQP Matrix with the given dimensions. 127 Matrix(unsigned Rows, unsigned Cols) : 128 Rows(Rows), Cols(Cols), Data(std::make_unique<PBQPNum []>(Rows * Cols)) { 129 } 130 131 /// Construct a PBQP Matrix with the given dimensions and initial 132 /// value. 133 Matrix(unsigned Rows, unsigned Cols, PBQPNum InitVal) 134 : Rows(Rows), Cols(Cols), 135 Data(std::make_unique<PBQPNum []>(Rows * Cols)) { 136 std::fill(Data.get(), Data.get() + (Rows * Cols), InitVal); 137 } 138 139 /// Copy construct a PBQP matrix. 140 Matrix(const Matrix &M) 141 : Rows(M.Rows), Cols(M.Cols), 142 Data(std::make_unique<PBQPNum []>(Rows * Cols)) { 143 std::copy(M.Data.get(), M.Data.get() + (Rows * Cols), Data.get()); 144 } 145 146 /// Move construct a PBQP matrix. 147 Matrix(Matrix &&M) 148 : Rows(M.Rows), Cols(M.Cols), Data(std::move(M.Data)) { 149 M.Rows = M.Cols = 0; 150 } 151 152 /// Comparison operator. 153 bool operator==(const Matrix &M) const { 154 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix"); 155 if (Rows != M.Rows || Cols != M.Cols) 156 return false; 157 return std::equal(Data.get(), Data.get() + (Rows * Cols), M.Data.get()); 158 } 159 160 /// Return the number of rows in this matrix. 161 unsigned getRows() const { 162 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix"); 163 return Rows; 164 } 165 166 /// Return the number of cols in this matrix. 167 unsigned getCols() const { 168 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix"); 169 return Cols; 170 } 171 172 /// Matrix element access. 173 PBQPNum* operator[](unsigned R) { 174 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix"); 175 assert(R < Rows && "Row out of bounds."); 176 return Data.get() + (R * Cols); 177 } 178 179 /// Matrix element access. 180 const PBQPNum* operator[](unsigned R) const { 181 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix"); 182 assert(R < Rows && "Row out of bounds."); 183 return Data.get() + (R * Cols); 184 } 185 186 /// Returns the given row as a vector. 187 Vector getRowAsVector(unsigned R) const { 188 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix"); 189 Vector V(Cols); 190 for (unsigned C = 0; C < Cols; ++C) 191 V[C] = (*this)[R][C]; 192 return V; 193 } 194 195 /// Returns the given column as a vector. 196 Vector getColAsVector(unsigned C) const { 197 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix"); 198 Vector V(Rows); 199 for (unsigned R = 0; R < Rows; ++R) 200 V[R] = (*this)[R][C]; 201 return V; 202 } 203 204 /// Matrix transpose. 205 Matrix transpose() const { 206 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix"); 207 Matrix M(Cols, Rows); 208 for (unsigned r = 0; r < Rows; ++r) 209 for (unsigned c = 0; c < Cols; ++c) 210 M[c][r] = (*this)[r][c]; 211 return M; 212 } 213 214 /// Add the given matrix to this one. 215 Matrix& operator+=(const Matrix &M) { 216 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix"); 217 assert(Rows == M.Rows && Cols == M.Cols && 218 "Matrix dimensions mismatch."); 219 std::transform(Data.get(), Data.get() + (Rows * Cols), M.Data.get(), 220 Data.get(), std::plus<PBQPNum>()); 221 return *this; 222 } 223 224 Matrix operator+(const Matrix &M) { 225 assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix"); 226 Matrix Tmp(*this); 227 Tmp += M; 228 return Tmp; 229 } 230 231 private: 232 unsigned Rows, Cols; 233 std::unique_ptr<PBQPNum []> Data; 234 }; 235 236 /// Return a hash_code for the given matrix. 237 inline hash_code hash_value(const Matrix &M) { 238 unsigned *MBegin = reinterpret_cast<unsigned*>(M.Data.get()); 239 unsigned *MEnd = 240 reinterpret_cast<unsigned*>(M.Data.get() + (M.Rows * M.Cols)); 241 return hash_combine(M.Rows, M.Cols, hash_combine_range(MBegin, MEnd)); 242 } 243 244 /// Output a textual representation of the given matrix on the given 245 /// output stream. 246 template <typename OStream> 247 OStream& operator<<(OStream &OS, const Matrix &M) { 248 assert((M.getRows() != 0) && "Zero-row matrix badness."); 249 for (unsigned i = 0; i < M.getRows(); ++i) 250 OS << M.getRowAsVector(i) << "\n"; 251 return OS; 252 } 253 254 template <typename Metadata> 255 class MDVector : public Vector { 256 public: 257 MDVector(const Vector &v) : Vector(v), md(*this) {} 258 MDVector(Vector &&v) : Vector(std::move(v)), md(*this) { } 259 260 const Metadata& getMetadata() const { return md; } 261 262 private: 263 Metadata md; 264 }; 265 266 template <typename Metadata> 267 inline hash_code hash_value(const MDVector<Metadata> &V) { 268 return hash_value(static_cast<const Vector&>(V)); 269 } 270 271 template <typename Metadata> 272 class MDMatrix : public Matrix { 273 public: 274 MDMatrix(const Matrix &m) : Matrix(m), md(*this) {} 275 MDMatrix(Matrix &&m) : Matrix(std::move(m)), md(*this) { } 276 277 const Metadata& getMetadata() const { return md; } 278 279 private: 280 Metadata md; 281 }; 282 283 template <typename Metadata> 284 inline hash_code hash_value(const MDMatrix<Metadata> &M) { 285 return hash_value(static_cast<const Matrix&>(M)); 286 } 287 288 } // end namespace PBQP 289 } // end namespace llvm 290 291 #endif // LLVM_CODEGEN_PBQP_MATH_H 292