1 #include "cado.h" // IWYU pragma: keep
2 
#include <cstdarg>         // for va_list, va_end, va_start
#include <cstddef>         // for ptrdiff_t
#include <cstdio>
#include <cstdlib>
#include <cstdint>
#include <cstring>
#include <climits>
#include <vector>
#include <gmp.h>
#include "matmul.h"
#include "matmul-common.h"
#include "mpfq_layer.h"

#include "matmul_facade.h"

#include "arith-modp.hpp"
#include "macros.h"
#include "params.h"
20 
21 /* define "gfp" as being our c++ type built from the number of words in
22  * the underlying mpfq data type.
23  *
24  * It's a bit hacky, but mpfq should provide a ``number of words''
25  * implementation info.
26  */
27 template<typename F, unsigned int m> struct our_gfp_type {
28     static const int mpfq_base_field_width = sizeof(F)/sizeof(unsigned long);
29     typedef arith_modp::gfp<mpfq_base_field_width> type;
30 };
31 
32 template<typename F> struct our_gfp_type<F, UINT_MAX> {
33     static const int mpfq_base_field_width = 0;
34     /* We *intentionally* do not provide a variable-width GF(p) type with
35      * the C++ code. That wouldn't be totally impossible, but I can assure
36      * that it would be a royal pain (something like a 200+-line patch of
37      * barely parseable c++ hacks to arith-modp.hpp -- tried it out and
38      * gave up...).
39      */
40 };
41 
42 /* If this line complains that ::type is not a type, then see above */
43 typedef our_gfp_type<abelt,abimpl_max_characteristic_bits()>::type gfp;
44 
45 
46 
/* This extension is used to distinguish between several possible
 * implementations of the product */
#define MM_EXTENSION   "-basicp"
#define MM_MAGIC_FAMILY        0xb001UL
#define MM_MAGIC_VERSION       0x1001UL
/* Fully parenthesized so that MM_MAGIC can be used inside any expression
 * (e.g. MM_MAGIC & 0xffff, MM_MAGIC >> 16) without precedence surprises. */
#define MM_MAGIC ((MM_MAGIC_FAMILY << 16) | MM_MAGIC_VERSION)

/* This selects the default behaviour as to which is our best code
 * for multiplying. If this flag is 1, then a multiplication matrix times
 * vector (direction==1) performs best if the in-memory structure
 * reflects the non-transposed matrix. Similarly, a vector times matrix
 * multiplication (direction==0) performs best if the in-memory structure
 * reflects the transposed matrix. When the flag is 0, the converse
 * happens.
 * This flag depends on the implementation, and possibly even on the cpu
 * type under certain circumstances.
 */
#define MM_DIR0_PREFERS_TRANSP_MULT   1
65 
/* Private state of the "basicp" matmul implementation.
 *
 * The functions in this file convert between matmul_ptr and a pointer to
 * this struct with plain casts, so public_ presumably has to remain the
 * first member. NOTE(review): confirm against matmul.h before reordering.
 */
