1 /* Shirokauer maps
2
3 This is largely copied from sm_simple.cpp ; however the purpose of
4 sm_simple was to be kept simple and stupid, so let's keep it like
5 this.
6
7 This program is to be used as an mpi accelerator for reconstructlog-dl
8
9 Given a relation file where each line begins with an a,b pair IN
10 HEXADECIMAL FORMAT, output the same line, appending the SM values in
11 the end.
12
13 SM computation is offloaded to the (many) MPI jobs.
14
15 */
16
17 #include "cado.h" // IWYU pragma: keep
18 #include <cstdio>
19 #include <cstdlib>
20 #include <cstring>
21 #include <cinttypes>
22 #include <vector>
23 #include <string>
24 #include <cstdarg> // for va_end, va_list, va_start
25 #include <cstdint> // for int64_t, uint64_t
26 #include <iosfwd> // for std
27 #include <memory> // for allocator_traits<>::value_type
28 #include <gmp.h>
29 #include "cado_poly.h" // for NB_POLYS_MAX, cado_poly_clear, cado_poly_init
30 #include "gzip.h" // fopen_maybe_compressed
31 #include "macros.h"
32 #include "mpz_poly.h" // for mpz_poly_clear, mpz_poly_init, mpz_poly, mpz_...
33 #include "params.h"
34 #include "select_mpi.h"
35 #include "sm_utils.h" // sm_side_info
36 #include "timing.h" // seconds
37 #include "verbose.h" // verbose_output_print
38
using namespace std;

/* An (a,b) pair as parsed from the leading field of a relation line.
 * This is the unit of work shipped to the slaves as raw bytes, so it
 * must stay a plain trivially-copyable struct. */
struct ab_pair {
    int64_t a;		/* only a is allowed to be negative */
    uint64_t b;
};

typedef vector<ab_pair> ab_pair_batch;

/* number of (a,b) pairs per MPI batch; settable with -b on the command line */
unsigned int batch_size = 128;
/* non-zero enables the per-turn timing traces printed by debug_fprintf() */
unsigned int debug = 0;
50
/* State shared by all per-peer operations on the master side */
struct task_globals {
    int nsm_total;          /* total number of SMs, summed over all sides */
    size_t limbs_per_ell;   /* mp_limb_t count of ell, i.e. size of one packed SM coefficient */
    FILE * in;              /* relation input stream */
    FILE * out;             /* augmented relation output stream */
    size_t nrels_in;        /* relations read so far */
    size_t nrels_out;       /* relations printed so far */
};
59
/* Master-side bookkeeping for one slave rank: the batch currently in
 * flight, the text of the corresponding relation lines (echoed back
 * with SM values appended once the results arrive), and the request
 * handle of the pending MPI_Isend of the batch data. */
