1import os
2import re
3from pathlib import Path
4from typing import Mapping
5import configparser
6
7import pytest
8
9from ase.calculators.calculator import (names as calculator_names,
10                                        get_calculator_class)
11
12
class NotInstalled(Exception):
    """Signals that the Python package backing a factory was not found."""
15
16
def get_testing_executables():
    """Read the ``[executables]`` section of the ASE test configuration.

    The per-user file ``~/.config/ase/ase.conf`` is always a candidate.
    Additional files may be listed in the ``ASE_CONFIG`` environment
    variable, separated by ``os.pathsep`` (``:`` on POSIX, ``;`` on
    Windows).

    Returns:
        A tuple ``(effective_paths, executables)``: the value returned by
        ``ConfigParser.read`` for the candidate paths, and the
        ``executables`` section (empty when no file defines it).
    """
    # TODO: a cross-platform global config file like /etc/ase/ase.conf
    paths = [Path.home() / '.config' / 'ase' / 'ase.conf']
    extra = os.environ.get('ASE_CONFIG', '')
    # Split on os.pathsep rather than a hard-coded ':' so that Windows
    # drive letters ("C:\...") stay intact; empty entries are dropped.
    paths += [Path(entry) for entry in extra.split(os.pathsep) if entry]
    conf = configparser.ConfigParser()
    # Guarantee the section exists even if no config file provides it.
    conf['executables'] = {}
    effective_paths = conf.read(paths)
    return effective_paths, conf['executables']
29
30
# Registry mapping a calculator name to its factory class; filled by @factory.
factory_classes = {}


def factory(name):
    """Return a class decorator that registers a factory under ``name``.

    The decorated class receives a ``name`` attribute, is stored in
    ``factory_classes[name]`` and is then returned.
    """
    def register(cls):
        factory_classes[name] = cls
        cls.name = name
        return cls

    return register
41
42
def make_factory_fixture(name):
    """Build a session-scoped pytest fixture named ``<name>_factory``.

    The fixture depends on a ``factories`` fixture; it calls
    ``factories.require(name)`` and then returns ``factories[name]``.
    """
    fixture_name = '{}_factory'.format(name)

    @pytest.fixture(scope='session')
    def _factory(factories):
        factories.require(name)
        return factories[name]

    _factory.__name__ = fixture_name
    return _factory
51
52
@factory('abinit')
class AbinitFactory:
    """Factory creating ``Abinit`` calculators for a configured executable."""

    def __init__(self, executable, pp_paths):
        self.executable = executable
        self.pp_paths = pp_paths
        self._version = None  # filled lazily by version()

    def version(self):
        """Return the Abinit version, querying the executable only once."""
        from ase.calculators.abinit import get_abinit_version
        if self._version is None:
            self._version = get_abinit_version(self.executable)
        return self._version

    def is_legacy_version(self):
        """Return True if the major version number is below 9."""
        major_ver = int(self.version().split('.')[0])
        return major_ver < 9

    def _base_kw(self, v8_legacy_format):
        """Return the default keyword arguments passed to ``Abinit``.

        The command line differs between the legacy format (input piped
        from ``PREFIX.files``) and the newer one (``PREFIX.in`` given as
        an argument).
        """
        if v8_legacy_format:
            command = f'{self.executable} < PREFIX.files > PREFIX.log'
        else:
            command = f'{self.executable} PREFIX.in > PREFIX.log'

        return dict(command=command,
                    v8_legacy_format=v8_legacy_format,
                    pp_paths=self.pp_paths,
                    ecut=150,
                    chksymbreak=0,
                    toldfe=1e-3)

    def calc(self, **kwargs):
        """Return an ``Abinit`` calculator; ``kwargs`` override defaults.

        If ``v8_legacy_format`` is missing or None, it is derived from the
        version of the executable.
        """
        from ase.calculators.abinit import Abinit
        legacy = kwargs.pop('v8_legacy_format', None)
        if legacy is None:
            legacy = self.is_legacy_version()

        kw = self._base_kw(legacy)
        kw.update(kwargs)
        return Abinit(**kw)

    @classmethod
    def fromconfig(cls, config):
        """Build the factory from the 'abinit' entries of ``config``.

        Reads ``config.executables['abinit']`` and
        ``config.datafiles['abinit']``.
        """
        instance = cls(config.executables['abinit'],
                       config.datafiles['abinit'])
        # Query the version eagerly; version() caches the result itself.
        instance.version()
        return instance