struct matmul_basicp_data_s {
    /* repeat the fields from the public interface */
    struct matmul_public_s public_[1];
    /* now our private fields */
    /* number of 32-bit words stored in q */
    size_t datasize;
    /* field handed to init(); mul() reads the characteristic from it */
    abdst_field xab;
    /* Matrix data, one record per stored row: a weight w followed by w
     * (index, signed 32-bit coefficient) pairs. Set by build_cache() or
     * allocated by reload_cache(); freed by clear(). */
    uint32_t * q;
};
74 
MATMUL_NAME(clear)75 void MATMUL_NAME(clear)(matmul_ptr mm0)
76 {
77     struct matmul_basicp_data_s * mm = (struct matmul_basicp_data_s *) mm0;
78     matmul_common_clear(mm->public_);
79     free(mm->q);
80     free(mm);
81 }
82 
MATMUL_NAME(init)83 matmul_ptr MATMUL_NAME(init)(void* xx, param_list pl, int optimized_direction)
84 {
85     struct matmul_basicp_data_s * mm;
86     mm = (struct matmul_basicp_data_s *) malloc(sizeof(struct matmul_basicp_data_s));
87     memset(mm, 0, sizeof(struct matmul_basicp_data_s));
88     mm->xab = (abdst_field) xx;
89 
90     int suggest = optimized_direction ^ MM_DIR0_PREFERS_TRANSP_MULT;
91     mm->public_->store_transposed = suggest;
92     if (pl) {
93         param_list_parse_int(pl, "mm_store_transposed",
94                 &mm->public_->store_transposed);
95         if (mm->public_->store_transposed != suggest) {
96             fprintf(stderr, "Warning, mm_store_transposed"
97                     " overrides suggested matrix storage ordering\n");
98         }
99     }
100 
101     return (matmul_ptr) mm;
102 }
103 
MATMUL_NAME(build_cache)104 void MATMUL_NAME(build_cache)(matmul_ptr mm0, uint32_t * data, size_t size)
105 {
106     ASSERT_ALWAYS(data);
107 
108     struct matmul_basicp_data_s * mm = (struct matmul_basicp_data_s *) mm0;
109     unsigned int nrows_t = mm->public_->dim[ mm->public_->store_transposed];
110 
111     uint32_t * ptr = data;
112     unsigned int i = 0;
113 
114     /* count coefficients */
115     for( ; i < nrows_t ; i++) {
116         unsigned int weight = 0;
117         weight += *ptr;
118         mm->public_->ncoeffs += weight;
119 
120         ptr++;
121         ptr += weight;
122         ptr += weight;
123     }
124 
125     mm->q = data;
126 
127     mm->datasize = nrows_t + 2 * mm->public_->ncoeffs;
128 
129     ASSERT_ALWAYS(size == mm->datasize);
130     ASSERT_ALWAYS(ptr - data == (ptrdiff_t) mm->datasize);
131 }
132 
MATMUL_NAME(reload_cache)133 int MATMUL_NAME(reload_cache)(matmul_ptr mm0)
134 {
135     FILE * f;
136     struct matmul_basicp_data_s * mm = (struct matmul_basicp_data_s *) mm0;
137     f = matmul_common_reload_cache_fopen(sizeof(abelt), mm->public_, MM_MAGIC);
138     if (!f) return 0;
139 
140     MATMUL_COMMON_READ_ONE32(mm->datasize, f);
141     mm->q = (uint32_t *) malloc(mm->datasize * sizeof(uint32_t));
142     MATMUL_COMMON_READ_MANY32(mm->q, mm->datasize, f);
143     fclose(f);
144 
145     return 1;
146 }
147 
MATMUL_NAME(save_cache)148 void MATMUL_NAME(save_cache)(matmul_ptr mm0)
149 {
150     FILE * f;
151 
152     struct matmul_basicp_data_s * mm = (struct matmul_basicp_data_s *) mm0;
153     f = matmul_common_save_cache_fopen(sizeof(abelt), mm->public_, MM_MAGIC);
154     if (!f) return;
155 
156     MATMUL_COMMON_WRITE_ONE32(mm->datasize, f);
157     MATMUL_COMMON_WRITE_MANY32(mm->q, mm->datasize, f);
158 
159     fclose(f);
160 }
161 
/* Accumulate c * x into the unreduced accumulator acc, where c is a signed
 * 32-bit matrix coefficient as stored in the cache. The cases c == 1 and
 * c == -1 avoid the scalar multiplication. */
static inline void basicp_accumulate(gfp::elt_ur & acc, gfp::elt const & x,
        int32_t c, gfp::elt const & prime, gfp::preinv const & preinverse)
{
    if (c == 1) {
        gfp::add(acc, x);
    } else if (c == -1) {
        gfp::sub(acc, x);
    } else if (c > 0) {
        gfp::addmul_ui(acc, x, (uint32_t) c, prime, preinverse);
    } else {
        /* |c| computed in unsigned arithmetic: -c would overflow (UB)
         * for c == INT32_MIN. */
        gfp::submul_ui(acc, x, -(uint32_t) c, prime, preinverse);
    }
}