struct peer_status {
    MPI_Request req;        /* pending Isend of `batch`, waited on in receive() */
    ab_pair_batch batch;    /* (a,b) pairs sent for the current turn */
    vector<string> rels;    /* original line text, parallel to `batch` */
    void receive(task_globals & tg, int peer, int turn);
    void send_finish(task_globals & tg, int peer, int turn);
    /* returns 1 on eof, 0 normally */
    int create_and_send_batch(task_globals& tg, int peer, int turn);
};
69
70 static int debug_fprintf(FILE * out, const char * fmt, ...)
71 ATTR_PRINTF(2, 3);
debug_fprintf(FILE * out,const char * fmt,...)72 static int debug_fprintf(FILE * out, const char * fmt, ...)
73 {
74 va_list ap;
75 va_start(ap, fmt);
76 int rc = debug ? vfprintf(out, fmt, ap) : 1;
77 va_end(ap);
78 return rc;
79 }
80
receive(task_globals & tg,int peer,int turn)81 void peer_status::receive(task_globals & tg, int peer, int turn)
82 {
83 unsigned long bsize = batch.size();
84
85 if (!bsize) return;
86
87 MPI_Wait(&req, MPI_STATUS_IGNORE);
88 batch.clear();
89
90 mp_limb_t * returns = new mp_limb_t[bsize * tg.nsm_total * tg.limbs_per_ell];
91 // [bsize][tg.nsm_total][tg.limbs_per_ell];
92
93 MPI_Recv(returns, bsize * tg.nsm_total * tg.limbs_per_ell * sizeof(mp_limb_t), MPI_BYTE, peer, turn, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
94
95 ASSERT_ALWAYS(rels.size() == bsize);
96
97 for(unsigned long i = 0 ; i < bsize ; i++) {
98 fputs(rels[i].c_str(), tg.out);
99 bool comma = false;
100 for (int j = 0 ; j < tg.nsm_total ; j++) {
101 mp_limb_t * rij = returns + ((i * tg.nsm_total) + j) * tg.limbs_per_ell;
102 gmp_fprintf(tg.out, "%c%Nd", comma ? ',' : ':', rij, tg.limbs_per_ell);
103 comma=true;
104 }
105 fputc('\n', tg.out);
106 tg.nrels_out++;
107 }
108 rels.clear();
109
110 delete[] returns;
111 }
112
send_finish(task_globals &,int peer,int turn)113 void peer_status::send_finish(task_globals &, int peer, int turn)
114 {
115 /* tell this slave to finish */
116 unsigned long bsize = 0;
117 MPI_Send(&bsize, 1, MPI_UNSIGNED_LONG, peer, turn, MPI_COMM_WORLD);
118 }
119
create_and_send_batch(task_globals & tg,int peer,int turn)120 int peer_status::create_and_send_batch(task_globals& tg, int peer, int turn)
121 {
122 int eof = 0;
123 char buf[1024];
124
125 ASSERT_ALWAYS(batch.empty());
126 while (!eof && batch.size() < batch_size && fgets(buf, 1024, tg.in)) {
127 int n = strlen(buf);
128 if (!n) {
129 fprintf(stderr, "Got 0-sized buffer in fgets, shouldn't happen. Assuming EOF.\n");
130 eof=true;
131 break;
132 }
133 buf[n-1]='\0';
134
135 if (buf[0] == '#') {
136 fputs(buf, tg.out);
137 fputc('\n', tg.out);
138 continue;
139 }
140
141 tg.nrels_in++;
142 rels.push_back(string(buf));
143
144 char * p = buf;
145 ab_pair ab;
146 int64_t sign = 1;
147 if (*p=='-') {
148 sign=-1;
149 p++;
150 }
151 if (sscanf(p, "%" SCNx64 ",%" SCNx64 ":", &ab.a, &ab.b) < 2) {
152 fprintf(stderr, "Parse error at line: %s\n", buf);
153 exit(EXIT_FAILURE);
154 }
155 ab.a *= sign;
156
157 batch.push_back(ab);
158 }
159 if (!eof && batch.size() < batch_size) {
160 eof=true;
161 if (ferror(stdin)) {
162 fprintf(stderr, "Error on stdin\n");
163 }
164 }
165 /* 0 bsize will be recognized by slaves as a
166 * reason to stop processing */
167 unsigned long bsize = batch.size();
168 MPI_Send(&bsize, 1, MPI_UNSIGNED_LONG, peer, turn, MPI_COMM_WORLD);
169 if (bsize)
170 MPI_Isend((char*) batch.data(), bsize * sizeof(ab_pair), MPI_BYTE, peer, turn, MPI_COMM_WORLD, &req);
171 return eof;
172 }
173
/* ANSI terminal color escape sequences, used to colorize the per-turn
 * debug traces on stderr */
#define CSI_RED "\033[00;31m"
#define CSI_GREEN "\033[00;32m"
#define CSI_YELLOW "\033[00;33m"
#define CSI_BLUE "\033[00;34m"
#define CSI_PINK "\033[00;35m"
#define CSI_BOLDRED "\033[01;31m"
#define CSI_BOLDGREEN "\033[01;32m"
#define CSI_BOLDYELLOW "\033[01;33m"
#define CSI_BOLDBLUE "\033[01;34m"
#define CSI_BOLDPINK "\033[01;35m"
#define CSI_RESET "\033[m"
185
sm_append_master(FILE * in,FILE * out,sm_side_info * sm_info,int nb_polys,int size)186 static void sm_append_master(FILE * in, FILE * out, sm_side_info *sm_info, int nb_polys, int size)
187 {
188 /* need to know how many mp_limb_t's we'll get back from each batch */
189 task_globals tg;
190 tg.limbs_per_ell = 0;
191 tg.nsm_total=0;
192 tg.in = in;
193 tg.out = out;
194 tg.nrels_in = 0;
195 tg.nrels_out = 0;
196 for(int side = 0; side < nb_polys; side++) {
197 tg.nsm_total += sm_info[side]->nsm;
198 if (sm_info[side]->nsm)
199 tg.limbs_per_ell = mpz_size(sm_info[side]->ell);
200 }
201
202 std::vector<peer_status> peers(size);
203
204 int eof = 0;
205 /* eof = 1 on first time. eof = 2 when all receives are done */
206
207 fprintf(stderr, "# running master with %d slaves, batch size %u\n",
208 size-1, batch_size);
209 fprintf(stderr, "# make sure you use \"--bind-to core\" or equivalent\n");
210
211 double t0 = wct_seconds();
212 int turn;
213 for(turn = 0 ; eof <= 2 ; turn++, eof += !!eof) {
214 double t = wct_seconds();
215 debug_fprintf(stderr, "%.3f " CSI_BOLDRED "start turn %d" CSI_RESET "\n", t0, turn);
216 for(int peer = 1; peer < size; peer++) {
217 if (eof && peers[peer].batch.empty()) {
218 /* Our last send was a 0-send, so we have nothing to do */
219 continue;
220 }
221
222 double dt = wct_seconds();
223 debug_fprintf(stderr, "%.3f start turn %d receive from peer %d\n", wct_seconds(), turn - 1, peer);
224 peers[peer].receive(tg, peer, turn - 1);
225 dt = wct_seconds() - dt;
226 debug_fprintf(stderr, "%.3f done turn %d receive from peer %d [taken %.1f]\n", wct_seconds(), turn - 1, peer, dt);
227
228 if (eof) {
229 debug_fprintf(stderr, "%.3f start turn %d send finish to peer %d\n", wct_seconds(), turn, peer);
230 peers[peer].send_finish(tg, peer, turn);
231 } else if (!eof) {
232 dt = wct_seconds();
233 debug_fprintf(stderr, "%.3f start turn %d send to peer %d\n", wct_seconds(), turn, peer);
234 eof = peers[peer].create_and_send_batch(tg, peer, turn);
235 dt = wct_seconds() - dt;
236 debug_fprintf(stderr, "%.3f done turn %d send to peer %d [taken %.1f]\n", wct_seconds(), turn, peer, dt);
237 }
238 }
239 debug_fprintf(stderr, "%.3f " CSI_BOLDRED "done turn %d " CSI_RESET "[taken %.1f] s\n", wct_seconds(), turn, wct_seconds()-t);
240 if (turn && !(turn & (turn+1))) {
241 /* print only when turn is a power of two */
242 fprintf(stderr, "# printed %zu rels in %.1f s"
243 " (%.1f / batch, %.1f rels/s)\n",
244 tg.nrels_out, wct_seconds()-t0,
245 (wct_seconds()-t0) / turn,
246 tg.nrels_out / (wct_seconds()-t0));
247 }
248 }
249 fprintf(stderr, "# final: printed %zu rels in %.1f s"
250 " (%.1f / batch, %.1f rels/s)\n",
251 tg.nrels_out, wct_seconds()-t0,
252 (wct_seconds()-t0) / turn,
253 tg.nrels_out / (wct_seconds()-t0));
254 }
255
sm_append_slave(sm_side_info * sm_info,int nb_polys)256 static void sm_append_slave(sm_side_info *sm_info, int nb_polys)
257 {
258 /* need to know how many mp_limb_t's we'll get back from each batch */
259 size_t limbs_per_ell = 0;
260 int nsm_total=0;
261 int maxdeg = 0;
262 int rank;
263 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
264
265 for(int side = 0; side < nb_polys; side++) {
266 nsm_total += sm_info[side]->nsm;
267 maxdeg = MAX(maxdeg, sm_info[side]->f->deg);
268 if (sm_info[side]->nsm) limbs_per_ell = mpz_size(sm_info[side]->ell);
269 }
270
271
272 for(int turn = 0 ; ; turn++) {
273 unsigned long bsize;
274 MPI_Recv(&bsize, 1, MPI_UNSIGNED_LONG, 0, turn, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
275 if (bsize == 0) {
276 debug_fprintf(stderr, "%.3f turn %d peer %d receive finish\n", wct_seconds(), turn, rank);
277 break;
278 }
279 ab_pair_batch batch(bsize);
280 MPI_Recv((char*) batch.data(), bsize * sizeof(ab_pair), MPI_BYTE, 0, turn, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
281
282 double t0 = wct_seconds();
283 debug_fprintf(stderr, "%.3f turn %d peer %d start on batch of size %lu\n", wct_seconds(), turn, rank, bsize);
284 mp_limb_t * returns = new mp_limb_t[bsize * nsm_total * limbs_per_ell];
285 memset(returns, 0, bsize*nsm_total*limbs_per_ell*sizeof(mp_limb_t));
286
287 #ifdef HAVE_OPENMP
288 // #pragma omp parallel
289 #endif
290 {
291 cxx_mpz_poly smpol(maxdeg), pol(1);
292 #ifdef HAVE_OPENMP
293 // #pragma omp for
294 #endif
295 for(unsigned long i = 0 ; i < bsize ; i++) {
296 mpz_poly_setcoeff_int64(pol, 0, batch[i].a);
297 mpz_poly_setcoeff_int64(pol, 1, -(int64_t) batch[i].b);
298 int smidx = 0;
299 for (int side = 0; side < nb_polys; ++side) {
300 compute_sm_piecewise(smpol, pol, sm_info[side]);
301 for(int k = 0 ; k < sm_info[side]->nsm ; k++, smidx++) {
302 if (k <= smpol->deg) {
303 mp_limb_t * rix = returns + (i * nsm_total + smidx) * limbs_per_ell;
304 for(size_t j = 0 ; j < limbs_per_ell ; j++) {
305 rix[j] = mpz_getlimbn(smpol->coeff[k], j);
306 }
307 }
308 }
309 }
310 }
311 }
312 debug_fprintf(stderr, "%.3f " CSI_BLUE "turn %d peer %d done batch of size %lu" CSI_RESET " [taken %.1f]\n", wct_seconds(), turn, rank, bsize, wct_seconds() - t0);
313 if (rank == 1 && turn == 2)
314 fprintf(stderr, "# peer processes batch of %lu in %.1f [%.1f SMs/s]\n",
315 bsize,
316 wct_seconds() - t0,
317 bsize / (wct_seconds() - t0));
318
319 t0 = wct_seconds();
320 MPI_Send(returns, bsize * nsm_total * limbs_per_ell * sizeof(mp_limb_t), MPI_BYTE, 0, turn, MPI_COMM_WORLD);
321 delete[] returns;
322 debug_fprintf(stderr, "%.3f turn %d peer %d send return took %.1f\n", wct_seconds(), turn, rank, wct_seconds() - t0);
323 }
324 }
325
sm_append_sync(FILE * in,FILE * out,sm_side_info * sm_info,int nb_polys)326 static void sm_append_sync(FILE * in, FILE * out, sm_side_info *sm_info, int nb_polys)
327 {
328 char buf[1024];
329 mpz_poly pol, smpol;
330 int maxdeg = sm_info[0]->f->deg;
331 for(int side = 1; side < nb_polys; side++)
332 maxdeg = MAX(maxdeg, sm_info[side]->f->deg);
333 mpz_poly_init(pol, maxdeg);
334 mpz_poly_init(smpol, maxdeg);
335 while (fgets(buf, 1024, in)) {
336 int n = strlen(buf);
337 if (!n) break;
338 buf[n-1]='\0';
339
340 if (buf[0] == '#') {
341 fputs(buf, out);
342 fputc('\n', out);
343 continue;
344 }
345
346 char * p = buf;
347 int64_t a; /* only a is allowed to be negative */
348 uint64_t b;
349 int64_t sign = 1;
350 if (*p=='-') {
351 sign=-1;
352 p++;
353 }
354 if (sscanf(p, "%" SCNx64 ",%" SCNx64 ":", &a, &b) < 2) {
355 fprintf(stderr, "Parse error at line: %s\n", buf);
356 exit(EXIT_FAILURE);
357 }
358
359 mpz_poly_init_set_ab(pol, a*sign, b);
360
361 fputs(buf, out);
362 fputc(':', out);
363 for (int side = 0; side < nb_polys; ++side) {
364 compute_sm_piecewise(smpol, pol, sm_info[side]);
365 print_sm2(out, sm_info[side], smpol, ",");
366 if (side == 0 && sm_info[0]->nsm > 0 && sm_info[1]->nsm > 0)
367 fputc(',', out);
368 }
369 fputc('\n', out);
370 mpz_poly_clear(pol);
371 }
372 mpz_poly_clear(smpol);
373 }
374
375
sm_append(FILE * in,FILE * out,sm_side_info * sm_info,int nb_polys)376 static void sm_append(FILE * in, FILE * out, sm_side_info *sm_info, int nb_polys)
377 {
378 int rank;
379 int size;
380 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
381 MPI_Comm_size(MPI_COMM_WORLD, &size);
382
383 if (size > 1) {
384 if (rank == 0) {
385 sm_append_master(in, out, sm_info, nb_polys, size);
386 } else {
387 sm_append_slave(sm_info, nb_polys);
388 }
389 } else {
390 sm_append_sync(in, out, sm_info, nb_polys);
391 }
392 }
393
394
declare_usage(param_list pl)395 static void declare_usage(param_list pl)
396 {
397 param_list_decl_usage(pl, "poly", "(required) poly file");
398 param_list_decl_usage(pl, "ell", "(required) group order");
399 param_list_decl_usage(pl, "nsm", "number of SMs to use per side");
400 param_list_decl_usage(pl, "sm-mode", "SM mode (see sm-portability.h)");
401 param_list_decl_usage(pl, "in", "data input (defaults to stdin)");
402 param_list_decl_usage(pl, "out", "data output (defaults to stdout)");
403 param_list_decl_usage(pl, "b", "batch size for MPI loop");
404 verbose_decl_usage(pl);
405 }
406
usage(const char * argv,const char * missing,param_list pl)407 static void usage (const char *argv, const char * missing, param_list pl)
408 {
409 if (missing) {
410 fprintf(stderr, "\nError: missing or invalid parameter \"-%s\"\n",
411 missing);
412 }
413 param_list_print_usage(pl, argv, stderr);
414 exit (EXIT_FAILURE);
415 }
416
417 /* -------------------------------------------------------------------------- */
418
// coverity[root_function]
/* Entry point: parse command-line parameters, read the polynomial pair
 * and the group order ell, set up the per-side SM information, then
 * dispatch to sm_append() (MPI master/slave or synchronous).  Only
 * rank 0 opens and owns the input/output streams. */
int main (int argc, char **argv)
{
    MPI_Init(&argc, & argv);
    int rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    char *argv0 = argv[0];

    const char *polyfile = NULL;

    param_list pl;
    cado_poly pol;
    mpz_poly_ptr F[NB_POLYS_MAX];

    mpz_t ell;

    /* read params */
    param_list_init(pl);
    declare_usage(pl);

    if (argc == 1)
        usage (argv[0], NULL, pl);

    argc--,argv++;
    for ( ; argc ; ) {
        if (param_list_update_cmdline (pl, &argc, &argv)) { continue; }
        fprintf (stderr, "Unhandled parameter %s\n", argv[0]);
        usage (argv0, NULL, pl);
    }

    /* Read poly filename from command line */
    if ((polyfile = param_list_lookup_string(pl, "poly")) == NULL) {
        fprintf(stderr, "Error: parameter -poly is mandatory\n");
        param_list_print_usage(pl, argv0, stderr);
        exit(EXIT_FAILURE);
    }

    /* Read ell from command line (assuming radix 10) */
    mpz_init (ell);
    if (!param_list_parse_mpz(pl, "ell", ell)) {
        fprintf(stderr, "Error: parameter -ell is mandatory\n");
        param_list_print_usage(pl, argv0, stderr);
        exit(EXIT_FAILURE);
    }

    /* Init polynomial */
    cado_poly_init (pol);
    cado_poly_read(pol, polyfile);
    for(int side = 0; side < pol->nb_polys; side++)
        F[side] = pol->pols[side];

    /* -1 means "not given on the command line": the value computed by
     * sm_side_info_init is then kept */
    int nsm_arg[NB_POLYS_MAX];
    for(int side = 0; side < pol->nb_polys; side++)
        nsm_arg[side]=-1;

    param_list_parse_int_list (pl, "nsm", nsm_arg, pol->nb_polys, ",");

    /* only rank 0 touches the data streams; the slaves get their work
     * over MPI */
    FILE * in = rank ? NULL : stdin;
    FILE * out = rank ? NULL: stdout;
    const char * infilename = param_list_lookup_string(pl, "in");
    const char * outfilename = param_list_lookup_string(pl, "out");

    if (!rank && infilename) {
        in = fopen_maybe_compressed(infilename, "r");
        ASSERT_ALWAYS(in != NULL);
    }
    if (!rank && outfilename) {
        out = fopen_maybe_compressed(outfilename, "w");
        ASSERT_ALWAYS(out != NULL);
    }

    param_list_parse_uint(pl, "b", &batch_size);

    verbose_interpret_parameters(pl);

    const char * sm_mode_string = param_list_lookup_string(pl, "sm-mode");

    if (param_list_warn_unused(pl))
        usage (argv0, NULL, pl);

    if (!rank)
        param_list_print_command_line (stdout, pl);

    sm_side_info sm_info[NB_POLYS_MAX];

    for(int side = 0 ; side < pol->nb_polys; side++) {
        sm_side_info_init(sm_info[side], F[side], ell);
        sm_side_info_set_mode(sm_info[side], sm_mode_string);
        if (nsm_arg[side] >= 0)
            sm_info[side]->nsm = nsm_arg[side]; /* command line wins */
        if (!rank)
            printf("# Using %d SMs on side %d\n", sm_info[side]->nsm, side);
    }

    /*
       if (!rank) {
       for (int side = 0; side < pol->nb_polys; side++) {
       printf("\n# Polynomial on side %d:\nF[%d] = ", side, side);
       mpz_poly_fprintf(stdout, F[side]);

       printf("# SM info on side %d:\n", side);
       sm_side_info_print(stdout, sm_info[side]);

       fflush(stdout);
       }
       }
       */

    sm_append(in, out, sm_info, pol->nb_polys);

    /* Make sure we print no footer line, because reconstructlog-dl won't
     * grok it */
    if (!rank) {
        fflush(stdout);
    }

    if (!rank && infilename) fclose_maybe_compressed(in, infilename);
    if (!rank && out != stdout) fclose_maybe_compressed(out, outfilename);

    for(int side = 0 ; side < pol->nb_polys ; side++) {
        sm_side_info_clear(sm_info[side]);
    }

    mpz_clear(ell);
    cado_poly_clear(pol);
    param_list_clear(pl);

    MPI_Finalize();

    return 0;
}
551