1 #ifndef BBLAS_BITMAT_HPP_ 2 #define BBLAS_BITMAT_HPP_ 3 4 // IWYU pragma: private, include "bblas.hpp" 5 // IWYU pragma: friend ".*/bblas.*" 6 7 /* very generic code on bit matrices. We don't keep much here. More 8 * concrete types such as mat64 provide most of the actual code. 9 */ 10 #include <cstdint> 11 #include <cstring> 12 #include <vector> 13 #include "macros.h" 14 #include "memory.h" // malloc_aligned in utils 15 16 template<typename T> class bitmat; 17 18 namespace bblas_bitmat_details { 19 20 template<typename T> struct bblas_bitmat_type_supported { 21 static constexpr const bool value = false; 22 }; 23 template<typename T> struct bitmat_ops { 24 typedef bitmat<T> matrix; 25 static void fill_random(matrix & w, gmp_randstate_t rstate); 26 static void add(matrix & C, matrix const & A, matrix const & B); 27 static void transpose(matrix & C, matrix const & A); 28 static void mul(matrix & C, matrix const & A, matrix const & B); mul_lt_gebblas_bitmat_details::bitmat_ops29 static void mul_lt_ge(matrix & C, matrix const & A, matrix const & B) { 30 mul(C, A, B); 31 } 32 /* do C[0] = A[0]*B, C[block_stride]=A[block_stride]*B, etc */ 33 static void mul_blocks(matrix * C, matrix const * A, matrix const& B, size_t nblocks, size_t Cstride, size_t Astride); 34 static void addmul_blocks(matrix * C, matrix const * A, matrix const& B, size_t nblocks, size_t Cstride, size_t Astride); 35 static void addmul(matrix & C, matrix const & A, matrix const & B); 36 static void addmul(matrix & C, 37 matrix const & A, 38 matrix const & B, 39 unsigned int i0, 40 unsigned int i1, 41 unsigned int yi0, 42 unsigned int yi1); 43 static void trsm(matrix const & L, 44 matrix & U, 45 unsigned int yi0, 46 unsigned int yi1); 47 static void trsm(matrix const & L, matrix & U); 48 static void extract_uppertriangular(matrix & a, matrix const & b); 49 static void extract_lowertriangular(matrix & a, matrix const & b); 50 /* Keeps only the upper triangular part in U, and copy the lower 51 * triangular, together with a unit block, to L */ 52 static void extract_LU(matrix & L, matrix & U); 53 protected: 54 /* these are accessed as _member functions_ in the matrix type */ 55 static bool is_lowertriangular(matrix const & a); 56 static bool is_uppertriangular(matrix const & a); 57 static bool triangular_is_unit(matrix const & a); 58 static void make_uppertriangular(matrix & a); 59 static void make_lowertriangular(matrix & a); 60 static void make_unit_uppertriangular(matrix & a); 61 static void make_unit_lowertriangular(matrix & a); 62 static void triangular_make_unit(matrix & a); 63 }; 64 } 65 66 template<typename T> 67 class bitmat 68 : public bblas_bitmat_details::bitmat_ops<T> 69 { 70 typedef bblas_bitmat_details::bitmat_ops<T> ops; 71 typedef bblas_bitmat_details::bblas_bitmat_type_supported<T> S; 72 static_assert(S::value, "bblas bitmap must be built on uintX_t"); 73 74 public: 75 static constexpr const int width = S::width; 76 typedef T datatype; 77 typedef std::vector<bitmat, aligned_allocator<bitmat, S::alignment>> vector_type; 78 // typedef std::vector<bitmat> vector_type; 79 80 private: 81 T x[width]; // ATTRIBUTE((aligned(64))) ; 82 83 public: alloc(size_t n)84 static inline bitmat * alloc(size_t n) { 85 return (bitmat *) malloc_aligned(n * sizeof(bitmat), S::alignment); 86 } free(bitmat * p)87 static inline void free(bitmat * p) { 88 free_aligned(p); 89 } 90 data()91 inline T* data() { return x; } data() const92 inline const T* data() const { return x; } operator [](int i)93 T& operator[](int i) { return x[i]; } operator [](int i) const94 T operator[](int i) const { return x[i]; } operator ==(bitmat const & a) const95 inline bool operator==(bitmat const& a) const 96 { 97 /* anyway we're not going to do it any smarter in instantiations, 98 * so let's rather keep this as a simple and stupid inline */ 99 return memcmp(x, a.x, sizeof(x)) == 0; 100 } operator !=(bitmat const & a) const101 inline bool operator!=(bitmat const& a) const { return !operator==(a); } bitmat()102 bitmat() {} bitmat(bitmat const & a)103 inline bitmat(bitmat const& a) { memcpy(x, a.x, sizeof(x)); } operator =(bitmat const & a)104 inline bitmat& operator=(bitmat const& a) 105 { 106 memcpy(x, a.x, sizeof(x)); 107 return *this; 108 } bitmat(int a)109 inline bitmat(int a) { *this = a; } operator =(int a)110 inline bitmat& operator=(int a) 111 { 112 if (a & 1) { 113 T mask = 1; 114 for (int j = 0; j < width; j++, mask <<= 1) 115 x[j] = mask; 116 } else { 117 memset(x, 0, sizeof(x)); 118 } 119 return *this; 120 } operator ==(int a) const121 inline bool operator==(int a) const 122 { 123 if (a&1) { 124 T mask = a&1; 125 for (int j = 0; j < width; j++, mask <<= 1) 126 if (x[j]&~mask) return false; 127 } else { 128 for (int j = 0; j < width; j++) 129 if (x[j]) return false; 130 } 131 return true; 132 } operator !=(int a) const133 inline bool operator!=(int a) const { return !operator==(a); } 134 is_lowertriangular() const135 inline bool is_lowertriangular() const { return ops::is_lowertriangular(*this); } is_uppertriangular() const136 inline bool is_uppertriangular() const { return ops::is_uppertriangular(*this); } triangular_is_unit() const137 inline bool triangular_is_unit() const { return ops::triangular_is_unit(*this); } make_uppertriangular()138 inline void make_uppertriangular() { ops::make_uppertriangular(*this); } make_lowertriangular()139 inline void make_lowertriangular() { ops::make_lowertriangular(*this); } make_unit_uppertriangular()140 inline void make_unit_uppertriangular() { ops::make_unit_uppertriangular(*this); } make_unit_lowertriangular()141 inline void make_unit_lowertriangular() { ops::make_unit_lowertriangular(*this); } triangular_make_unit()142 inline void triangular_make_unit() { ops::triangular_make_unit(*this); } 143 }; 144 145 #endif /* BBLAS_BITMAT_HPP_ */ 146