102
103
@factory('aims')
class AimsFactory:
    """Factory producing ``Aims`` calculators that run a given executable."""

    def __init__(self, executable):
        # XXX pseudo_dir
        self.executable = executable

    @classmethod
    def fromconfig(cls, config):
        """Create the factory from ``config.executables['aims']``."""
        return cls(config.executables['aims'])

    def version(self):
        """Run the executable and parse its stdout for the version."""
        from ase.calculators.aims import get_aims_version
        return get_aims_version(read_stdout([self.executable]))

    def calc(self, **kwargs):
        """Return an ``Aims`` calculator; ``xc`` defaults to 'LDA'."""
        from ase.calculators.aims import Aims
        params = {'xc': 'LDA'}
        params.update(kwargs)
        return Aims(command=self.executable, **params)
124
125
@factory('asap')
class AsapFactory:
    """Factory for the ``EMT`` calculator provided by ``asap3``."""

    importname = 'asap3'

    def calc(self, **kwargs):
        """Return an ``asap3.EMT`` calculator built from ``kwargs``."""
        from asap3 import EMT
        return EMT(**kwargs)

    def version(self):
        """Return ``asap3.__version__``."""
        import asap3
        return asap3.__version__

    @classmethod
    def fromconfig(cls, config):
        """Return a factory, or raise ``NotInstalled`` if asap3 is absent.

        ``config`` is not used.
        """
        # XXXX TODO Clean this up.  Copy of GPAW.
        # How do we design these things?
        #
        # Import the submodule explicitly: a bare ``import importlib``
        # does not guarantee that ``importlib.util`` has been loaded.
        from importlib.util import find_spec
        if find_spec(cls.importname) is None:
            raise NotInstalled(cls.importname)
        return cls()
147
148
@factory('cp2k')
class CP2KFactory:
    """Factory producing ``CP2K`` calculators for a configured executable."""

    def __init__(self, executable):
        self.executable = executable

    def version(self):
        """Return the ``version`` attribute of a ``Cp2kShell``."""
        from ase.calculators.cp2k import Cp2kShell
        shell = Cp2kShell(self.executable, debug=False)
        return shell.version

    def calc(self, **kwargs):
        """Return a ``CP2K`` calculator using the executable as command."""
        from ase.calculators.cp2k import CP2K
        return CP2K(command=self.executable, **kwargs)

    @classmethod
    def fromconfig(cls, config):
        """Create the factory from ``config.executables['cp2k']``."""
        # Use cls (not the hard-coded class) so subclasses get their own type.
        return cls(config.executables['cp2k'])
166
167
@factory('castep')
class CastepFactory:
    """Factory producing ``Castep`` calculators for a given executable."""

    def __init__(self, executable):
        self.executable = executable

    @classmethod
    def fromconfig(cls, config):
        """Create the factory from ``config.executables['castep']``."""
        executable = config.executables['castep']
        return cls(executable)

    def calc(self, **kwargs):
        """Return a ``Castep`` calculator with ``castep_command`` set."""
        from ase.calculators.castep import Castep
        return Castep(castep_command=self.executable, **kwargs)

    def version(self):
        """Return what ``get_castep_version`` reports for the executable."""
        from ase.calculators.castep import get_castep_version
        castep_version = get_castep_version(self.executable)
        return castep_version