/* Sparse matrix product over GF(p).
 *
 * d == 1: matrix times vector product (dst = M * src)
 * d == 0: vector times matrix product (dst = src * M)
 * src has dim[d] elements, dst receives dim[!d] elements.
 *
 * The matrix may be stored either row-major (store_transposed == 0) or
 * column-major (store_transposed == 1). When the stored rows match the
 * output coordinates (d == !store_transposed) each output is computed as
 * one row sum; otherwise the rows are scattered into a full-length
 * unreduced accumulator that is reduced at the end.
 */
void MATMUL_NAME(mul)(matmul_ptr mm0, void * xdst, void const * xsrc, int d)
{
    struct matmul_basicp_data_s * mm = (struct matmul_basicp_data_s *) mm0;
    ASM_COMMENT("multiplication code");
    uint32_t const * q = mm->q;
    abdst_field x = mm->xab;
    const gfp::elt * src = (const gfp::elt *) xsrc;
    gfp::elt * dst = (gfp::elt *) xdst;

    unsigned int const nin = mm->public_->dim[d];    /* length of src */
    unsigned int const nout = mm->public_->dim[!d];  /* length of dst */

    /* Fetch p from the mpfq field and precompute the reduction data. */
    gfp::preinv preinverse;
    gfp::elt prime;
    {
        mpz_t p;
        mpz_init(p);
        abfield_characteristic(x, p);
        prime = p;
        mpz_clear(p);
    }
    gfp::compute_preinv(preinverse, prime);

    gfp::elt::zero(dst, nout);

    if (d == !mm->public_->store_transposed) {
        gfp::elt_ur rowsum;
        ASM_COMMENT("critical loop");
        for(unsigned int i = 0 ; i < nout ; i++) {
            uint32_t len = *q++;
            rowsum.zero();
            for( ; len-- ; ) {
                unsigned int j = *q++;
                int32_t c = *(int32_t const *) q++;
                ASSERT(j < nin);
                basicp_accumulate(rowsum, src[j], c, prime, preinverse);
            }
            gfp::reduce(dst[i], rowsum, prime, preinverse);
        }
        ASM_COMMENT("end of critical loop");
    } else {
        /* One unreduced accumulator per output coordinate. The vector owns
         * the storage (released on every exit path) and value-initializes
         * its elements. */
        std::vector<gfp::elt_ur> tdst(nout);
        if (mm->public_->iteration[d] == 10) {
            fprintf(stderr, "Warning: Doing many iterations with transposed code (not a huge problem for impl=basicp)\n");
        }
        ASM_COMMENT("critical loop (transposed mult)");
        for(unsigned int i = 0 ; i < nin ; i++) {
            uint32_t len = *q++;
            for( ; len-- ; ) {
                unsigned int j = *q++;
                int32_t c = *(int32_t const *) q++;
                ASSERT(j < nout);
                basicp_accumulate(tdst[j], src[i], c, prime, preinverse);
            }
        }
        for(unsigned int j = 0 ; j < nout ; j++) {
            gfp::reduce(dst[j], tdst[j], prime, preinverse);
        }
        ASM_COMMENT("end of critical loop (transposed mult)");
    }
    ASM_COMMENT("end of multiplication code");

    mm->public_->iteration[d]++;
}
252 
MATMUL_NAME(report)253 void MATMUL_NAME(report)(matmul_ptr mm0 MAYBE_UNUSED, double scale MAYBE_UNUSED) {
254 }
255 
MATMUL_NAME(auxv)256 void MATMUL_NAME(auxv)(matmul_ptr mm0 MAYBE_UNUSED, int op MAYBE_UNUSED, va_list ap MAYBE_UNUSED)
257 {
258 }
259 
MATMUL_NAME(aux)260 void MATMUL_NAME(aux)(matmul_ptr mm0, int op, ...)
261 {
262     va_list ap;
263     va_start(ap, op);
264     MATMUL_NAME(auxv) (mm0, op, ap);
265     va_end(ap);
266 }
267 
268 /* vim: set sw=4: */
269