1import os
2import sys
3import unittest
4
here = os.path.split(__file__)[0]

## In Python 3.x, generators have a __next__() method
## instead of a next() method
nextmethod = "next"
ispy2 = True
if sys.version_info[0] >= 3:
    nextmethod = "__next__"
    ispy2 = False

# Every toolkit name is bound up front so that the TestCase classes below
# (``toolkit = pybel`` etc.) can always be defined; a toolkit that cannot be
# imported simply stays None.
pybel = rdkit = cdk = None
try:
    # os.write is presumably absent on the Jython platform this suite also
    # targets (see myTestCase); there only the CDK toolkit is imported.
    os.write
except AttributeError:
    from cinfony import cdk
else:
    try:
        from cinfony import pybel, rdkit, cdk
    except ImportError:
        pass  # cinfony is optional; the names keep their None defaults
    try:
        from openbabel import pybel
    except ImportError:
        pybel = None
    else:
        # Open Babel's own pybel is used; the cinfony toolkits are disabled.
        rdkit = cdk = None
try:
    set
except NameError:
    from sets import Set as set
32
class myTestCase(unittest.TestCase):
    """TestCase providing assertion helpers missing from Jython 2.2's unittest.

    The implementations follow unittest.py from the Python 2.5 distribution;
    on newer interpreters they override the standard-library versions, so
    they accept the same arguments (including ``delta``).
    """

    def assertFalse(self, expr, msg=None):
        """Fail the test if the expression is true."""
        if expr:
            raise self.failureException(msg)

    def assertTrue(self, expr, msg=None):
        """Fail the test unless the expression is true."""
        if not expr:
            raise self.failureException(msg)

    def assertAlmostEqual(self, first, second, places=7, msg=None, delta=None):
        """Fail if the two objects are not approximately equal.

        By default the difference is rounded to ``places`` decimal places
        (7 when ``places`` is None) and compared with zero.  When ``delta``
        is given, the absolute difference must instead be at most ``delta``.

        Note that decimal places (from zero) are usually not the same
        as significant digits (measured from the most significant digit).

        Raises:
            self.failureException: with ``msg`` if given, otherwise a
            message describing the mismatch.
        """
        if first == second:
            # Shortcut that also accepts equal infinities: inf - inf is nan,
            # which would never round to zero below.
            return
        diff = abs(second - first)
        if delta is not None:
            if diff <= delta:
                return
            standard_msg = '%r != %r within %r delta' % (first, second, delta)
        else:
            if places is None:
                places = 7
            if round(diff, places) == 0:
                return
            standard_msg = '%r != %r within %r places' % (first, second, places)
        raise self.failureException(msg or standard_msg)