184
185
@factory('dftb')
class DFTBFactory:
    """Factory producing ``Dftb`` calculators with one Slater-Koster path."""

    def __init__(self, executable, skt_paths):
        # Exactly one parameter directory is expected.
        assert len(skt_paths) == 1
        self.executable = executable
        self.skt_path = skt_paths[0]

    @classmethod
    def fromconfig(cls, config):
        """Create the factory from the 'dftb' entries of ``config``."""
        executable = config.executables['dftb']
        return cls(executable, config.datafiles['dftb'])

    def version(self):
        """Extract the release string from the executable's stdout."""
        output = read_stdout([self.executable])
        found = re.search(r'DFTB\+ release\s*(\S+)', output, re.M)
        return found.group(1)

    def calc(self, **kwargs):
        """Return a ``Dftb`` calculator writing stdout to ``PREFIX.out``."""
        from ase.calculators.dftb import Dftb
        # XXX not obvious: slako_dir is passed with a trailing slash.
        slako_dir = str(self.skt_path) + '/'
        return Dftb(command=f'{self.executable} > PREFIX.out',
                    slako_dir=slako_dir,
                    **kwargs)
209
210
@factory('dftd3')
class DFTD3Factory:
    """Factory producing ``DFTD3`` calculators for a given executable."""

    def __init__(self, executable):
        self.executable = executable

    @classmethod
    def fromconfig(cls, config):
        """Create the factory from ``config.executables['dftd3']``."""
        executable = config.executables['dftd3']
        return cls(executable)

    def calc(self, **kwargs):
        """Return a ``DFTD3`` calculator using the executable as command."""
        from ase.calculators.dftd3 import DFTD3
        calculator = DFTD3(command=self.executable, **kwargs)
        return calculator
223
224
def read_stdout(args, createfile=None):
    """Run ``args`` in a fresh temporary directory and return its stdout.

    If ``createfile`` is given, an empty file of that name is created in
    the directory before the process is started.  stderr is captured and
    discarded, the exit status is not checked, and the output is decoded
    as ASCII.
    """
    import tempfile
    from subprocess import Popen, PIPE

    popen_options = dict(stdout=PIPE,
                         stderr=PIPE,
                         stdin=PIPE,
                         encoding='ascii')
    with tempfile.TemporaryDirectory() as workdir:
        if createfile is not None:
            (Path(workdir) / createfile).touch()
        process = Popen(args, cwd=workdir, **popen_options)
        # Exit code will be != 0 because there isn't an input file
        output = process.communicate()[0]
    return output
241
242
@factory('elk')
class ElkFactory:
    """Factory producing ``ELK`` calculators with a species directory."""

    def __init__(self, executable, species_dir):
        self.executable = executable
        self.species_dir = species_dir

    @classmethod
    def fromconfig(cls, config):
        """Use the 'elk' executable and the first 'elk' datafile path."""
        executable = config.executables['elk']
        return cls(executable, config.datafiles['elk'][0])

    def version(self):
        """Extract the version from the executable's stdout."""
        stdout = read_stdout([self.executable])
        found = re.search(r'Elk code version (\S+)', stdout, re.M)
        return found.group(1)

    def calc(self, **kwargs):
        """Return an ``ELK`` calculator writing stdout to ``elk.out``."""
        from ase.calculators.elk import ELK
        return ELK(command=f'{self.executable} > elk.out',
                   species_dir=self.species_dir,
                   **kwargs)
