1 #include "cado.h" // IWYU pragma: keep
2
#include <cstdarg> // for va_list, va_end, va_start
#include <cstddef> // for ptrdiff_t
#include <cstdio>
#include <cstdlib>
#include <cstdint>
#include <cstring>
#include <climits>
#include <vector>
10 #include <gmp.h>
11 #include "matmul.h"
12 #include "matmul-common.h"
13 #include "mpfq_layer.h"
14
15 #include "matmul_facade.h"
16
17 #include "arith-modp.hpp"
18 #include "macros.h"
19 #include "params.h"
20
21 /* define "gfp" as being our c++ type built from the number of words in
22 * the underlying mpfq data type.
23 *
24 * It's a bit hacky, but mpfq should provide a ``number of words''
25 * implementation info.
26 */
27 template<typename F, unsigned int m> struct our_gfp_type {
28 static const int mpfq_base_field_width = sizeof(F)/sizeof(unsigned long);
29 typedef arith_modp::gfp<mpfq_base_field_width> type;
30 };
31
32 template<typename F> struct our_gfp_type<F, UINT_MAX> {
33 static const int mpfq_base_field_width = 0;
34 /* We *intentionally* do not provide a variable-width GF(p) type with
35 * the C++ code. That wouldn't be totally impossible, but I can assure
36 * that it would be a royal pain (something like a 200+-line patch of
37 * barely parseable c++ hacks to arith-modp.hpp -- tried it out and
38 * gave up...).
39 */
40 };
41
42 /* If this line complains that ::type is not a type, then see above */
43 typedef our_gfp_type<abelt,abimpl_max_characteristic_bits()>::type gfp;
44
45
46
/* Suffix naming this implementation of the matrix product, used to tell
 * several possible implementations apart (not referenced in the code
 * visible in this file). */
#define MM_EXTENSION "-basicp"

/* Magic number handed to the cache-file open helpers in save_cache and
 * reload_cache: family in bits 16..31, version in bits 0..15. */
#define MM_MAGIC_FAMILY 0xb001UL
#define MM_MAGIC_VERSION 0x1001UL
#define MM_MAGIC ((MM_MAGIC_FAMILY << 16) | MM_MAGIC_VERSION)
53
/* This selects the default behaviour as to which is our best code
 * for multiplying. If this flag is 1, then a multiplication matrix times
 * vector (direction==1) performs best if the in-memory structure
 * reflects the non-transposed matrix. Similarly, a vector times matrix
 * multiplication (direction==0) performs best if the in-memory structure
 * reflects the transposed matrix. When the flag is 0, the converse
 * happens.
 * This flag depends on the implementation, and possibly even on the cpu
 * type under certain circumstances.
 *
 * init() uses it to compute the suggested storage ordering as
 * optimized_direction ^ MM_DIR0_PREFERS_TRANSP_MULT; an explicit
 * mm_store_transposed parameter may override that suggestion.
 */
#define MM_DIR0_PREFERS_TRANSP_MULT 1
65
/* Private state of the "basicp" matmul implementation.
 *
 * The functions below cast the opaque matmul_ptr they receive to this
 * struct, and init() casts it back, so public_ should stay the first
 * member (those casts rely on it sitting at offset 0).
 */
