1 /* ----------------------------------------------------------------------
2    LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
3    https://www.lammps.org/ Sandia National Laboratories
4    Steve Plimpton, sjplimp@sandia.gov
5 
6    Copyright (2003) Sandia Corporation.  Under the terms of Contract
7    DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
8    certain rights in this software.  This software is distributed under
9    the GNU General Public License.
10 
11    See the README file in the top-level LAMMPS directory.
12 ------------------------------------------------------------------------- */
#include "gmock/gmock.h"
#include "gtest/gtest.h"

#include <cstring>
#include <deque>
#include <memory>
#include <mpi.h>
#include <sstream>
#include <string>
18 
19 using ::testing::TestCase;
20 using ::testing::TestEventListener;
21 using ::testing::TestInfo;
22 using ::testing::TestPartResult;
23 using ::testing::TestSuite;
24 using ::testing::UnitTest;
25 
26 class MPIPrinter : public TestEventListener {
27     MPI_Comm comm;
28     TestEventListener *default_listener;
29     int me;
30     int nprocs;
31     char *buffer;
32     size_t buffer_size;
33     std::deque<TestPartResult> results;
34     bool finalize_test;
35 
36 public:
MPIPrinter(TestEventListener * default_listener)37     MPIPrinter(TestEventListener *default_listener) : default_listener(default_listener)
38     {
39         comm = MPI_COMM_WORLD;
40         MPI_Comm_rank(comm, &me);
41         MPI_Comm_size(comm, &nprocs);
42         buffer_size   = 1024;
43         buffer        = new char[buffer_size];
44         finalize_test = false;
45     }
46 
~MPIPrinter()47     ~MPIPrinter() override
48     {
49         delete default_listener;
50         default_listener = nullptr;
51 
52         delete[] buffer;
53         buffer      = nullptr;
54         buffer_size = 0;
55     }
56 
OnTestProgramStart(const UnitTest & unit_test)57     virtual void OnTestProgramStart(const UnitTest &unit_test) override
58     {
59         if (me == 0) default_listener->OnTestProgramStart(unit_test);
60     }
61 
OnTestIterationStart(const UnitTest & unit_test,int iteration)62     virtual void OnTestIterationStart(const UnitTest &unit_test, int iteration) override
63     {
64         if (me == 0) default_listener->OnTestIterationStart(unit_test, iteration);
65     }
66 
OnEnvironmentsSetUpStart(const UnitTest & unit_test)67     virtual void OnEnvironmentsSetUpStart(const UnitTest &unit_test) override
68     {
69         if (me == 0) default_listener->OnEnvironmentsSetUpStart(unit_test);
70     }
71 
OnEnvironmentsSetUpEnd(const UnitTest & unit_test)72     virtual void OnEnvironmentsSetUpEnd(const UnitTest &unit_test) override
73     {
74         if (me == 0) default_listener->OnEnvironmentsSetUpEnd(unit_test);
75     }
76 
OnTestSuiteStart(const TestSuite & test_suite)77     virtual void OnTestSuiteStart(const TestSuite &test_suite) override
78     {
79         if (me == 0) default_listener->OnTestSuiteStart(test_suite);
80     }
81 
82     //  Legacy API is deprecated but still available
83 #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_
OnTestCaseStart(const TestCase & test_case)84     virtual void OnTestCaseStart(const TestCase &test_case) override
85     {
86         if (me == 0) default_listener->OnTestSuiteStart(test_case);
87     }
88 #endif //  GTEST_REMOVE_LEGACY_TEST_CASEAPI_
89 
OnTestStart(const TestInfo & test_info)90     virtual void OnTestStart(const TestInfo &test_info) override
91     {
92         // Called before a test starts.
93         if (me == 0) default_listener->OnTestStart(test_info);
94         results.clear();
95         finalize_test = false;
96     }
97 
OnTestPartResult(const TestPartResult & test_part_result)98     virtual void OnTestPartResult(const TestPartResult &test_part_result) override
99     {
100         // Called after a failed assertion or a SUCCESS().
101         // test_part_result()
102 
103         if (me == 0 && finalize_test) {
104             default_listener->OnTestPartResult(test_part_result);
105         } else {
106             std::stringstream proc_message;
107             std::istringstream msg(test_part_result.message());
108             std::string line;
109 
110             while (std::getline(msg, line)) {
111                 proc_message << "[Rank " << me << "] " << line << std::endl;
112             }
113 
114             results.push_back(TestPartResult(test_part_result.type(), test_part_result.file_name(),
115                                              test_part_result.line_number(),
116                                              proc_message.str().c_str()));
117         }
118     }
119 
OnTestEnd(const TestInfo & test_info)120     virtual void OnTestEnd(const TestInfo &test_info) override
121     {
122         // Called after a test ends.
123         MPI_Barrier(comm);
124 
125         // other procs send their test part results
126         if (me != 0) {
127             int nresults = results.size();
128             MPI_Send(&nresults, 1, MPI_INT, 0, 0, comm);
129 
130             for (auto &test_part_result : results) {
131 
132                 int type = test_part_result.type();
133                 MPI_Send(&type, 1, MPI_INT, 0, 0, comm);
134 
135                 const char *str = test_part_result.file_name();
136                 int length      = 0;
137                 if (str) length = strlen(str) + 1;
138                 MPI_Send(&length, 1, MPI_INT, 0, 0, comm);
139                 if (str) MPI_Send(str, length, MPI_CHAR, 0, 0, comm);
140 
141                 int lineno = test_part_result.line_number();
142                 MPI_Send(&lineno, 1, MPI_INT, 0, 0, comm);
143 
144                 str    = test_part_result.message();
145                 length = 0;
146                 if (str) length = strlen(str) + 1;
147                 MPI_Send(&length, 1, MPI_INT, 0, 0, comm);
148                 if (str) MPI_Send(str, length, MPI_CHAR, 0, 0, comm);
149             }
150         }
151 
152         if (me == 0) {
153             // collect results from other procs
154             for (int p = 1; p < nprocs; p++) {
155                 int nresults = 0;
156                 MPI_Recv(&nresults, 1, MPI_INT, p, 0, comm, MPI_STATUS_IGNORE);
157 
158                 for (int r = 0; r < nresults; r++) {
159 
160                     int type;
161                     MPI_Recv(&type, 1, MPI_INT, p, 0, comm, MPI_STATUS_IGNORE);
162 
163                     int length = 0;
164                     MPI_Recv(&length, 1, MPI_INT, p, 0, comm, MPI_STATUS_IGNORE);
165                     std::string file_name;
166 
167                     if (length > 0) {
168                         if (length > buffer_size) {
169                             delete[] buffer;
170                             buffer      = new char[length];
171                             buffer_size = length;
172                         }
173                         MPI_Recv(buffer, length, MPI_CHAR, p, 0, comm, MPI_STATUS_IGNORE);
174                         file_name = buffer;
175                     }
176 
177                     int lineno;
178                     MPI_Recv(&lineno, 1, MPI_INT, p, 0, comm, MPI_STATUS_IGNORE);
179 
180                     MPI_Recv(&length, 1, MPI_INT, p, 0, comm, MPI_STATUS_IGNORE);
181                     std::string message;
182 
183                     if (length > 0) {
184                         if (length > buffer_size) {
185                             delete[] buffer;
186                             buffer      = new char[length];
187                             buffer_size = length;
188                         }
189                         MPI_Recv(buffer, length, MPI_CHAR, p, 0, comm, MPI_STATUS_IGNORE);
190                         message = std::string(buffer);
191                     }
192 
193                     results.push_back(TestPartResult((TestPartResult::Type)type, file_name.c_str(),
194                                                      lineno, message.c_str()));
195                 }
196             }
197 
198             // ensure failures are reported
199             finalize_test = true;
200 
201             // add all failures
202             while (!results.empty()) {
203                 auto result = results.front();
204                 if (result.failed()) {
205                     ADD_FAILURE_AT(result.file_name(), result.line_number()) << result.message();
206                 } else {
207                     default_listener->OnTestPartResult(result);
208                 }
209                 results.pop_front();
210             }
211 
212             default_listener->OnTestEnd(test_info);
213         }
214     }
215 
OnTestSuiteEnd(const TestSuite & test_suite)216     virtual void OnTestSuiteEnd(const TestSuite &test_suite) override
217     {
218         if (me == 0) default_listener->OnTestSuiteEnd(test_suite);
219     }
220 
221 #ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_
OnTestCaseEnd(const TestCase & test_case)222     virtual void OnTestCaseEnd(const TestCase &test_case) override
223     {
224         if (me == 0) default_listener->OnTestCaseEnd(test_case);
225     }
226 #endif //  GTEST_REMOVE_LEGACY_TEST_CASEAPI_
227 
OnEnvironmentsTearDownStart(const UnitTest & unit_test)228     virtual void OnEnvironmentsTearDownStart(const UnitTest &unit_test) override
229     {
230         if (me == 0) default_listener->OnEnvironmentsTearDownStart(unit_test);
231     }
232 
OnEnvironmentsTearDownEnd(const UnitTest & unit_test)233     virtual void OnEnvironmentsTearDownEnd(const UnitTest &unit_test) override
234     {
235         if (me == 0) default_listener->OnEnvironmentsTearDownEnd(unit_test);
236     }
237 
OnTestIterationEnd(const UnitTest & unit_test,int iteration)238     virtual void OnTestIterationEnd(const UnitTest &unit_test, int iteration) override
239     {
240         if (me == 0) default_listener->OnTestIterationEnd(unit_test, iteration);
241     }
242 
OnTestProgramEnd(const UnitTest & unit_test)243     virtual void OnTestProgramEnd(const UnitTest &unit_test) override
244     {
245         if (me == 0) default_listener->OnTestProgramEnd(unit_test);
246     }
247 };
248