1 /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2    Copyright (c) 2012-2021 The plumed team
3    (see the PEOPLE file at the root of the distribution for a list of names)
4 
5    See http://www.plumed.org for more information.
6 
7    This file is part of plumed, version 2.
8 
9    plumed is free software: you can redistribute it and/or modify
10    it under the terms of the GNU Lesser General Public License as published by
11    the Free Software Foundation, either version 3 of the License, or
12    (at your option) any later version.
13 
14    plumed is distributed in the hope that it will be useful,
15    but WITHOUT ANY WARRANTY; without even the implied warranty of
16    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17    GNU Lesser General Public License for more details.
18 
19    You should have received a copy of the GNU Lesser General Public License
20    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
21 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 #include "ActionWithVessel.h"
23 #include "tools/Communicator.h"
24 #include "Vessel.h"
25 #include "ShortcutVessel.h"
26 #include "StoreDataVessel.h"
27 #include "VesselRegister.h"
28 #include "BridgeVessel.h"
29 #include "FunctionVessel.h"
30 #include "StoreDataVessel.h"
31 #include "tools/OpenMP.h"
32 #include "tools/Stopwatch.h"
33 
34 using namespace std;
35 namespace PLMD {
36 namespace vesselbase {
37 
registerKeywords(Keywords & keys)38 void ActionWithVessel::registerKeywords(Keywords& keys) {
39   keys.add("hidden","TOL","this keyword can be used to speed up your calculation. When accumulating sums in which the individual "
40            "terms are numbers in between zero and one it is assumed that terms less than a certain tolerance "
41            "make only a small contribution to the sum.  They can thus be safely ignored as can the the derivatives "
42            "wrt these small quantities.");
43   keys.add("hidden","MAXDERIVATIVES","The maximum number of derivatives that can be used when storing data.  This controls when "
44            "we have to start using lowmem");
45   keys.addFlag("SERIAL",false,"do the calculation in serial.  Do not use MPI");
46   keys.addFlag("LOWMEM",false,"lower the memory requirements");
47   keys.addFlag("TIMINGS",false,"output information on the timings of the various parts of the calculation");
48   keys.reserveFlag("HIGHMEM",false,"use a more memory intensive version of this collective variable");
49   keys.add( vesselRegister().getKeywords() );
50 }
51 
ActionWithVessel(const ActionOptions & ao)52 ActionWithVessel::ActionWithVessel(const ActionOptions&ao):
53   Action(ao),
54   serial(false),
55   lowmem(false),
56   noderiv(true),
57   actionIsBridged(false),
58   nactive_tasks(0),
59   dertime_can_be_off(false),
60   dertime(true),
61   contributorsAreUnlocked(false),
62   weightHasDerivatives(false),
63   mydata(NULL)
64 {
65   maxderivatives=309; parse("MAXDERIVATIVES",maxderivatives);
66   if( keywords.exists("SERIAL") ) parseFlag("SERIAL",serial);
67   else serial=true;
68   if(serial)log.printf("  doing calculation in serial\n");
69   if( keywords.exists("LOWMEM") ) {
70     plumed_assert( !keywords.exists("HIGHMEM") );
71     parseFlag("LOWMEM",lowmem);
72     if(lowmem) {
73       log.printf("  lowering memory requirements\n");
74       dertime_can_be_off=true;
75     }
76   }
77   if( keywords.exists("HIGHMEM") ) {
78     plumed_assert( !keywords.exists("LOWMEM") );
79     bool highmem; parseFlag("HIGHMEM",highmem);
80     lowmem=!highmem;
81     if(!lowmem) log.printf("  increasing the memory requirements\n");
82   }
83   tolerance=nl_tolerance=epsilon;
84   if( keywords.exists("TOL") ) parse("TOL",tolerance);
85   if( tolerance>epsilon) {
86     log.printf(" Ignoring contributions less than %f \n",tolerance);
87   }
88   parseFlag("TIMINGS",timers);
89   stopwatch.start(); stopwatch.pause();
90 }
91 
~ActionWithVessel()92 ActionWithVessel::~ActionWithVessel() {
93   stopwatch.start(); stopwatch.stop();
94   if(timers) {
95     log.printf("timings for action %s with label %s \n", getName().c_str(), getLabel().c_str() );
96     log<<stopwatch;
97   }
98 }
99 
addVessel(const std::string & name,const std::string & input,const int numlab)100 void ActionWithVessel::addVessel( const std::string& name, const std::string& input, const int numlab ) {
101   VesselOptions da(name,"",numlab,input,this);
102   auto vv=vesselRegister().create(name,da);
103   FunctionVessel* fv=dynamic_cast<FunctionVessel*>(vv.get());
104   if( fv ) {
105     std::string mylabel=Vessel::transformName( name );
106     plumed_massert( keywords.outputComponentExists(mylabel,false), "a description of the value calculated by vessel " + name + " has not been added to the manual");
107   }
108   addVessel(std::move(vv));
109 }
110 
addVessel(std::unique_ptr<Vessel> vv_ptr)111 void ActionWithVessel::addVessel( std::unique_ptr<Vessel> vv_ptr ) {
112 
113 // In the original code, the dynamically casted pointer was deleted here.
114 // Now that vv_ptr is a unique_ptr, the object will be deleted automatically when
115 // exiting this routine.
116   if(dynamic_cast<ShortcutVessel*>(vv_ptr.get())) return;
117 
118   vv_ptr->checkRead();
119 
120   StoreDataVessel* mm=dynamic_cast<StoreDataVessel*>( vv_ptr.get() );
121   if( mydata && mm ) error("cannot have more than one StoreDataVessel in one action");
122   else if( mm ) mydata=mm;
123   else dertime_can_be_off=false;
124 
125 // Ownership is transferred to functions
126   functions.emplace_back(std::move(vv_ptr));
127 }
128 
addBridgingVessel(ActionWithVessel * tome)129 BridgeVessel* ActionWithVessel::addBridgingVessel( ActionWithVessel* tome ) {
130   VesselOptions da("","",0,"",this);
131   std::unique_ptr<BridgeVessel> bv(new BridgeVessel(da));
132   bv->setOutputAction( tome );
133   tome->actionIsBridged=true; dertime_can_be_off=false;
134 // store this pointer in order to return it later.
135 // notice that I cannot access this with functions.tail().get()
136 // since functions contains pointers to a different class (Vessel)
137   auto toBeReturned=bv.get();
138   functions.emplace_back( std::move(bv) );
139   resizeFunctions();
140   return toBeReturned;
141 }
142 
buildDataStashes(ActionWithVessel * actionThatUses)143 StoreDataVessel* ActionWithVessel::buildDataStashes( ActionWithVessel* actionThatUses ) {
144   if(mydata) {
145     if( actionThatUses ) mydata->addActionThatUses( actionThatUses );
146     return mydata;
147   }
148 
149   VesselOptions da("","",0,"",this);
150   std::unique_ptr<StoreDataVessel> mm( new StoreDataVessel(da) );
151   if( actionThatUses ) mm->addActionThatUses( actionThatUses );
152   addVessel(std::move(mm));
153 
154   // Make sure resizing of vessels is done
155   resizeFunctions();
156 
157   return mydata;
158 }
159 
addTaskToList(const unsigned & taskCode)160 void ActionWithVessel::addTaskToList( const unsigned& taskCode ) {
161   fullTaskList.push_back( taskCode ); taskFlags.push_back(0);
162   plumed_assert( fullTaskList.size()==taskFlags.size() );
163 }
164 
readVesselKeywords()165 void ActionWithVessel::readVesselKeywords() {
166   // Set maxderivatives if it is too big
167   if( maxderivatives>getNumberOfDerivatives() ) maxderivatives=getNumberOfDerivatives();
168 
169   // Loop over all keywords find the vessels and create appropriate functions
170   for(unsigned i=0; i<keywords.size(); ++i) {
171     std::string thiskey,input; thiskey=keywords.getKeyword(i);
172     // Check if this is a key for a vessel
173     if( vesselRegister().check(thiskey) ) {
174       plumed_assert( keywords.style(thiskey,"vessel") );
175       bool dothis=false; parseFlag(thiskey,dothis);
176       if(dothis) addVessel( thiskey, input );
177 
178       parse(thiskey,input);
179       if(input.size()!=0) {
180         addVessel( thiskey, input );
181       } else {
182         for(unsigned i=1;; ++i) {
183           if( !parseNumbered(thiskey,i,input) ) break;
184           std::string ss; Tools::convert(i,ss);
185           addVessel( thiskey, input, i );
186           input.clear();
187         }
188       }
189     }
190   }
191 
192   // Make sure all vessels have had been resized at start
193   if( functions.size()>0 ) resizeFunctions();
194 }
195 
resizeFunctions()196 void ActionWithVessel::resizeFunctions() {
197   for(unsigned i=0; i<functions.size(); ++i) functions[i]->resize();
198 }
199 
needsDerivatives()200 void ActionWithVessel::needsDerivatives() {
201   // Turn on the derivatives and resize
202   noderiv=false; resizeFunctions();
203   // Setting contributors unlocked here ensures that link cells are ignored
204   contributorsAreUnlocked=true; contributorsAreUnlocked=false;
205   // And turn on the derivatives in all actions on which we are dependent
206   for(unsigned i=0; i<getDependencies().size(); ++i) {
207     ActionWithVessel* vv=dynamic_cast<ActionWithVessel*>( getDependencies()[i] );
208     if(vv) vv->needsDerivatives();
209   }
210 }
211 
lockContributors()212 void ActionWithVessel::lockContributors() {
213   nactive_tasks = 0;
214   for(unsigned i=0; i<fullTaskList.size(); ++i) {
215     if( taskFlags[i]>0 ) nactive_tasks++;
216   }
217 
218   unsigned n=0;
219   partialTaskList.resize( nactive_tasks );
220   indexOfTaskInFullList.resize( nactive_tasks );
221   for(unsigned i=0; i<fullTaskList.size(); ++i) {
222     // Deactivate sets inactive tasks to number not equal to zero
223     if( taskFlags[i]>0 ) {
224       partialTaskList[n] = fullTaskList[i];
225       indexOfTaskInFullList[n]=i;
226       n++;
227     }
228   }
229   plumed_dbg_assert( n==nactive_tasks );
230   for(unsigned i=0; i<functions.size(); ++i) {
231     BridgeVessel* bb = dynamic_cast<BridgeVessel*>( functions[i].get() );
232     if( bb ) bb->copyTaskFlags();
233   }
234   // Resize mydata to accommodate all active tasks
235   if( mydata ) mydata->resize();
236   contributorsAreUnlocked=false;
237 }
238 
deactivateAllTasks()239 void ActionWithVessel::deactivateAllTasks() {
240   contributorsAreUnlocked=true; nactive_tasks = 0;
241   taskFlags.assign(taskFlags.size(),0);
242 }
243 
taskIsCurrentlyActive(const unsigned & index) const244 bool ActionWithVessel::taskIsCurrentlyActive( const unsigned& index ) const {
245   plumed_dbg_assert( index<taskFlags.size() ); return (taskFlags[index]>0);
246 }
247 
doJobsRequiredBeforeTaskList()248 void ActionWithVessel::doJobsRequiredBeforeTaskList() {
249   // Do any preparatory stuff for functions
250   for(unsigned j=0; j<functions.size(); ++j) functions[j]->prepare();
251 }
252 
getSizeOfBuffer(unsigned & bufsize)253 unsigned ActionWithVessel::getSizeOfBuffer( unsigned& bufsize ) {
254   for(unsigned i=0; i<functions.size(); ++i) functions[i]->setBufferStart( bufsize );
255   if( buffer.size()!=bufsize ) buffer.resize( bufsize );
256   if( mydata ) {
257     unsigned dsize=mydata->getSizeOfDerivativeList();
258     if( der_list.size()!=dsize ) der_list.resize( dsize );
259   }
260   return bufsize;
261 }
262 
runAllTasks()263 void ActionWithVessel::runAllTasks() {
264   plumed_massert( !contributorsAreUnlocked && functions.size()>0, "you must have a call to readVesselKeywords somewhere" );
265   unsigned stride=comm.Get_size();
266   unsigned rank=comm.Get_rank();
267   if(serial) { stride=1; rank=0; }
268 
269   // Make sure jobs are done
270   if(timers) stopwatch.start("1 Prepare Tasks");
271   doJobsRequiredBeforeTaskList();
272   if(timers) stopwatch.stop("1 Prepare Tasks");
273 
274   // Get number of threads for OpenMP
275   unsigned nt=OpenMP::getNumThreads();
276   if( nt*stride*2>nactive_tasks || !threadSafe()) nt=1;
277 
278   // Get size for buffer
279   unsigned bsize=0, bufsize=getSizeOfBuffer( bsize );
280   // Clear buffer
281   buffer.assign( buffer.size(), 0.0 );
282   // Switch off calculation of derivatives in main loop
283   if( dertime_can_be_off ) dertime=false;
284 
285   if(timers) stopwatch.start("2 Loop over tasks");
286   #pragma omp parallel num_threads(nt)
287   {
288     std::vector<double> omp_buffer;
289     if( nt>1 ) omp_buffer.resize( bufsize, 0.0 );
290     MultiValue myvals( getNumberOfQuantities(), getNumberOfDerivatives() );
291     MultiValue bvals( getNumberOfQuantities(), getNumberOfDerivatives() );
292     myvals.clearAll(); bvals.clearAll();
293 
294     #pragma omp for nowait schedule(dynamic)
295     for(unsigned i=rank; i<nactive_tasks; i+=stride) {
296       // Calculate the stuff in the loop for this action
297       performTask( indexOfTaskInFullList[i], partialTaskList[i], myvals );
298 
299       // Check for conditions that allow us to just to skip the calculation
300       // the condition is that the weight of the contribution is low
301       // N.B. Here weights are assumed to be between zero and one
302       if( myvals.get(0)<tolerance ) {
303         // Clear the derivatives
304         myvals.clearAll();
305         continue;
306       }
307 
308       // Now calculate all the functions
309       // If the contribution of this quantity is very small at neighbour list time ignore it
310       // until next neighbour list time
311       if( nt>1 ) {
312         calculateAllVessels( indexOfTaskInFullList[i], myvals, bvals, omp_buffer, der_list );
313       } else {
314         calculateAllVessels( indexOfTaskInFullList[i], myvals, bvals, buffer, der_list );
315       }
316 
317       // Clear the value
318       myvals.clearAll();
319     }
320     #pragma omp critical
321     if(nt>1) for(unsigned i=0; i<bufsize; ++i) buffer[i]+=omp_buffer[i];
322   }
323   if(timers) stopwatch.stop("2 Loop over tasks");
324   // Turn back on derivative calculation
325   dertime=true;
326 
327   if(timers) stopwatch.start("3 MPI gather");
328   // MPI Gather everything
329   if( !serial && buffer.size()>0 ) comm.Sum( buffer );
330   // MPI Gather index stores
331   if( mydata && !lowmem && !noderiv ) {
332     comm.Sum( der_list ); mydata->setActiveValsAndDerivatives( der_list );
333   }
334   // Update the elements that are makign contributions to the sum here
335   // this causes problems if we do it in prepare
336   if(timers) stopwatch.stop("3 MPI gather");
337 
338   if(timers) stopwatch.start("4 Finishing computations");
339   finishComputations( buffer );
340   if(timers) stopwatch.stop("4 Finishing computations");
341 }
342 
transformBridgedDerivatives(const unsigned & current,MultiValue & invals,MultiValue & outvals) const343 void ActionWithVessel::transformBridgedDerivatives( const unsigned& current, MultiValue& invals, MultiValue& outvals ) const {
344   plumed_error();
345 }
346 
calculateAllVessels(const unsigned & taskCode,MultiValue & myvals,MultiValue & bvals,std::vector<double> & buffer,std::vector<unsigned> & der_list)347 void ActionWithVessel::calculateAllVessels( const unsigned& taskCode, MultiValue& myvals, MultiValue& bvals, std::vector<double>& buffer, std::vector<unsigned>& der_list ) {
348   for(unsigned j=0; j<functions.size(); ++j) {
349     // Calculate returns a bool that tells us if this particular
350     // quantity is contributing more than the tolerance
351     functions[j]->calculate( taskCode, functions[j]->transformDerivatives(taskCode, myvals, bvals), buffer, der_list );
352     if( !actionIsBridged ) bvals.clearAll();
353   }
354   return;
355 }
356 
finishComputations(const std::vector<double> & buffer)357 void ActionWithVessel::finishComputations( const std::vector<double>& buffer ) {
358   // Set the final value of the function
359   for(unsigned j=0; j<functions.size(); ++j) functions[j]->finish( buffer );
360 }
361 
getForcesFromVessels(std::vector<double> & forcesToApply)362 bool ActionWithVessel::getForcesFromVessels( std::vector<double>& forcesToApply ) {
363 #ifndef NDEBUG
364   if( forcesToApply.size()>0 ) plumed_dbg_assert( forcesToApply.size()==getNumberOfDerivatives() );
365 #endif
366   if(tmpforces.size()!=forcesToApply.size() ) tmpforces.resize( forcesToApply.size() );
367 
368   forcesToApply.assign( forcesToApply.size(),0.0 );
369   bool wasforced=false;
370   for(unsigned i=0; i<getNumberOfVessels(); ++i) {
371     if( (functions[i]->applyForce( tmpforces )) ) {
372       wasforced=true;
373       for(unsigned j=0; j<forcesToApply.size(); ++j) forcesToApply[j]+=tmpforces[j];
374     }
375   }
376   return wasforced;
377 }
378 
retrieveDomain(std::string & min,std::string & max)379 void ActionWithVessel::retrieveDomain( std::string& min, std::string& max ) {
380   plumed_merror("If your function is periodic you need to add a retrieveDomain function so that ActionWithVessel can retrieve the domain");
381 }
382 
getVesselWithName(const std::string & mynam)383 Vessel* ActionWithVessel::getVesselWithName( const std::string& mynam ) {
384   int target=-1;
385   for(unsigned i=0; i<functions.size(); ++i) {
386     if( functions[i]->getName().find(mynam)!=std::string::npos ) {
387       if( target<0 ) target=i;
388       else error("found more than one " + mynam + " object in action");
389     }
390   }
391   return functions[target].get();
392 }
393 
394 }
395 }
396