262
263
@factory('espresso')
class EspressoFactory:
    """Factory for ``Espresso`` calculators using one pseudopotential dir."""

    def __init__(self, executable, pseudo_dir):
        self.executable = executable
        self.pseudo_dir = pseudo_dir

    @classmethod
    def fromconfig(cls, config):
        """Create the factory; exactly one 'espresso' datafile path is
        expected."""
        paths = config.datafiles['espresso']
        assert len(paths) == 1
        executable = config.executables['espresso']
        return cls(executable, paths[0])

    def _base_kw(self):
        """Return the default keyword arguments (``ecutwfc`` only)."""
        from ase.units import Ry
        return {'ecutwfc': 300 / Ry}

    def version(self):
        """Extract the PWSCF version from the start of the stdout."""
        output = read_stdout([self.executable])
        found = re.match(r'\s*Program PWSCF\s*(\S+)', output, re.M)
        assert found is not None
        return found.group(1)

    def calc(self, **kwargs):
        """Return an ``Espresso`` calculator; ``kwargs`` override defaults.

        Every ``*.UPF`` file in ``pseudo_dir`` is registered as the
        pseudopotential of the symbol given by the capitalized part of
        its name before the first underscore.
        """
        from ase.calculators.espresso import Espresso
        command = '{} -in PREFIX.pwi > PREFIX.pwo'.format(self.executable)
        # Names are e.g. si_lda_v1.uspp.F.UPF
        pseudopotentials = {
            upf.name.split('_', 1)[0].capitalize(): upf.name
            for upf in self.pseudo_dir.glob('*.UPF')
        }

        params = self._base_kw()
        params.update(kwargs)
        return Espresso(command=command,
                        pseudo_dir=str(self.pseudo_dir),
                        pseudopotentials=pseudopotentials,
                        **params)
302
303
@factory('exciting')
class ExcitingFactory:
    """Factory producing ``Exciting`` calculators for a given executable."""

    def __init__(self, executable):
        # XXX species path
        self.executable = executable

    @classmethod
    def fromconfig(cls, config):
        """Create the factory from ``config.executables['exciting']``."""
        executable = config.executables['exciting']
        return cls(executable)

    def calc(self, **kwargs):
        """Return an ``Exciting`` calculator with ``bin`` set."""
        from ase.calculators.exciting import Exciting
        calculator = Exciting(bin=self.executable, **kwargs)
        return calculator
317
318
@factory('vasp')
class VaspFactory:
    """Factory producing ``Vasp`` calculators for a given executable."""

    def __init__(self, executable):
        self.executable = executable

    @classmethod
    def fromconfig(cls, config):
        """Create the factory from ``config.executables['vasp']``."""
        executable = config.executables['vasp']
        return cls(executable)

    def version(self):
        """Run the executable next to an empty INCAR and parse its stdout."""
        from ase.calculators.vasp import get_vasp_version
        stdout = read_stdout([self.executable], createfile='INCAR')
        return get_vasp_version(stdout)

    def calc(self, **kwargs):
        """Return a ``Vasp`` calculator, or skip the running test.

        The test is skipped when the environment variable named by
        ``Vasp.VASP_PP_PATH`` is not set.
        """
        from ase.calculators.vasp import Vasp
        # XXX We assume the user has set VASP_PP_PATH
        pp_variable = Vasp.VASP_PP_PATH
        if pp_variable not in os.environ:
            # For now, we skip with a message that we cannot run the test
            message = ('No VASP pseudopotential path set. '
                       'Set the ${} environment variable to enable.'
                       .format(pp_variable))
            pytest.skip(message)
        return Vasp(command=self.executable, **kwargs)
342
343
@factory('gpaw')
class GPAWFactory:
    """Factory for ``GPAW`` calculators from the ``gpaw`` package."""

    importname = 'gpaw'

    def calc(self, **kwargs):
        """Return a ``GPAW`` calculator built from ``kwargs``."""
        from gpaw import GPAW
        return GPAW(**kwargs)

    def version(self):
        """Return ``gpaw.__version__``."""
        import gpaw
        return gpaw.__version__

    @classmethod
    def fromconfig(cls, config):
        """Return a factory, or raise ``NotInstalled`` if gpaw is absent.

        ``config`` is not used.
        """
        # Import the submodule explicitly: a bare ``import importlib``
        # does not guarantee that ``importlib.util`` has been loaded.
        from importlib.util import find_spec
        # XXX should be made non-pytest dependent
        if find_spec(cls.importname) is None:
            raise NotInstalled(cls.importname)
        return cls()
