1 //////////////////////////////////////////////////////////////////////////////////////
2 // This file is distributed under the University of Illinois/NCSA Open Source License.
3 // See LICENSE file in top directory for details.
4 //
5 // Copyright (c) 2020 QMCPACK developers.
6 //
7 // File developed by: Peter Doak, doakpw@ornl.gov, Oak Ridge National Laboratory
8 //
9 // File created by: Peter Doak, doakpw@ornl.gov, Oak Ridge National Laboratory
10 //////////////////////////////////////////////////////////////////////////////////////
11 
12 #include "catch.hpp"
13 #include "Message/Communicate.h"
14 
15 #include "Platforms/Host/OutputManager.h"
16 
17 #include "Estimators/EstimatorManagerNew.h"
18 #include "Estimators/tests/EstimatorManagerNewTest.h"
19 
20 namespace qmcplusplus
21 {
22 namespace testing
23 {
testMakeBlockAverages()24 bool EstimatorManagerNewTest::testMakeBlockAverages()
25 {
26   if (em.my_comm_->rank() == 1)
27   {
28     estimators_[1].scalars[0](3.0);
29     estimators_[1].scalars[1](3.0);
30     estimators_[1].scalars[2](3.0);
31     estimators_[1].scalars[3](3.0);
32   }
33 
34   // manipulation of state to arrive at to be tested state.
35   // - From EstimatorManagerBase::reset
36   em.weightInd      = em.BlockProperties.add("BlockWeight");
37   em.cpuInd         = em.BlockProperties.add("BlockCPU");
38   em.acceptRatioInd = em.BlockProperties.add("AcceptRatio");
39 
40   // - From EstimatorManagerBase::start
41   em.PropertyCache.resize(em.BlockProperties.size());
42 
43   // - From EstimatorManagerBase::stopBlocknew
44   //   three estimators
45   //   - 2 with 1 sample 1
46   //   - 1 with 2
47   double block_weight = 0;
48   std::for_each(estimators_.begin(), estimators_.end(),
49                 [&block_weight](auto& est) { block_weight += est.scalars[0].count(); });
50   em.PropertyCache[em.weightInd] = block_weight;
51   em.PropertyCache[em.cpuInd]    = 1.0;
52 
53   RefVector<ScalarEstimatorBase> est_list = makeRefVector<ScalarEstimatorBase>(estimators_);
54   em.collectScalarEstimators(est_list);
55 
56   unsigned long accepts = 4;
57   unsigned long rejects = 1;
58   em.makeBlockAverages(accepts, rejects);
59   return true;
60 }
61 
62 } // namespace testing
63 
64 TEST_CASE("EstimatorManagerNew::makeBlockAverages()", "[estimators]")
65 {
66   Communicate* c = OHMMS::Controller;
67   int num_ranks  = c->size();
68   testing::EstimatorManagerNewTest embt(c, num_ranks);
69 
70   embt.fakeSomeScalarSamples();
71   embt.testMakeBlockAverages();
72 
73   // right now only rank() == 0 gets the actual averages
74   if (c->rank() == 0)
75   {
76     double correct_value = (5.0 * num_ranks + 3.0) / (4 * (num_ranks - 1) + 5);
77     CHECK(embt.em.get_AverageCache()[0] == Approx(correct_value));
78     correct_value = (8.0 * num_ranks + 3.0) / (4 * (num_ranks - 1) + 5);
79     CHECK(embt.em.get_AverageCache()[1] == Approx(correct_value));
80     correct_value = (11.0 * num_ranks + 3.0) / (4 * (num_ranks - 1) + 5);
81     CHECK(embt.em.get_AverageCache()[2] == Approx(correct_value));
82     correct_value = (14.0 * num_ranks + 3.0) / (4 * (num_ranks - 1) + 5);
83     CHECK(embt.em.get_AverageCache()[3] == Approx(correct_value));
84   }
85 }
86 
87 TEST_CASE("EstimatorManagerNew::reduceOperatorestimators()", "[estimators]")
88 {
89   Communicate* c = OHMMS::Controller;
90   int num_ranks  = c->size();
91   testing::EstimatorManagerNewTest embt(c, num_ranks);
92 
93   embt.fakeSomeOperatorEstimatorSamples(c->rank());
94   std::vector<QMCTraits::RealType> good_data = embt.generateGoodOperatorData(num_ranks);
95 
96   // Normalization is done by reduceOperatorEstimators based on the the total weight of the
97   // estimators for that block.
98   embt.testReduceOperatorEstimators();
99 
100   if (c->rank() == 0)
101   {
102     auto& test_data          = embt.get_operator_data();
103 
104     QMCTraits::RealType norm = 1.0 / static_cast<QMCTraits::RealType>(num_ranks);
105     for (size_t i = 0; i < test_data.size(); ++i)
106     {
107       QMCTraits::RealType norm_good_data = good_data[i] * norm;
108       if (norm_good_data != test_data[i])
109       {
110         FAIL_CHECK("norm_good_data " << norm_good_data << " != test_data " << test_data[i] << " at index " << i);
111         break;
112       }
113     }
114   }
115 }
116 
117 } // namespace qmcplusplus
118