/* Shirokauer maps

   This is largely copied from sm_simple.cpp ; however the purpose of
   sm_simple was to be kept simple and stupid, so let's keep it like
   this.

   This program is to be used as an MPI accelerator for reconstructlog-dl.

   Given a relation file where each line begins with an a,b pair IN
   HEXADECIMAL FORMAT, output the same line, appending the SM values at
   the end.

   SM computation is offloaded to the (many) MPI jobs.

*/
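
/* Illustrative invocation (the binary name and file names below are
 * placeholders, not taken from this file):
 *
 *   mpirun -n 17 --bind-to core ./sm_append -poly foo.poly -ell <ell> \
 *          -in rels.txt -out rels.withsm.txt -b 128
 *
 * Rank 0 reads the relations and writes the output; the other ranks only
 * compute SMs.  Input lines look like "a,b:..." with a,b in hexadecimal
 * (a may carry a leading '-'); the corresponding output line is the input
 * line with ":sm,sm,..." appended, the SM residues being printed in
 * decimal.
 */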

#include "cado.h" // IWYU pragma: keep
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cinttypes>
#include <vector>
#include <string>
#include <cstdarg>     // for va_end, va_list, va_start
#include <cstdint>     // for int64_t, uint64_t
#include <iosfwd>       // for std
#include <memory>       // for allocator_traits<>::value_type
#include <gmp.h>
#include "cado_poly.h"  // for NB_POLYS_MAX, cado_poly_clear, cado_poly_init
#include "gzip.h"       // fopen_maybe_compressed
#include "macros.h"
#include "mpz_poly.h"   // for mpz_poly_clear, mpz_poly_init, mpz_poly, mpz_...
#include "params.h"
#include "select_mpi.h"
#include "sm_utils.h"   // sm_side_info
#include "timing.h"     // seconds
#include "verbose.h"    // verbose_output_print

using namespace std;

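/* (a,b) pairs are shipped between MPI ranks as raw MPI_BYTE buffers (see the
 * MPI_Isend / MPI_Recv calls below), so ab_pair must remain a plain,
 * trivially copyable struct with a fixed layout. */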
struct ab_pair {
    int64_t a;          /* only a is allowed to be negative */
    uint64_t b;
};

typedef vector<ab_pair> ab_pair_batch;

unsigned int batch_size = 128;
unsigned int debug = 0;

struct task_globals {
    int nsm_total;
    size_t limbs_per_ell;
    FILE * in;
    FILE * out;
    size_t nrels_in;
    size_t nrels_out;
};

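/* Per-slave bookkeeping on the master side.  Each peer has at most one batch
 * in flight: create_and_send_batch() posts the next batch for a given turn,
 * receive() collects the SM limbs computed for the batch of the previous
 * turn and appends them to the saved relation lines, and send_finish() tells
 * the slave to stop by sending a zero-sized batch. */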
struct peer_status {
    MPI_Request req;
    ab_pair_batch batch;
    vector<string> rels;
    void receive(task_globals & tg, int peer, int turn);
    void send_finish(task_globals & tg, int peer, int turn);
    /* returns 1 on eof, 0 normally */
    int create_and_send_batch(task_globals& tg, int peer, int turn);
};

static int debug_fprintf(FILE * out, const char * fmt, ...)
    ATTR_PRINTF(2, 3);
static int debug_fprintf(FILE * out, const char * fmt, ...)
{
    va_list ap;
    va_start(ap, fmt);
    int rc = debug ? vfprintf(out, fmt, ap) : 1;
    va_end(ap);
    return rc;
}

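/* Collect the SM values computed by `peer` for the batch that was sent at
 * `turn`, then print each saved relation line followed by ":sm,sm,..."
 * (one decimal residue per SM) to tg.out.  A no-op if nothing is in flight
 * for this peer. */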
void peer_status::receive(task_globals & tg, int peer, int turn)
{
    unsigned long bsize = batch.size();

    if (!bsize) return;

    MPI_Wait(&req, MPI_STATUS_IGNORE);
    batch.clear();

    mp_limb_t * returns = new mp_limb_t[bsize * tg.nsm_total * tg.limbs_per_ell];
    // [bsize][tg.nsm_total][tg.limbs_per_ell];

    MPI_Recv(returns, bsize * tg.nsm_total * tg.limbs_per_ell * sizeof(mp_limb_t), MPI_BYTE, peer, turn, MPI_COMM_WORLD, MPI_STATUS_IGNORE);

    ASSERT_ALWAYS(rels.size() == bsize);

    for(unsigned long i = 0 ; i < bsize ; i++) {
        fputs(rels[i].c_str(), tg.out);
        bool comma = false;
        for (int j = 0 ; j < tg.nsm_total ; j++) {
            mp_limb_t * rij = returns + ((i * tg.nsm_total) + j) * tg.limbs_per_ell;
            gmp_fprintf(tg.out, "%c%Nd", comma ? ',' : ':', rij, tg.limbs_per_ell);
            comma=true;
        }
        fputc('\n', tg.out);
        tg.nrels_out++;
    }
    rels.clear();

    delete[] returns;
}

void peer_status::send_finish(task_globals &, int peer, int turn)
{
    /* tell this slave to finish */
    unsigned long bsize = 0;
    MPI_Send(&bsize, 1, MPI_UNSIGNED_LONG, peer, turn, MPI_COMM_WORLD);
}

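/* Read up to batch_size lines from tg.in.  Comment lines (starting with '#')
 * are echoed directly to tg.out; every other line is saved verbatim in
 * `rels` and its leading hexadecimal a,b pair is parsed into the batch,
 * which is then sent to `peer` (size first, data asynchronously).
 * Returns 1 once the input is exhausted, 0 otherwise. */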
int peer_status::create_and_send_batch(task_globals& tg, int peer, int turn)
{
    int eof = 0;
    char buf[1024];

    ASSERT_ALWAYS(batch.empty());
    while (!eof && batch.size() < batch_size && fgets(buf, 1024, tg.in)) {
        int n = strlen(buf);
        if (!n) {
            fprintf(stderr, "Got 0-sized buffer in fgets, shouldn't happen. Assuming EOF.\n");
            eof=true;
            break;
        }
        buf[n-1]='\0';

        if (buf[0] == '#') {
            fputs(buf, tg.out);
            fputc('\n', tg.out);
            continue;
        }

        tg.nrels_in++;
        rels.push_back(string(buf));

        char * p = buf;
        ab_pair ab;
        int64_t sign = 1;
        if (*p=='-') {
            sign=-1;
            p++;
        }
        if (sscanf(p, "%" SCNx64 ",%" SCNx64 ":", &ab.a, &ab.b) < 2) {
            fprintf(stderr, "Parse error at line: %s\n", buf);
            exit(EXIT_FAILURE);
        }
        ab.a *= sign;

        batch.push_back(ab);
    }
    if (!eof && batch.size() < batch_size) {
        /* fgets returned NULL: end of input (or a read error) */
        eof=true;
        if (ferror(tg.in)) {
            fprintf(stderr, "Error reading input\n");
        }
    }
    /* 0 bsize will be recognized by slaves as a
     * reason to stop processing */
    unsigned long bsize = batch.size();
    MPI_Send(&bsize, 1, MPI_UNSIGNED_LONG, peer, turn, MPI_COMM_WORLD);
    if (bsize)
        MPI_Isend((char*) batch.data(), bsize * sizeof(ab_pair), MPI_BYTE, peer, turn, MPI_COMM_WORLD, &req);
    return eof;
}

#define CSI_RED "\033[00;31m"
#define CSI_GREEN "\033[00;32m"
#define CSI_YELLOW "\033[00;33m"
#define CSI_BLUE "\033[00;34m"
#define CSI_PINK "\033[00;35m"
#define CSI_BOLDRED "\033[01;31m"
#define CSI_BOLDGREEN "\033[01;32m"
#define CSI_BOLDYELLOW "\033[01;33m"
#define CSI_BOLDBLUE "\033[01;34m"
#define CSI_BOLDPINK "\033[01;35m"
#define CSI_RESET "\033[m"

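/* Master loop (rank 0).  Relations are dealt out to the slaves in rounds
 * ("turns"); within a turn, each slave first returns the SMs for the batch
 * it received in the previous turn, then gets a new batch.  The turn number
 * doubles as the MPI message tag. */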
static void sm_append_master(FILE * in, FILE * out, sm_side_info *sm_info, int nb_polys, int size)
{
    /* need to know how many mp_limb_t's we'll get back from each batch */
    task_globals tg;
    tg.limbs_per_ell = 0;
    tg.nsm_total=0;
    tg.in = in;
    tg.out = out;
    tg.nrels_in = 0;
    tg.nrels_out = 0;
    for(int side = 0; side < nb_polys; side++) {
        tg.nsm_total += sm_info[side]->nsm;
        if (sm_info[side]->nsm)
            tg.limbs_per_ell = mpz_size(sm_info[side]->ell);
    }

    std::vector<peer_status> peers(size);

    int eof = 0;
    /* eof == 1 while we are in the turn where the input ran out; eof == 2
     * during the following turn, which drains the outstanding batches and
     * sends the "finish" message to the remaining slaves; the loop below
     * stops once eof exceeds 2. */

    fprintf(stderr, "# running master with %d slaves, batch size %u\n",
            size-1, batch_size);
    fprintf(stderr, "# make sure you use \"--bind-to core\" or equivalent\n");

    double t0 = wct_seconds();
    int turn;
    for(turn = 0 ; eof <= 2 ; turn++, eof += !!eof) {
        double t = wct_seconds();
        debug_fprintf(stderr, "%.3f " CSI_BOLDRED "start turn %d" CSI_RESET "\n", t, turn);
        for(int peer = 1; peer < size; peer++) {
            if (eof && peers[peer].batch.empty()) {
                /* Our last send was a 0-send, so we have nothing to do */
                continue;
            }

            double dt = wct_seconds();
            debug_fprintf(stderr, "%.3f start turn %d receive from peer %d\n", wct_seconds(), turn - 1, peer);
            peers[peer].receive(tg, peer, turn - 1);
            dt = wct_seconds() - dt;
            debug_fprintf(stderr, "%.3f done turn %d receive from peer %d [taken %.1f]\n", wct_seconds(), turn - 1, peer, dt);

            if (eof) {
                debug_fprintf(stderr, "%.3f start turn %d send finish to peer %d\n", wct_seconds(), turn, peer);
                peers[peer].send_finish(tg, peer, turn);
            } else {
                dt = wct_seconds();
                debug_fprintf(stderr, "%.3f start turn %d send to peer %d\n", wct_seconds(), turn, peer);
                eof = peers[peer].create_and_send_batch(tg, peer, turn);
                dt = wct_seconds() - dt;
                debug_fprintf(stderr, "%.3f done turn %d send to peer %d [taken %.1f]\n", wct_seconds(), turn, peer, dt);
            }
        }
        debug_fprintf(stderr, "%.3f " CSI_BOLDRED "done turn %d " CSI_RESET "[taken %.1f s]\n", wct_seconds(), turn, wct_seconds()-t);
        if (turn && !(turn & (turn+1))) {
            /* print stats only when turn+1 is a power of two (turn = 1, 3, 7, ...) */
            fprintf(stderr, "# printed %zu rels in %.1f s"
                    " (%.1f s/turn, %.1f rels/s)\n",
                    tg.nrels_out, wct_seconds()-t0,
                    (wct_seconds()-t0) / turn,
                    tg.nrels_out / (wct_seconds()-t0));
        }
    }
    fprintf(stderr, "# final: printed %zu rels in %.1f s"
            " (%.1f s/turn, %.1f rels/s)\n",
            tg.nrels_out, wct_seconds()-t0,
            (wct_seconds()-t0) / turn,
            tg.nrels_out / (wct_seconds()-t0));
}

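/* Slave loop (ranks >= 1).  For each turn: receive a batch size (0 means
 * stop), receive the batch of (a,b) pairs, compute the SMs of a - b*x for
 * every pair, and send the raw limbs back to rank 0. */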
static void sm_append_slave(sm_side_info *sm_info, int nb_polys)
{
    /* need to know how many mp_limb_t's we'll get back from each batch */
    size_t limbs_per_ell = 0;
    int nsm_total=0;
    int maxdeg = 0;
    int rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    for(int side = 0; side < nb_polys; side++) {
        nsm_total += sm_info[side]->nsm;
        maxdeg = MAX(maxdeg, sm_info[side]->f->deg);
        if (sm_info[side]->nsm) limbs_per_ell = mpz_size(sm_info[side]->ell);
    }

    for(int turn = 0 ; ; turn++) {
        unsigned long bsize;
        MPI_Recv(&bsize, 1, MPI_UNSIGNED_LONG, 0, turn, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        if (bsize == 0) {
            debug_fprintf(stderr, "%.3f turn %d peer %d receive finish\n", wct_seconds(), turn, rank);
            break;
        }
        ab_pair_batch batch(bsize);
        MPI_Recv((char*) batch.data(), bsize * sizeof(ab_pair), MPI_BYTE, 0, turn, MPI_COMM_WORLD, MPI_STATUS_IGNORE);

        double t0 = wct_seconds();
        debug_fprintf(stderr, "%.3f turn %d peer %d start on batch of size %lu\n", wct_seconds(), turn, rank, bsize);
        mp_limb_t * returns = new mp_limb_t[bsize * nsm_total * limbs_per_ell];
        memset(returns, 0, bsize*nsm_total*limbs_per_ell*sizeof(mp_limb_t));

#ifdef HAVE_OPENMP
// #pragma omp parallel
#endif
        {
            cxx_mpz_poly smpol(maxdeg), pol(1);
#ifdef HAVE_OPENMP
// #pragma omp for
#endif
            for(unsigned long i = 0 ; i < bsize ; i++) {
                mpz_poly_setcoeff_int64(pol, 0, batch[i].a);
                mpz_poly_setcoeff_int64(pol, 1, -(int64_t) batch[i].b);
                int smidx = 0;
                for (int side = 0; side < nb_polys; ++side) {
                    compute_sm_piecewise(smpol, pol, sm_info[side]);
                    for(int k = 0 ; k < sm_info[side]->nsm ; k++, smidx++) {
                        if (k <= smpol->deg) {
                            mp_limb_t * rix = returns + (i * nsm_total + smidx) * limbs_per_ell;
                            for(size_t j = 0 ; j < limbs_per_ell ; j++) {
                                rix[j] = mpz_getlimbn(smpol->coeff[k], j);
                            }
                        }
                    }
                }
            }
        }
        debug_fprintf(stderr, "%.3f " CSI_BLUE "turn %d peer %d done batch of size %lu" CSI_RESET " [taken %.1f]\n", wct_seconds(), turn, rank, bsize, wct_seconds() - t0);
        if (rank == 1 && turn == 2) {
            /* print a one-off throughput estimate */
            fprintf(stderr, "# peer processes batch of %lu in %.1f s [%.1f rels/s]\n",
                    bsize,
                    wct_seconds() - t0,
                    bsize / (wct_seconds() - t0));
        }

        t0 = wct_seconds();
        MPI_Send(returns, bsize * nsm_total * limbs_per_ell * sizeof(mp_limb_t), MPI_BYTE, 0, turn, MPI_COMM_WORLD);
        delete[] returns;
        debug_fprintf(stderr, "%.3f turn %d peer %d send return took %.1f\n", wct_seconds(), turn, rank, wct_seconds() - t0);
    }
}

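/* Single-process fallback: when there is only one MPI rank, parse the
 * relations and compute the SMs inline, with no message passing at all. */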
static void sm_append_sync(FILE * in, FILE * out, sm_side_info *sm_info, int nb_polys)
{
    char buf[1024];
    mpz_poly pol, smpol;
    int maxdeg = sm_info[0]->f->deg;
    for(int side = 1; side < nb_polys; side++)
        maxdeg = MAX(maxdeg, sm_info[side]->f->deg);
    mpz_poly_init(pol, maxdeg);
    mpz_poly_init(smpol, maxdeg);
    while (fgets(buf, 1024, in)) {
        int n = strlen(buf);
        if (!n) break;
        buf[n-1]='\0';

        if (buf[0] == '#') {
            fputs(buf, out);
            fputc('\n', out);
            continue;
        }

        char * p = buf;
        int64_t a;              /* only a is allowed to be negative */
        uint64_t b;
        int64_t sign = 1;
        if (*p=='-') {
            sign=-1;
            p++;
        }
        if (sscanf(p, "%" SCNx64 ",%" SCNx64 ":", &a, &b) < 2) {
            fprintf(stderr, "Parse error at line: %s\n", buf);
            exit(EXIT_FAILURE);
        }

        /* set pol = a - b*x, as in the slave code above (pol is already
         * initialized; re-initializing it here would leak) */
        mpz_poly_setcoeff_int64(pol, 0, a*sign);
        mpz_poly_setcoeff_int64(pol, 1, -(int64_t) b);

        fputs(buf, out);
        fputc(':', out);
        for (int side = 0; side < nb_polys; ++side) {
            compute_sm_piecewise(smpol, pol, sm_info[side]);
            print_sm2(out, sm_info[side], smpol, ",");
            if (side == 0 && sm_info[0]->nsm > 0 && sm_info[1]->nsm > 0)
                fputc(',', out);
        }
        fputc('\n', out);
    }
    mpz_poly_clear(pol);
    mpz_poly_clear(smpol);
}


static void sm_append(FILE * in, FILE * out, sm_side_info *sm_info, int nb_polys)
{
    int rank;
    int size;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    if (size > 1) {
        if (rank == 0) {
            sm_append_master(in, out, sm_info, nb_polys, size);
        } else {
            sm_append_slave(sm_info, nb_polys);
        }
    } else {
        sm_append_sync(in, out, sm_info, nb_polys);
    }
}


static void declare_usage(param_list pl)
{
    param_list_decl_usage(pl, "poly", "(required) poly file");
    param_list_decl_usage(pl, "ell", "(required) group order");
    param_list_decl_usage(pl, "nsm", "number of SMs to use per side");
    param_list_decl_usage(pl, "sm-mode", "SM mode (see sm-portability.h)");
    param_list_decl_usage(pl, "in", "data input (defaults to stdin)");
    param_list_decl_usage(pl, "out", "data output (defaults to stdout)");
    param_list_decl_usage(pl, "b", "batch size for MPI loop");
    verbose_decl_usage(pl);
}

static void usage (const char *argv, const char * missing, param_list pl)
{
    if (missing) {
        fprintf(stderr, "\nError: missing or invalid parameter \"-%s\"\n",
                missing);
    }
    param_list_print_usage(pl, argv, stderr);
    exit (EXIT_FAILURE);
}

/* -------------------------------------------------------------------------- */

// coverity[root_function]
int main (int argc, char **argv)
{
    MPI_Init(&argc, & argv);
    int rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    char *argv0 = argv[0];

    const char *polyfile = NULL;

    param_list pl;
    cado_poly pol;
    mpz_poly_ptr F[NB_POLYS_MAX];

    mpz_t ell;

    /* read params */
    param_list_init(pl);
    declare_usage(pl);

    if (argc == 1)
        usage (argv[0], NULL, pl);

    argc--,argv++;
    for ( ; argc ; ) {
        if (param_list_update_cmdline (pl, &argc, &argv)) { continue; }
        fprintf (stderr, "Unhandled parameter %s\n", argv[0]);
        usage (argv0, NULL, pl);
    }

    /* Read poly filename from command line */
    if ((polyfile = param_list_lookup_string(pl, "poly")) == NULL) {
        fprintf(stderr, "Error: parameter -poly is mandatory\n");
        param_list_print_usage(pl, argv0, stderr);
        exit(EXIT_FAILURE);
    }

    /* Read ell from command line (assuming radix 10) */
    mpz_init (ell);
    if (!param_list_parse_mpz(pl, "ell", ell)) {
        fprintf(stderr, "Error: parameter -ell is mandatory\n");
        param_list_print_usage(pl, argv0, stderr);
        exit(EXIT_FAILURE);
    }

    /* Init polynomial */
    cado_poly_init (pol);
    cado_poly_read(pol, polyfile);
    for(int side = 0; side < pol->nb_polys; side++)
        F[side] = pol->pols[side];

    int nsm_arg[NB_POLYS_MAX];
    for(int side = 0; side < pol->nb_polys; side++)
        nsm_arg[side]=-1;

    param_list_parse_int_list (pl, "nsm", nsm_arg, pol->nb_polys, ",");

    FILE * in = rank ? NULL : stdin;
    FILE * out = rank ? NULL : stdout;
    const char * infilename = param_list_lookup_string(pl, "in");
    const char * outfilename = param_list_lookup_string(pl, "out");

    if (!rank && infilename) {
        in = fopen_maybe_compressed(infilename, "r");
        ASSERT_ALWAYS(in != NULL);
    }
    if (!rank && outfilename) {
        out = fopen_maybe_compressed(outfilename, "w");
        ASSERT_ALWAYS(out != NULL);
    }

    param_list_parse_uint(pl, "b", &batch_size);

    verbose_interpret_parameters(pl);

    const char * sm_mode_string = param_list_lookup_string(pl, "sm-mode");

    if (param_list_warn_unused(pl))
        usage (argv0, NULL, pl);

    if (!rank)
        param_list_print_command_line (stdout, pl);

    sm_side_info sm_info[NB_POLYS_MAX];

    for(int side = 0 ; side < pol->nb_polys; side++) {
        sm_side_info_init(sm_info[side], F[side], ell);
        sm_side_info_set_mode(sm_info[side], sm_mode_string);
        if (nsm_arg[side] >= 0)
            sm_info[side]->nsm = nsm_arg[side]; /* command line wins */
        if (!rank)
            printf("# Using %d SMs on side %d\n", sm_info[side]->nsm, side);
    }

    /*
       if (!rank) {
       for (int side = 0; side < pol->nb_polys; side++) {
       printf("\n# Polynomial on side %d:\nF[%d] = ", side, side);
       mpz_poly_fprintf(stdout, F[side]);

       printf("# SM info on side %d:\n", side);
       sm_side_info_print(stdout, sm_info[side]);

       fflush(stdout);
       }
       }
       */

    sm_append(in, out, sm_info, pol->nb_polys);

    /* Make sure we print no footer line, because reconstructlog-dl won't
     * grok it */
    if (!rank) {
        fflush(stdout);
    }

    if (!rank && infilename) fclose_maybe_compressed(in, infilename);
    if (!rank && out != stdout) fclose_maybe_compressed(out, outfilename);

    for(int side = 0 ; side < pol->nb_polys ; side++) {
        sm_side_info_clear(sm_info[side]);
    }

    mpz_clear(ell);
    cado_poly_clear(pol);
    param_list_clear(pl);

    MPI_Finalize();

    return 0;
}