364
365
@factory('gromacs')
class GromacsFactory:
    """Factory producing ``Gromacs`` calculators for a given executable."""

    def __init__(self, executable):
        self.executable = executable

    @classmethod
    def fromconfig(cls, config):
        """Create the factory from ``config.executables['gromacs']``."""
        executable = config.executables['gromacs']
        return cls(executable)

    def calc(self, **kwargs):
        """Return a ``Gromacs`` calculator using the executable as command."""
        from ase.calculators.gromacs import Gromacs
        return Gromacs(command=self.executable, **kwargs)

    def version(self):
        """Return what ``get_gromacs_version`` reports for the executable."""
        from ase.calculators.gromacs import get_gromacs_version
        gromacs_version = get_gromacs_version(self.executable)
        return gromacs_version
382
383
class BuiltinCalculatorFactory:
    """Base factory that looks the calculator class up by ``self.name``.

    The class itself defines no ``name``; it has to be supplied on the
    subclass or instance.
    """

    @classmethod
    def fromconfig(cls, config):
        """Return a new instance; ``config`` is ignored."""
        return cls()

    def calc(self, **kwargs):
        """Instantiate the calculator class registered as ``self.name``."""
        from ase.calculators.calculator import get_calculator_class
        calculator_class = get_calculator_class(self.name)
        return calculator_class(**kwargs)
393
394
@factory('emt')
class EMTFactory(BuiltinCalculatorFactory):
    """Factory registered as 'emt'; all behaviour comes from the base."""
398
399
@factory('lammpsrun')
class LammpsRunFactory:
    """Factory producing ``LAMMPS`` calculators for a given executable."""

    def __init__(self, executable):
        self.executable = executable

    @classmethod
    def fromconfig(cls, config):
        """Create the factory from ``config.executables['lammpsrun']``."""
        executable = config.executables['lammpsrun']
        return cls(executable)

    def version(self):
        """Extract the parenthesized version at the start of the stdout."""
        output = read_stdout([self.executable])
        found = re.match(r'LAMMPS\s*\((.+?)\)', output, re.M)
        return found.group(1)

    def calc(self, **kwargs):
        """Return a ``LAMMPS`` calculator using the executable as command."""
        from ase.calculators.lammpsrun import LAMMPS
        calculator = LAMMPS(command=self.executable, **kwargs)
        return calculator
417
418
@factory('lammpslib')
class LammpsLibFactory:
    """Factory producing ``LAMMPSlib`` calculators.

    Creating an instance sets the ``LAMMPS_POTENTIALS`` environment
    variable of the current process.
    """

    def __init__(self, potentials_path):
        self.potentials_path = potentials_path
        # Set the path where LAMMPS will look for potential parameter files
        os.environ["LAMMPS_POTENTIALS"] = str(potentials_path)

    @classmethod
    def fromconfig(cls, config):
        """Create the factory from the first 'lammps' datafile path."""
        potentials_path = config.datafiles['lammps'][0]
        return cls(potentials_path)

    def version(self):
        """Return ``version()`` of a temporary ``lammps.lammps`` object.

        The object is closed again, also when ``version()`` raises.
        """
        import lammps
        quiet_args = [
            "-echo", "log", "-log", "none", "-screen", "none", "-nocite"
        ]
        session = lammps.lammps(name="", cmdargs=quiet_args, comm=None)
        try:
            lammps_version = session.version()
        finally:
            session.close()
        return lammps_version

    def calc(self, **kwargs):
        """Return a ``LAMMPSlib`` calculator built from ``kwargs``."""
        from ase.calculators.lammpslib import LAMMPSlib
        return LAMMPSlib(**kwargs)
