1 //
2 // CDDL HEADER START
3 //
4 // The contents of this file are subject to the terms of the Common Development
5 // and Distribution License Version 1.0 (the "License").
6 //
7 // You can obtain a copy of the license at
8 // http://www.opensource.org/licenses/CDDL-1.0.  See the License for the
9 // specific language governing permissions and limitations under the License.
10 //
11 // When distributing Covered Code, include this CDDL HEADER in each file and
12 // include the License file in a prominent location with the name LICENSE.CDDL.
13 // If applicable, add the following below this CDDL HEADER, with the fields
14 // enclosed by brackets "[]" replaced with your own identifying information:
15 //
16 // Portions Copyright (c) [yyyy] [name of copyright owner]. All rights reserved.
17 //
18 // CDDL HEADER END
19 //
20 
21 //
22 // Copyright (c) 2019, Regents of the University of Minnesota.
23 // All rights reserved.
24 //
25 // Contributors:
26 //    Mingjian Wen
27 //
28 
29 #include <cmath>
30 #include <cstdlib>
31 #include <cstring>
32 #include <fstream>
33 #include <iostream>
34 #include <map>
35 
36 #include "ANNImplementation.hpp"
37 #include "KIM_ModelDriverHeaders.hpp"
38 
39 //==============================================================================
40 //
41 // Implementation of ANNImplementation public member functions
42 //
43 //==============================================================================
44 
45 //******************************************************************************
46 #undef KIM_LOGGER_OBJECT_NAME
47 #define KIM_LOGGER_OBJECT_NAME modelDriverCreate
ANNImplementation(KIM::ModelDriverCreate * const modelDriverCreate,KIM::LengthUnit const requestedLengthUnit,KIM::EnergyUnit const requestedEnergyUnit,KIM::ChargeUnit const requestedChargeUnit,KIM::TemperatureUnit const requestedTemperatureUnit,KIM::TimeUnit const requestedTimeUnit,int * const ier)48 ANNImplementation::ANNImplementation(
49     KIM::ModelDriverCreate * const modelDriverCreate,
50     KIM::LengthUnit const requestedLengthUnit,
51     KIM::EnergyUnit const requestedEnergyUnit,
52     KIM::ChargeUnit const requestedChargeUnit,
53     KIM::TemperatureUnit const requestedTemperatureUnit,
54     KIM::TimeUnit const requestedTimeUnit,
55     int * const ier) :
56     energyScale_(1.0),
57     ensemble_size_(0),
58     last_ensemble_size_(0),
59     active_member_id_(-1),
60     last_active_member_id_(-1),
61     influenceDistance_(0.0),
62     modelWillNotRequestNeighborsOfNoncontributingParticles_(1),
63     cachedNumberOfParticles_(0)
64 {
65   // create descriptor and network classes
66   descriptor_ = new Descriptor();
67   network_ = new NeuralNetwork();
68 
69   FILE * parameterFilePointers[MAX_PARAMETER_FILES];
70   int numberParameterFiles;
71 
72   modelDriverCreate->GetNumberOfParameterFiles(&numberParameterFiles);
73   *ier = OpenParameterFiles(
74       modelDriverCreate, numberParameterFiles, parameterFilePointers);
75   if (*ier) { return; }
76 
77   *ier = ProcessParameterFiles(
78       modelDriverCreate, numberParameterFiles, parameterFilePointers);
79   CloseParameterFiles(numberParameterFiles, parameterFilePointers);
80   if (*ier) { return; }
81 
82   *ier = ConvertUnits(modelDriverCreate,
83                       requestedLengthUnit,
84                       requestedEnergyUnit,
85                       requestedChargeUnit,
86                       requestedTemperatureUnit,
87                       requestedTimeUnit);
88   if (*ier) { return; }
89 
90   *ier = SetRefreshMutableValues(modelDriverCreate);
91   if (*ier) { return; }
92 
93   *ier = RegisterKIMModelSettings(modelDriverCreate);
94   if (*ier) { return; }
95 
96   *ier = RegisterKIMParameters(modelDriverCreate);
97   if (*ier) { return; }
98 
99   *ier = RegisterKIMFunctions(modelDriverCreate);
100   if (*ier) { return; }
101 
102   // everything is good
103   *ier = false;
104   return;
105 }
106 
107 //******************************************************************************
~ANNImplementation()108 ANNImplementation::~ANNImplementation()
109 {  // note: it is ok to delete a null
110    // pointer and we have ensured that
111   // everything is initialized to null
112 
113   delete descriptor_;
114   delete network_;
115 }
116 
117 //******************************************************************************
118 #undef KIM_LOGGER_OBJECT_NAME
119 #define KIM_LOGGER_OBJECT_NAME modelRefresh
Refresh(KIM::ModelRefresh * const modelRefresh)120 int ANNImplementation::Refresh(KIM::ModelRefresh * const modelRefresh)
121 {
122   int ier;
123 
124   ier = SetRefreshMutableValues(modelRefresh);
125   if (ier) { return ier; }
126 
127   // nothing else to do for this case
128 
129   // everything is good
130   ier = false;
131   return ier;
132 }
133 
134 //******************************************************************************
// Entry point for a KIM compute request.
//
// First collects the compute flags and argument pointers for this call via
// SetComputeMutableValues.  Then textually #includes
// ANNImplementationComputeDispatch.cpp.
//
// NOTE(review): that file is not visible here.  It presumably selects the
// compute kernel matching the flags (cf. GetComputeIndex) and assigns its
// status to `ier`, which is returned below.  Because the code is #included,
// it most likely refers to the local variable names declared in this
// function, so do not rename them.
//
// NOTE(review): KIM_LOGGER_OBJECT_NAME is still `modelRefresh` here (last
// defined above Refresh).  Any LOG_* used in this body would name a variable
// that is not in scope, so redefine it to modelCompute before adding logging.
int ANNImplementation::Compute(
    KIM::ModelCompute const * const modelCompute,
    KIM::ModelComputeArguments const * const modelComputeArguments)
{
  int ier;

  // KIM API Model Input compute flags
  bool isComputeProcess_dEdr = false;
  bool isComputeProcess_d2Edr2 = false;
  //
  // KIM API Model Output compute flags
  bool isComputeEnergy = false;
  bool isComputeForces = false;
  bool isComputeParticleEnergy = false;
  bool isComputeVirial = false;
  bool isComputeParticleVirial = false;
  //
  // KIM API Model Input
  int const * particleSpeciesCodes = NULL;
  int const * particleContributing = NULL;
  VectorOfSizeDIM const * coordinates = NULL;
  //
  // KIM API Model Output
  double * energy = NULL;
  double * particleEnergy = NULL;
  VectorOfSizeDIM * forces = NULL;
  VectorOfSizeSix * virial = NULL;
  VectorOfSizeSix * particleVirial = NULL;

  // Fill in the flags and pointers above.  Each output flag ends up true
  // exactly when the corresponding output pointer is non-NULL.
  ier = SetComputeMutableValues(modelComputeArguments,
                                isComputeProcess_dEdr,
                                isComputeProcess_d2Edr2,
                                isComputeEnergy,
                                isComputeForces,
                                isComputeParticleEnergy,
                                isComputeVirial,
                                isComputeParticleVirial,
                                particleSpeciesCodes,
                                particleContributing,
                                coordinates,
                                energy,
                                forces,
                                particleEnergy,
                                virial,
                                particleVirial);
  if (ier) { return ier; }

  // Skip this check for efficiency: it would scan every particle's species
  // code on each call.  To re-enable it, call the method that actually
  // exists (it must run after SetComputeMutableValues, which caches the
  // particle count it loops over):
  //
  // ier = CheckParticleSpeciesCodes(modelCompute, particleSpeciesCodes);
  // if (ier) return ier;

  // Dispatch to the compute kernel; presumably assigns `ier` (see above).
#include "ANNImplementationComputeDispatch.cpp"

  return ier;
}
191 
192 //******************************************************************************
ComputeArgumentsCreate(KIM::ModelComputeArgumentsCreate * const modelComputeArgumentsCreate) const193 int ANNImplementation::ComputeArgumentsCreate(
194     KIM::ModelComputeArgumentsCreate * const modelComputeArgumentsCreate) const
195 {
196   int ier;
197 
198   ier = RegisterKIMComputeArgumentsSettings(modelComputeArgumentsCreate);
199   if (ier) { return ier; }
200 
201   // nothing else to do for this case
202 
203   // everything is good
204   ier = false;
205   return ier;
206 }
207 
208 //******************************************************************************
ComputeArgumentsDestroy(KIM::ModelComputeArgumentsDestroy * const modelComputeArgumentsDestroy) const209 int ANNImplementation::ComputeArgumentsDestroy(
210     KIM::ModelComputeArgumentsDestroy * const modelComputeArgumentsDestroy)
211     const
212 {
213   int ier;
214 
215   (void) modelComputeArgumentsDestroy;  // avoid not used warning
216 
217   // nothing else to do for this case
218 
219   // everything is good
220   ier = false;
221   return ier;
222 }
223 
224 //==============================================================================
225 //
226 // Implementation of ANNImplementation private member functions
227 //
228 //==============================================================================
229 
230 //******************************************************************************
AllocatePrivateParameterMemory()231 void ANNImplementation::AllocatePrivateParameterMemory()
232 {
233   // nothing to do for this case
234 }
235 
236 //******************************************************************************
AllocateParameterMemory()237 void ANNImplementation::AllocateParameterMemory()
238 {
239   // nothing to do for this case
240 }
241 
242 //******************************************************************************
243 #undef KIM_LOGGER_OBJECT_NAME
244 #define KIM_LOGGER_OBJECT_NAME modelDriverCreate
OpenParameterFiles(KIM::ModelDriverCreate * const modelDriverCreate,int const numberParameterFiles,FILE * parameterFilePointers[MAX_PARAMETER_FILES])245 int ANNImplementation::OpenParameterFiles(
246     KIM::ModelDriverCreate * const modelDriverCreate,
247     int const numberParameterFiles,
248     FILE * parameterFilePointers[MAX_PARAMETER_FILES])
249 {
250   int ier;
251 
252   if (numberParameterFiles > MAX_PARAMETER_FILES)
253   {
254     ier = true;
255     LOG_ERROR("ANN given too many parameter files");
256     return ier;
257   }
258 
259   for (int i = 0; i < numberParameterFiles; ++i)
260   {
261     std::string const * paramFileName;
262     ier = modelDriverCreate->GetParameterFileName(i, &paramFileName);
263     if (ier)
264     {
265       LOG_ERROR("Unable to get parameter file name");
266       return ier;
267     }
268 
269     parameterFilePointers[i] = fopen(paramFileName->c_str(), "r");
270     if (parameterFilePointers[i] == 0)
271     {
272       char message[MAXLINE];
273       sprintf(message, "ANN parameter file number %d cannot be opened", i);
274       ier = true;
275       LOG_ERROR(message);
276       for (int j = i - 1; i <= 0; --i) { fclose(parameterFilePointers[j]); }
277       return ier;
278     }
279   }
280 
281   // everything is good
282   ier = false;
283   return ier;
284 }
285 
286 //******************************************************************************
287 #undef KIM_LOGGER_OBJECT_NAME
288 #define KIM_LOGGER_OBJECT_NAME modelDriverCreate
ProcessParameterFiles(KIM::ModelDriverCreate * const modelDriverCreate,int const numberParameterFiles,FILE * const parameterFilePointers[MAX_PARAMETER_FILES])289 int ANNImplementation::ProcessParameterFiles(
290     KIM::ModelDriverCreate * const modelDriverCreate,
291     int const numberParameterFiles,
292     FILE * const parameterFilePointers[MAX_PARAMETER_FILES])
293 {
294   (void) numberParameterFiles;  // avoid not used warning
295 
296   int ier;
297   char errorMsg[1024];
298 
299   //#######################
300   // descriptor params
301   //#######################
302   ier = descriptor_->read_parameter_file(parameterFilePointers[0]);
303   if (ier)
304   {
305     sprintf(errorMsg, "unable to read descriptor parameter file\n");
306     LOG_ERROR(errorMsg);
307     return true;
308   }
309 
310   // set species
311   int Nspecies = descriptor_->get_num_species();
312   std::vector<std::string> species;
313   descriptor_->get_species(species);
314   for (int i = 0; i < Nspecies; i++)
315   {
316     KIM::SpeciesName const specName(species[i]);
317     if (!specName.Known())
318     {
319       sprintf(errorMsg, "get unknown species\n");
320       LOG_ERROR(errorMsg);
321       return true;
322     }
323     ier = modelDriverCreate->SetSpeciesCode(specName, i);
324     if (ier) { return ier; }
325   }
326 
327   //#######################
328   // model parameters
329   //#######################
330 
331   int desc_size = descriptor_->get_num_descriptors();
332   ier = network_->read_parameter_file(parameterFilePointers[1], desc_size);
333   if (ier)
334   {
335     sprintf(errorMsg, "unable to read neural network parameter file\n");
336     LOG_ERROR(errorMsg);
337     return true;
338   }
339 
340   //#######################
341   // read dropout binary
342   //#######################
343   ier = network_->read_dropout_file(parameterFilePointers[2]);
344   if (ier)
345   {
346     sprintf(errorMsg, "unable to read dropout file\n");
347     LOG_ERROR(errorMsg);
348     return true;
349   }
350   ensemble_size_ = last_ensemble_size_ = network_->get_ensemble_size();
351   active_member_id_ = last_active_member_id_ = -1;  // default to average
352 
353   // everything is good
354   ier = false;
355   return ier;
356 }
357 
358 //******************************************************************************
CloseParameterFiles(int const numberParameterFiles,FILE * const parameterFilePointers[MAX_PARAMETER_FILES])359 void ANNImplementation::CloseParameterFiles(
360     int const numberParameterFiles,
361     FILE * const parameterFilePointers[MAX_PARAMETER_FILES])
362 {
363   for (int i = 0; i < numberParameterFiles; ++i)
364   { fclose(parameterFilePointers[i]); }
365 }
366 
367 //******************************************************************************
368 #undef KIM_LOGGER_OBJECT_NAME
369 #define KIM_LOGGER_OBJECT_NAME modelDriverCreate
ConvertUnits(KIM::ModelDriverCreate * const modelDriverCreate,KIM::LengthUnit const requestedLengthUnit,KIM::EnergyUnit const requestedEnergyUnit,KIM::ChargeUnit const requestedChargeUnit,KIM::TemperatureUnit const requestedTemperatureUnit,KIM::TimeUnit const requestedTimeUnit)370 int ANNImplementation::ConvertUnits(
371     KIM::ModelDriverCreate * const modelDriverCreate,
372     KIM::LengthUnit const requestedLengthUnit,
373     KIM::EnergyUnit const requestedEnergyUnit,
374     KIM::ChargeUnit const requestedChargeUnit,
375     KIM::TemperatureUnit const requestedTemperatureUnit,
376     KIM::TimeUnit const requestedTimeUnit)
377 {
378   int ier;
379 
380   // define default base units
381   KIM::LengthUnit fromLength = KIM::LENGTH_UNIT::A;
382   KIM::EnergyUnit fromEnergy = KIM::ENERGY_UNIT::eV;
383   KIM::ChargeUnit fromCharge = KIM::CHARGE_UNIT::e;
384   KIM::TemperatureUnit fromTemperature = KIM::TEMPERATURE_UNIT::K;
385   KIM::TimeUnit fromTime = KIM::TIME_UNIT::ps;
386 
387   double convertLength = 1.0;
388 
389   ier = modelDriverCreate->ConvertUnit(fromLength,
390                                        fromEnergy,
391                                        fromCharge,
392                                        fromTemperature,
393                                        fromTime,
394                                        requestedLengthUnit,
395                                        requestedEnergyUnit,
396                                        requestedChargeUnit,
397                                        requestedTemperatureUnit,
398                                        requestedTimeUnit,
399                                        1.0,
400                                        0.0,
401                                        0.0,
402                                        0.0,
403                                        0.0,
404                                        &convertLength);
405   if (ier)
406   {
407     LOG_ERROR("Unable to convert length unit");
408     return ier;
409   }
410 
411   double convertEnergy = 1.0;
412   ier = modelDriverCreate->ConvertUnit(fromLength,
413                                        fromEnergy,
414                                        fromCharge,
415                                        fromTemperature,
416                                        fromTime,
417                                        requestedLengthUnit,
418                                        requestedEnergyUnit,
419                                        requestedChargeUnit,
420                                        requestedTemperatureUnit,
421                                        requestedTimeUnit,
422                                        0.0,
423                                        1.0,
424                                        0.0,
425                                        0.0,
426                                        0.0,
427                                        &convertEnergy);
428   if (ier)
429   {
430     LOG_ERROR("Unable to convert energy unit");
431     return ier;
432   }
433 
434   // convert to active units
435   if (convertEnergy != ONE or convertLength != ONE)
436   {
437     descriptor_->convert_units(convertEnergy, convertLength);
438     energyScale_ = convertEnergy;
439   }
440 
441   // register units
442   ier = modelDriverCreate->SetUnits(requestedLengthUnit,
443                                     requestedEnergyUnit,
444                                     KIM::CHARGE_UNIT::unused,
445                                     KIM::TEMPERATURE_UNIT::unused,
446                                     KIM::TIME_UNIT::unused);
447   if (ier)
448   {
449     LOG_ERROR("Unable to set units to requested values");
450     return ier;
451   }
452 
453   // everything is good
454   ier = false;
455   return ier;
456 }
457 
458 //******************************************************************************
RegisterKIMModelSettings(KIM::ModelDriverCreate * const modelDriverCreate) const459 int ANNImplementation::RegisterKIMModelSettings(
460     KIM::ModelDriverCreate * const modelDriverCreate) const
461 {
462   // register numbering
463   int error = modelDriverCreate->SetModelNumbering(KIM::NUMBERING::zeroBased);
464 
465   return error;
466 }
467 
468 //******************************************************************************
469 #undef KIM_LOGGER_OBJECT_NAME
470 #define KIM_LOGGER_OBJECT_NAME modelComputeArgumentsCreate
RegisterKIMComputeArgumentsSettings(KIM::ModelComputeArgumentsCreate * const modelComputeArgumentsCreate) const471 int ANNImplementation::RegisterKIMComputeArgumentsSettings(
472     KIM::ModelComputeArgumentsCreate * const modelComputeArgumentsCreate) const
473 {
474   // register arguments
475   LOG_INFORMATION("Register argument supportStatus");
476 
477   int error = modelComputeArgumentsCreate->SetArgumentSupportStatus(
478                   KIM::COMPUTE_ARGUMENT_NAME::partialEnergy,
479                   KIM::SUPPORT_STATUS::optional)
480               || modelComputeArgumentsCreate->SetArgumentSupportStatus(
481                   KIM::COMPUTE_ARGUMENT_NAME::partialForces,
482                   KIM::SUPPORT_STATUS::optional)
483               || modelComputeArgumentsCreate->SetArgumentSupportStatus(
484                   KIM::COMPUTE_ARGUMENT_NAME::partialParticleEnergy,
485                   KIM::SUPPORT_STATUS::optional)
486               || modelComputeArgumentsCreate->SetArgumentSupportStatus(
487                   KIM::COMPUTE_ARGUMENT_NAME::partialVirial,
488                   KIM::SUPPORT_STATUS::optional)
489               || modelComputeArgumentsCreate->SetArgumentSupportStatus(
490                   KIM::COMPUTE_ARGUMENT_NAME::partialParticleVirial,
491                   KIM::SUPPORT_STATUS::optional);
492 
493   // register callbacks
494   LOG_INFORMATION("Register callback supportStatus");
495   error = error
496           || modelComputeArgumentsCreate->SetCallbackSupportStatus(
497               KIM::COMPUTE_CALLBACK_NAME::ProcessDEDrTerm,
498               KIM::SUPPORT_STATUS::optional)
499           || modelComputeArgumentsCreate->SetCallbackSupportStatus(
500               KIM::COMPUTE_CALLBACK_NAME::ProcessD2EDr2Term,
501               KIM::SUPPORT_STATUS::optional);
502 
503   return error;
504 }
505 
506 //******************************************************************************
507 #undef KIM_LOGGER_OBJECT_NAME
508 #define KIM_LOGGER_OBJECT_NAME modelDriverCreate
RegisterKIMParameters(KIM::ModelDriverCreate * const modelDriverCreate)509 int ANNImplementation::RegisterKIMParameters(
510     KIM::ModelDriverCreate * const modelDriverCreate)
511 {
512   int ier = false;
513 
514   // publish parameters (order is important)
515   ier = modelDriverCreate->SetParameterPointer(
516             1,
517             &ensemble_size_,
518             "ensemble_size",
519             "Size of the ensemble of models. `0` means this is a fully-"
520             "connected neural network that does not support running in "
521             "ensemble mode.")
522         || modelDriverCreate->SetParameterPointer(
523             1,
524             &active_member_id_,
525             "active_member_id",
526             "Running mode of the ensemble of models, with available values of "
527             "`-1, 0, 1, 2, ..., ensemble_size`. If `ensemble_size = 0`, "
528             "this is ignored. Otherwise, `active_member_id = -1` means the "
529             "output "
530             "(energy, forces, etc.) will be obtained by averaging over all "
531             "members of the ensemble (different dropout matrices); "
532             "`active_member_id = 0` means the fully-connected network without "
533             "dropout will be used; and `active_member_id = i` where i is an "
534             "integer from 1 to `ensemble_size` means ensemble member i will be "
535             "used to calculate the output.");
536 
537   if (ier)
538   {
539     LOG_ERROR("set_parameters");
540     return ier;
541   }
542 
543   // everything is good
544   ier = false;
545   return ier;
546 }
547 
548 //******************************************************************************
RegisterKIMFunctions(KIM::ModelDriverCreate * const modelDriverCreate) const549 int ANNImplementation::RegisterKIMFunctions(
550     KIM::ModelDriverCreate * const modelDriverCreate) const
551 {
552   int error;
553 
554   // register functions
555   error = modelDriverCreate->SetRoutinePointer(
556               KIM::MODEL_ROUTINE_NAME::Destroy,
557               KIM::LANGUAGE_NAME::cpp,
558               true,
559               reinterpret_cast<KIM::Function *>(ANN::Destroy))
560           || modelDriverCreate->SetRoutinePointer(
561               KIM::MODEL_ROUTINE_NAME::Refresh,
562               KIM::LANGUAGE_NAME::cpp,
563               true,
564               reinterpret_cast<KIM::Function *>(ANN::Refresh))
565           || modelDriverCreate->SetRoutinePointer(
566               KIM::MODEL_ROUTINE_NAME::Compute,
567               KIM::LANGUAGE_NAME::cpp,
568               true,
569               reinterpret_cast<KIM::Function *>(ANN::Compute))
570           || modelDriverCreate->SetRoutinePointer(
571               KIM::MODEL_ROUTINE_NAME::ComputeArgumentsCreate,
572               KIM::LANGUAGE_NAME::cpp,
573               true,
574               reinterpret_cast<KIM::Function *>(ANN::ComputeArgumentsCreate))
575           || modelDriverCreate->SetRoutinePointer(
576               KIM::MODEL_ROUTINE_NAME::ComputeArgumentsDestroy,
577               KIM::LANGUAGE_NAME::cpp,
578               true,
579               reinterpret_cast<KIM::Function *>(ANN::ComputeArgumentsDestroy));
580 
581   return error;
582 }
583 
584 //******************************************************************************
585 #undef KIM_LOGGER_OBJECT_NAME
586 #define KIM_LOGGER_OBJECT_NAME modelObj
587 template<class ModelObj>
SetRefreshMutableValues(ModelObj * const modelObj)588 int ANNImplementation::SetRefreshMutableValues(ModelObj * const modelObj)
589 {  // use (possibly) new values of parameters to
590    // compute other quantities
591   // NOTE: This function is templated because it's called with both a
592   //       modelDriverCreate object during initialization and with a
593   //       modelRefresh object when the Model's parameters have been altered
594   int ier = true;
595 
596   // checks to make sure ensemble_size_ and active_member_id_ are correct
597   if (ensemble_size_ != last_ensemble_size_)
598   {
599     LOG_ERROR("Value of `ensemble_size` changed.");
600     return ier;
601   }
602   if (active_member_id_ < -1 || active_member_id_ > ensemble_size_)
603   {
604     char message[MAXLINE];
605     sprintf(message,
606             "`active_member_id=%d` out of range. Should be [-1, %d]",
607             active_member_id_,
608             ensemble_size_);
609     LOG_ERROR(message);
610     return ier;
611   }
612   if ((last_ensemble_size_ == 0)
613       && (active_member_id_ != last_active_member_id_))
614   { LOG_INFORMATION("`active_member_id`ignored since `ensemble_size=0`."); }
615   last_active_member_id_ = active_member_id_;
616 
617   // update influence distance value in KIM API object
618   int Nspecies = descriptor_->get_num_species();
619   influenceDistance_ = 0.0;
620   for (int i = 0; i < Nspecies; i++)
621   {
622     for (int j = 0; j < Nspecies; j++)
623     {
624       double cutoff = descriptor_->get_cutoff(i, j);
625       if (influenceDistance_ < cutoff) { influenceDistance_ = cutoff; }
626     }
627   }
628 
629   modelObj->SetInfluenceDistancePointer(&influenceDistance_);
630   modelObj->SetNeighborListPointers(
631       1,
632       &influenceDistance_,
633       &modelWillNotRequestNeighborsOfNoncontributingParticles_);
634 
635   // everything is good
636   ier = false;
637   return ier;
638 }
639 
640 //******************************************************************************
641 #undef KIM_LOGGER_OBJECT_NAME
642 #define KIM_LOGGER_OBJECT_NAME modelComputeArguments
SetComputeMutableValues(KIM::ModelComputeArguments const * const modelComputeArguments,bool & isComputeProcess_dEdr,bool & isComputeProcess_d2Edr2,bool & isComputeEnergy,bool & isComputeForces,bool & isComputeParticleEnergy,bool & isComputeVirial,bool & isComputeParticleVirial,int const * & particleSpeciesCodes,int const * & particleContributing,VectorOfSizeDIM const * & coordinates,double * & energy,VectorOfSizeDIM * & forces,double * & particleEnergy,VectorOfSizeSix * & virial,VectorOfSizeSix * & particleVirial)643 int ANNImplementation::SetComputeMutableValues(
644     KIM::ModelComputeArguments const * const modelComputeArguments,
645     bool & isComputeProcess_dEdr,
646     bool & isComputeProcess_d2Edr2,
647     bool & isComputeEnergy,
648     bool & isComputeForces,
649     bool & isComputeParticleEnergy,
650     bool & isComputeVirial,
651     bool & isComputeParticleVirial,
652     int const *& particleSpeciesCodes,
653     int const *& particleContributing,
654     VectorOfSizeDIM const *& coordinates,
655     double *& energy,
656     VectorOfSizeDIM *& forces,
657     double *& particleEnergy,
658     VectorOfSizeSix *& virial,
659     VectorOfSizeSix *& particleVirial)
660 {
661   int ier = true;
662 
663   // get compute flags
664   int compProcess_dEdr;
665   int compProcess_d2Edr2;
666 
667   modelComputeArguments->IsCallbackPresent(
668       KIM::COMPUTE_CALLBACK_NAME::ProcessDEDrTerm, &compProcess_dEdr);
669   modelComputeArguments->IsCallbackPresent(
670       KIM::COMPUTE_CALLBACK_NAME::ProcessD2EDr2Term, &compProcess_d2Edr2);
671 
672   isComputeProcess_dEdr = compProcess_dEdr;
673   isComputeProcess_d2Edr2 = compProcess_d2Edr2;
674 
675   int const * numberOfParticles;
676   ier = modelComputeArguments->GetArgumentPointer(
677             KIM::COMPUTE_ARGUMENT_NAME::numberOfParticles, &numberOfParticles)
678         || modelComputeArguments->GetArgumentPointer(
679             KIM::COMPUTE_ARGUMENT_NAME::particleSpeciesCodes,
680             &particleSpeciesCodes)
681         || modelComputeArguments->GetArgumentPointer(
682             KIM::COMPUTE_ARGUMENT_NAME::particleContributing,
683             &particleContributing)
684         || modelComputeArguments->GetArgumentPointer(
685             KIM::COMPUTE_ARGUMENT_NAME::coordinates,
686             (double const **) &coordinates)
687         || modelComputeArguments->GetArgumentPointer(
688             KIM::COMPUTE_ARGUMENT_NAME::partialEnergy, &energy)
689         || modelComputeArguments->GetArgumentPointer(
690             KIM::COMPUTE_ARGUMENT_NAME::partialForces,
691             (double const **) &forces)
692         || modelComputeArguments->GetArgumentPointer(
693             KIM::COMPUTE_ARGUMENT_NAME::partialParticleEnergy, &particleEnergy)
694         || modelComputeArguments->GetArgumentPointer(
695             KIM::COMPUTE_ARGUMENT_NAME::partialVirial,
696             (double const **) &virial)
697         || modelComputeArguments->GetArgumentPointer(
698             KIM::COMPUTE_ARGUMENT_NAME::partialParticleVirial,
699             (double const **) &particleVirial);
700   if (ier)
701   {
702     LOG_ERROR("GetArgumentPointer");
703     return ier;
704   }
705 
706   isComputeEnergy = (energy != NULL);
707   isComputeForces = (forces != NULL);
708   isComputeParticleEnergy = (particleEnergy != NULL);
709   isComputeVirial = (virial != NULL);
710   isComputeParticleVirial = (particleVirial != NULL);
711 
712   // update values
713   cachedNumberOfParticles_ = *numberOfParticles;
714 
715   // everything is good
716   ier = false;
717   return ier;
718 }
719 
720 //******************************************************************************
721 // Assume that the particle species interge code starts from 0
722 #undef KIM_LOGGER_OBJECT_NAME
723 #define KIM_LOGGER_OBJECT_NAME modelCompute
CheckParticleSpeciesCodes(KIM::ModelCompute const * const modelCompute,int const * const particleSpeciesCodes) const724 int ANNImplementation::CheckParticleSpeciesCodes(
725     KIM::ModelCompute const * const modelCompute,
726     int const * const particleSpeciesCodes) const
727 {
728   int ier;
729 
730   for (int i = 0; i < cachedNumberOfParticles_; ++i)
731   {
732     if ((particleSpeciesCodes[i] < 0)
733         || (particleSpeciesCodes[i] >= descriptor_->get_num_species()))
734     {
735       ier = true;
736       LOG_ERROR("unsupported particle species codes detected");
737       return ier;
738     }
739   }
740 
741   // everything is good
742   ier = false;
743   return ier;
744 }
745 
746 //******************************************************************************
GetComputeIndex(const bool & isComputeProcess_dEdr,const bool & isComputeProcess_d2Edr2,const bool & isComputeEnergy,const bool & isComputeForces,const bool & isComputeParticleEnergy,const bool & isComputeVirial,const bool & isComputeParticleVirial) const747 int ANNImplementation::GetComputeIndex(
748     const bool & isComputeProcess_dEdr,
749     const bool & isComputeProcess_d2Edr2,
750     const bool & isComputeEnergy,
751     const bool & isComputeForces,
752     const bool & isComputeParticleEnergy,
753     const bool & isComputeVirial,
754     const bool & isComputeParticleVirial) const
755 {
756   // const int processdE = 2;
757   const int processd2E = 2;
758   const int energy = 2;
759   const int force = 2;
760   const int particleEnergy = 2;
761   const int virial = 2;
762   const int particleVirial = 2;
763 
764   int index = 0;
765 
766   // processdE
767   index += (int(isComputeProcess_dEdr)) * processd2E * energy * force
768            * particleEnergy * virial * particleVirial;
769 
770   // processd2E
771   index += (int(isComputeProcess_d2Edr2)) * energy * force * particleEnergy
772            * virial * particleVirial;
773 
774   // energy
775   index += (int(isComputeEnergy)) * force * particleEnergy * virial
776            * particleVirial;
777 
778   // force
779   index += (int(isComputeForces)) * particleEnergy * virial * particleVirial;
780 
781   // particleEnergy
782   index += (int(isComputeParticleEnergy)) * virial * particleVirial;
783 
784   // virial
785   index += (int(isComputeVirial)) * particleVirial;
786 
787   // particleVirial
788   index += (int(isComputeParticleVirial));
789 
790   return index;
791 }
792