1 //
2 //  Copyright (C) 2001-2018 Rational Discovery LLC
3 //
4 //   @@ All Rights Reserved @@
5 //  This file is part of the RDKit.
6 //  The contents are covered by the terms of the BSD license
7 //  which is included in the file license.txt, found at the root
8 //  of the RDKit source tree.
9 //
#include <RDGeneral/test.h>
#include "AlignMolecules.h"
#include "O3AAlignMolecules.h"
#include <GraphMol/FileParsers/MolSupplier.h>
#include <GraphMol/FileParsers/MolWriters.h>
#include <GraphMol/FileParsers/FileParsers.h>
#include <GraphMol/Descriptors/Crippen.h>
#include <GraphMol/ROMol.h>
#include <GraphMol/Conformer.h>
#include <GraphMol/Substruct/SubstructMatch.h>
#include <Numerics/Vector.h>
#include <ForceField/ForceField.h>
#include <GraphMol/ForceFieldHelpers/UFF/Builder.h>
#include <GraphMol/ForceFieldHelpers/MMFF/Builder.h>
#include <GraphMol/MolPickler.h>
#include <GraphMol/DistGeomHelpers/Embedder.h>
#include <GraphMol/SmilesParse/SmilesParse.h>
#include <GraphMol/MolTransforms/MolTransforms.h>