444
445
@factory('openmx')
class OpenMXFactory:
    """Factory producing ``OpenMX`` calculators with a data path."""

    def __init__(self, executable, data_path):
        self.executable = executable
        self.data_path = data_path

    @classmethod
    def fromconfig(cls, config):
        """Use the 'openmx' executable and the first 'openmx' datafile."""
        executable = config.executables['openmx']
        return cls(executable, data_path=config.datafiles['openmx'][0])

    def version(self):
        """Run the executable on an empty dummy input and parse stdout."""
        from ase.calculators.openmx.openmx import parse_omx_version
        dummy_input = 'omx_dummy_input'
        output = read_stdout([self.executable, dummy_input],
                             createfile=dummy_input)
        return parse_omx_version(output)

    def calc(self, **kwargs):
        """Return an ``OpenMX`` calculator; ``data_path`` is passed as str."""
        from ase.calculators.openmx import OpenMX
        data_path = str(self.data_path)
        return OpenMX(command=self.executable, data_path=data_path, **kwargs)
469
470
@factory('octopus')
class OctopusFactory:
    """Factory producing ``Octopus`` calculators for a given executable."""

    def __init__(self, executable):
        self.executable = executable

    @classmethod
    def fromconfig(cls, config):
        """Create the factory from ``config.executables['octopus']``."""
        executable = config.executables['octopus']
        return cls(executable)

    def version(self):
        """Parse the output of ``<executable> --version``."""
        output = read_stdout([self.executable, '--version'])
        found = re.match(r'octopus\s*(.+)', output)
        return found.group(1)

    def calc(self, **kwargs):
        """Return an ``Octopus`` calculator writing to ``stdout.log``."""
        from ase.calculators.octopus import Octopus
        return Octopus(command=f'{self.executable} > stdout.log', **kwargs)
489
490
@factory('siesta')
class SiestaFactory:
    """Factory producing ``Siesta`` calculators with one pseudo path."""

    def __init__(self, executable, pseudo_path):
        self.executable = executable
        self.pseudo_path = pseudo_path

    @classmethod
    def fromconfig(cls, config):
        """Create the factory; exactly one 'siesta' datafile path is
        expected and is passed on as a string."""
        paths = config.datafiles['siesta']
        assert len(paths) == 1
        executable = config.executables['siesta']
        return cls(executable, str(paths[0]))

    def version(self):
        """Return the version with a leading ``siesta-`` prefix removed.

        If the string reported by ``get_siesta_version`` does not start
        with that prefix (case-insensitive), it is returned unchanged.
        """
        from ase.calculators.siesta.siesta import get_siesta_version
        reported = get_siesta_version(self.executable)
        found = re.match(r'siesta-(\S+)', reported, flags=re.I)
        return found.group(1) if found else reported

    def calc(self, **kwargs):
        """Return a ``Siesta`` calculator reading ``PREFIX.fdf``."""
        from ase.calculators.siesta import Siesta
        command = '{} < PREFIX.fdf > PREFIX.out'.format(self.executable)
        pseudo_path = str(self.pseudo_path)
        return Siesta(command=command, pseudo_path=pseudo_path, **kwargs)
518
519
@factory('nwchem')
class NWChemFactory:
    """Factory producing ``NWChem`` calculators for a given executable."""

    def __init__(self, executable):
        self.executable = executable

    @classmethod
    def fromconfig(cls, config):
        """Create the factory from ``config.executables['nwchem']``."""
        executable = config.executables['nwchem']
        return cls(executable)

    def version(self):
        """Run the executable next to an empty ``nwchem.nw`` and parse
        the version from its stdout."""
        output = read_stdout([self.executable], createfile='nwchem.nw')
        pattern = r'Northwest Computational Chemistry Package \(NWChem\) (\S+)'
        found = re.search(pattern, output, re.M)
        return found.group(1)

    def calc(self, **kwargs):
        """Return an ``NWChem`` calculator reading ``PREFIX.nwi``."""
        from ase.calculators.nwchem import NWChem
        return NWChem(command=f'{self.executable} PREFIX.nwi > PREFIX.nwo',
                      **kwargs)