53
class TestToolkit(myTestCase):
    """Toolkit-independent tests shared by the per-toolkit subclasses.

    Subclasses set ``toolkit`` to the module under test (None when it could
    not be imported) plus the expected values the tests compare against:
    ``tanimotoresult``, ``Natoms``, ``tpsaname``, ``Nbits``, ``Nfpbits`` and
    ``datakeys``.  Input files (``head.sdf``) are read from the directory
    containing this script; scratch output (``testoutput.txt``, ``*.png``)
    is written to the current working directory.
    """

    def setUp(self):
        toolkit = getattr(self, "toolkit", None)
        if toolkit is None:
            # The base class defines no ``toolkit`` and subclasses get
            # ``toolkit = None`` when the import failed: skip instead of
            # reporting every test as an AttributeError.
            self.skipTest("toolkit not available")
        # butane and propylamine, built from SMILES (so without coordinates)
        self.mols = [toolkit.readstring("smi", "CCCC"),
                     toolkit.readstring("smi", "CCCN")]
        self.head = list(toolkit.readfile("sdf", os.path.join(here, "head.sdf")))
        self.atom = self.head[0].atoms[1]

    def testattributes(self):
        """Test attributes like informats, descs and so on"""
        self.assertNotEqual(len(list(self.toolkit.informats.keys())), 0)
        self.assertNotEqual(len(list(self.toolkit.outformats.keys())), 0)
        self.assertNotEqual(len(self.toolkit.descs), 0)
        self.assertNotEqual(len(self.toolkit.forcefields), 0)
        self.assertNotEqual(len(self.toolkit.fps), 0)

    def FPaccesstest(self):
        # Should raise AttributeError
        return self.mols[0].calcfp().nosuchname

    def testFPTanimoto(self):
        """Test the calculation of the Tanimoto coefficient"""
        fps = [x.calcfp() for x in self.mols]
        # ``|`` between two fingerprints yields their Tanimoto coefficient
        self.assertEqual(fps[0] | fps[1], self.tanimotoresult)

    def testFPstringrepr(self):
        """Test the string representation and corner cases."""
        self.assertRaises(ValueError, self.mols[0].calcfp, "Nosuchname")
        self.assertRaises(AttributeError, self.FPaccesstest)
        r = str(self.mols[0].calcfp())
        t = r.split(", ")
        self.assertEqual(len(t), self.Nfpbits)

    def testFPbits(self):
        """Test whether the bits are set correctly."""
        bits = [x.calcfp().bits for x in self.mols]
        self.assertEqual(len(bits[0]), self.Nbits)
        bits = [set(x) for x in bits]
        # Calculate the Tanimoto coefficient the old-fashioned way
        tanimoto = len(bits[0] & bits[1]) / float(len(bits[0] | bits[1]))
        self.assertEqual(tanimoto, self.tanimotoresult)

    def RSaccesstest(self):
        # Should raise AttributeError
        return self.mols[0].nosuchname

    def testRSformaterror(self):
        """Test that invalid formats raise an error"""
        self.assertRaises(ValueError, self.toolkit.readstring, "noel", "jkjk")
        self.assertRaises(IOError, self.toolkit.readstring, "smi", "&*)(%)($)")

    def testselfconversion(self):
        """Test that the toolkit can eat its own dog-food."""
        newmol = self.toolkit.Molecule(self.head[0])
        self.assertEqual(newmol._exchange,
                         self.head[0]._exchange)
        newmol = self.toolkit.Molecule(self.mols[0])
        self.assertEqual(newmol._exchange,
                         self.mols[0]._exchange)

    def testLocalOpt(self):
        """Test that local optimisation affects the coordinates"""
        oldcoords = self.head[0].atoms[0].coords
        self.head[0].localopt()
        newcoords = self.head[0].atoms[0].coords
        self.assertNotEqual(oldcoords, newcoords)
        # Make sure that make3D() is called for molecules without coordinates
        mol = self.mols[0]
        mol.localopt()
        self.assertNotEqual(mol.atoms[3].coords, (0., 0., 0.))

    def testMake2D(self):
        """Test that 2D coordinate generation does something"""
        mol = self.mols[1]
        mol.make2D()
        self.assertNotEqual(mol.atoms[2].coords, (0., 0., 0.))
        self.assertEqual(mol.atoms[2].coords[2], 0.)

    def testMake3D(self):
        """Test that 3D coordinate generation does something"""
        mol = self.mols[0]
        mol.make3D()
        self.assertNotEqual(mol.atoms[3].coords, (0., 0., 0.))

    def testDraw(self):
        """Create a 2D depiction"""
        self.mols[0].draw(show=False,
                          filename="%s.png" % self.toolkit.__name__)
        self.mols[0].draw(show=False) # Just making sure that it doesn't raise an Error
        self.mols[0].draw(show=False, update=True)
        coords = [x.coords for x in self.mols[0].atoms[0:2]]
        self.assertNotEqual(coords, [(0., 0., 0.), (0., 0., 0.)])
        self.mols[0].draw(show=False, usecoords=True,
                          filename="%s_b.png" % self.toolkit.__name__)

    def testRSgetprops(self):
        """Get the values of the properties."""
        # exactmass is not checked here: only OpenBabel has a working one.
        # CDK doesn't include implicit Hs when calculating the molwt
        self.assertAlmostEqual(self.mols[0].molwt, 58.12, 2)
        self.assertEqual(len(self.mols[0].atoms), 4)
        self.assertRaises(AttributeError, self.RSaccesstest)

    def testRSconversiontoMOL(self):
        """Convert to mol"""
        as_mol = self.mols[0].write("mol")
        # Only the line count and the "M  END" line are compared, so the
        # toolkit-specific header line does not have to match.
        test = """
 OpenBabel04220815032D

  4  3  0  0  0  0  0  0  0  0999 V2000
    0.0000    0.0000    0.0000 C   0  0  0  0  0
    0.0000    0.0000    0.0000 C   0  0  0  0  0
    0.0000    0.0000    0.0000 C   0  0  0  0  0
    0.0000    0.0000    0.0000 C   0  0  0  0  0
  1  2  1  0  0  0
  2  3  1  0  0  0
  3  4  1  0  0  0
M  END
"""
        data, result = test.split("\n"), as_mol.split("\n")
        self.assertEqual(len(data), len(result))
        self.assertEqual(data[-2], result[-2].rstrip()) # M  END

    def testRSstringrepr(self):
        """Test the string representation of a molecule"""
        self.assertEqual(str(self.mols[0]).strip(), "CCCC")

    def testRFread(self):
        """Is the right number of molecules read from the file?"""
        self.assertEqual(len(self.mols), 2)

    def RFreaderror(self):
        # Should raise IOError: advance the reader over a missing file
        getattr(self.toolkit.readfile("sdf", "nosuchfile.sdf"), nextmethod)()

    def testRFmissingfile(self):
        """Test that reading from a non-existent file raises an error."""
        self.assertRaises(IOError, self.RFreaderror)

    def RFformaterror(self):
        # Should raise ValueError. The file is located like in setUp() so
        # a toolkit that checks for the file first does not raise IOError
        # when the suite runs from another directory.
        sdf = os.path.join(here, "head.sdf")
        getattr(self.toolkit.readfile("noel", sdf), nextmethod)()

    def testRFformaterror(self):
        """Test that invalid formats raise an error"""
        self.assertRaises(ValueError, self.RFformaterror)

    def RFunitcellerror(self):
        # Should raise AttributeError: a SMILES molecule has no unit cell
        self.mols[0].unitcell

    def testRFunitcellerror(self):
        """Test that accessing the unitcell raises an error"""
        self.assertRaises(AttributeError, self.RFunitcellerror)

    def testRFconversion(self):
        """Convert to smiles"""
        as_smi = [mol.write("smi").split("\t")[0] for mol in self.mols]
        # Compare sorted characters so the check does not depend on the
        # atom order the toolkit chooses when writing SMILES.
        ans = ["".join(sorted(smi)) for smi in as_smi]
        self.assertEqual(ans, ['CCCC', 'CCCN'])

    def testRFsingletofile(self):
        """Test the molecule.write() method"""
        mol = self.mols[0]
        mol.write("smi", "testoutput.txt")
        try:
            with open("testoutput.txt", "r") as handle:
                firstline = handle.readline()
            self.assertEqual(firstline.split("\t")[0].strip(), 'CCCC')
            # Writing over an existing file must fail
            self.assertRaises(IOError, mol.write, "smi", "testoutput.txt")
        finally:
            # Always clean up: a leftover file makes later writes fail
            os.remove("testoutput.txt")
        self.assertRaises(ValueError, mol.write, "noel", "testoutput.txt")

    def testRFoutputfile(self):
        """Test the Outputfile class"""
        self.assertRaises(ValueError, self.toolkit.Outputfile, "noel", "testoutput.txt")
        try:
            with self.toolkit.Outputfile("sdf", "testoutput.txt") as outputfile:
                for mol in self.head:
                    outputfile.write(mol)
            # Writing after close, or reopening an existing file, must fail
            self.assertRaises(IOError, outputfile.write, mol)
            self.assertRaises(IOError, self.toolkit.Outputfile, "sdf", "testoutput.txt")
            with open("testoutput.txt", "r") as handle:
                # "$$$$" terminates each record in an SD file
                numdollar = sum(1 for line in handle if line.rstrip() == "$$$$")
        finally:
            # The file may not exist if Outputfile itself failed
            if os.path.isfile("testoutput.txt"):
                os.remove("testoutput.txt")
        self.assertEqual(numdollar, 2)

    def RFdesctest(self):
        # Should raise ValueError
        self.mols[0].calcdesc("BadDescName")

    def testRFdesc(self):
        """Test the descriptors"""
        if self.toolkit.__name__ == "cinfony.cdk":
            # For the CDK, you need to call addh()
            # or some descriptors will be incorrectly calculated
            # (even those that are supposed to be immune like TPSA)
            self.mols[1].addh()
        desc = self.mols[1].calcdesc()
        self.assertTrue(len(desc) > 3)
        self.assertAlmostEqual(desc[self.tpsaname], 26.02, 2)
        self.assertRaises(ValueError, self.RFdesctest)

    def MDaccesstest(self):
        # Should raise KeyError
        return self.head[0].data['noel']

    def testMDaccess(self):
        """Change the value of a field"""
        data = self.head[0].data
        self.assertRaises(KeyError, self.MDaccesstest)
        data['noel'] = 'testvalue'
        self.assertEqual(data['noel'], 'testvalue')
        newvalues = {'hey':'there', 'yo':1}
        data.update(newvalues)
        # values are stored as strings
        self.assertEqual(data['yo'], '1')
        self.assertTrue('there' in data.values())

    def testMDglobalaccess(self):
        """Check out the keys"""
        data = self.head[0].data
        self.assertFalse('Noel' in data)
        self.assertEqual(len(data), len(self.datakeys))
        for key in data:
            self.assertTrue(key in self.datakeys, key)
        r = repr(data)
        self.assertTrue(r[0]=="{" and r[-2:]=="'}", r)

    def testMDdelete(self):
        """Delete some keys"""
        data = self.head[0].data
        self.assertTrue('NSC' in data)
        del data['NSC']
        self.assertFalse('NSC' in data)
        data.clear()
        self.assertEqual(len(data), 0)

    def testAiteration(self):
        """Test the ability to iterate over the atoms"""
        atoms = list(self.head[0])
        self.assertEqual(len(atoms), self.Natoms)

    def Atomaccesstest(self):
        # Should raise AttributeError
        return self.atom.nosuchname

    def testAattributes(self):
        """Get the values of some properties"""
        self.assertRaises(AttributeError, self.Atomaccesstest)
        self.assertAlmostEqual(self.atom.coords[0], -0.0691, 4)

    def testAstringrepr(self):
        """Test the string representation of the Atom"""
        test = "Atom: 8 (-0.07 5.24 0.03)"
        self.assertEqual(str(self.atom), test)

    def invalidSMARTStest(self):
        # Should raise IOError
        return self.toolkit.Smarts("[#NOEL][#NOEL]")

    def testSMARTS(self):
        """Searching for ethyl groups in triethylamine"""
        mol = self.toolkit.readstring("smi", "CCN(CC)CC")
        smarts = self.toolkit.Smarts("[#6][#6]")
        ans = smarts.findall(mol)
        self.assertEqual(len(ans), 3)
        # Only Open Babel based toolkits expose ``ob``; for those, lower the
        # log output to errors only while provoking the invalid SMARTS.
        ob = getattr(self.toolkit, "ob", None)
        if ob is not None:
            ob.obErrorLog.SetOutputLevel(ob.obError)
        try:
            self.assertRaises(IOError, self.invalidSMARTStest)
        finally:
            # Restore the warning level even if the assertion fails
            if ob is not None:
                ob.obErrorLog.SetOutputLevel(ob.obWarning)

    def testAddh(self):
        """Adding and removing hydrogens"""
        self.assertEqual(len(self.mols[0].atoms),4)
        self.mols[0].addh()
        self.assertEqual(len(self.mols[0].atoms),14)
        self.mols[0].removeh()
        self.assertEqual(len(self.mols[0].atoms),4)