#include <memory>
28 
29 using namespace RDKit;
30 
testMMFFO3A()31 void testMMFFO3A() {
32   std::string rdbase = getenv("RDBASE");
33   std::string sdf = rdbase + "/Code/GraphMol/MolAlign/test_data/ref_e2";
34   std::string newSdf = sdf + "_MMFFO3A.sdf";
35   sdf += ".sdf";
36   SDMolSupplier supplier(sdf, true, false);
37   int nMol = supplier.length();
38   const int refNum = 48;
39   // SDWriter *newMol = new SDWriter(newSdf);
40   ROMol *refMol = supplier[refNum];
41   MMFF::MMFFMolProperties refMP(*refMol);
42   double cumScore = 0.0;
43   double cumMsd = 0.0;
44   for (int prbNum = 0; prbNum < nMol; ++prbNum) {
45     ROMol *prbMol = supplier[prbNum];
46     MMFF::MMFFMolProperties prbMP(*prbMol);
47     MolAlign::O3A o3a(*prbMol, *refMol, &prbMP, &refMP);
48     double rmsd = o3a.align();
49     cumScore += o3a.score();
50     cumMsd += rmsd * rmsd;
51     // newMol->write(prbMol);
52     delete prbMol;
53   }
54   cumMsd /= (double)nMol;
55   delete refMol;
56   // newMol->close();
57   // std::cerr<<cumScore<<","<<sqrt(cumMsd)<<std::endl;
58   TEST_ASSERT(RDKit::feq(cumScore, 6941.8, 1));
59   TEST_ASSERT(RDKit::feq(sqrt(cumMsd), .345, .001));
60 }
61 
testCrippenO3A()62 void testCrippenO3A() {
63   std::string rdbase = getenv("RDBASE");
64   std::string sdf = rdbase + "/Code/GraphMol/MolAlign/test_data/ref_e2";
65   std::string newSdf = sdf + "_CrippenO3A.sdf";
66   sdf += ".sdf";
67   SDMolSupplier supplier(sdf, true, false);
68   int nMol = supplier.length();
69   const int refNum = 48;
70   // SDWriter *newMol = new SDWriter(newSdf);
71   ROMol *refMol = supplier[refNum];
72   unsigned int refNAtoms = refMol->getNumAtoms();
73   std::vector<double> refLogpContribs(refNAtoms);
74   std::vector<double> refMRContribs(refNAtoms);
75   std::vector<unsigned int> refAtomTypes(refNAtoms);
76   std::vector<std::string> refAtomTypeLabels(refNAtoms);
77   Descriptors::getCrippenAtomContribs(*refMol, refLogpContribs, refMRContribs,
78                                       true, &refAtomTypes, &refAtomTypeLabels);
79   double cumScore = 0.0;
80   double cumMsd = 0.0;
81   for (int prbNum = 0; prbNum < nMol; ++prbNum) {
82     ROMol *prbMol = supplier[prbNum];
83     unsigned int prbNAtoms = prbMol->getNumAtoms();
84     std::vector<double> prbLogpContribs(prbNAtoms);
85     std::vector<double> prbMRContribs(prbNAtoms);
86     std::vector<unsigned int> prbAtomTypes(prbNAtoms);
87     std::vector<std::string> prbAtomTypeLabels(prbNAtoms);
88     Descriptors::getCrippenAtomContribs(*prbMol, prbLogpContribs, prbMRContribs,
89                                         true, &prbAtomTypes,
90                                         &prbAtomTypeLabels);
91     MolAlign::O3A o3a(*prbMol, *refMol, &prbLogpContribs, &refLogpContribs,
92                       MolAlign::O3A::CRIPPEN);
93     double rmsd = o3a.align();
94     cumScore += o3a.score();
95     cumMsd += rmsd * rmsd;
96     // newMol->write(prbMol);
97     delete prbMol;
98   }
99   cumMsd /= (double)nMol;
100   delete refMol;
101   // newMol->close();
102   // std::cerr<<cumScore<<","<<sqrt(cumMsd)<<std::endl;
103   TEST_ASSERT(RDKit::feq(cumScore, 4918.1, 1));
104   TEST_ASSERT(RDKit::feq(sqrt(cumMsd), .304, .001));
105 }
106 
testMMFFO3AMolHist()107 void testMMFFO3AMolHist() {
108   std::string rdbase = getenv("RDBASE");
109   std::string sdf = rdbase + "/Code/GraphMol/MolAlign/test_data/ref_e2";
110   std::string newSdf = sdf + "_MMFFO3A.sdf";
111   sdf += ".sdf";
112   SDMolSupplier supplier(sdf, true, false);
113   int nMol = supplier.length();
114   const int refNum = 48;
115   // SDWriter *newMol = new SDWriter(newSdf);
116   ROMol *refMol = supplier[refNum];
117   MMFF::MMFFMolProperties refMP(*refMol);
118   double *refDmat = MolOps::get3DDistanceMat(*refMol);
119   MolAlign::MolHistogram refHist(*refMol, refDmat);
120   double cumScore = 0.0;
121   double cumMsd = 0.0;
122   for (int prbNum = 0; prbNum < nMol; ++prbNum) {
123     ROMol *prbMol = supplier[prbNum];
124     MMFF::MMFFMolProperties prbMP(*prbMol);
125     double *prbDmat = MolOps::get3DDistanceMat(*prbMol);
126     MolAlign::MolHistogram prbHist(*prbMol, prbDmat);
127 
128     MolAlign::O3A o3a(*prbMol, *refMol, &prbMP, &refMP, MolAlign::O3A::MMFF94,
129                       -1, -1, false, 50, 0, nullptr, nullptr, nullptr, &prbHist,
130                       &refHist);
131     double rmsd = o3a.align();
132     cumScore += o3a.score();
133     cumMsd += rmsd * rmsd;
134     // newMol->write(prbMol);
135     delete prbMol;
136   }
137   cumMsd /= (double)nMol;
138   delete refMol;
139   // newMol->close();
140   // std::cerr<<cumScore<<","<<sqrt(cumMsd)<<std::endl;
141   TEST_ASSERT(RDKit::feq(cumScore, 6941.8, 1));
142   TEST_ASSERT(RDKit::feq(sqrt(cumMsd), .345, .001));
143 }
144 
testCrippenO3AMolHist()145 void testCrippenO3AMolHist() {
146   std::string rdbase = getenv("RDBASE");
147   std::string sdf = rdbase + "/Code/GraphMol/MolAlign/test_data/ref_e2";
148   std::string newSdf = sdf + "_CrippenO3A.sdf";
149   sdf += ".sdf";
150   SDMolSupplier supplier(sdf, true, false);
151   int nMol = supplier.length();
152   const int refNum = 48;
153   // SDWriter *newMol = new SDWriter(newSdf);
154   ROMol *refMol = supplier[refNum];
155   unsigned int refNAtoms = refMol->getNumAtoms();
156   std::vector<double> refLogpContribs(refNAtoms);
157   std::vector<double> refMRContribs(refNAtoms);
158   std::vector<unsigned int> refAtomTypes(refNAtoms);
159   std::vector<std::string> refAtomTypeLabels(refNAtoms);
160   Descriptors::getCrippenAtomContribs(*refMol, refLogpContribs, refMRContribs,
161                                       true, &refAtomTypes, &refAtomTypeLabels);
162   double *refDmat = MolOps::get3DDistanceMat(*refMol);
163   MolAlign::MolHistogram refHist(*refMol, refDmat);
164   double cumScore = 0.0;
165   double cumMsd = 0.0;
166   for (int prbNum = 0; prbNum < nMol; ++prbNum) {
167     ROMol *prbMol = supplier[prbNum];
168     unsigned int prbNAtoms = prbMol->getNumAtoms();
169     std::vector<double> prbLogpContribs(prbNAtoms);
170     std::vector<double> prbMRContribs(prbNAtoms);
171     std::vector<unsigned int> prbAtomTypes(prbNAtoms);
172     std::vector<std::string> prbAtomTypeLabels(prbNAtoms);
173     Descriptors::getCrippenAtomContribs(*prbMol, prbLogpContribs, prbMRContribs,
174                                         true, &prbAtomTypes,
175                                         &prbAtomTypeLabels);
176     double *prbDmat = MolOps::get3DDistanceMat(*prbMol);
177     MolAlign::MolHistogram prbHist(*prbMol, prbDmat);
178 
179     MolAlign::O3A o3a(*prbMol, *refMol, &prbLogpContribs, &refLogpContribs,
180                       MolAlign::O3A::CRIPPEN, -1, -1, false, 50, 0, nullptr,
181                       nullptr, nullptr, &prbHist, &refHist);
182     double rmsd = o3a.align();
183     cumScore += o3a.score();
184     cumMsd += rmsd * rmsd;
185     // newMol->write(prbMol);
186     delete prbMol;
187   }
188   cumMsd /= (double)nMol;
189   delete refMol;
190   // newMol->close();
191   // std::cerr<<cumScore<<","<<sqrt(cumMsd)<<std::endl;
192   TEST_ASSERT(RDKit::feq(cumScore, 4918.1, 1));
193   TEST_ASSERT(RDKit::feq(sqrt(cumMsd), .304, .001));
194 }
195 
testMMFFO3AConstraints()196 void testMMFFO3AConstraints() {
197   ROMol *m = SmilesToMol("n1ccc(cc1)-c1ccccc1");
198   TEST_ASSERT(m);
199   ROMol *m1 = MolOps::addHs(*m);
200   delete m;
201   TEST_ASSERT(m1);
202   DGeomHelpers::EmbedMolecule(*m1);
203   MMFF::sanitizeMMFFMol((RWMol &)(*m1));
204   MMFF::MMFFMolProperties mp(*m1);
205   TEST_ASSERT(mp.isValid());
206   ForceFields::ForceField *field = MMFF::constructForceField(*m1, &mp);
207   field->initialize();
208   field->minimize();
209   delete field;
210 
211   RWMol *patt = SmartsToMol("nccc-cccc");
212   MatchVectType matchVect;
213   TEST_ASSERT(SubstructMatch(*m1, (ROMol &)*patt, matchVect));
214   delete patt;
215   unsigned int nIdx = matchVect[0].second;
216   unsigned int cIdx = matchVect[matchVect.size() - 1].second;
217   MolTransforms::setDihedralDeg(m1->getConformer(), matchVect[2].second,
218                                 matchVect[3].second, matchVect[4].second,
219                                 matchVect[5].second, 0.0);
220   ROMol m2(*m1);
221   MolAlign::randomTransform(m2);
222   ROMol m3(m2);
223   auto *o3a = new MolAlign::O3A(m2, *m1, &mp, &mp);
224   TEST_ASSERT(o3a);
225   o3a->align();
226   delete o3a;
227   double d =
228       (m2.getConformer().getAtomPos(cIdx) - m1->getConformer().getAtomPos(cIdx))
229           .length();
230   TEST_ASSERT(feq(d, 0.0, 1));
231   MatchVectType constraintMap;
232   constraintMap.push_back(std::make_pair(cIdx, nIdx));
233   o3a = new MolAlign::O3A(m3, *m1, &mp, &mp, MolAlign::O3A::MMFF94, -1, -1,
234                           false, 50, 0, &constraintMap);
235   TEST_ASSERT(o3a);
236   o3a->align();
237   delete o3a;
238   d = (m3.getConformer().getAtomPos(cIdx) - m1->getConformer().getAtomPos(cIdx))
239           .length();
240   TEST_ASSERT(feq(d, 7.0, 1.0));
241   delete m1;
242 }
243 
testCrippenO3AConstraints()244 void testCrippenO3AConstraints() {
245   ROMol *m = SmilesToMol("n1ccc(cc1)-c1ccccc1");
246   TEST_ASSERT(m);
247   ROMol *m1 = MolOps::addHs(*m);
248   delete m;
249   TEST_ASSERT(m1);
250   DGeomHelpers::EmbedMolecule(*m1);
251   MMFF::sanitizeMMFFMol((RWMol &)(*m1));
252   MMFF::MMFFMolProperties mp(*m1);
253   TEST_ASSERT(mp.isValid());
254   ForceFields::ForceField *field = MMFF::constructForceField(*m1, &mp);
255   field->initialize();
256   field->minimize();
257   delete field;
258   RWMol *patt = SmartsToMol("nccc-cccc");
259   MatchVectType matchVect;
260   TEST_ASSERT(SubstructMatch(*m1, (ROMol &)*patt, matchVect));
261   delete patt;
262   unsigned int nIdx = matchVect[0].second;
263   unsigned int cIdx = matchVect[matchVect.size() - 1].second;
264   MolTransforms::setDihedralDeg(m1->getConformer(), matchVect[2].second,
265                                 matchVect[3].second, matchVect[4].second,
266                                 matchVect[5].second, 0.0);
267   ROMol m2(*m1);
268   MolAlign::randomTransform(m2);
269   ROMol m3(m2);
270   unsigned int prbNAtoms = m2.getNumAtoms();
271   std::vector<double> prbLogpContribs(prbNAtoms);
272   std::vector<double> prbMRContribs(prbNAtoms);
273   std::vector<unsigned int> prbAtomTypes(prbNAtoms);
274   std::vector<std::string> prbAtomTypeLabels(prbNAtoms);
275   Descriptors::getCrippenAtomContribs(m2, prbLogpContribs, prbMRContribs, true,
276                                       &prbAtomTypes, &prbAtomTypeLabels);
277   auto *o3a = new MolAlign::O3A(m2, *m1, &prbLogpContribs, &prbLogpContribs,
278                                 MolAlign::O3A::CRIPPEN);
279   TEST_ASSERT(o3a);
280   o3a->align();
281   delete o3a;
282   double d =
283       (m2.getConformer().getAtomPos(cIdx) - m1->getConformer().getAtomPos(cIdx))
284           .length();
285   TEST_ASSERT(feq(d, 0.0, 1));
286   MatchVectType constraintMap;
287   constraintMap.push_back(std::make_pair(cIdx, nIdx));
288   o3a = new MolAlign::O3A(m3, *m1, &prbLogpContribs, &prbLogpContribs,
289                           MolAlign::O3A::CRIPPEN, -1, -1, false, 50, 0,
290                           &constraintMap);
291   TEST_ASSERT(o3a);
292   o3a->align();
293   delete o3a;
294   d = (m3.getConformer().getAtomPos(cIdx) - m1->getConformer().getAtomPos(cIdx))
295           .length();
296   TEST_ASSERT(feq(d, 7.0, 1.0));
297   delete m1;
298 }
299 
testMMFFO3AConstraintsAndLocalOnly()300 void testMMFFO3AConstraintsAndLocalOnly() {
301   std::string rdbase = getenv("RDBASE");
302   std::string sdf = rdbase + "/Code/GraphMol/MolAlign/test_data/ref_e2.sdf";
303   SDMolSupplier supplier(sdf, true, false);
304   const int refNum = 23;
305   const int prbNum = 32;
306   ROMol *refMol = supplier[refNum];
307   ROMol *prbMol = supplier[prbNum];
308   unsigned int refNAtoms = refMol->getNumAtoms();
309   std::vector<double> refLogpContribs(refNAtoms);
310   std::vector<double> refMRContribs(refNAtoms);
311   std::vector<unsigned int> refAtomTypes(refNAtoms);
312   std::vector<std::string> refAtomTypeLabels(refNAtoms);
313   Descriptors::getCrippenAtomContribs(*refMol, refLogpContribs, refMRContribs,
314                                       true, &refAtomTypes, &refAtomTypeLabels);
315   unsigned int prbNAtoms = prbMol->getNumAtoms();
316   std::vector<double> prbLogpContribs(prbNAtoms);
317   std::vector<double> prbMRContribs(prbNAtoms);
318   std::vector<unsigned int> prbAtomTypes(prbNAtoms);
319   std::vector<std::string> prbAtomTypeLabels(prbNAtoms);
320   Descriptors::getCrippenAtomContribs(*prbMol, prbLogpContribs, prbMRContribs,
321                                       true, &prbAtomTypes, &prbAtomTypeLabels);
322   RWMol *patt = SmartsToMol("S");
323   MatchVectType matchVect;
324   TEST_ASSERT(SubstructMatch(*refMol, (ROMol &)*patt, matchVect));
325   delete patt;
326   unsigned int refSIdx = matchVect[0].second;
327   matchVect.clear();
328   patt = SmartsToMol("O");
329   TEST_ASSERT(SubstructMatch(*prbMol, (ROMol &)*patt, matchVect));
330   delete patt;
331   unsigned int prbOIdx = matchVect[0].second;
332   std::vector<double> distOS(2);
333   distOS[0] = 2.7;
334   distOS[1] = 0.4;
335   std::vector<double> weights(2);
336   weights[0] = 0.1;
337   weights[1] = 100.0;
338   for (unsigned int i = 0; i < 2; ++i) {
339     MatchVectType constraintMap;
340     constraintMap.push_back(std::make_pair(prbOIdx, refSIdx));
341     RDNumeric::DoubleVector constraintWeights(1);
342     constraintWeights[0] = weights[i];
343     auto *o3a =
344         new MolAlign::O3A(*prbMol, *refMol, &prbLogpContribs, &refLogpContribs,
345                           MolAlign::O3A::CRIPPEN, -1, -1, false, 50, 0,
346                           &constraintMap, &constraintWeights);
347     TEST_ASSERT(o3a);
348     o3a->align();
349     delete o3a;
350     o3a = new MolAlign::O3A(*prbMol, *refMol, &prbLogpContribs,
351                             &refLogpContribs, MolAlign::O3A::CRIPPEN, -1, -1,
352                             false, 50, MolAlign::O3_LOCAL_ONLY);
353     TEST_ASSERT(o3a);
354     o3a->align();
355     delete o3a;
356     double d = (prbMol->getConformer().getAtomPos(prbOIdx) -
357                 refMol->getConformer().getAtomPos(refSIdx))
358                    .length();
359     TEST_ASSERT(feq(d, distOS[i], 0.1));
360   }
361   delete refMol;
362   delete prbMol;
363 }
364 
testCrippenO3AConstraintsAndLocalOnly()365 void testCrippenO3AConstraintsAndLocalOnly() {
366   std::string rdbase = getenv("RDBASE");
367   std::string sdf = rdbase + "/Code/GraphMol/MolAlign/test_data/ref_e2.sdf";
368   SDMolSupplier supplier(sdf, true, false);
369   const int refNum = 23;
370   const int prbNum = 32;
371   ROMol *refMol = supplier[refNum];
372   ROMol *prbMol = supplier[prbNum];
373   MMFF::MMFFMolProperties refMP(*refMol);
374   TEST_ASSERT(refMP.isValid());
375   MMFF::MMFFMolProperties prbMP(*prbMol);
376   TEST_ASSERT(prbMP.isValid());
377   RWMol *patt = SmartsToMol("S");
378   MatchVectType matchVect;
379   TEST_ASSERT(SubstructMatch(*refMol, (ROMol &)*patt, matchVect));
380   delete patt;
381   unsigned int refSIdx = matchVect[0].second;
382   matchVect.clear();
383   patt = SmartsToMol("O");
384   TEST_ASSERT(SubstructMatch(*prbMol, (ROMol &)*patt, matchVect));
385   delete patt;
386   unsigned int prbOIdx = matchVect[0].second;
387   std::vector<double> distOS(2);
388   distOS[0] = 3.2;
389   distOS[1] = 0.3;
390   std::vector<double> weights(2);
391   weights[0] = 10.0;
392   weights[1] = 100.0;
393   for (unsigned int i = 0; i < 2; ++i) {
394     MatchVectType constraintMap;
395     constraintMap.push_back(std::make_pair(prbOIdx, refSIdx));
396     RDNumeric::DoubleVector constraintWeights(1);
397     constraintWeights[0] = weights[i];
398     auto *o3a = new MolAlign::O3A(*prbMol, *refMol, &prbMP, &refMP,
399                                   MolAlign::O3A::MMFF94, -1, -1, false, 50, 0,
400                                   &constraintMap, &constraintWeights);
401     TEST_ASSERT(o3a);
402     o3a->align();
403     delete o3a;
404     o3a = new MolAlign::O3A(*prbMol, *refMol, &prbMP, &refMP,
405                             MolAlign::O3A::MMFF94, -1, -1, false, 50,
406                             MolAlign::O3_LOCAL_ONLY);
407     TEST_ASSERT(o3a);
408     o3a->align();
409     delete o3a;
410     double d = (prbMol->getConformer().getAtomPos(prbOIdx) -
411                 refMol->getConformer().getAtomPos(refSIdx))
412                    .length();
413     TEST_ASSERT(feq(d, distOS[i], 0.1));
414   }
415   delete prbMol;
416   delete refMol;
417 }
418 
419 #ifdef RDK_TEST_MULTITHREADED
420 namespace {
runblock_o3a_mmff(ROMol * refMol,const std::vector<ROMol * > & mols,const std::vector<double> & rmsds,const std::vector<double> & scores,unsigned int count,unsigned int idx)421 void runblock_o3a_mmff(ROMol *refMol, const std::vector<ROMol *> &mols,
422                        const std::vector<double> &rmsds,
423                        const std::vector<double> &scores, unsigned int count,
424                        unsigned int idx) {
425   for (unsigned int rep = 0; rep < 10; ++rep) {
426     MMFF::MMFFMolProperties refMP(*refMol);
427     for (unsigned int i = 0; i < mols.size(); ++i) {
428       if (i % count != idx) {
429         continue;
430       }
431       if (!(rep % 10)) {
432         BOOST_LOG(rdErrorLog) << "Rep: " << rep << " Mol:" << i << std::endl;
433       }
434       ROMol prbMol(*mols[i]);
435       MMFF::MMFFMolProperties prbMP(prbMol);
436       MolAlign::O3A o3a(prbMol, *refMol, &prbMP, &refMP);
437       double rmsd = o3a.align();
438       double score = o3a.score();
439       TEST_ASSERT(feq(rmsd, rmsds[i]));
440       TEST_ASSERT(feq(score, scores[i]));
441     }
442   }
443 }
runblock_o3a_crippen(ROMol * refMol,const std::vector<ROMol * > & mols,const std::vector<double> & rmsds,const std::vector<double> & scores,unsigned int count,unsigned int idx)444 void runblock_o3a_crippen(ROMol *refMol, const std::vector<ROMol *> &mols,
445                           const std::vector<double> &rmsds,
446                           const std::vector<double> &scores, unsigned int count,
447                           unsigned int idx) {
448   ROMol refMolCopy(*refMol);
449   for (unsigned int rep = 0; rep < 10; ++rep) {
450     unsigned int refNAtoms = refMolCopy.getNumAtoms();
451     std::vector<double> refLogpContribs(refNAtoms);
452     std::vector<double> refMRContribs(refNAtoms);
453     std::vector<unsigned int> refAtomTypes(refNAtoms);
454     std::vector<std::string> refAtomTypeLabels(refNAtoms);
455     Descriptors::getCrippenAtomContribs(refMolCopy, refLogpContribs,
456                                         refMRContribs, true, &refAtomTypes,
457                                         &refAtomTypeLabels);
458     for (unsigned int i = 0; i < mols.size(); ++i) {
459       if (i % count != idx) {
460         continue;
461       }
462       if (!(rep % 10)) {
463         BOOST_LOG(rdErrorLog) << "Rep: " << rep << " Mol:" << i << std::endl;
464       }
465       ROMol prbMol(*mols[i]);
466       unsigned int prbNAtoms = prbMol.getNumAtoms();
467       std::vector<double> prbLogpContribs(prbNAtoms);
468       std::vector<double> prbMRContribs(prbNAtoms);
469       std::vector<unsigned int> prbAtomTypes(prbNAtoms);
470       std::vector<std::string> prbAtomTypeLabels(prbNAtoms);
471       Descriptors::getCrippenAtomContribs(prbMol, prbLogpContribs,
472                                           prbMRContribs, true, &prbAtomTypes,
473                                           &prbAtomTypeLabels);
474       MolAlign::O3A o3a(prbMol, refMolCopy, &prbLogpContribs, &refLogpContribs,
475                         MolAlign::O3A::CRIPPEN);
476       double rmsd = o3a.align();
477       double score = o3a.score();
478       TEST_ASSERT(feq(rmsd, rmsds[i]));
479       TEST_ASSERT(feq(score, scores[i]));
480     }
481   }
482 }
483 }  // namespace
484 #include <thread>
485 #include <future>
testMMFFO3AMultiThread()486 void testMMFFO3AMultiThread() {
487   std::string rdbase = getenv("RDBASE");
488   std::string sdf = rdbase + "/Code/GraphMol/MolAlign/test_data/ref_e2.sdf";
489 
490   SDMolSupplier suppl(sdf, true, false);
491 
492   std::vector<ROMol *> mols;
493   while (!suppl.atEnd() && mols.size() < 100) {
494     ROMol *mol = nullptr;
495     try {
496       mol = suppl.next();
497     } catch (...) {
498       continue;
499     }
500     if (!mol) {
501       continue;
502     }
503     mols.push_back(mol);
504   }
505 
506   std::cerr << "generating reference data" << std::endl;
507   std::vector<double> rmsds(mols.size(), 0.0);
508   std::vector<double> scores(mols.size(), 0.0);
509   const int refNum = 48;
510   ROMol *refMol = mols[refNum];
511   MMFF::MMFFMolProperties refMP(*refMol);
512 
513   for (unsigned int i = 0; i < mols.size(); ++i) {
514     ROMol prbMol(*mols[i]);
515     MMFF::MMFFMolProperties prbMP(prbMol);
516     MolAlign::O3A o3a(prbMol, *refMol, &prbMP, &refMP);
517     rmsds[i] = o3a.align();
518     scores[i] = o3a.score();
519   }
520 
521   std::vector<std::future<void>> tg;
522 
523   std::cerr << "processing" << std::endl;
524   unsigned int count = 4;
525   for (unsigned int i = 0; i < count; ++i) {
526     std::cerr << " launch :" << i << std::endl;
527     std::cerr.flush();
528     tg.emplace_back(std::async(std::launch::async, runblock_o3a_mmff, refMol,
529                                mols, rmsds, scores, count, i));
530   }
531   for (auto &fut : tg) {
532     fut.get();
533   }
534 
535   for (auto &&mol : mols) { delete mol; }
536   BOOST_LOG(rdErrorLog) << "  done" << std::endl;
537 }
538 
testCrippenO3AMultiThread()539 void testCrippenO3AMultiThread() {
540   std::string rdbase = getenv("RDBASE");
541   std::string sdf = rdbase + "/Code/GraphMol/MolAlign/test_data/ref_e2.sdf";
542 
543   SDMolSupplier suppl(sdf, true, false);
544 
545   std::vector<ROMol *> mols;
546   while (!suppl.atEnd() && mols.size() < 100) {
547     ROMol *mol = nullptr;
548     try {
549       mol = suppl.next();
550     } catch (...) {
551       continue;
552     }
553     if (!mol) {
554       continue;
555     }
556     mols.push_back(mol);
557   }
558 
559   std::cerr << "generating reference data" << std::endl;
560   std::vector<double> rmsds(mols.size(), 0.0);
561   std::vector<double> scores(mols.size(), 0.0);
562   const int refNum = 48;
563   ROMol *refMol = mols[refNum];
564   unsigned int refNAtoms = refMol->getNumAtoms();
565   std::vector<double> refLogpContribs(refNAtoms);
566   std::vector<double> refMRContribs(refNAtoms);
567   std::vector<unsigned int> refAtomTypes(refNAtoms);
568   std::vector<std::string> refAtomTypeLabels(refNAtoms);
569   Descriptors::getCrippenAtomContribs(*refMol, refLogpContribs, refMRContribs,
570                                       true, &refAtomTypes, &refAtomTypeLabels);
571 
572   for (unsigned int i = 0; i < mols.size(); ++i) {
573     ROMol prbMol(*mols[i]);
574     unsigned int prbNAtoms = prbMol.getNumAtoms();
575     std::vector<double> prbLogpContribs(prbNAtoms);
576     std::vector<double> prbMRContribs(prbNAtoms);
577     std::vector<unsigned int> prbAtomTypes(prbNAtoms);
578     std::vector<std::string> prbAtomTypeLabels(prbNAtoms);
579     Descriptors::getCrippenAtomContribs(prbMol, prbLogpContribs, prbMRContribs,
580                                         true, &prbAtomTypes,
581                                         &prbAtomTypeLabels);
582     MolAlign::O3A o3a(prbMol, *refMol, &prbLogpContribs, &refLogpContribs,
583                       MolAlign::O3A::CRIPPEN);
584     rmsds[i] = o3a.align();
585     scores[i] = o3a.score();
586   }
587 
588   std::vector<std::future<void>> tg;
589 
590   std::cerr << "processing" << std::endl;
591   unsigned int count = 4;
592   for (unsigned int i = 0; i < count; ++i) {
593     std::cerr << " launch :" << i << std::endl;
594     std::cerr.flush();
595     tg.emplace_back(std::async(std::launch::async, runblock_o3a_crippen, refMol,
596                                mols, rmsds, scores, count, i));
597   }
598   for (auto &fut : tg) {
599     fut.get();
600   }
601 
602   for (auto *mol : mols) { delete mol; }
603   BOOST_LOG(rdErrorLog) << "  done" << std::endl;
604 }
605 #endif
606 
testGetO3AForProbeConfs()607 void testGetO3AForProbeConfs() {
608   std::string rdbase = getenv("RDBASE");
609   std::string sdf = rdbase + "/Code/GraphMol/MolAlign/test_data/ref_e2.sdf";
610 
611   SDMolSupplier suppl(sdf, true, false);
612   ROMol *refMol = suppl[13];
613   TEST_ASSERT(refMol);
614 
615   sdf = rdbase + "/Code/GraphMol/MolAlign/test_data/probe_mol.sdf";
616   SDMolSupplier psuppl(sdf, true, false);
617   ROMol *prbMol = psuppl.next();
618   TEST_ASSERT(prbMol);
619   while (!psuppl.atEnd()) {
620     ROMol *mol = psuppl.next();
621     if (!mol) {
622       continue;
623     }
624     auto *conf = new Conformer(mol->getConformer());
625     prbMol->addConformer(conf, true);
626     delete mol;
627   }
628   TEST_ASSERT(prbMol->getNumConformers() == 50);
629 
630   MMFF::MMFFMolProperties refMP(*refMol);
631   MMFF::MMFFMolProperties prbMP(*prbMol);
632 
633   std::vector<std::pair<double, double>> oscores;
634   for (unsigned int i = 0; i < prbMol->getNumConformers(); ++i) {
635     MolAlign::O3A o3a(*prbMol, *refMol, &prbMP, &refMP, MolAlign::O3A::MMFF94,
636                       i);
637     double rmsd = o3a.align();
638     double score = o3a.score();
639     oscores.push_back(std::make_pair(rmsd, score));
640   }
641 
642   {
643     std::vector<boost::shared_ptr<MolAlign::O3A>> o3s;
644     MolAlign::getO3AForProbeConfs(*prbMol, *refMol, &prbMP, &refMP, o3s);
645     TEST_ASSERT(o3s.size() == prbMol->getNumConformers());
646     for (unsigned int i = 0; i < prbMol->getNumConformers(); ++i) {
647       TEST_ASSERT(feq(oscores[i].first, o3s[i]->align()));
648       TEST_ASSERT(feq(oscores[i].second, o3s[i]->score()));
649     }
650   }
651 #ifdef RDK_TEST_MULTITHREADED
652   {
653     ROMol prbMol2(*prbMol);
654     unsigned int nDups = 10;
655     for (unsigned int j = 0; j < nDups; ++j) {
656       for (unsigned int i = 0; i < prbMol->getNumConformers(); ++i) {
657         prbMol2.addConformer(new Conformer(prbMol->getConformer(i)), true);
658       }
659     }
660 
661     std::vector<boost::shared_ptr<MolAlign::O3A>> o3s;
662     MolAlign::getO3AForProbeConfs(prbMol2, *refMol, &prbMP, &refMP, o3s, 4);
663     TEST_ASSERT(o3s.size() == prbMol2.getNumConformers());
664     for (unsigned int i = 0; i < prbMol2.getNumConformers(); ++i) {
665       TEST_ASSERT(
666           feq(oscores[i % prbMol->getNumConformers()].first, o3s[i]->align()));
667       TEST_ASSERT(
668           feq(oscores[i % prbMol->getNumConformers()].second, o3s[i]->score()));
669     }
670   }
671 
672 #endif
673   delete refMol;
674   delete prbMol;
675 
676   BOOST_LOG(rdErrorLog) << "  done" << std::endl;
677 }
678 
testO3AMultiThreadBug()679 void testO3AMultiThreadBug() {
680   std::string rdbase = getenv("RDBASE");
681   std::string sdf = rdbase + "/Code/GraphMol/MolAlign/test_data/bzr_data.sdf";
682 
683   SDMolSupplier suppl(sdf, true, false);
684 
685   std::vector<ROMol *> mols;
686   while (!suppl.atEnd()) {
687     ROMol *mol = suppl.next();
688     if (!mol) {
689       continue;
690     }
691 
692     while (mol->getNumConformers() < 20) {
693       auto *conf = new Conformer(mol->getConformer(0));
694       mol->addConformer(conf, true);
695     }
696     mols.push_back(mol);
697   }
698   TEST_ASSERT(mols.size() == 10);
699 
700   auto *refMol = new ROMol(*mols[0]);
701   TEST_ASSERT(refMol);
702 
703   MMFF::MMFFMolProperties refMP(*refMol);
704 
705 #ifdef RDK_TEST_MULTITHREADED
706   {
707     for (auto &mol : mols) {
708       ROMol prbMol = *mol;
709       TEST_ASSERT(prbMol.getNumConformers() == 20);
710 
711       MMFF::MMFFMolProperties prbMP(prbMol);
712 
713       std::vector<std::pair<double, double>> oscores;
714       for (unsigned int i = 0; i < prbMol.getNumConformers(); ++i) {
715         MolAlign::O3A o3a(prbMol, *refMol, &prbMP, &refMP,
716                           MolAlign::O3A::MMFF94, i);
717         double rmsd = o3a.align();
718         double score = o3a.score();
719         oscores.push_back(std::make_pair(rmsd, score));
720       }
721 
722       ROMol prbMol2 = *mol;
723       std::vector<boost::shared_ptr<MolAlign::O3A>> o3s;
724       MolAlign::getO3AForProbeConfs(prbMol2, *refMol, &prbMP, &refMP, o3s, 0);
725       TEST_ASSERT(o3s.size() == prbMol2.getNumConformers());
726       for (unsigned int i = 0; i < prbMol2.getNumConformers(); ++i) {
727         TEST_ASSERT(
728             feq(oscores[i % prbMol.getNumConformers()].first, o3s[i]->align()));
729         TEST_ASSERT(feq(oscores[i % prbMol.getNumConformers()].second,
730                         o3s[i]->score()));
731       }
732     }
733   }
734 
735 #endif
736   delete refMol;
737   for (auto &&mol : mols) { delete mol; }
738   BOOST_LOG(rdErrorLog) << "  done" << std::endl;
739 }
740 
main()741 int main() {
742   std::cout << "***********************************************************\n";
743   std::cout << "Testing O3AAlign\n";
744 
745 #if 1
746   std::cout << "\t---------------------------------\n";
747   std::cout << "\t testMMFFO3A \n\n";
748   testMMFFO3A();
749 
750   std::cout << "\t---------------------------------\n";
751   std::cout << "\t testMMFFO3A with pre-computed dmat and MolHistogram\n\n";
752   testMMFFO3AMolHist();
753 
754   std::cout << "\t---------------------------------\n";
755   std::cout << "\t testMMFFO3A with constraints\n\n";
756   testMMFFO3AConstraints();
757 
758   std::cout << "\t---------------------------------\n";
759   std::cout << "\t testMMFFO3A with variable weight constraints followed by "
760                "local-only optimization\n\n";
761   testMMFFO3AConstraintsAndLocalOnly();
762 
763   std::cout << "\t---------------------------------\n";
764   std::cout << "\t testCrippenO3A \n\n";
765   testCrippenO3A();
766 
767   std::cout << "\t---------------------------------\n";
768   std::cout << "\t testCrippenO3A with pre-computed dmat and MolHistogram\n\n";
769   testCrippenO3AMolHist();
770 
771   std::cout << "\t---------------------------------\n";
772   std::cout << "\t testCrippenO3A with constraints\n\n";
773   testCrippenO3AConstraints();
774 
775   std::cout << "\t---------------------------------\n";
776   std::cout << "\t testCrippenO3A with variable weight constraints followed by "
777                "local-only optimization\n\n";
778   testCrippenO3AConstraintsAndLocalOnly();
779 
780 #ifdef RDK_TEST_MULTITHREADED
781   std::cout << "\t---------------------------------\n";
782   std::cout << "\t testMMFFO3A multithreading\n\n";
783   testMMFFO3AMultiThread();
784 
785   std::cout << "\t---------------------------------\n";
786   std::cout << "\t test O3A multithreading bug\n\n";
787   testO3AMultiThreadBug();
788 #endif
789 
790 #ifdef RDK_TEST_MULTITHREADED
791   std::cout << "\t---------------------------------\n";
792   std::cout << "\t testCrippenO3A multithreading\n\n";
793   testCrippenO3AMultiThread();
794 #endif
795 
796   std::cout << "\t---------------------------------\n";
797   std::cout << "\t test getO3AForProbeConfs\n\n";
798   testGetO3AForProbeConfs();
799 #endif
800 
801   std::cout << "***********************************************************\n";
802 }
803