1 /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 Copyright (c) 2015-2021 The plumed team
3 (see the PEOPLE file at the root of the distribution for a list of names)
4
5 See http://www.plumed.org for more information.
6
7 This file is part of plumed, version 2.
8
9 plumed is free software: you can redistribute it and/or modify
10 it under the terms of the GNU Lesser General Public License as published by
11 the Free Software Foundation, either version 3 of the License, or
12 (at your option) any later version.
13
14 plumed is distributed in the hope that it will be useful,
15 but WITHOUT ANY WARRANTY; without even the implied warranty of
16 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 GNU Lesser General Public License for more details.
18
19 You should have received a copy of the GNU Lesser General Public License
20 along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 #include "Function.h"
23 #include "ActionRegister.h"
24
25 using namespace std;
26
27 namespace PLMD {
28 namespace function {
29
30 //+PLUMEDOC FUNCTION STATS
31 /*
32 Calculates statistical properties of a set of collective variables with respect to a set of reference values.
33
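By default it calculates and stores as components the sum of the squared deviations, the correlation, and the
slope and intercept of a linear fit. With SQDEVSUM only the sum of the squared deviations is calculated, while with
SQDEV each squared deviation is stored as a separate component. UPPERDISTS, used together with either of these two
flags, sets to zero the deviations in which an argument is smaller than its reference value.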
36
The reference values can be provided either as numbers using PARAMETERS, or as values without derivatives
taken from other actions using PARARG (for example experimental values from collective variables such as
\ref CS2BACKBONE, \ref RDC, \ref NOE, \ref PRE). Values given through PARARG are read once, when the action is set up.
40
41 \par Examples
42
The following input tells plumed to print the distances between three pairs of atoms
and compare them with three reference distances.
45
46 \plumedfile
47 d1: DISTANCE ATOMS=10,50
48 d2: DISTANCE ATOMS=1,100
49 d3: DISTANCE ATOMS=45,75
50 st: STATS ARG=d1,d2,d3 PARAMETERS=1.5,4.0,2.0
51 PRINT ARG=d1,d2,d3,st.*
52 \endplumedfile
53
54 */
55 //+ENDPLUMEDOC
56
57
58 class Stats :
59 public Function
60 {
61 std::vector<double> parameters;
62 bool sqdonly;
63 bool components;
64 bool upperd;
65 public:
66 explicit Stats(const ActionOptions&);
67 void calculate() override;
68 static void registerKeywords(Keywords& keys);
69 };
70
71
72 PLUMED_REGISTER_ACTION(Stats,"STATS")
73
registerKeywords(Keywords & keys)74 void Stats::registerKeywords(Keywords& keys) {
75 Function::registerKeywords(keys);
76 keys.use("ARG");
77 keys.add("optional","PARARG","the input for this action is the scalar output from one or more other actions without derivatives.");
78 keys.add("optional","PARAMETERS","the parameters of the arguments in your function");
79 keys.addFlag("SQDEVSUM",false,"calculates only SQDEVSUM");
80 keys.addFlag("SQDEV",false,"calculates and store the SQDEV as components");
81 keys.addFlag("UPPERDISTS",false,"calculates and store the SQDEV as components");
82 keys.addOutputComponent("sqdevsum","default","the sum of the squared deviations between arguments and parameters");
83 keys.addOutputComponent("corr","default","the correlation between arguments and parameters");
84 keys.addOutputComponent("slope","default","the slope of a linear fit between arguments and parameters");
85 keys.addOutputComponent("intercept","default","the intercept of a linear fit between arguments and parameters");
86 keys.addOutputComponent("sqd","SQDEV","the squared deviations between arguments and parameters");
87 }
88
Stats(const ActionOptions & ao)89 Stats::Stats(const ActionOptions&ao):
90 Action(ao),
91 Function(ao),
92 sqdonly(false),
93 components(false),
94 upperd(false)
95 {
96 parseVector("PARAMETERS",parameters);
97 if(parameters.size()!=static_cast<unsigned>(getNumberOfArguments())&&!parameters.empty())
98 error("Size of PARAMETERS array should be either 0 or the same as of the number of arguments in ARG1");
99
100 vector<Value*> arg2;
101 parseArgumentList("PARARG",arg2);
102
103 if(!arg2.empty()) {
104 if(parameters.size()>0) error("It is not possible to use PARARG and PARAMETERS together");
105 if(arg2.size()!=getNumberOfArguments()) error("Size of PARARG array should be the same as number for arguments in ARG");
106 for(unsigned i=0; i<arg2.size(); i++) {
107 parameters.push_back(arg2[i]->get());
108 if(arg2[i]->hasDerivatives()==true) error("PARARG can only accept arguments without derivatives");
109 }
110 }
111
112 if(parameters.size()!=getNumberOfArguments())
113 error("PARARG or PARAMETERS arrays should include the same number of elements as the arguments in ARG");
114
115 if(getNumberOfArguments()<2) error("STATS need at least two arguments to be used");
116
117 parseFlag("SQDEVSUM",sqdonly);
118 parseFlag("SQDEV",components);
119 parseFlag("UPPERDISTS",upperd);
120
121 if(sqdonly&&components) error("You cannot used SQDEVSUM and SQDEV at the sametime");
122
123 if(components) sqdonly = true;
124
125 if(!arg2.empty()) log.printf(" using %zu parameters from inactive actions:", arg2.size());
126 else log.printf(" using %zu parameters:", arg2.size());
127 for(unsigned i=0; i<parameters.size(); i++) log.printf(" %f",parameters[i]);
128 log.printf("\n");
129
130 if(sqdonly) {
131 if(components) {
132 for(unsigned i=0; i<parameters.size(); i++) {
133 std::string num; Tools::convert(i,num);
134 addComponentWithDerivatives("sqd-"+num);
135 componentIsNotPeriodic("sqd-"+num);
136 }
137 } else {
138 addComponentWithDerivatives("sqdevsum");
139 componentIsNotPeriodic("sqdevsum");
140 }
141 } else {
142 addComponentWithDerivatives("sqdevsum");
143 componentIsNotPeriodic("sqdevsum");
144 addComponentWithDerivatives("corr");
145 componentIsNotPeriodic("corr");
146 addComponentWithDerivatives("slope");
147 componentIsNotPeriodic("slope");
148 addComponentWithDerivatives("intercept");
149 componentIsNotPeriodic("intercept");
150 }
151
152
153 checkRead();
154 }
155
calculate()156 void Stats::calculate()
157 {
158 if(sqdonly) {
159
160 double nsqd = 0.;
161 Value* val;
162 if(!components) val=getPntrToComponent("sqdevsum");
163 for(unsigned i=0; i<parameters.size(); ++i) {
164 double dev = getArgument(i)-parameters[i];
165 if(upperd&&dev<0) dev=0.;
166 if(components) {
167 val=getPntrToComponent(i);
168 val->set(dev*dev);
169 } else {
170 nsqd += dev*dev;
171 }
172 setDerivative(val,i,2.*dev);
173 }
174 if(!components) val->set(nsqd);
175
176 } else {
177
178 double scx=0., scx2=0., scy=0., scy2=0., scxy=0.;
179
180 for(unsigned i=0; i<parameters.size(); ++i) {
181 const double tmpx=getArgument(i);
182 const double tmpy=parameters[i];
183 scx += tmpx;
184 scx2 += tmpx*tmpx;
185 scy += tmpy;
186 scy2 += tmpy*tmpy;
187 scxy += tmpx*tmpy;
188 }
189
190 const double ns = parameters.size();
191
192 const double num = ns*scxy - scx*scy;
193 const double idev2x = 1./(ns*scx2-scx*scx);
194 const double idevx = sqrt(idev2x);
195 const double idevy = 1./sqrt(ns*scy2-scy*scy);
196
197 /* sd */
198 const double nsqd = scx2 + scy2 - 2.*scxy;
199 /* correlation */
200 const double correlation = num * idevx * idevy;
201 /* slope and intercept */
202 const double slope = num * idev2x;
203 const double inter = (scy - slope * scx)/ns;
204
205 Value* valuea=getPntrToComponent("sqdevsum");
206 Value* valueb=getPntrToComponent("corr");
207 Value* valuec=getPntrToComponent("slope");
208 Value* valued=getPntrToComponent("intercept");
209
210 valuea->set(nsqd);
211 valueb->set(correlation);
212 valuec->set(slope);
213 valued->set(inter);
214
215 /* derivatives */
216 for(unsigned i=0; i<parameters.size(); ++i) {
217 const double common_d1 = (ns*parameters[i]-scy)*idevx;
218 const double common_d2 = num*(ns*getArgument(i)-scx)*idev2x*idevx;
219 const double common_d3 = common_d1 - common_d2;
220
221 /* sqdevsum */
222 const double sq_der = 2.*(getArgument(i)-parameters[i]);
223 /* correlation */
224 const double co_der = common_d3*idevy;
225 /* slope */
226 const double sl_der = (common_d1-2.*common_d2)*idevx;
227 /* intercept */
228 const double int_der = -(slope+ scx*sl_der)/ns;
229
230 setDerivative(valuea,i,sq_der);
231 setDerivative(valueb,i,co_der);
232 setDerivative(valuec,i,sl_der);
233 setDerivative(valued,i,int_der);
234 }
235
236 }
237 }
238
239 }
240 }
241
242
243