1 /*
2  *            Copyright 2009-2020 The VOTCA Development Team
3  *                       (http://www.votca.org)
4  *
5  *      Licensed under the Apache License, Version 2.0 (the "License")
6  *
7  * You may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  *              http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  *
18  */
19 
20 // Local VOTCA includes
21 #include "votca/xtp/statetracker.h"
22 #include "votca/xtp/filterfactory.h"
23 
24 namespace votca {
25 namespace xtp {
26 using std::flush;
27 
Initialize(const tools::Property & options)28 void StateTracker::Initialize(const tools::Property& options) {
29 
30   FilterFactory::RegisterAll();
31   for (const tools::Property& filter : options) {
32     filters_.push_back(Filter().Create(filter.name()));
33   }
34 
35   for (auto& filter : filters_) {
36     const tools::Property& filterop = options.get(filter->Identify());
37     filter->Initialize(filterop);
38   }
39 }
40 
PrintInfo() const41 void StateTracker::PrintInfo() const {
42   XTP_LOG(Log::error, *log_)
43       << "Initial state: " << statehist_[0].ToString() << flush;
44   if (statehist_.size() > 1) {
45     XTP_LOG(Log::error, *log_)
46         << "Last state: " << statehist_.back().ToString() << flush;
47   }
48 
49   if (filters_.empty()) {
50     XTP_LOG(Log::error, *log_) << "WARNING: No tracker is used " << flush;
51   } else {
52     for (const auto& filter : filters_) {
53       filter->Info(*log_);
54     }
55   }
56 }
57 
ComparePairofVectors(std::vector<Index> & vec1,std::vector<Index> & vec2) const58 std::vector<Index> StateTracker::ComparePairofVectors(
59     std::vector<Index>& vec1, std::vector<Index>& vec2) const {
60   std::vector<Index> result(std::min(vec1, vec2));
61   std::sort(vec1.begin(), vec1.end());
62   std::sort(vec2.begin(), vec2.end());
63   std::vector<Index>::iterator it = std::set_intersection(
64       vec1.begin(), vec1.end(), vec2.begin(), vec2.end(), result.begin());
65   result.resize(it - result.begin());
66   return result;
67 }
68 
CollapseResults(std::vector<std::vector<Index>> & results) const69 std::vector<Index> StateTracker::CollapseResults(
70     std::vector<std::vector<Index>>& results) const {
71   if (results.empty()) {
72     return std::vector<Index>(0);
73   } else {
74     std::vector<Index> result = results[0];
75     for (Index i = 1; i < Index(results.size()); i++) {
76       result = ComparePairofVectors(result, results[i]);
77     }
78     return result;
79   }
80 }
81 
CalcState(const Orbitals & orbitals) const82 QMState StateTracker::CalcState(const Orbitals& orbitals) const {
83 
84   if (filters_.empty()) {
85     return statehist_[0];
86   }
87 
88   std::vector<std::vector<Index>> results;
89   for (const auto& filter : filters_) {
90     if (statehist_.size() < 2 && filter->NeedsInitialState()) {
91       XTP_LOG(Log::error, *log_)
92           << "Filter " << filter->Identify()
93           << " not used in first iteration as it needs a reference state"
94           << flush;
95       continue;
96     }
97     results.push_back(filter->CalcIndeces(orbitals, statehist_[0].Type()));
98   }
99 
100   std::vector<Index> result = CollapseResults(results);
101   QMState state;
102   if (result.size() < 1) {
103     state = statehist_.back();
104     XTP_LOG(Log::error, *log_)
105         << "No State found by tracker using last state: " << state.ToString()
106         << flush;
107   } else {
108     state = QMState(statehist_.back().Type(), result[0], false);
109     XTP_LOG(Log::error, *log_)
110         << "Next State is: " << state.ToString() << flush;
111   }
112   return state;
113 }
114 
CalcStateAndUpdate(const Orbitals & orbitals)115 QMState StateTracker::CalcStateAndUpdate(const Orbitals& orbitals) {
116   QMState result = CalcState(orbitals);
117   statehist_.push_back(result);
118   for (auto& filter : filters_) {
119     filter->UpdateHist(orbitals, result);
120   }
121   return result;
122 }
123 
WriteToCpt(CheckpointWriter & w) const124 void StateTracker::WriteToCpt(CheckpointWriter& w) const {
125   std::vector<std::string> statehiststring;
126   statehiststring.reserve(statehist_.size());
127   for (const QMState& s : statehist_) {
128     statehiststring.push_back(s.ToString());
129   }
130   w(statehiststring, "statehist");
131 
132   for (const auto& filter : filters_) {
133     CheckpointWriter ww = w.openChild(filter->Identify());
134     filter->WriteToCpt(ww);
135   }
136 }
137 
ReadFromCpt(CheckpointReader & r)138 void StateTracker::ReadFromCpt(CheckpointReader& r) {
139   FilterFactory::RegisterAll();
140   std::vector<std::string> statehiststring;
141   r(statehiststring, "statehist");
142   statehist_.clear();
143   statehist_.reserve(statehiststring.size());
144   for (const std::string& s : statehiststring) {
145     statehist_.push_back(QMState(s));
146   }
147   filters_.clear();
148   for (const std::string& filtername : r.getChildGroupNames()) {
149     CheckpointReader rr = r.openChild(filtername);
150     filters_.push_back(Filter().Create(filtername));
151     filters_.back()->ReadFromCpt(rr);
152   }
153 }
154 
155 }  // namespace xtp
156 }  // namespace votca
157