struct matmul_basicp_data_s {
    /* repeat the fields from the public interface */
    struct matmul_public_s public_[1];
    /* now our private fields */
    /* number of uint32_t words stored in q */
    size_t datasize;
    /* field handed to init(); mul() reads its characteristic */
    abdst_field xab;
    /* matrix data: for each stored row, a length w followed by w
     * (column index, signed coefficient) pairs, all as uint32_t.
     * Released with free() by clear(). */
    uint32_t * q;
};
74
MATMUL_NAME(clear)75 void MATMUL_NAME(clear)(matmul_ptr mm0)
76 {
77 struct matmul_basicp_data_s * mm = (struct matmul_basicp_data_s *) mm0;
78 matmul_common_clear(mm->public_);
79 free(mm->q);
80 free(mm);
81 }
82
MATMUL_NAME(init)83 matmul_ptr MATMUL_NAME(init)(void* xx, param_list pl, int optimized_direction)
84 {
85 struct matmul_basicp_data_s * mm;
86 mm = (struct matmul_basicp_data_s *) malloc(sizeof(struct matmul_basicp_data_s));
87 memset(mm, 0, sizeof(struct matmul_basicp_data_s));
88 mm->xab = (abdst_field) xx;
89
90 int suggest = optimized_direction ^ MM_DIR0_PREFERS_TRANSP_MULT;
91 mm->public_->store_transposed = suggest;
92 if (pl) {
93 param_list_parse_int(pl, "mm_store_transposed",
94 &mm->public_->store_transposed);
95 if (mm->public_->store_transposed != suggest) {
96 fprintf(stderr, "Warning, mm_store_transposed"
97 " overrides suggested matrix storage ordering\n");
98 }
99 }
100
101 return (matmul_ptr) mm;
102 }
103
MATMUL_NAME(build_cache)104 void MATMUL_NAME(build_cache)(matmul_ptr mm0, uint32_t * data, size_t size)
105 {
106 ASSERT_ALWAYS(data);
107
108 struct matmul_basicp_data_s * mm = (struct matmul_basicp_data_s *) mm0;
109 unsigned int nrows_t = mm->public_->dim[ mm->public_->store_transposed];
110
111 uint32_t * ptr = data;
112 unsigned int i = 0;
113
114 /* count coefficients */
115 for( ; i < nrows_t ; i++) {
116 unsigned int weight = 0;
117 weight += *ptr;
118 mm->public_->ncoeffs += weight;
119
120 ptr++;
121 ptr += weight;
122 ptr += weight;
123 }
124
125 mm->q = data;
126
127 mm->datasize = nrows_t + 2 * mm->public_->ncoeffs;
128
129 ASSERT_ALWAYS(size == mm->datasize);
130 ASSERT_ALWAYS(ptr - data == (ptrdiff_t) mm->datasize);
131 }
132
MATMUL_NAME(reload_cache)133 int MATMUL_NAME(reload_cache)(matmul_ptr mm0)
134 {
135 FILE * f;
136 struct matmul_basicp_data_s * mm = (struct matmul_basicp_data_s *) mm0;
137 f = matmul_common_reload_cache_fopen(sizeof(abelt), mm->public_, MM_MAGIC);
138 if (!f) return 0;
139
140 MATMUL_COMMON_READ_ONE32(mm->datasize, f);
141 mm->q = (uint32_t *) malloc(mm->datasize * sizeof(uint32_t));
142 MATMUL_COMMON_READ_MANY32(mm->q, mm->datasize, f);
143 fclose(f);
144
145 return 1;
146 }
147
MATMUL_NAME(save_cache)148 void MATMUL_NAME(save_cache)(matmul_ptr mm0)
149 {
150 FILE * f;
151
152 struct matmul_basicp_data_s * mm = (struct matmul_basicp_data_s *) mm0;
153 f = matmul_common_save_cache_fopen(sizeof(abelt), mm->public_, MM_MAGIC);
154 if (!f) return;
155
156 MATMUL_COMMON_WRITE_ONE32(mm->datasize, f);
157 MATMUL_COMMON_WRITE_MANY32(mm->q, mm->datasize, f);
158
159 fclose(f);
160 }
161
MATMUL_NAME(mul)162 void MATMUL_NAME(mul)(matmul_ptr mm0, void * xdst, void const * xsrc, int d)
163 {
164 struct matmul_basicp_data_s * mm = (struct matmul_basicp_data_s *) mm0;
165 ASM_COMMENT("multiplication code");
166 uint32_t * q = mm->q;
167 abdst_field x = mm->xab;
168 const gfp::elt * src = (const gfp::elt *) xsrc;
169 gfp::elt * dst = (gfp::elt *) xdst;
170
171 gfp::preinv preinverse;
172 gfp::elt prime;
173
174 {
175 mpz_t p;
176 mpz_init(p);
177 abfield_characteristic(x, p);
178 prime = p;
179 mpz_clear(p);
180 }
181
182 gfp::compute_preinv(preinverse, prime);
183
184 /* d == 1: matrix times vector product */
185 /* d == 0: vector times matrix product */
186
187 /* However the matrix may be stored either row-major
188 * (store_transposed == 0) or column-major (store_transposed == 1)
189 */
190
191 gfp::elt::zero(dst, mm->public_->dim[!d]);
192
193 if (d == !mm->public_->store_transposed) {
194 gfp::elt_ur rowsum;
195 ASM_COMMENT("critical loop");
196 for(unsigned int i = 0 ; i < mm->public_->dim[!d] ; i++) {
197 uint32_t len = *q++;
198 unsigned int j = 0;
199 rowsum.zero();
200 for( ; len-- ; ) {
201 j = *q++;
202 int32_t c = *(int32_t*)q++;
203 ASSERT(j < mm->public_->dim[d]);
204 if (c == 1) {
205 gfp::add(rowsum, src[j]);
206 } else if (c == -1) {
207 gfp::sub(rowsum, src[j]);
208 } else if (c > 0) {
209 gfp::addmul_ui(rowsum, src[j], c, prime, preinverse);
210 } else {
211 gfp::submul_ui(rowsum, src[j], -c, prime, preinverse);
212 }
213 }
214 gfp::reduce(dst[i], rowsum, prime, preinverse);
215 }
216 ASM_COMMENT("end of critical loop");
217 } else {
218 gfp::elt_ur * tdst = new gfp::elt_ur[mm->public_->dim[!d]];
219 // gfp::elt::zero(tdst, mm->public_->dim[!d]);
220 if (mm->public_->iteration[d] == 10) {
221 fprintf(stderr, "Warning: Doing many iterations with transposed code (not a huge problem for impl=basicp)\n");
222 }
223 ASM_COMMENT("critical loop (transposed mult)");
224 for(unsigned int i = 0 ; i < mm->public_->dim[d] ; i++) {
225 uint32_t len = *q++;
226 unsigned int j = 0;
227 for( ; len-- ; ) {
228 j = *q++;
229 int32_t c = *(int32_t*)q++;
230 ASSERT(j < mm->public_->dim[!d]);
231 if (c == 1) {
232 gfp::add(tdst[j], src[i]);
233 } else if (c == -1) {
234 gfp::sub(tdst[j], src[i]);
235 } else if (c > 0) {
236 gfp::addmul_ui(tdst[j], src[i], c, prime, preinverse);
237 } else {
238 gfp::submul_ui(tdst[j], src[i], -c, prime, preinverse);
239 }
240 }
241 }
242 for(unsigned int j = 0 ; j < mm->public_->dim[!d] ; j++) {
243 gfp::reduce(dst[j], tdst[j], prime, preinverse);
244 }
245 ASM_COMMENT("end of critical loop (transposed mult)");
246 delete[] tdst;
247 }
248 ASM_COMMENT("end of multiplication code");
249
250 mm->public_->iteration[d]++;
251 }
252
MATMUL_NAME(report)253 void MATMUL_NAME(report)(matmul_ptr mm0 MAYBE_UNUSED, double scale MAYBE_UNUSED) {
254 }
255
MATMUL_NAME(auxv)256 void MATMUL_NAME(auxv)(matmul_ptr mm0 MAYBE_UNUSED, int op MAYBE_UNUSED, va_list ap MAYBE_UNUSED)
257 {
258 }
259
MATMUL_NAME(aux)260 void MATMUL_NAME(aux)(matmul_ptr mm0, int op, ...)
261 {
262 va_list ap;
263 va_start(ap, op);
264 MATMUL_NAME(auxv) (mm0, op, ap);
265 va_end(ap);
266 }
267
268 /* vim: set sw=4: */
269