1 /******************************************
2 Copyright (c) 2016, Mate Soos
3 
4 Permission is hereby granted, free of charge, to any person obtaining a copy
5 of this software and associated documentation files (the "Software"), to deal
6 in the Software without restriction, including without limitation the rights
7 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 copies of the Software, and to permit persons to whom the Software is
9 furnished to do so, subject to the following conditions:
10 
11 The above copyright notice and this permission notice shall be included in
12 all copies or substantial portions of the Software.
13 
14 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 THE SOFTWARE.
21 ***********************************************/
22 
23 #include "datasync.h"
24 #include "varreplacer.h"
25 #include "solver.h"
26 #include "shareddata.h"
27 #include <iomanip>
28 
29 using namespace CMSat;
30 
DataSync(Solver * _solver,SharedData * _sharedData,bool _is_mpi)31 DataSync::DataSync(Solver* _solver, SharedData* _sharedData, bool _is_mpi) :
32     solver(_solver)
33     , sharedData(_sharedData)
34     , is_mpi(_is_mpi)
35     , seen(solver->seen)
36     , toClear(solver->toClear)
37 {
38 #ifdef USE_MPI
39     if (is_mpi) {
40         int err;
41         err = MPI_Comm_rank(MPI_COMM_WORLD, &mpiRank);
42         assert(err == MPI_SUCCESS);
43 
44         err = MPI_Comm_size(MPI_COMM_WORLD, &mpiSize);
45         assert(err == MPI_SUCCESS);
46         release_assert(!(mpiSize > 1 && mpiRank == 0));
47         assert(!(mpiSize > 1 && sharedData == NULL));
48     }
49 #endif
50 }
51 
set_shared_data(SharedData * _sharedData)52 void DataSync::set_shared_data(SharedData* _sharedData)
53 {
54     sharedData = _sharedData;
55 }
56 
new_var(const bool bva)57 void DataSync::new_var(const bool bva)
58 {
59     if (!enabled())
60         return;
61 
62     if (!bva) {
63         syncFinish.push_back(0);
64         syncFinish.push_back(0);
65     }
66     assert(solver->nVarsOutside()*2 == syncFinish.size());
67 }
68 
new_vars(size_t n)69 void DataSync::new_vars(size_t n)
70 {
71     if (!enabled())
72         return;
73 
74     syncFinish.insert(syncFinish.end(), 2*n, 0);
75     assert(solver->nVarsOutside()*2 == syncFinish.size());
76 }
77 
save_on_var_memory()78 void DataSync::save_on_var_memory()
79 {
80 }
81 
rebuild_bva_map()82 void DataSync::rebuild_bva_map()
83 {
84     must_rebuild_bva_map = true;
85 }
86 
updateVars(const vector<uint32_t> &,const vector<uint32_t> &)87 void DataSync::updateVars(
88     const vector<uint32_t>& /*outerToInter*/
89     , const vector<uint32_t>& /*interToOuter*/
90 ) {
91 }
92 
syncData()93 bool DataSync::syncData()
94 {
95     if (!enabled()
96         || lastSyncConf + solver->conf.sync_every_confl >= solver->sumConflicts
97     ) {
98         return true;
99     }
100     numCalls++;
101 
102     assert(sharedData != NULL);
103     assert(solver->decisionLevel() == 0);
104 
105     if (must_rebuild_bva_map) {
106         outer_to_without_bva_map = solver->build_outer_to_without_bva_map();
107         must_rebuild_bva_map = false;
108     }
109 
110     bool ok;
111     sharedData->unit_mutex.lock();
112     ok = shareUnitData();
113     sharedData->unit_mutex.unlock();
114     if (!ok) return false;
115 
116     sharedData->bin_mutex.lock();
117     extend_bins_if_needed();
118     clear_set_binary_values();
119     ok = shareBinData();
120     sharedData->bin_mutex.unlock();
121     if (!ok) return false;
122 
123     #ifdef USE_MPI
124     if (is_mpi && mpiSize > 1 && solver->conf.thread_num == 0) {
125         sharedData->unit_mutex.lock();
126         sharedData->bin_mutex.lock();
127         ok = syncFromMPI();
128         if (ok && numCalls % 2 == 1) {
129             syncToMPI();
130         }
131         if (!ok) return false;
132     }
133 
134     if (is_mpi) {
135         getNeedToInterruptFromMPI();
136     }
137     #endif
138 
139     lastSyncConf = solver->sumConflicts;
140 
141     return true;
142 }
143 
clear_set_binary_values()144 void DataSync::clear_set_binary_values()
145 {
146     for(size_t i = 0; i < solver->nVarsOutside()*2; i++) {
147         Lit lit1 = Lit::toLit(i);
148         lit1 = solver->map_to_with_bva(lit1);
149         lit1 = solver->varReplacer->get_lit_replaced_with_outer(lit1);
150         lit1 = solver->map_outer_to_inter(lit1);
151         if (solver->value(lit1) != l_Undef) {
152             sharedData->bins[i].clear();
153         }
154     }
155 }
156 
extend_bins_if_needed()157 void DataSync::extend_bins_if_needed()
158 {
159     assert(sharedData->bins.size() <= (solver->nVarsOutside())*2);
160     if (sharedData->bins.size() == (solver->nVarsOutside())*2)
161         return;
162 
163     sharedData->bins.resize(solver->nVarsOutside()*2);
164 }
165 
shareBinData()166 bool DataSync::shareBinData()
167 {
168     uint32_t oldRecvBinData = stats.recvBinData;
169     uint32_t oldSentBinData = stats.sentBinData;
170 
171     syncBinFromOthers();
172     syncBinToOthers();
173     size_t mem = sharedData->calc_memory_use_bins();
174 
175     if (solver->conf.verbosity >= 3) {
176         cout
177         << "c [sync] got bins " << (stats.recvBinData - oldRecvBinData)
178         << " sent bins " << (stats.sentBinData - oldSentBinData)
179         << " mem use: " << mem/(1024*1024) << " M"
180         << endl;
181     }
182 
183     return true;
184 }
185 
syncBinFromOthers()186 bool DataSync::syncBinFromOthers()
187 {
188     for (uint32_t wsLit = 0; wsLit < sharedData->bins.size(); wsLit++) {
189         if (sharedData->bins[wsLit].data == NULL) {
190             continue;
191         }
192 
193         Lit lit1 = Lit::toLit(wsLit);
194         lit1 = solver->map_to_with_bva(lit1);
195         lit1 = solver->varReplacer->get_lit_replaced_with_outer(lit1);
196         lit1 = solver->map_outer_to_inter(lit1);
197         if (solver->varData[lit1.var()].removed != Removed::none
198             || solver->value(lit1.var()) != l_Undef
199         ) {
200             continue;
201         }
202 
203         vector<Lit>& bins = *sharedData->bins[wsLit].data;
204         watch_subarray ws = solver->watches[lit1];
205 
206         assert(syncFinish.size() > wsLit);
207         if (bins.size() > syncFinish[wsLit]
208             && !syncBinFromOthers(lit1, bins, syncFinish[wsLit], ws)
209         ) {
210             return false;
211         }
212     }
213 
214     return true;
215 }
216 
syncBinFromOthers(const Lit lit,const vector<Lit> & bins,uint32_t & finished,watch_subarray ws)217 bool DataSync::syncBinFromOthers(
218     const Lit lit
219     , const vector<Lit>& bins
220     , uint32_t& finished
221     , watch_subarray ws
222 ) {
223     assert(solver->varReplacer->get_lit_replaced_with(lit) == lit);
224     assert(solver->varData[lit.var()].removed == Removed::none);
225 
226     assert(toClear.empty());
227     for (const Watched& w: ws) {
228         if (w.isBin()) {
229             toClear.push_back(w.lit2());
230             assert(seen.size() > w.lit2().toInt());
231             seen[w.lit2().toInt()] = true;
232         }
233     }
234 
235     vector<Lit> lits(2);
236     for (uint32_t i = finished; i < bins.size(); i++) {
237         Lit otherLit = bins[i];
238         otherLit = solver->map_to_with_bva(otherLit);
239         otherLit = solver->varReplacer->get_lit_replaced_with_outer(otherLit);
240         otherLit = solver->map_outer_to_inter(otherLit);
241         if (solver->varData[otherLit.var()].removed != Removed::none
242             || solver->value(otherLit) != l_Undef
243         ) {
244             continue;
245         }
246         assert(seen.size() > otherLit.toInt());
247         if (!seen[otherLit.toInt()]) {
248             stats.recvBinData++;
249             lits[0] = lit;
250             lits[1] = otherLit;
251 
252             //Don't add DRAT: it would add to the thread data, too
253             solver->add_clause_int(lits, true, ClauseStats(), true, NULL, false);
254             if (!solver->ok) {
255                 goto end;
256             }
257         }
258     }
259     finished = bins.size();
260 
261     end:
262     for (const Lit l: toClear) {
263         seen[l.toInt()] = false;
264     }
265     toClear.clear();
266 
267     return solver->okay();
268 }
269 
syncBinToOthers()270 void DataSync::syncBinToOthers()
271 {
272     for(const std::pair<Lit, Lit>& bin: newBinClauses) {
273         addOneBinToOthers(bin.first, bin.second);
274     }
275 
276     newBinClauses.clear();
277 }
278 
addOneBinToOthers(Lit lit1,Lit lit2)279 void DataSync::addOneBinToOthers(Lit lit1, Lit lit2)
280 {
281     assert(lit1 < lit2);
282     if (sharedData->bins[lit1.toInt()].data == NULL) {
283         return;
284     }
285 
286     vector<Lit>& bins = *sharedData->bins[lit1.toInt()].data;
287     for (const Lit lit : bins) {
288         if (lit == lit2)
289             return;
290     }
291 
292     bins.push_back(lit2);
293     stats.sentBinData++;
294 }
295 
shareUnitData()296 bool DataSync::shareUnitData()
297 {
298     uint32_t thisGotUnitData = 0;
299     uint32_t thisSentUnitData = 0;
300 
301     SharedData& shared = *sharedData;
302     if (shared.value.size() < solver->nVarsOutside()) {
303         shared.value.resize(solver->nVarsOutside(), l_Undef);
304     }
305     for (uint32_t var = 0; var < solver->nVarsOutside(); var++) {
306         Lit thisLit = Lit(var, false);
307         thisLit = solver->map_to_with_bva(thisLit);
308         thisLit = solver->varReplacer->get_lit_replaced_with_outer(thisLit);
309         thisLit = solver->map_outer_to_inter(thisLit);
310         const lbool thisVal = solver->value(thisLit);
311         const lbool otherVal = shared.value[var];
312 
313         if (thisVal == l_Undef && otherVal == l_Undef) {
314             continue;
315         }
316 
317         if (thisVal != l_Undef && otherVal != l_Undef) {
318             if (thisVal != otherVal) {
319                 solver->ok = false;
320                 return false;
321             } else {
322                 continue;
323             }
324         }
325 
326         if (otherVal != l_Undef) {
327             assert(thisVal == l_Undef);
328             Lit litToEnqueue = thisLit ^ (otherVal == l_False);
329             if (solver->varData[litToEnqueue.var()].removed != Removed::none) {
330                 continue;
331             }
332 
333             solver->enqueue(litToEnqueue);
334             solver->ok = solver->propagate<false>().isNULL();
335             if (!solver->ok) {
336                 return false;
337             }
338 
339             thisGotUnitData++;
340             continue;
341         }
342 
343         if (thisVal != l_Undef) {
344             assert(otherVal == l_Undef);
345             shared.value[var] = thisVal;
346             thisSentUnitData++;
347             continue;
348         }
349     }
350 
351     if (solver->conf.verbosity >= 3
352         //&& (thisGotUnitData > 0 || thisSentUnitData > 0)
353     ) {
354         cout
355         << "c [sync] got units " << thisGotUnitData
356         << " sent units " << thisSentUnitData
357         << endl;
358     }
359 
360     stats.recvUnitData += thisGotUnitData;
361     stats.sentUnitData += thisSentUnitData;
362 
363     return true;
364 }
365 
signalNewBinClause(Lit lit1,Lit lit2)366 void DataSync::signalNewBinClause(Lit lit1, Lit lit2)
367 {
368     if (!enabled()) {
369         return;
370     }
371 
372     if (must_rebuild_bva_map) {
373         outer_to_without_bva_map = solver->build_outer_to_without_bva_map();
374         must_rebuild_bva_map = false;
375     }
376 
377     if (solver->varData[lit1.var()].is_bva)
378         return;
379     if (solver->varData[lit2.var()].is_bva)
380         return;
381 
382     lit1 = solver->map_inter_to_outer(lit1);
383     lit1 = map_outside_without_bva(lit1);
384     lit2 = solver->map_inter_to_outer(lit2);
385     lit2 = map_outside_without_bva(lit2);
386 
387     if (lit1.toInt() > lit2.toInt()) {
388         std::swap(lit1, lit2);
389     }
390     newBinClauses.push_back(std::make_pair(lit1, lit2));
391 }
392 
393 
394 ///////////////////////////////////////
395 // MPI
396 ///////////////////////////////////////
397 
398 #ifdef USE_MPI
getNeedToInterruptFromMPI()399 void DataSync::getNeedToInterruptFromMPI()
400 {
401     int flag;
402     MPI_Status status;
403     int err = MPI_Iprobe(0, 1, MPI_COMM_WORLD, &flag, &status);
404     assert(err == MPI_SUCCESS);
405     if (flag == false) {
406         return;
407     }
408 
409     char* buf = NULL;
410     err = MPI_Recv((unsigned*)buf, 0, MPI_UNSIGNED, 0, 1, MPI_COMM_WORLD, &status);
411     assert(err == MPI_SUCCESS);
412     solver->set_must_interrupt_asap();
413 }
414 
syncFromMPI()415 bool DataSync::syncFromMPI()
416 {
417     int err;
418     MPI_Status status;
419     int flag;
420     int count;
421     uint32_t tmp = 0;
422 
423     uint32_t thisMpiRecvUnitData = 0;
424     uint32_t thisMpiRecvBinData = 0;
425 
426     err = MPI_Iprobe(0, 0, MPI_COMM_WORLD, &flag, &status);
427     assert(err == MPI_SUCCESS);
428     if (flag == false) return true;
429 
430     err = MPI_Get_count(&status, MPI_UNSIGNED, &count);
431     assert(err == MPI_SUCCESS);
432     #ifdef VERBOSE_DEBUG_MPI_SENDRCV
433     std::cout << "-->> MPI " << mpiRank << " Received " << count << " uint32_t-s" << std::endl;
434     #endif
435 
436     uint32_t* buf = new uint32_t[count];
437     err = MPI_Recv((unsigned*)buf, count, MPI_UNSIGNED, 0, 0, MPI_COMM_WORLD, &status);
438     assert(err == MPI_SUCCESS);
439 
440     //Unit clauses
441     int at = 0;
442     assert(solver->nVars() == buf[at]);
443     at++;
444     for (uint32_t var = 0; var < solver->nVars(); var++, at++) {
445         const lbool otherVal = toLbool(buf[at]);
446         if (!sync_mpi_unit(otherVal, var, NULL, thisMpiRecvUnitData, tmp)) {
447             #ifdef VERBOSE_DEBUG_MPI_SENDRCV
448             std::cout << "-->> MPI " << mpiRank << " solver FALSE" << std::endl;
449             #endif
450             goto end;
451         }
452     }
453     solver->ok = solver->propagate<true>().isNULL();
454     if (!solver->ok) goto end;
455     mpiRecvUnitData += thisMpiRecvUnitData;
456 
457     //Binary clauses
458     assert(buf[at] == solver->nVars()*2);
459     at++;
460     for (uint32_t wsLit = 0; wsLit < solver->nVars()*2; wsLit++) {
461         Lit lit = ~Lit::toLit(wsLit);
462         uint32_t num = buf[at];
463         at++;
464         for (uint32_t i = 0; i < num; i++, at++) {
465             Lit otherLit = Lit::toLit(buf[at]);
466             addOneBinToOthers(lit, otherLit);
467             thisMpiRecvBinData++;
468         }
469     }
470     mpiRecvBinData += thisMpiRecvBinData;
471 
472     end:
473     #ifdef VERBOSE_DEBUG_MPI_SENDRCV
474     std::cout << "-->> MPI " << mpiRank << " Received " << thisMpiRecvUnitData << " units" << std::endl;
475     std::cout << "-->> MPI " << mpiRank << " Received " << thisMpiRecvBinData << " bins" << std::endl;
476     std::cout << "-->> MPI " << mpiRank << " Received " << thisMpiRecvTriData << " tris" << std::endl;
477     #endif
478 
479     delete[] buf;
480     return solver->ok;
481 }
482 
syncToMPI()483 void DataSync::syncToMPI()
484 {
485     int err;
486     if (mpiSendData != NULL) {
487         MPI_Status status;
488         err = MPI_Wait(&sendReq, &status);
489         assert(err == MPI_SUCCESS);
490         delete mpiSendData;
491         mpiSendData = NULL;
492     }
493 
494     vector<uint32_t> data;
495     data.push_back((uint32_t)solver->nVars());
496     for (uint32_t var = 0; var < solver->nVars(); var++) {
497         data.push_back(toInt(solver->value(var)));
498     }
499 
500     //Binary
501     uint32_t thisMpiSentBinData = 0;
502     data.push_back((uint32_t)solver->nVars()*2);
503     uint32_t wsLit = 0;
504     for(auto it = sharedData->bins.begin()
505         , end = sharedData->bins.end(); it != end; it++, wsLit++
506     ) {
507         //Lit lit1 = ~Lit::toLit(wsLit);
508         assert(it->data->size() >= syncMPIFinish[wsLit]);
509         uint32_t sizeToSend = it->data->size() - syncMPIFinish[wsLit];
510         data.push_back(sizeToSend);
511         for (uint32_t i = syncMPIFinish[wsLit]; i < it->data->size(); i++) {
512             data.push_back(it->data->at(i).toInt());
513             thisMpiSentBinData++;
514         }
515         syncMPIFinish[wsLit] = it->data->size();
516     }
517     assert(wsLit == solver->nVars()*2);
518     mpiSentBinData += thisMpiSentBinData;
519 
520     #ifdef VERBOSE_DEBUG_MPI_SENDRCV
521     std::cout << "-->> MPI " << mpiRank << " Sent " << data.size() << " uint32_t -s" << std::endl;
522     std::cout << "-->> MPI " << mpiRank << " Sent " << thisMpiSentBinData << " bins " << std::endl;
523     std::cout << "-->> MPI " << mpiRank << " Sent " << thisMpiSentTriData << " tris " << std::endl;
524     #endif
525 
526     mpiSendData = new uint32_t[data.size()];
527     std::copy(data.begin(), data.end(), mpiSendData);
528     err = MPI_Isend(mpiSendData, data.size(), MPI_UNSIGNED, 0, 0, MPI_COMM_WORLD, &sendReq);
529     assert(err == MPI_SUCCESS);
530 }
531 
sync_mpi_unit(const lbool otherVal,const uint32_t var,SharedData * shared,uint32_t & thisGotUnitData,uint32_t & thisSentUnitData)532 bool DataSync::sync_mpi_unit(
533     const lbool otherVal,
534     const uint32_t var,
535     SharedData* shared,
536     uint32_t& thisGotUnitData,
537     uint32_t& thisSentUnitData
538 ) {
539     Lit l = Lit(var, false);
540     Lit lit1 = solver->map_to_with_bva(l);
541     lit1 = solver->varReplacer->get_lit_replaced_with_outer(lit1);
542     lit1 = solver->map_outer_to_inter(lit1);
543     const lbool thisVal = solver->value(lit1);
544 
545     if (thisVal == l_Undef && otherVal == l_Undef) {
546         return true;
547     }
548     if (thisVal != l_Undef && otherVal != l_Undef) {
549         if (thisVal != otherVal) {
550             solver->ok = false;
551             return false;
552         } else {
553             return true;
554         }
555     }
556 
557     if (otherVal != l_Undef) {
558         assert(thisVal == l_Undef);
559         Lit litToEnqueue = lit1 ^ (otherVal == l_False);
560         if (solver->varData[litToEnqueue.var()].removed != Removed::none) {
561             return true;
562         }
563 
564         solver->enqueue(litToEnqueue);
565         solver->ok = solver->propagate<false>().isNULL();
566         if (!solver->ok) {
567             return false;
568         }
569 
570         thisGotUnitData++;
571         return true;
572     }
573 
574     if (shared != NULL && thisVal != l_Undef) {
575         assert(otherVal == l_Undef);
576         shared->value[var] = thisVal;
577         thisSentUnitData++;
578         return true;
579     }
580 
581     return true;
582 }
583 
584 #endif
585