337
class TestPybel(TestToolkit):
    """Run the shared toolkit tests against pybel, plus pybel-only checks
    (FP3 fingerprints, CIF unit cells, MOL2 output, exact mass, OB iterators).
    """
    toolkit = pybel
    datakeys = ['NSC', 'Comment', 'OpenBabel Symmetry Classes', 'MOL Chiral Flag']
    Natoms = 15
    Nbits = 3
    Nfpbits = 32
    tanimotoresult = 1 / 3.0
    tpsaname = "TPSA"

    def testFP_FP3(self):
        "Checking the results from FP3"
        first_fp, second_fp = [m.calcfp("FP3") for m in self.mols]
        self.assertEqual(first_fp | second_fp, 0.)

    def testunitcell(self):
        """Testing unit cell access"""
        cif_path = os.path.join(here, "hashizume.cif")
        reader = self.toolkit.readfile("cif", cif_path)
        crystal = getattr(reader, nextmethod)()
        self.assertAlmostEqual(crystal.unitcell.GetAlpha(), 93.0, 1)

    def testMDcomment(self):
        """Mess about with the comment field"""
        fields = self.head[0].data
        self.assertEqual('Comment' in fields, True)
        self.assertEqual(fields['Comment'], 'CORINA 2.61 0041  25.10.2001')
        fields['Comment'] = 'New comment'
        self.assertEqual(fields['Comment'], 'New comment')

    def importtest(self):
        # Helper for TestPybelWithDraw: an interactive draw that needs the
        # optional depiction dependencies.
        self.mols[0].draw(show=True, usecoords=True)

    def testRSconversiontoMOL2(self):
        """Convert to mol2"""
        expected = """@<TRIPOS>MOLECULE
*****
 4 3 0 0 0
SMALL
GASTEIGER

@<TRIPOS>ATOM
      1 C           0.0000    0.0000    0.0000 C.3     1  UNL1        0.0000
      2 C           0.0000    0.0000    0.0000 C.3     1  UNL1        0.0000
      3 C           0.0000    0.0000    0.0000 C.3     1  UNL1        0.0000
      4 C           0.0000    0.0000    0.0000 C.3     1  UNL1        0.0000
@<TRIPOS>BOND
     1     1     2    1
     2     2     3    1
     3     3     4    1
"""
        self.assertEqual(self.mols[0].write("mol2"), expected)

    def testRSgetprops(self):
        """Get the values of the properties."""
        butane = self.mols[0]
        self.assertAlmostEqual(butane.exactmass, 58.078, 3)
        self.assertAlmostEqual(butane.molwt, 58.122, 3)
        self.assertEqual(len(butane.atoms), 4)
        self.assertRaises(AttributeError, self.RSaccesstest)

    def testIterators(self):
        """Check out the OB iterators"""
        atom_iter = self.toolkit.ob.OBMolAtomIter(self.mols[0].OBMol)
        self.assertEqual(len(list(atom_iter)), 4)
