1 //
2 // CDDL HEADER START
3 //
4 // The contents of this file are subject to the terms of the Common Development
5 // and Distribution License Version 1.0 (the "License").
6 //
7 // You can obtain a copy of the license at
8 // http://www.opensource.org/licenses/CDDL-1.0. See the License for the
9 // specific language governing permissions and limitations under the License.
10 //
11 // When distributing Covered Code, include this CDDL HEADER in each file and
12 // include the License file in a prominent location with the name LICENSE.CDDL.
13 // If applicable, add the following below this CDDL HEADER, with the fields
14 // enclosed by brackets "[]" replaced with your own identifying information:
15 //
16 // Portions Copyright (c) [yyyy] [name of copyright owner]. All rights reserved.
17 //
18 // CDDL HEADER END
19 //
20
21 //
22 // Copyright (c) 2019, Regents of the University of Minnesota.
23 // All rights reserved.
24 //
25 // Contributors:
26 // Mingjian Wen
27 //
28
29 #include <cmath>
30 #include <cstdlib>
31 #include <cstring>
32 #include <fstream>
33 #include <iostream>
34 #include <map>
35
36 #include "ANNImplementation.hpp"
37 #include "KIM_ModelDriverHeaders.hpp"
38
39 //==============================================================================
40 //
41 // Implementation of ANNImplementation public member functions
42 //
43 //==============================================================================
44
45 //******************************************************************************
46 #undef KIM_LOGGER_OBJECT_NAME
47 #define KIM_LOGGER_OBJECT_NAME modelDriverCreate
ANNImplementation(KIM::ModelDriverCreate * const modelDriverCreate,KIM::LengthUnit const requestedLengthUnit,KIM::EnergyUnit const requestedEnergyUnit,KIM::ChargeUnit const requestedChargeUnit,KIM::TemperatureUnit const requestedTemperatureUnit,KIM::TimeUnit const requestedTimeUnit,int * const ier)48 ANNImplementation::ANNImplementation(
49 KIM::ModelDriverCreate * const modelDriverCreate,
50 KIM::LengthUnit const requestedLengthUnit,
51 KIM::EnergyUnit const requestedEnergyUnit,
52 KIM::ChargeUnit const requestedChargeUnit,
53 KIM::TemperatureUnit const requestedTemperatureUnit,
54 KIM::TimeUnit const requestedTimeUnit,
55 int * const ier) :
56 energyScale_(1.0),
57 ensemble_size_(0),
58 last_ensemble_size_(0),
59 active_member_id_(-1),
60 last_active_member_id_(-1),
61 influenceDistance_(0.0),
62 modelWillNotRequestNeighborsOfNoncontributingParticles_(1),
63 cachedNumberOfParticles_(0)
64 {
65 // create descriptor and network classes
66 descriptor_ = new Descriptor();
67 network_ = new NeuralNetwork();
68
69 FILE * parameterFilePointers[MAX_PARAMETER_FILES];
70 int numberParameterFiles;
71
72 modelDriverCreate->GetNumberOfParameterFiles(&numberParameterFiles);
73 *ier = OpenParameterFiles(
74 modelDriverCreate, numberParameterFiles, parameterFilePointers);
75 if (*ier) { return; }
76
77 *ier = ProcessParameterFiles(
78 modelDriverCreate, numberParameterFiles, parameterFilePointers);
79 CloseParameterFiles(numberParameterFiles, parameterFilePointers);
80 if (*ier) { return; }
81
82 *ier = ConvertUnits(modelDriverCreate,
83 requestedLengthUnit,
84 requestedEnergyUnit,
85 requestedChargeUnit,
86 requestedTemperatureUnit,
87 requestedTimeUnit);
88 if (*ier) { return; }
89
90 *ier = SetRefreshMutableValues(modelDriverCreate);
91 if (*ier) { return; }
92
93 *ier = RegisterKIMModelSettings(modelDriverCreate);
94 if (*ier) { return; }
95
96 *ier = RegisterKIMParameters(modelDriverCreate);
97 if (*ier) { return; }
98
99 *ier = RegisterKIMFunctions(modelDriverCreate);
100 if (*ier) { return; }
101
102 // everything is good
103 *ier = false;
104 return;
105 }
106
107 //******************************************************************************
~ANNImplementation()108 ANNImplementation::~ANNImplementation()
109 { // note: it is ok to delete a null
110 // pointer and we have ensured that
111 // everything is initialized to null
112
113 delete descriptor_;
114 delete network_;
115 }
116
117 //******************************************************************************
118 #undef KIM_LOGGER_OBJECT_NAME
119 #define KIM_LOGGER_OBJECT_NAME modelRefresh
Refresh(KIM::ModelRefresh * const modelRefresh)120 int ANNImplementation::Refresh(KIM::ModelRefresh * const modelRefresh)
121 {
122 int ier;
123
124 ier = SetRefreshMutableValues(modelRefresh);
125 if (ier) { return ier; }
126
127 // nothing else to do for this case
128
129 // everything is good
130 ier = false;
131 return ier;
132 }
133
134 //******************************************************************************
// Evaluate the quantities requested by the simulator for the current
// configuration.
//
// SetComputeMutableValues() fills in the compute flags and argument pointers
// declared below; the actual work is done by ANNImplementationComputeDispatch.cpp,
// which is textually #included near the end of this function.
// NOTE(review): the dispatch file presumably refers to these locals (and to
// `ier`) by name — do not rename them without checking that file.
//
// Returns the status left in `ier` (nonzero = error, as elsewhere in this file).
int ANNImplementation::Compute(
    KIM::ModelCompute const * const modelCompute,
    KIM::ModelComputeArguments const * const modelComputeArguments)
{
  int ier;

  // KIM API Model Input compute flags
  bool isComputeProcess_dEdr = false;
  bool isComputeProcess_d2Edr2 = false;
  //
  // KIM API Model Output compute flags
  bool isComputeEnergy = false;
  bool isComputeForces = false;
  bool isComputeParticleEnergy = false;
  bool isComputeVirial = false;
  bool isComputeParticleVirial = false;
  //
  // KIM API Model Input
  int const * particleSpeciesCodes = NULL;
  int const * particleContributing = NULL;
  VectorOfSizeDIM const * coordinates = NULL;
  //
  // KIM API Model Output (NULL when the simulator did not request it)
  double * energy = NULL;
  double * particleEnergy = NULL;
  VectorOfSizeDIM * forces = NULL;
  VectorOfSizeSix * virial = NULL;
  VectorOfSizeSix * particleVirial = NULL;

  // Populate every flag and pointer above from the compute-arguments object.
  ier = SetComputeMutableValues(modelComputeArguments,
                                isComputeProcess_dEdr,
                                isComputeProcess_d2Edr2,
                                isComputeEnergy,
                                isComputeForces,
                                isComputeParticleEnergy,
                                isComputeVirial,
                                isComputeParticleVirial,
                                particleSpeciesCodes,
                                particleContributing,
                                coordinates,
                                energy,
                                forces,
                                particleEnergy,
                                virial,
                                particleVirial);
  if (ier) { return ier; }

  // Skip this check for efficiency. To re-enable it, use:
  //
  // ier = CheckParticleSpeciesCodes(modelCompute, particleSpeciesCodes);
  // if (ier) return ier;

  // Dispatch to the templated compute kernel matching the active flags.
#include "ANNImplementationComputeDispatch.cpp"

  return ier;
}
191
192 //******************************************************************************
ComputeArgumentsCreate(KIM::ModelComputeArgumentsCreate * const modelComputeArgumentsCreate) const193 int ANNImplementation::ComputeArgumentsCreate(
194 KIM::ModelComputeArgumentsCreate * const modelComputeArgumentsCreate) const
195 {
196 int ier;
197
198 ier = RegisterKIMComputeArgumentsSettings(modelComputeArgumentsCreate);
199 if (ier) { return ier; }
200
201 // nothing else to do for this case
202
203 // everything is good
204 ier = false;
205 return ier;
206 }
207
208 //******************************************************************************
ComputeArgumentsDestroy(KIM::ModelComputeArgumentsDestroy * const modelComputeArgumentsDestroy) const209 int ANNImplementation::ComputeArgumentsDestroy(
210 KIM::ModelComputeArgumentsDestroy * const modelComputeArgumentsDestroy)
211 const
212 {
213 int ier;
214
215 (void) modelComputeArgumentsDestroy; // avoid not used warning
216
217 // nothing else to do for this case
218
219 // everything is good
220 ier = false;
221 return ier;
222 }
223
224 //==============================================================================
225 //
226 // Implementation of ANNImplementation private member functions
227 //
228 //==============================================================================
229
230 //******************************************************************************
AllocatePrivateParameterMemory()231 void ANNImplementation::AllocatePrivateParameterMemory()
232 {
233 // nothing to do for this case
234 }
235
236 //******************************************************************************
AllocateParameterMemory()237 void ANNImplementation::AllocateParameterMemory()
238 {
239 // nothing to do for this case
240 }
241
242 //******************************************************************************
243 #undef KIM_LOGGER_OBJECT_NAME
244 #define KIM_LOGGER_OBJECT_NAME modelDriverCreate
OpenParameterFiles(KIM::ModelDriverCreate * const modelDriverCreate,int const numberParameterFiles,FILE * parameterFilePointers[MAX_PARAMETER_FILES])245 int ANNImplementation::OpenParameterFiles(
246 KIM::ModelDriverCreate * const modelDriverCreate,
247 int const numberParameterFiles,
248 FILE * parameterFilePointers[MAX_PARAMETER_FILES])
249 {
250 int ier;
251
252 if (numberParameterFiles > MAX_PARAMETER_FILES)
253 {
254 ier = true;
255 LOG_ERROR("ANN given too many parameter files");
256 return ier;
257 }
258
259 for (int i = 0; i < numberParameterFiles; ++i)
260 {
261 std::string const * paramFileName;
262 ier = modelDriverCreate->GetParameterFileName(i, ¶mFileName);
263 if (ier)
264 {
265 LOG_ERROR("Unable to get parameter file name");
266 return ier;
267 }
268
269 parameterFilePointers[i] = fopen(paramFileName->c_str(), "r");
270 if (parameterFilePointers[i] == 0)
271 {
272 char message[MAXLINE];
273 sprintf(message, "ANN parameter file number %d cannot be opened", i);
274 ier = true;
275 LOG_ERROR(message);
276 for (int j = i - 1; i <= 0; --i) { fclose(parameterFilePointers[j]); }
277 return ier;
278 }
279 }
280
281 // everything is good
282 ier = false;
283 return ier;
284 }
285
286 //******************************************************************************
287 #undef KIM_LOGGER_OBJECT_NAME
288 #define KIM_LOGGER_OBJECT_NAME modelDriverCreate
ProcessParameterFiles(KIM::ModelDriverCreate * const modelDriverCreate,int const numberParameterFiles,FILE * const parameterFilePointers[MAX_PARAMETER_FILES])289 int ANNImplementation::ProcessParameterFiles(
290 KIM::ModelDriverCreate * const modelDriverCreate,
291 int const numberParameterFiles,
292 FILE * const parameterFilePointers[MAX_PARAMETER_FILES])
293 {
294 (void) numberParameterFiles; // avoid not used warning
295
296 int ier;
297 char errorMsg[1024];
298
299 //#######################
300 // descriptor params
301 //#######################
302 ier = descriptor_->read_parameter_file(parameterFilePointers[0]);
303 if (ier)
304 {
305 sprintf(errorMsg, "unable to read descriptor parameter file\n");
306 LOG_ERROR(errorMsg);
307 return true;
308 }
309
310 // set species
311 int Nspecies = descriptor_->get_num_species();
312 std::vector<std::string> species;
313 descriptor_->get_species(species);
314 for (int i = 0; i < Nspecies; i++)
315 {
316 KIM::SpeciesName const specName(species[i]);
317 if (!specName.Known())
318 {
319 sprintf(errorMsg, "get unknown species\n");
320 LOG_ERROR(errorMsg);
321 return true;
322 }
323 ier = modelDriverCreate->SetSpeciesCode(specName, i);
324 if (ier) { return ier; }
325 }
326
327 //#######################
328 // model parameters
329 //#######################
330
331 int desc_size = descriptor_->get_num_descriptors();
332 ier = network_->read_parameter_file(parameterFilePointers[1], desc_size);
333 if (ier)
334 {
335 sprintf(errorMsg, "unable to read neural network parameter file\n");
336 LOG_ERROR(errorMsg);
337 return true;
338 }
339
340 //#######################
341 // read dropout binary
342 //#######################
343 ier = network_->read_dropout_file(parameterFilePointers[2]);
344 if (ier)
345 {
346 sprintf(errorMsg, "unable to read dropout file\n");
347 LOG_ERROR(errorMsg);
348 return true;
349 }
350 ensemble_size_ = last_ensemble_size_ = network_->get_ensemble_size();
351 active_member_id_ = last_active_member_id_ = -1; // default to average
352
353 // everything is good
354 ier = false;
355 return ier;
356 }
357
358 //******************************************************************************
CloseParameterFiles(int const numberParameterFiles,FILE * const parameterFilePointers[MAX_PARAMETER_FILES])359 void ANNImplementation::CloseParameterFiles(
360 int const numberParameterFiles,
361 FILE * const parameterFilePointers[MAX_PARAMETER_FILES])
362 {
363 for (int i = 0; i < numberParameterFiles; ++i)
364 { fclose(parameterFilePointers[i]); }
365 }
366
367 //******************************************************************************
368 #undef KIM_LOGGER_OBJECT_NAME
369 #define KIM_LOGGER_OBJECT_NAME modelDriverCreate
ConvertUnits(KIM::ModelDriverCreate * const modelDriverCreate,KIM::LengthUnit const requestedLengthUnit,KIM::EnergyUnit const requestedEnergyUnit,KIM::ChargeUnit const requestedChargeUnit,KIM::TemperatureUnit const requestedTemperatureUnit,KIM::TimeUnit const requestedTimeUnit)370 int ANNImplementation::ConvertUnits(
371 KIM::ModelDriverCreate * const modelDriverCreate,
372 KIM::LengthUnit const requestedLengthUnit,
373 KIM::EnergyUnit const requestedEnergyUnit,
374 KIM::ChargeUnit const requestedChargeUnit,
375 KIM::TemperatureUnit const requestedTemperatureUnit,
376 KIM::TimeUnit const requestedTimeUnit)
377 {
378 int ier;
379
380 // define default base units
381 KIM::LengthUnit fromLength = KIM::LENGTH_UNIT::A;
382 KIM::EnergyUnit fromEnergy = KIM::ENERGY_UNIT::eV;
383 KIM::ChargeUnit fromCharge = KIM::CHARGE_UNIT::e;
384 KIM::TemperatureUnit fromTemperature = KIM::TEMPERATURE_UNIT::K;
385 KIM::TimeUnit fromTime = KIM::TIME_UNIT::ps;
386
387 double convertLength = 1.0;
388
389 ier = modelDriverCreate->ConvertUnit(fromLength,
390 fromEnergy,
391 fromCharge,
392 fromTemperature,
393 fromTime,
394 requestedLengthUnit,
395 requestedEnergyUnit,
396 requestedChargeUnit,
397 requestedTemperatureUnit,
398 requestedTimeUnit,
399 1.0,
400 0.0,
401 0.0,
402 0.0,
403 0.0,
404 &convertLength);
405 if (ier)
406 {
407 LOG_ERROR("Unable to convert length unit");
408 return ier;
409 }
410
411 double convertEnergy = 1.0;
412 ier = modelDriverCreate->ConvertUnit(fromLength,
413 fromEnergy,
414 fromCharge,
415 fromTemperature,
416 fromTime,
417 requestedLengthUnit,
418 requestedEnergyUnit,
419 requestedChargeUnit,
420 requestedTemperatureUnit,
421 requestedTimeUnit,
422 0.0,
423 1.0,
424 0.0,
425 0.0,
426 0.0,
427 &convertEnergy);
428 if (ier)
429 {
430 LOG_ERROR("Unable to convert energy unit");
431 return ier;
432 }
433
434 // convert to active units
435 if (convertEnergy != ONE or convertLength != ONE)
436 {
437 descriptor_->convert_units(convertEnergy, convertLength);
438 energyScale_ = convertEnergy;
439 }
440
441 // register units
442 ier = modelDriverCreate->SetUnits(requestedLengthUnit,
443 requestedEnergyUnit,
444 KIM::CHARGE_UNIT::unused,
445 KIM::TEMPERATURE_UNIT::unused,
446 KIM::TIME_UNIT::unused);
447 if (ier)
448 {
449 LOG_ERROR("Unable to set units to requested values");
450 return ier;
451 }
452
453 // everything is good
454 ier = false;
455 return ier;
456 }
457
458 //******************************************************************************
RegisterKIMModelSettings(KIM::ModelDriverCreate * const modelDriverCreate) const459 int ANNImplementation::RegisterKIMModelSettings(
460 KIM::ModelDriverCreate * const modelDriverCreate) const
461 {
462 // register numbering
463 int error = modelDriverCreate->SetModelNumbering(KIM::NUMBERING::zeroBased);
464
465 return error;
466 }
467
468 //******************************************************************************
469 #undef KIM_LOGGER_OBJECT_NAME
470 #define KIM_LOGGER_OBJECT_NAME modelComputeArgumentsCreate
RegisterKIMComputeArgumentsSettings(KIM::ModelComputeArgumentsCreate * const modelComputeArgumentsCreate) const471 int ANNImplementation::RegisterKIMComputeArgumentsSettings(
472 KIM::ModelComputeArgumentsCreate * const modelComputeArgumentsCreate) const
473 {
474 // register arguments
475 LOG_INFORMATION("Register argument supportStatus");
476
477 int error = modelComputeArgumentsCreate->SetArgumentSupportStatus(
478 KIM::COMPUTE_ARGUMENT_NAME::partialEnergy,
479 KIM::SUPPORT_STATUS::optional)
480 || modelComputeArgumentsCreate->SetArgumentSupportStatus(
481 KIM::COMPUTE_ARGUMENT_NAME::partialForces,
482 KIM::SUPPORT_STATUS::optional)
483 || modelComputeArgumentsCreate->SetArgumentSupportStatus(
484 KIM::COMPUTE_ARGUMENT_NAME::partialParticleEnergy,
485 KIM::SUPPORT_STATUS::optional)
486 || modelComputeArgumentsCreate->SetArgumentSupportStatus(
487 KIM::COMPUTE_ARGUMENT_NAME::partialVirial,
488 KIM::SUPPORT_STATUS::optional)
489 || modelComputeArgumentsCreate->SetArgumentSupportStatus(
490 KIM::COMPUTE_ARGUMENT_NAME::partialParticleVirial,
491 KIM::SUPPORT_STATUS::optional);
492
493 // register callbacks
494 LOG_INFORMATION("Register callback supportStatus");
495 error = error
496 || modelComputeArgumentsCreate->SetCallbackSupportStatus(
497 KIM::COMPUTE_CALLBACK_NAME::ProcessDEDrTerm,
498 KIM::SUPPORT_STATUS::optional)
499 || modelComputeArgumentsCreate->SetCallbackSupportStatus(
500 KIM::COMPUTE_CALLBACK_NAME::ProcessD2EDr2Term,
501 KIM::SUPPORT_STATUS::optional);
502
503 return error;
504 }
505
506 //******************************************************************************
507 #undef KIM_LOGGER_OBJECT_NAME
508 #define KIM_LOGGER_OBJECT_NAME modelDriverCreate
RegisterKIMParameters(KIM::ModelDriverCreate * const modelDriverCreate)509 int ANNImplementation::RegisterKIMParameters(
510 KIM::ModelDriverCreate * const modelDriverCreate)
511 {
512 int ier = false;
513
514 // publish parameters (order is important)
515 ier = modelDriverCreate->SetParameterPointer(
516 1,
517 &ensemble_size_,
518 "ensemble_size",
519 "Size of the ensemble of models. `0` means this is a fully-"
520 "connected neural network that does not support running in "
521 "ensemble mode.")
522 || modelDriverCreate->SetParameterPointer(
523 1,
524 &active_member_id_,
525 "active_member_id",
526 "Running mode of the ensemble of models, with available values of "
527 "`-1, 0, 1, 2, ..., ensemble_size`. If `ensemble_size = 0`, "
528 "this is ignored. Otherwise, `active_member_id = -1` means the "
529 "output "
530 "(energy, forces, etc.) will be obtained by averaging over all "
531 "members of the ensemble (different dropout matrices); "
532 "`active_member_id = 0` means the fully-connected network without "
533 "dropout will be used; and `active_member_id = i` where i is an "
534 "integer from 1 to `ensemble_size` means ensemble member i will be "
535 "used to calculate the output.");
536
537 if (ier)
538 {
539 LOG_ERROR("set_parameters");
540 return ier;
541 }
542
543 // everything is good
544 ier = false;
545 return ier;
546 }
547
548 //******************************************************************************
RegisterKIMFunctions(KIM::ModelDriverCreate * const modelDriverCreate) const549 int ANNImplementation::RegisterKIMFunctions(
550 KIM::ModelDriverCreate * const modelDriverCreate) const
551 {
552 int error;
553
554 // register functions
555 error = modelDriverCreate->SetRoutinePointer(
556 KIM::MODEL_ROUTINE_NAME::Destroy,
557 KIM::LANGUAGE_NAME::cpp,
558 true,
559 reinterpret_cast<KIM::Function *>(ANN::Destroy))
560 || modelDriverCreate->SetRoutinePointer(
561 KIM::MODEL_ROUTINE_NAME::Refresh,
562 KIM::LANGUAGE_NAME::cpp,
563 true,
564 reinterpret_cast<KIM::Function *>(ANN::Refresh))
565 || modelDriverCreate->SetRoutinePointer(
566 KIM::MODEL_ROUTINE_NAME::Compute,
567 KIM::LANGUAGE_NAME::cpp,
568 true,
569 reinterpret_cast<KIM::Function *>(ANN::Compute))
570 || modelDriverCreate->SetRoutinePointer(
571 KIM::MODEL_ROUTINE_NAME::ComputeArgumentsCreate,
572 KIM::LANGUAGE_NAME::cpp,
573 true,
574 reinterpret_cast<KIM::Function *>(ANN::ComputeArgumentsCreate))
575 || modelDriverCreate->SetRoutinePointer(
576 KIM::MODEL_ROUTINE_NAME::ComputeArgumentsDestroy,
577 KIM::LANGUAGE_NAME::cpp,
578 true,
579 reinterpret_cast<KIM::Function *>(ANN::ComputeArgumentsDestroy));
580
581 return error;
582 }
583
584 //******************************************************************************
585 #undef KIM_LOGGER_OBJECT_NAME
586 #define KIM_LOGGER_OBJECT_NAME modelObj
587 template<class ModelObj>
SetRefreshMutableValues(ModelObj * const modelObj)588 int ANNImplementation::SetRefreshMutableValues(ModelObj * const modelObj)
589 { // use (possibly) new values of parameters to
590 // compute other quantities
591 // NOTE: This function is templated because it's called with both a
592 // modelDriverCreate object during initialization and with a
593 // modelRefresh object when the Model's parameters have been altered
594 int ier = true;
595
596 // checks to make sure ensemble_size_ and active_member_id_ are correct
597 if (ensemble_size_ != last_ensemble_size_)
598 {
599 LOG_ERROR("Value of `ensemble_size` changed.");
600 return ier;
601 }
602 if (active_member_id_ < -1 || active_member_id_ > ensemble_size_)
603 {
604 char message[MAXLINE];
605 sprintf(message,
606 "`active_member_id=%d` out of range. Should be [-1, %d]",
607 active_member_id_,
608 ensemble_size_);
609 LOG_ERROR(message);
610 return ier;
611 }
612 if ((last_ensemble_size_ == 0)
613 && (active_member_id_ != last_active_member_id_))
614 { LOG_INFORMATION("`active_member_id`ignored since `ensemble_size=0`."); }
615 last_active_member_id_ = active_member_id_;
616
617 // update influence distance value in KIM API object
618 int Nspecies = descriptor_->get_num_species();
619 influenceDistance_ = 0.0;
620 for (int i = 0; i < Nspecies; i++)
621 {
622 for (int j = 0; j < Nspecies; j++)
623 {
624 double cutoff = descriptor_->get_cutoff(i, j);
625 if (influenceDistance_ < cutoff) { influenceDistance_ = cutoff; }
626 }
627 }
628
629 modelObj->SetInfluenceDistancePointer(&influenceDistance_);
630 modelObj->SetNeighborListPointers(
631 1,
632 &influenceDistance_,
633 &modelWillNotRequestNeighborsOfNoncontributingParticles_);
634
635 // everything is good
636 ier = false;
637 return ier;
638 }
639
640 //******************************************************************************
641 #undef KIM_LOGGER_OBJECT_NAME
642 #define KIM_LOGGER_OBJECT_NAME modelComputeArguments
SetComputeMutableValues(KIM::ModelComputeArguments const * const modelComputeArguments,bool & isComputeProcess_dEdr,bool & isComputeProcess_d2Edr2,bool & isComputeEnergy,bool & isComputeForces,bool & isComputeParticleEnergy,bool & isComputeVirial,bool & isComputeParticleVirial,int const * & particleSpeciesCodes,int const * & particleContributing,VectorOfSizeDIM const * & coordinates,double * & energy,VectorOfSizeDIM * & forces,double * & particleEnergy,VectorOfSizeSix * & virial,VectorOfSizeSix * & particleVirial)643 int ANNImplementation::SetComputeMutableValues(
644 KIM::ModelComputeArguments const * const modelComputeArguments,
645 bool & isComputeProcess_dEdr,
646 bool & isComputeProcess_d2Edr2,
647 bool & isComputeEnergy,
648 bool & isComputeForces,
649 bool & isComputeParticleEnergy,
650 bool & isComputeVirial,
651 bool & isComputeParticleVirial,
652 int const *& particleSpeciesCodes,
653 int const *& particleContributing,
654 VectorOfSizeDIM const *& coordinates,
655 double *& energy,
656 VectorOfSizeDIM *& forces,
657 double *& particleEnergy,
658 VectorOfSizeSix *& virial,
659 VectorOfSizeSix *& particleVirial)
660 {
661 int ier = true;
662
663 // get compute flags
664 int compProcess_dEdr;
665 int compProcess_d2Edr2;
666
667 modelComputeArguments->IsCallbackPresent(
668 KIM::COMPUTE_CALLBACK_NAME::ProcessDEDrTerm, &compProcess_dEdr);
669 modelComputeArguments->IsCallbackPresent(
670 KIM::COMPUTE_CALLBACK_NAME::ProcessD2EDr2Term, &compProcess_d2Edr2);
671
672 isComputeProcess_dEdr = compProcess_dEdr;
673 isComputeProcess_d2Edr2 = compProcess_d2Edr2;
674
675 int const * numberOfParticles;
676 ier = modelComputeArguments->GetArgumentPointer(
677 KIM::COMPUTE_ARGUMENT_NAME::numberOfParticles, &numberOfParticles)
678 || modelComputeArguments->GetArgumentPointer(
679 KIM::COMPUTE_ARGUMENT_NAME::particleSpeciesCodes,
680 &particleSpeciesCodes)
681 || modelComputeArguments->GetArgumentPointer(
682 KIM::COMPUTE_ARGUMENT_NAME::particleContributing,
683 &particleContributing)
684 || modelComputeArguments->GetArgumentPointer(
685 KIM::COMPUTE_ARGUMENT_NAME::coordinates,
686 (double const **) &coordinates)
687 || modelComputeArguments->GetArgumentPointer(
688 KIM::COMPUTE_ARGUMENT_NAME::partialEnergy, &energy)
689 || modelComputeArguments->GetArgumentPointer(
690 KIM::COMPUTE_ARGUMENT_NAME::partialForces,
691 (double const **) &forces)
692 || modelComputeArguments->GetArgumentPointer(
693 KIM::COMPUTE_ARGUMENT_NAME::partialParticleEnergy, &particleEnergy)
694 || modelComputeArguments->GetArgumentPointer(
695 KIM::COMPUTE_ARGUMENT_NAME::partialVirial,
696 (double const **) &virial)
697 || modelComputeArguments->GetArgumentPointer(
698 KIM::COMPUTE_ARGUMENT_NAME::partialParticleVirial,
699 (double const **) &particleVirial);
700 if (ier)
701 {
702 LOG_ERROR("GetArgumentPointer");
703 return ier;
704 }
705
706 isComputeEnergy = (energy != NULL);
707 isComputeForces = (forces != NULL);
708 isComputeParticleEnergy = (particleEnergy != NULL);
709 isComputeVirial = (virial != NULL);
710 isComputeParticleVirial = (particleVirial != NULL);
711
712 // update values
713 cachedNumberOfParticles_ = *numberOfParticles;
714
715 // everything is good
716 ier = false;
717 return ier;
718 }
719
720 //******************************************************************************
// Assumes that particle species integer codes start from 0.
722 #undef KIM_LOGGER_OBJECT_NAME
723 #define KIM_LOGGER_OBJECT_NAME modelCompute
CheckParticleSpeciesCodes(KIM::ModelCompute const * const modelCompute,int const * const particleSpeciesCodes) const724 int ANNImplementation::CheckParticleSpeciesCodes(
725 KIM::ModelCompute const * const modelCompute,
726 int const * const particleSpeciesCodes) const
727 {
728 int ier;
729
730 for (int i = 0; i < cachedNumberOfParticles_; ++i)
731 {
732 if ((particleSpeciesCodes[i] < 0)
733 || (particleSpeciesCodes[i] >= descriptor_->get_num_species()))
734 {
735 ier = true;
736 LOG_ERROR("unsupported particle species codes detected");
737 return ier;
738 }
739 }
740
741 // everything is good
742 ier = false;
743 return ier;
744 }
745
746 //******************************************************************************
GetComputeIndex(const bool & isComputeProcess_dEdr,const bool & isComputeProcess_d2Edr2,const bool & isComputeEnergy,const bool & isComputeForces,const bool & isComputeParticleEnergy,const bool & isComputeVirial,const bool & isComputeParticleVirial) const747 int ANNImplementation::GetComputeIndex(
748 const bool & isComputeProcess_dEdr,
749 const bool & isComputeProcess_d2Edr2,
750 const bool & isComputeEnergy,
751 const bool & isComputeForces,
752 const bool & isComputeParticleEnergy,
753 const bool & isComputeVirial,
754 const bool & isComputeParticleVirial) const
755 {
756 // const int processdE = 2;
757 const int processd2E = 2;
758 const int energy = 2;
759 const int force = 2;
760 const int particleEnergy = 2;
761 const int virial = 2;
762 const int particleVirial = 2;
763
764 int index = 0;
765
766 // processdE
767 index += (int(isComputeProcess_dEdr)) * processd2E * energy * force
768 * particleEnergy * virial * particleVirial;
769
770 // processd2E
771 index += (int(isComputeProcess_d2Edr2)) * energy * force * particleEnergy
772 * virial * particleVirial;
773
774 // energy
775 index += (int(isComputeEnergy)) * force * particleEnergy * virial
776 * particleVirial;
777
778 // force
779 index += (int(isComputeForces)) * particleEnergy * virial * particleVirial;
780
781 // particleEnergy
782 index += (int(isComputeParticleEnergy)) * virial * particleVirial;
783
784 // virial
785 index += (int(isComputeVirial)) * particleVirial;
786
787 // particleVirial
788 index += (int(isComputeParticleVirial));
789
790 return index;
791 }
792