1 #ifndef BBLAS_BITMAT_HPP_
2 #define BBLAS_BITMAT_HPP_
3 
4 // IWYU pragma: private, include "bblas.hpp"
5 // IWYU pragma: friend ".*/bblas.*"
6 
7 /* very generic code on bit matrices. We don't keep much here. More
8  * concrete types such as mat64 provide most of the actual code.
9  */
10 #include <cstdint>
11 #include <cstring>
12 #include <vector>
13 #include "macros.h"
14 #include "memory.h"     // malloc_aligned in utils
15 
16 template<typename T> class bitmat;
17 
18 namespace bblas_bitmat_details {
19 
20     template<typename T> struct bblas_bitmat_type_supported {
21         static constexpr const bool value = false;
22     };
23     template<typename T> struct bitmat_ops {
24         typedef bitmat<T> matrix;
25         static void fill_random(matrix & w, gmp_randstate_t rstate);
26         static void add(matrix & C, matrix const & A, matrix const & B);
27         static void transpose(matrix & C, matrix const & A);
28         static void mul(matrix & C, matrix const & A, matrix const & B);
mul_lt_gebblas_bitmat_details::bitmat_ops29         static void mul_lt_ge(matrix & C, matrix const & A, matrix const & B) {
30             mul(C, A, B);
31         }
32         /* do C[0] = A[0]*B, C[block_stride]=A[block_stride]*B, etc */
33         static void mul_blocks(matrix * C, matrix const * A, matrix const& B, size_t nblocks, size_t Cstride, size_t Astride);
34         static void addmul_blocks(matrix * C, matrix const * A, matrix const& B, size_t nblocks, size_t Cstride, size_t Astride);
35         static void addmul(matrix & C, matrix const & A, matrix const & B);
36         static void addmul(matrix & C,
37                    matrix const & A,
38                    matrix const & B,
39                    unsigned int i0,
40                    unsigned int i1,
41                    unsigned int yi0,
42                    unsigned int yi1);
43         static void trsm(matrix const & L,
44                 matrix & U,
45                 unsigned int yi0,
46                 unsigned int yi1);
47         static void trsm(matrix const & L, matrix & U);
48         static void extract_uppertriangular(matrix & a, matrix const & b);
49         static void extract_lowertriangular(matrix & a, matrix const & b);
50         /* Keeps only the upper triangular part in U, and copy the lower
51          * triangular, together with a unit block, to L */
52         static void extract_LU(matrix & L, matrix & U);
53         protected:
54         /* these are accessed as _member functions_ in the matrix type */
55         static bool is_lowertriangular(matrix const & a);
56         static bool is_uppertriangular(matrix const & a);
57         static bool triangular_is_unit(matrix const & a);
58         static void make_uppertriangular(matrix & a);
59         static void make_lowertriangular(matrix & a);
60         static void make_unit_uppertriangular(matrix & a);
61         static void make_unit_lowertriangular(matrix & a);
62         static void triangular_make_unit(matrix & a);
63     };
64 }
65 
66 template<typename T>
67 class bitmat
68     : public bblas_bitmat_details::bitmat_ops<T>
69 {
70     typedef bblas_bitmat_details::bitmat_ops<T> ops;
71     typedef bblas_bitmat_details::bblas_bitmat_type_supported<T> S;
72     static_assert(S::value, "bblas bitmap must be built on uintX_t");
73 
74     public:
75     static constexpr const int width = S::width;
76     typedef T datatype;
77     typedef std::vector<bitmat, aligned_allocator<bitmat, S::alignment>> vector_type;
78     // typedef std::vector<bitmat> vector_type;
79 
80     private:
81     T x[width]; // ATTRIBUTE((aligned(64))) ;
82 
83     public:
alloc(size_t n)84     static inline bitmat * alloc(size_t n) {
85         return (bitmat *) malloc_aligned(n * sizeof(bitmat), S::alignment);
86     }
free(bitmat * p)87     static inline void free(bitmat * p) {
88         free_aligned(p);
89     }
90 
data()91     inline T* data() { return x; }
data() const92     inline const T* data() const { return x; }
operator [](int i)93     T& operator[](int i) { return x[i]; }
operator [](int i) const94     T operator[](int i) const { return x[i]; }
operator ==(bitmat const & a) const95     inline bool operator==(bitmat const& a) const
96     {
97         /* anyway we're not going to do it any smarter in instantiations,
98          * so let's rather keep this as a simple and stupid inline */
99         return memcmp(x, a.x, sizeof(x)) == 0;
100     }
operator !=(bitmat const & a) const101     inline bool operator!=(bitmat const& a) const { return !operator==(a); }
bitmat()102     bitmat() {}
bitmat(bitmat const & a)103     inline bitmat(bitmat const& a) { memcpy(x, a.x, sizeof(x)); }
operator =(bitmat const & a)104     inline bitmat& operator=(bitmat const& a)
105     {
106         memcpy(x, a.x, sizeof(x));
107         return *this;
108     }
bitmat(int a)109     inline bitmat(int a) { *this = a; }
operator =(int a)110     inline bitmat& operator=(int a)
111     {
112         if (a & 1) {
113             T mask = 1;
114             for (int j = 0; j < width; j++, mask <<= 1)
115                 x[j] = mask;
116         } else {
117             memset(x, 0, sizeof(x));
118         }
119         return *this;
120     }
operator ==(int a) const121     inline bool operator==(int a) const
122     {
123         if (a&1) {
124             T mask = a&1;
125             for (int j = 0; j < width; j++, mask <<= 1)
126                 if (x[j]&~mask) return false;
127         } else {
128             for (int j = 0; j < width; j++)
129                 if (x[j]) return false;
130         }
131         return true;
132     }
operator !=(int a) const133     inline bool operator!=(int a) const { return !operator==(a); }
134 
is_lowertriangular() const135     inline bool is_lowertriangular() const { return ops::is_lowertriangular(*this); }
is_uppertriangular() const136     inline bool is_uppertriangular() const { return ops::is_uppertriangular(*this); }
triangular_is_unit() const137     inline bool triangular_is_unit() const { return ops::triangular_is_unit(*this); }
make_uppertriangular()138     inline void make_uppertriangular() { ops::make_uppertriangular(*this); }
make_lowertriangular()139     inline void make_lowertriangular() { ops::make_lowertriangular(*this); }
make_unit_uppertriangular()140     inline void make_unit_uppertriangular() { ops::make_unit_uppertriangular(*this); }
make_unit_lowertriangular()141     inline void make_unit_lowertriangular() { ops::make_unit_lowertriangular(*this); }
triangular_make_unit()142     inline void triangular_make_unit() { ops::triangular_make_unit(*this); }
143 };
144 
145 #endif	/* BBLAS_BITMAT_HPP_ */
146