401
class TestOBPybelNoDraw(TestPybel):
    """Run the pybel tests with the 2D depiction test disabled."""

    def testDraw(self):
        """No drawing done"""
406
class TestPybelWithDraw(TestPybel):
    """Run the pybel tests plus a check of the depiction dependencies."""

    def testDrawdependencies(self):
        "Testing the draw dependencies"
        # With ``tk`` unavailable, drawing to a file still works but an
        # interactive draw (importtest) must raise ImportError.
        saved_tk = self.toolkit.tk
        self.toolkit.tk = None
        try:
            self.mols[0].draw(show=False, usecoords=True,
                              filename="%s_b.png" % self.toolkit.__name__)
            self.assertRaises(ImportError,
                              self.importtest)
        finally:
            # Restore module state so later tests are unaffected
            self.toolkit.tk = saved_tk

        # With ``oasa`` unavailable, drawing must raise ImportError as well.
        saved_oasa = self.toolkit.oasa
        self.toolkit.oasa = None
        try:
            self.assertRaises(ImportError,
                              self.importtest)
        finally:
            # Previously never restored, which broke every later draw test
            self.toolkit.oasa = saved_oasa
423
424
class TestRDKit(TestToolkit):
    """Run the shared toolkit tests against the RDKit toolkit module."""
    toolkit = rdkit
    datakeys = ['NSC']
    Natoms = 9
    Nbits = 12
    Nfpbits = 64
    tanimotoresult = 1 / 3.0
    tpsaname = "TPSA"