540
541
class NoSuchCalculator(Exception):
    """Raised for a requested calculator name that is not a known one."""
544
545
class Factories:
    """Factories that could be built from the test configuration.

    On construction every class in ``factory_classes`` is asked to build
    itself from this object via ``fromconfig``; those raising
    ``NotInstalled`` or ``KeyError`` are left out.  The object also keeps
    track of which calculators were requested for testing.
    """

    all_calculators = set(calculator_names)
    builtin_calculators = {'eam', 'emt', 'ff', 'lj', 'morse', 'tip3p', 'tip4p'}
    autoenabled_calculators = {'asap'} | builtin_calculators

    # TODO: Port calculators to use factories.  As we do so, remove names
    # from list of calculators that we monkeypatch:
    monkeypatch_calculator_constructors = {
        'ace',
        'aims',
        'amber',
        'crystal',
        'demon',
        'demonnano',
        'dftd3',
        'dmol',
        'exciting',
        'fleur',
        'gamess_us',
        'gaussian',
        'gulp',
        'hotbit',
        'lammpslib',
        'mopac',
        'onetep',
        'orca',
        'Psi4',
        'qchem',
        'turbomole',
    }

    def __init__(self, requested_calculators):
        """Collect configuration, build factories, record the request.

        Args:
            requested_calculators: iterable of calculator names; the
                special name 'auto' is replaced by the names of all
                factories that could be built.

        Raises:
            NoSuchCalculator: if a requested name is not in
                ``all_calculators``.
        """
        executable_config_paths, executables = get_testing_executables()
        assert isinstance(executables, Mapping), executables
        self.executables = executables
        self.executable_config_paths = executable_config_paths

        # Data files come from the optional ``asetest`` module.
        datafiles = {}
        try:
            import asetest
        except ImportError:
            datafiles_module = None
        else:
            datafiles_module = asetest
            datafiles.update(asetest.datafiles.paths)

        self.datafiles_module = datafiles_module
        self.datafiles = datafiles

        factories = {}

        for name, cls in factory_classes.items():
            try:
                instance = cls.fromconfig(self)
            except (NotInstalled, KeyError):
                # No usable factory for this calculator: leave it out.
                pass
            else:
                factories[name] = instance

        self.factories = factories

        requested_calculators = set(requested_calculators)
        if 'auto' in requested_calculators:
            requested_calculators.remove('auto')
            requested_calculators |= set(self.factories)
        self.requested_calculators = requested_calculators

        for name in self.requested_calculators:
            if name not in self.all_calculators:
                raise NoSuchCalculator(name)

    def installed(self, name):
        """Return True if ``name`` is builtin or has a built factory."""
        return name in self.builtin_calculators or name in self.factories

    def is_adhoc(self, name):
        """Return True if no factory class is registered for ``name``."""
        return name not in factory_classes

    def optional(self, name):
        """Return True if ``name`` is not a builtin calculator."""
        return name not in self.builtin_calculators

    def enabled(self, name):
        """Return True if ``name`` is requested, or autoenabled and
        installed."""
        auto = name in self.autoenabled_calculators and self.installed(name)
        return auto or (name in self.requested_calculators)

    def require(self, name):
        """Skip the running test unless ``name`` is usable and requested."""
        # XXX This is for old-style calculator tests.
        # Newer calculator tests would depend on a fixture which would
        # make them skip.
        # Older tests call require(name) explicitly.
        assert name in calculator_names
        if not self.installed(name) and not self.is_adhoc(name):
            pytest.skip(f'Not installed: {name}')
        if name not in self.requested_calculators:
            pytest.skip(f'Use --calculators={name} to enable')

    def __getitem__(self, name):
        """Return the factory built for ``name`` (KeyError if none)."""
        return self.factories[name]

    def monkeypatch_disabled_calculators(self):
        """Make constructing a non-enabled calculator skip the test.

        For every name in ``monkeypatch_calculator_constructors`` that is
        neither autoenabled, builtin nor requested, the calculator class's
        ``__init__`` is replaced by a function calling ``pytest.skip`` and
        its ``__del__`` by a no-op.  Names whose class cannot be imported
        are ignored.
        """
        test_calculator_names = (self.autoenabled_calculators
                                 | self.builtin_calculators
                                 | self.requested_calculators)
        disable_names = (self.monkeypatch_calculator_constructors
                         - test_calculator_names)
        # disable_names = self.all_calculators - test_calculator_names

        def get_mock_init(name):
            # A factory function binds ``name`` per calculator, so each
            # skip message mentions the right one.
            def mock_init(obj, *args, **kwargs):
                pytest.skip(f'use --calculators={name} to enable')

            return mock_init

        def mock_del(obj):
            pass

        for name in disable_names:
            try:
                cls = get_calculator_class(name)
            except ImportError:
                continue
            cls.__init__ = get_mock_init(name)
            cls.__del__ = mock_del
