// =============================================================================
// PROJECT CHRONO - http://projectchrono.org
//
// Copyright (c) 2020 projectchrono.org
// All rights reserved.
//
// Use of this source code is governed by a BSD-style license that can be found
// in the LICENSE file at the top level of the distribution and at
// http://projectchrono.org/license-chrono.txt.
//
// =============================================================================
// Authors: Jay Taves
// =============================================================================
//
// Unit test for SynChrono MPI code
//
// =============================================================================

#include <numeric>

#include "gtest/gtest.h"

#include "chrono_thirdparty/cxxopts/ChCLI.h"
#include "chrono_synchrono/SynChronoManager.h"
#include "chrono_synchrono/communication/mpi/SynMPICommunicator.h"

#include "chrono_synchrono/utils/SynDataLoader.h"

#include "chrono_synchrono/agent/SynEnvironmentAgent.h"
#include "chrono_synchrono/agent/SynWheeledVehicleAgent.h"

using namespace chrono;
using namespace synchrono;

int rank;
int num_ranks;

// Define our own main here to handle the MPI setup
int main(int argc, char* argv[]) {
    // Let google strip their cli arguments
    ::testing::InitGoogleTest(&argc, argv);

    // Create the MPI communicator and the manager
    auto communicator = chrono_types::make_shared<SynMPICommunicator>(argc, argv);
    rank = communicator->GetRank();
    num_ranks = communicator->GetNumRanks();
    SynChronoManager syn_manager(rank, num_ranks, communicator);

    // Only rank 0 keeps the default result printer, so test output is not duplicated across ranks
    ::testing::TestEventListeners& listeners = ::testing::UnitTest::GetInstance()->listeners();
    if (rank != 0) {
        delete listeners.Release(listeners.default_result_printer());
    }

    // Each rank will be running each test
    return RUN_ALL_TESTS();
}

TEST(SynChrono, SynChronoInit) {
    int* msg_lengths = new int[num_ranks];
    int* msg_displs = new int[num_ranks];

    // Just a meaningless message length that we will fill with data
    int msg_length = 10 + num_ranks - rank;

    // Determine how much stuff we get from the other ranks
    MPI_Allgather(&msg_length, 1, MPI_INT,  // Sending args
                  msg_lengths, 1, MPI_INT,  // Receiving args
                  MPI_COMM_WORLD);
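    // After the gather, msg_lengths[i] holds the number of ints that rank i will send,
    // i.e. 10 + num_ranks - i for every rank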

    int total_length = 0;
    std::vector<int> all_data;
    std::vector<int> my_data;

    // Fill the data we will send
    for (int i = 0; i < msg_length; i++)
        my_data.push_back(rank);

    // Compute offsets for our receiving buffer
    // In C++17 this could just be std::exclusive_scan
    // Didn't use std::partial_sum since we also want total_length computed
    // msg_displs is needed by MPI_Allgatherv
    for (int i = 0; i < num_ranks; i++) {
        msg_displs[i] = total_length;
        total_length += msg_lengths[i];
    }
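    // For example, with 3 ranks the lengths are {13, 12, 11}, so the
    // displacements come out to {0, 13, 25} and total_length is 36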

    // Need resize rather than reserve so that MPI can just copy into the buffer
    all_data.resize(total_length);

    MPI_Allgatherv(my_data.data(), msg_length, MPI_INT,                // Sending args
                   all_data.data(), msg_lengths, msg_displs, MPI_INT,  // Receiving args
                   MPI_COMM_WORLD);
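    // all_data is now every rank's payload concatenated in rank order:
    // msg_lengths[0] copies of 0, then msg_lengths[1] copies of 1, and so on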

    int sum = std::accumulate(all_data.begin(), all_data.end(), 0);

    // Σ rank * (10 + num_ranks - rank)     from 0 -> num_ranks - 1
    int check = (num_ranks - 1) * num_ranks * (num_ranks + 31) / 6;
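    // Closed form with n = num_ranks:
    //   Σ r * (10 + n - r) = (10 + n) * n(n - 1)/2 - (n - 1)n(2n - 1)/6 = (n - 1) * n * (n + 31) / 6
    // e.g. with 3 ranks: 0 * 13 + 1 * 12 + 2 * 11 = 34 = 2 * 3 * 34 / 6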

    MPI_Barrier(MPI_COMM_WORLD);

    ASSERT_EQ(sum, check);

    delete[] msg_lengths;
    delete[] msg_displs;
}