1 // =============================================================================
2 // PROJECT CHRONO - http://projectchrono.org
3 //
4 // Copyright (c) 2020 projectchrono.org
5 // All rights reserved.
6 //
7 // Use of this source code is governed by a BSD-style license that can be found
8 // in the LICENSE file at the top level of the distribution and at
9 // http://projectchrono.org/license-chrono.txt.
10 //
11 // =============================================================================
12 // Authors: Jay Taves
13 // =============================================================================
14 //
15 // Unit test for SynChrono MPI code
16 //
17 // =============================================================================
18
19 #include <numeric>
20
21 #include "gtest/gtest.h"
22
23 #include "chrono_thirdparty/cxxopts/ChCLI.h"
24 #include "chrono_synchrono/SynChronoManager.h"
25 #include "chrono_synchrono/communication/mpi/SynMPICommunicator.h"
26
27 #include "chrono_synchrono/utils/SynDataLoader.h"
28
29 #include "chrono_synchrono/agent/SynEnvironmentAgent.h"
30 #include "chrono_synchrono/agent/SynWheeledVehicleAgent.h"
31
32 using namespace chrono;
33 using namespace synchrono;
34
// MPI rank of this process and total number of ranks in the job.
// Both are written once in main() from the communicator and read by the tests.
int rank = 0;
int num_ranks = 0;
37
38 // Define our own main here to handle the MPI setup
main(int argc,char * argv[])39 int main(int argc, char* argv[]) {
40 // Let google strip their cli arguments
41 ::testing::InitGoogleTest(&argc, argv);
42
43 // Create the MPI communicator and the manager
44 auto communicator = chrono_types::make_shared<SynMPICommunicator>(argc, argv);
45 rank = communicator->GetRank();
46 num_ranks = communicator->GetNumRanks();
47 SynChronoManager syn_manager(rank, num_ranks, communicator);
48
49 ::testing::TestEventListeners& listeners = ::testing::UnitTest::GetInstance()->listeners();
50 if (rank != 0) {
51 delete listeners.Release(listeners.default_result_printer());
52 }
53
54 // Each rank will be running each test
55 return RUN_ALL_TESTS();
56 }
57
// Exercises an MPI_Allgather / MPI_Allgatherv round trip across all ranks.
// Each rank sends (10 + num_ranks - rank) copies of its own rank id. Every rank
// gathers all messages, and the sum of the gathered data is compared with its
// closed-form value.
TEST(SynChrono, SynChronoInit) {
    // Arbitrary, rank-dependent message length so that ranks send different amounts
    int msg_length = 10 + num_ranks - rank;

    // Learn how much data every rank will contribute.
    // The buffers are owned by std::vector so that an early return from a failing
    // ASSERT_* cannot leak them.
    std::vector<int> msg_lengths(num_ranks);
    MPI_Allgather(&msg_length, 1, MPI_INT,         // Sending args
                  msg_lengths.data(), 1, MPI_INT,  // Receiving args
                  MPI_COMM_WORLD);

    // Payload: msg_length copies of this rank's id
    std::vector<int> my_data(msg_length, rank);

    // Exclusive prefix sum of the per-rank lengths: msg_displs[i] is where rank i's
    // data lands in the receive buffer (required by MPI_Allgatherv), and
    // total_length is the size of that buffer.
    std::vector<int> msg_displs(num_ranks);
    int total_length = 0;
    for (int i = 0; i < num_ranks; i++) {
        msg_displs[i] = total_length;
        total_length += msg_lengths[i];
    }

    // Sized (not merely reserved) so MPI can write directly into the buffer
    std::vector<int> all_data(total_length);

    MPI_Allgatherv(my_data.data(), msg_length, MPI_INT,                              // Sending args
                   all_data.data(), msg_lengths.data(), msg_displs.data(), MPI_INT,  // Receiving args
                   MPI_COMM_WORLD);

    // 64-bit accumulation: the expected value grows cubically with num_ranks and
    // would overflow int for large rank counts.
    long long sum = std::accumulate(all_data.begin(), all_data.end(), 0LL);

    // Σ r * (10 + num_ranks - r) for r = 0 .. num_ranks - 1
    //   = (num_ranks - 1) * num_ranks * (num_ranks + 31) / 6
    // The product is always divisible by 6, so the integer division is exact.
    const long long n = num_ranks;
    long long check = (n - 1) * n * (n + 31) / 6;

    MPI_Barrier(MPI_COMM_WORLD);

    ASSERT_EQ(sum, check);
}