671
672
def get_factories(pytestconfig):
    """Build ``Factories`` from the ``--calculators`` pytest option.

    The option value is split on commas; an empty or missing value means
    that no calculators are requested.
    """
    opt = pytestconfig.getoption('--calculators')
    if opt:
        requested = opt.split(',')
    else:
        requested = []
    return Factories(requested)
677
678
def parametrize_calculator_tests(metafunc):
    """Parametrize tests using our custom markers.

    Tests marked with @pytest.mark.calculator(names) get their 'factory'
    argument parametrized (indirectly) over the named calculator or
    calculators.  Each parameter is a ``(name, kwargs)`` tuple; a 'marks'
    keyword on the marker is removed from ``kwargs`` and attached to the
    parameters instead.  Nothing happens without such markers.
    """
    params = []

    for marker in metafunc.definition.iter_markers(name='calculator'):
        options = dict(marker.kwargs)
        marks = options.pop('marks', [])
        # All names of one marker share the same options dict.
        params.extend(pytest.param((name, options), marks=marks)
                      for name in marker.args)

    if not params:
        return

    metafunc.parametrize('factory',
                         params,
                         indirect=True,
                         ids=lambda param: param[0])
699
700
class CalculatorInputs:
    """A factory paired with default calculator parameters."""

    def __init__(self, factory, parameters=None):
        self.parameters = {} if parameters is None else parameters
        self.factory = factory

    @property
    def name(self):
        """Name of the underlying factory."""
        return self.factory.name

    def __repr__(self):
        return '{}({}, {})'.format(type(self).__name__, self.name,
                                   self.parameters)

    def require_version(self, version):
        """Skip the running test if the installed version is older than
        ``version``."""
        from ase.utils import tokenize_version
        installed_version = self.factory.version()
        required = tokenize_version(version)
        if tokenize_version(installed_version) < required:
            pytest.skip('Version too old: Requires {}; got {}'
                        .format(version, installed_version))

    def _merged(self, overrides):
        # Copy of the stored parameters with ``overrides`` applied on top.
        merged = dict(self.parameters)
        merged.update(overrides)
        return merged

    def new(self, **kwargs):
        """Return new inputs for the same factory with updated parameters.

        The stored parameters of this object are left untouched.
        """
        return CalculatorInputs(self.factory, self._merged(kwargs))

    def calc(self, **kwargs):
        """Create a calculator from the stored parameters plus ``kwargs``."""
        return self.factory.calc(**self._merged(kwargs))
733
734
class ObsoleteFactoryWrapper:
    """Minimal factory that looks a calculator class up by name.

    We use this for transitioning older tests to the new framework.
    """

    def __init__(self, name):
        self.name = name

    def calc(self, **kwargs):
        """Instantiate the calculator class registered as ``self.name``."""
        from ase.calculators.calculator import get_calculator_class
        calculator_class = get_calculator_class(self.name)
        return calculator_class(**kwargs)
744