1 #include "cado.h" // IWYU pragma: keep
2 
3 /* This is a very silly program which merely reads the matrix and then
4  * exits. It must come before the prep program, and coalescing this one
5  * and prep into one would be difficult and/or artificial, because they
6  * handle different data widths.
7  */
8 #include <cstdint>     /* AIX wants it first (it's a bug) */
9 #include <cstdio>
10 #include <cstring>
11 #include <cstdlib>
12 #include <gmp.h>                 // for mpz_cmp_ui
13 #include "balancing.h"           // for DUMMY_VECTOR_COORD_VALUE, DUMMY_VECT...
14 #include "parallelizing_info.h"
15 #include "matmul_top.h"
16 #include "select_mpi.h"
17 #include "params.h"
18 #include "bw-common.h"
19 #include "async.h"
20 #include "mpfq/mpfq.h"
21 #include "mpfq/mpfq_vbase.h"
22 #include "cheating_vec_init.h"
23 #include "intersections.h"
24 #include "macros.h"
25 
dispatch_prog(parallelizing_info_ptr pi,param_list pl,void * arg MAYBE_UNUSED)26 void * dispatch_prog(parallelizing_info_ptr pi, param_list pl, void * arg MAYBE_UNUSED)
27 {
28     matmul_top_data mmt;
29 
30     int ys[2] = { bw->ys[0], bw->ys[1], };
31     /*
32      * Hmm. Interleaving doesn't make a lot of sense for this program,
33      * right ? Furthermore, it gets in the way for the sanity checks. We
34      * tend to always receive ys=0..64 as an argument.
35     if (pi->interleaved) {
36         ASSERT_ALWAYS((bw->ys[1]-bw->ys[0]) % 2 == 0);
37         ys[0] = bw->ys[0] + pi->interleaved->idx * (bw->ys[1]-bw->ys[0])/2;
38         ys[1] = ys[0] + (bw->ys[1]-bw->ys[0])/2;
39     }
40     */
41 
42     mpfq_vbase A;
43     mpfq_vbase_oo_field_init_byfeatures(A,
44             MPFQ_PRIME_MPZ, bw->p,
45             MPFQ_SIMD_GROUPSIZE, ys[1]-ys[0],
46             MPFQ_DONE);
47 
48     block_control_signals();
49 
50     /*****************************************
51      *             Watch out !               *
52      *****************************************/
53 
54     /* HERE is the place where something actually happens. The rest of
55      * this function are just sanity checks. Matrix dispatch can, in
56      * fact, be done from any program which does matmul_top_init. This
57      * function calls mmt_finish_init, which calls
58      * matmul_top_read_submatrix, which eventuallmy,a!ls
59      * balancing_get_matrix_u32, which hooks into the balancing code.
60      * This is UNLESS either of the two command-line arguments are set,
61      * because these trigger special behaviour:
62      * export_cachelist : see this file, in fact. This makes "dispatch" a
63      * quick cache file listing tool, which helps the perl script a bit.
64      * random_matrix_size : for creating test matrices, essentially for
65      * krylov/mksol speed testing.
66      */
67     matmul_top_init(mmt, A, pi, pl, bw->dir);
68 
69     mmt_vec ymy[2];
70     mmt_vec_ptr y = ymy[0];
71     mmt_vec_ptr my = ymy[1];
72     mmt_vec_init(mmt,0,0, y,  1, 0, mmt->n[1]);
73     mmt_vec_init(mmt,0,0, my, 0, 0, mmt->n[0]);
74 
75     unsigned int unpadded = MAX(mmt->n0[0], mmt->n0[1]);
76 
77     const char * sanity_check_vector = param_list_lookup_string(pl, "sanity_check_vector");
78     int only_export = param_list_lookup_string(pl, "export_cachelist") != NULL;
79 
80     // in no situation shall we try to do our sanity check if we've just
81     // been told to export our cache list. Note also that this sanity
82     // check is currently only valid for GF(2).
83     if (sanity_check_vector != NULL && !only_export && mpz_cmp_ui(bw->p, 2) == 0 && mmt->abase->simd_groupsize(mmt->abase)) {
84         /* We have computed a sanity check vector, which is H=M*K, with K
85          * constant and easily given. Note that we have not computed K*M,
86          * but really M*K. Thus independently of which side we prefer, we
87          * are going to check the matrix product in a rigid direction.
88          *
89          * check 1: compute M*K again using the mmt structures, and
90          * compare with H1. This is matmul_top_mul(mmt, 1).
91          *
92          * check 2: for some vector L, compute (L, H=M*K) and (L*M, K).
93          * This means matmul_top_mul(mmt, 0).
94          */
95         const char * checkname;
96 
97         checkname = "1st check: consistency of M*arbitrary1 (Hx == H1)";
98 
99         mmt_full_vec_set_zero(y);
100         ASSERT_ALWAYS(y->siblings);     /* shared vector undesired */
101         for(unsigned int i = y->i0 ; i < y->i1 && i < unpadded ; i++) {
102             void * dst = A->vec_subvec(A, y->v, i - y->i0);
103             uint64_t value = DUMMY_VECTOR_COORD_VALUE(i);
104             memcpy(dst, &value, sizeof(uint64_t));
105         }
106         mmt_vec_twist(mmt, y);
107         matmul_top_mul(mmt, ymy, NULL);
108         mmt_vec_untwist(mmt, y);
109 
110         mmt_vec_save(y, "Hx%u-%u", unpadded, 0);
111 
112         // compare if files are equal.
113         if (pi->m->jrank == 0 && pi->m->trank == 0) {
114             char cmd[1024];
115             int rc = snprintf(cmd, 80, "diff -q %s Hx0-64", sanity_check_vector);
116             ASSERT_ALWAYS(rc>=0);
117             rc = system(cmd);
118             if (rc) {
119                 printf("%s : failed\n", checkname);
120 #ifdef WEXITSTATUS
121                 fprintf(stderr, "%s returned %d\n", cmd, WEXITSTATUS(rc));
122 #else
123                 fprintf(stderr, "%s returned %d\n", cmd, rc);
124 #endif
125                 exit(EXIT_FAILURE);
126             } else {
127                 printf("%s : ok\n", checkname);
128             }
129         }
130         serialize(pi->m);
131 
132         checkname = "2nd check: (arbitrary2, M*arbitrary1) == (arbitrary2*M==Hy, arbitrary1)";
133 
134         mmt_full_vec_set_zero(my);
135         ASSERT_ALWAYS(my->siblings);     /* shared vector undesired */
136         for(unsigned int i = my->i0 ; i < my->i1 && i < unpadded ; i++) {
137             void * dst = A->vec_subvec(A, my->v, i - my->i0);
138             uint64_t value = DUMMY_VECTOR_COORD_VALUE2(i);
139             memcpy(dst, &value, sizeof(uint64_t));
140         }
141         /* This is L. Now compute the dot product. */
142         void * dp0;
143         void * dp1;
144         cheating_vec_init(A, &dp0, A->simd_groupsize(A));
145         cheating_vec_init(A, &dp1, A->simd_groupsize(A));
146         unsigned int how_many;
147         unsigned int offset_c;
148         unsigned int offset_v;
149         how_many = intersect_two_intervals(&offset_c, &offset_v,
150                 my->i0, my->i1,
151                 y->i0, y->i1);
152         A->vec_set_zero(A, dp0, A->simd_groupsize(A));
153         A->add_dotprod(A, dp0,
154                 A->vec_subvec(A, my->v, offset_c),
155                 A->vec_subvec(A, y->v, offset_v),
156                 how_many);
157         pi_allreduce(NULL, dp0, A->simd_groupsize(A), mmt->pitype, BWC_PI_SUM, pi->m);
158 
159         /* now we can throw away Hx */
160 
161         /* we do a transposed multiplication, here. It's a bit of a
162          * quirk, admittedly. We need to build a reversed vector list.
163          */
164         mmt_vec_twist(mmt, my);
165         {
166             mmt_vec myy[2];
167             mmt_vec_init(mmt,0,0, myy[0],  0, 0, mmt->n[0]);
168             mmt_vec_init(mmt,0,0, myy[1],  1, 0, mmt->n[1]);
169             mmt_full_vec_set(myy[0], my);
170             matmul_top_mul(mmt, myy, NULL);
171             mmt_full_vec_set(my, myy[0]);
172             mmt_vec_clear(mmt, myy[0]);
173             mmt_vec_clear(mmt, myy[1]);
174         }
175         mmt_vec_untwist(mmt, my);
176         mmt_vec_save(my, "Hy%u-%u", unpadded, 0);
177 
178         mmt_full_vec_set_zero(y);
179         ASSERT_ALWAYS(y->siblings);     /* shared vector undesired */
180         for(unsigned int i = y->i0 ; i < y->i1 && i < unpadded ; i++) {
181             void * dst = A->vec_subvec(A, y->v, i - y->i0);
182             uint64_t value = DUMMY_VECTOR_COORD_VALUE(i);
183             memcpy(dst, &value, sizeof(uint64_t));
184         }
185         A->vec_set_zero(A, dp1, A->simd_groupsize(A));
186         A->add_dotprod(A, dp1,
187                 A->vec_subvec(A, my->v, offset_c),
188                 A->vec_subvec(A, y->v, offset_v),
189                 how_many);
190         pi_allreduce(NULL, dp1, A->simd_groupsize(A), mmt->pitype, BWC_PI_SUM, pi->m);
191         int diff = memcmp(dp0, dp1, A->vec_elt_stride(A, A->simd_groupsize(A)));
192         if (pi->m->jrank == 0 && pi->m->trank == 0) {
193             if (diff) {
194                 printf("%s : failed\n", checkname);
195                 fprintf(stderr, "aborting on sanity check failure.\n");
196                 exit(1);
197             }
198             printf("%s : ok\n", checkname);
199         }
200         cheating_vec_clear(A, &dp0, A->simd_groupsize(A));
201         cheating_vec_clear(A, &dp1, A->simd_groupsize(A));
202     }
203 
204     mmt_vec_clear(mmt, y);
205     mmt_vec_clear(mmt, my);
206     matmul_top_clear(mmt);
207 
208     A->oo_field_clear(A);
209 
210     return NULL;
211 }
212 
213 
214 // coverity[root_function]
main(int argc,char * argv[])215 int main(int argc, char * argv[])
216 {
217     param_list pl;
218 
219     bw_common_init(bw, &argc, &argv);
220     param_list_init(pl);
221     parallelizing_info_init();
222 
223     bw_common_decl_usage(pl);
224     parallelizing_info_decl_usage(pl);
225     matmul_top_decl_usage(pl);
226     /* declare local parameters and switches */
227 
228     bw_common_parse_cmdline(bw, pl, &argc, &argv);
229 
230     bw_common_interpret_parameters(bw, pl);
231     parallelizing_info_lookup_parameters(pl);
232     matmul_top_lookup_parameters(pl);
233     /* interpret our parameters: none here (so far). */
234 
235     ASSERT_ALWAYS(param_list_lookup_string(pl, "ys"));
236     ASSERT_ALWAYS(!param_list_lookup_string(pl, "solutions"));
237 
238     if (param_list_warn_unused(pl)) {
239         int rank;
240         MPI_Comm_rank(MPI_COMM_WORLD, &rank);
241         if (!rank) param_list_print_usage(pl, bw->original_argv[0], stderr);
242         MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
243     }
244     // param_list_clear(pl);
245 
246     if (bw->ys[0] < 0) { fprintf(stderr, "no ys value set\n"); exit(1); }
247 
248     /* Forcibly disable interleaving here */
249     param_list_remove_key(pl, "interleaving");
250 
251     catch_control_signals();
252     pi_go(dispatch_prog, pl, 0);
253 
254     parallelizing_info_finish();
255     param_list_clear(pl);
256     bw_common_clear(bw);
257 
258     return 0;
259 }
260 
261