433
434
class TestCDK(TestToolkit):
    """Run the shared toolkit tests against the CDK toolkit module.

    SMARTS, local optimisation and 2D/3D coordinate generation are not
    exercised for the CDK; those tests are overridden with no-ops.
    """
    toolkit = cdk
    datakeys = ['NSC', 'Remark', 'Title']
    Natoms = 15
    Nbits = 4
    Nfpbits = 4  # The CDK uses a true java.util.Bitset
    tanimotoresult = 0.375
    tpsaname = "tpsa"

    def testSMARTS(self):
        """No SMARTS testing done"""

    def testLocalOpt(self):
        """No local opt testing done"""

    def testMake2D(self):
        """No 2D coordinate generation done"""

    def testMake3D(self):
        """No 3D coordinate generation done"""

    def testRSgetprops(self):
        """Get the values of the properties."""
        # exactmass is not checked: only OpenBabel has a working one.
        # The CDK leaves implicit Hs out of molwt, so they are made
        # explicit first (hence 14 atoms rather than 4).
        butane = self.mols[0]
        butane.addh()
        self.assertAlmostEqual(butane.molwt, 58.12, 2)
        self.assertEqual(len(butane.atoms), 14)
        self.assertRaises(AttributeError, self.RSaccesstest)
469
if __name__=="__main__":
    # Tidy up output left by an earlier interrupted run: the file-writing
    # tests expect that testoutput.txt does not exist yet.
    if os.path.isfile("testoutput.txt"):
        os.remove("testoutput.txt")

    # Other suites that can be listed here: TestCDK, TestRDKit,
    # TestOBPybelNoDraw, TestPybelWithDraw.
    testcases = [TestPybel]
    allpassed = True
    for testcase in testcases:
        sys.stdout.write("\n\n\nTESTING %s\n%s\n\n\n" % (testcase.__name__, "== "*10))
        myunittest = unittest.defaultTestLoader.loadTestsFromTestCase(testcase)
        result = unittest.TextTestRunner(verbosity=2).run(myunittest)
        allpassed = allpassed and result.wasSuccessful()
    # Report failures through the exit status so callers (e.g. CI) see them
    sys.exit(0 if allpassed else 1)
484