1import os
2import socket
3from subprocess import Popen, PIPE
4from contextlib import contextmanager
5
6import numpy as np
7
8from ase.calculators.calculator import (Calculator, all_changes,
9                                        PropertyNotImplementedError)
10import ase.units as units
11from ase.utils import IOContext
12from ase.stress import full_3x3_to_voigt_6_stress
13
14
def actualunixsocketname(name):
    """Return the filesystem path of the UNIX socket called ``name``.

    The path is ``name`` prefixed with ``/tmp/ipi_``.
    """
    prefix = '/tmp/ipi_'
    return f'{prefix}{name}'
17
18
class SocketClosed(OSError):
    """Signal that the other end of a socket has closed the connection."""
21
22
class IPIProtocol:
    """Communication using IPI protocol.

    Wraps a connected socket and implements both sides of the exchange:
    fixed-width text headers (``sendmsg``/``recvmsg``) followed by raw
    binary arrays (``send``/``recv``).  Cells, positions, energies,
    forces and virials are rescaled with ``units.Bohr`` and ``units.Ha``
    when they cross the socket.
    """

    def __init__(self, socket, txt=None):
        """Store the connected ``socket`` and set up logging.

        Parameters:

        socket: socket object
            Connected socket used for all communication.
        txt: file object or None
            If given, every protocol event is printed to it (prefixed
            with ``Driver:``) and flushed immediately.  Otherwise
            logging is a no-op.
        """
        self.socket = socket

        if txt is None:
            def log(*args):
                pass
        else:
            def log(*args):
                print('Driver:', *args, file=txt)
                txt.flush()
        self.log = log

    def sendmsg(self, msg):
        """Send the header string ``msg``, space-padded to 12 bytes."""
        self.log('  sendmsg', repr(msg))
        # assert msg in self.statements, msg
        msg = msg.encode('ascii').ljust(12)
        self.socket.sendall(msg)

    def _recvall(self, nbytes):
        """Repeatedly read chunks until we have nbytes.

        Normally we get all bytes in one read, but that is not guaranteed.

        Raises SocketClosed if the peer closes the connection before
        all bytes have arrived."""
        remaining = nbytes
        chunks = []
        while remaining > 0:
            chunk = self.socket.recv(remaining)
            if len(chunk) == 0:
                # (If socket is still open, recv returns at least one byte)
                raise SocketClosed()
            chunks.append(chunk)
            remaining -= len(chunk)
        msg = b''.join(chunks)
        assert len(msg) == nbytes and remaining == 0
        return msg

    def recvmsg(self):
        """Receive one 12-byte header and return it as a stripped string.

        Raises SocketClosed (from ``_recvall``) if the connection was
        closed by the peer."""
        # _recvall() either returns exactly 12 bytes or raises
        # SocketClosed, so no separate emptiness check is required.
        msg = self._recvall(12)
        assert len(msg) == 12, msg
        msg = msg.rstrip().decode('ascii')
        # assert msg in self.responses, msg
        self.log('  recvmsg', repr(msg))
        return msg

    def send(self, a, dtype):
        """Send ``a`` converted to ``dtype`` as raw bytes."""
        buf = np.asarray(a, dtype).tobytes()
        # self.log('  send {}'.format(np.array(a).ravel().tolist()))
        self.log('  send {} bytes of {}'.format(len(buf), dtype))
        self.socket.sendall(buf)

    def recv(self, shape, dtype):
        """Receive an array of the given ``shape`` and ``dtype``.

        Blocks until all bytes have arrived.  Asserts that every
        received value is finite."""
        a = np.empty(shape, dtype)
        # ndarray.nbytes is a plain Python int for any valid shape
        # (including the scalar shape ``()``), so the socket never sees
        # a NumPy scalar or a float as its byte count.
        nbytes = a.nbytes
        buf = self._recvall(nbytes)
        assert len(buf) == nbytes, (len(buf), nbytes)
        self.log('  recv {} bytes of {}'.format(len(buf), dtype))
        # print(np.frombuffer(buf, dtype=dtype))
        a.flat[:] = np.frombuffer(buf, dtype=dtype)
        # self.log('  recv {}'.format(a.ravel().tolist()))
        assert np.isfinite(a).all()
        return a

    def _recv_int(self):
        """Receive a single int32 and return it as a plain Python int.

        Indexing the one-element array before calling int() avoids the
        array-to-scalar conversion that newer NumPy versions deprecate."""
        return int(self.recv(1, np.int32)[0])

    @staticmethod
    def _has_payload(morebytes):
        """Return True if ``morebytes`` holds at least one nonzero byte.

        An empty bytestring and the all-zero placeholder written by
        ``sendforce`` both count as "no payload".  Works for arrays of
        any length, unlike a bare truth test on the array."""
        return len(morebytes) > 0 and bool(np.any(morebytes))

    def sendposdata(self, cell, icell, positions):
        """Send the POSDATA header followed by cell, inverse cell,
        number of atoms and positions.

        ``cell`` and ``icell`` must have 9 elements each; they are sent
        transposed.  Lengths are divided by ``units.Bohr`` (and the
        inverse cell multiplied by it) before sending."""
        assert cell.size == 9
        assert icell.size == 9
        assert positions.size % 3 == 0

        self.log(' sendposdata')
        self.sendmsg('POSDATA')
        self.send(cell.T / units.Bohr, np.float64)
        self.send(icell.T * units.Bohr, np.float64)
        self.send(len(positions), np.int32)
        self.send(positions / units.Bohr, np.float64)

    def recvposdata(self):
        """Receive the payload that follows a POSDATA header.

        Returns the tuple (cell, icell, positions), converted back with
        ``units.Bohr`` (inverse of what ``sendposdata`` applies)."""
        cell = self.recv((3, 3), np.float64).T.copy()
        icell = self.recv((3, 3), np.float64).T.copy()
        natoms = self._recv_int()
        positions = self.recv((natoms, 3), np.float64)
        return cell * units.Bohr, icell / units.Bohr, positions * units.Bohr

    def sendrecv_force(self):
        """Send GETFORCE and receive the FORCEREADY reply.

        Returns (energy, forces, virial, morebytes), where the first
        three are scaled with ``units.Ha`` / ``units.Bohr``.  ``morebytes``
        is an array of ``np.byte`` or ``b''`` if the peer sent none."""
        self.log(' sendrecv_force')
        self.sendmsg('GETFORCE')
        msg = self.recvmsg()
        assert msg == 'FORCEREADY', msg
        e = self.recv(1, np.float64)[0]
        natoms = self._recv_int()
        assert natoms >= 0
        forces = self.recv((natoms, 3), np.float64)
        virial = self.recv((3, 3), np.float64).T.copy()
        nmorebytes = self._recv_int()
        if nmorebytes > 0:
            morebytes = self.recv(nmorebytes, np.byte)
        else:
            # Nothing further to read for a zero-length trailer.
            morebytes = b''
        return (e * units.Ha, (units.Ha / units.Bohr) * forces,
                units.Ha * virial, morebytes)

    def sendforce(self, energy, forces, virial, morebytes=None):
        """Send the FORCEREADY header followed by energy, forces, virial
        and an extra byte string.

        Parameters:

        energy: scalar
            Energy; divided by ``units.Ha`` before sending.
        forces: array of shape (natoms, 3)
            Forces; multiplied by ``units.Bohr / units.Ha``.
        virial: array of shape (3, 3)
            Virial; sent transposed and divided by ``units.Ha``.
        morebytes: array of bytes or None
            Extra payload.  None (the default) sends a single zero
            byte."""
        assert np.array([energy]).size == 1
        assert forces.shape[1] == 3
        assert virial.shape == (3, 3)

        if morebytes is None:
            # Fresh placeholder per call instead of a shared default array.
            morebytes = np.zeros(1, dtype=np.byte)

        self.log(' sendforce')
        self.sendmsg('FORCEREADY')  # mind the units
        self.send(np.array([energy / units.Ha]), np.float64)
        natoms = len(forces)
        self.send(np.array([natoms]), np.int32)
        self.send(units.Bohr / units.Ha * forces, np.float64)
        self.send(1.0 / units.Ha * virial.T, np.float64)
        # We prefer to always send at least one byte due to trouble with
        # empty messages.  Reading a closed socket yields 0 bytes
        # and thus can be confused with a 0-length bytestring.
        self.send(np.array([len(morebytes)]), np.int32)
        self.send(morebytes, np.byte)

    def status(self):
        """Send STATUS and return the peer's reply header."""
        self.log(' status')
        self.sendmsg('STATUS')
        msg = self.recvmsg()
        return msg

    def end(self):
        """Send the EXIT header."""
        self.log(' end')
        self.sendmsg('EXIT')

    def recvinit(self):
        """Receive the payload that follows an INIT header.

        Returns (bead_index, initbytes): a one-element int32 array and
        an array of ``np.byte``."""
        self.log(' recvinit')
        bead_index = self.recv(1, np.int32)
        nbytes = self._recv_int()
        initbytes = self.recv(nbytes, np.byte)
        return bead_index, initbytes

    def sendinit(self):
        """Send an INIT header with bead index 0 and a one-byte
        initialization string."""
        # XXX Not sure what this function is supposed to send.
        # It 'works' with QE, but for now we try not to call it.
        self.log(' sendinit')
        self.sendmsg('INIT')
        self.send(0, np.int32)  # 'bead index' always zero for now
        # We send one byte, which is zero, since things may not work
        # with 0 bytes.  Apparently implementations ignore the
        # initialization string anyway.
        self.send(1, np.int32)
        self.send(np.zeros(1), np.byte)  # initialization string

    def calculate(self, positions, cell):
        """Run one full request cycle and return the results as a dict.

        Sends the geometry, waits until the peer reports HAVEDATA and
        fetches the results.  The returned dict has the keys 'energy',
        'forces' and 'virial', plus 'morebytes' when the peer sent a
        payload containing at least one nonzero byte."""
        self.log('calculate')
        msg = self.status()
        # We don't know how NEEDINIT is supposed to work, but some codes
        # seem to be okay if we skip it and send the positions instead.
        if msg == 'NEEDINIT':
            self.sendinit()
            msg = self.status()
        assert msg == 'READY', msg
        icell = np.linalg.pinv(cell).transpose()
        self.sendposdata(cell, icell, positions)
        msg = self.status()
        assert msg == 'HAVEDATA', msg
        e, forces, virial, morebytes = self.sendrecv_force()
        r = dict(energy=e,
                 forces=forces,
                 virial=virial)
        # A plain ``if morebytes:`` raises ValueError for arrays with
        # more than one element, so test the payload explicitly.
        if self._has_payload(morebytes):
            r['morebytes'] = morebytes
        return r
198
199
@contextmanager
def bind_unixsocket(socketfile):
    """Context manager yielding a UNIX socket bound to ``socketfile``.

    ``socketfile`` must start with ``/tmp/ipi_``.  On exit the socket is
    closed and the socket file is removed.

    Raises OSError (chained to the original error and mentioning the
    path) if the socket cannot be bound; in that case the socket is
    closed and the file is left untouched.
    """
    assert socketfile.startswith('/tmp/ipi_'), socketfile
    serversocket = socket.socket(socket.AF_UNIX)
    try:
        serversocket.bind(socketfile)
    except OSError as err:
        # The socket never reaches the ``with`` below, so release its
        # descriptor here instead of leaking it.
        serversocket.close()
        raise OSError('{}: {}'.format(err, repr(socketfile))) from err

    try:
        with serversocket:
            yield serversocket
    finally:
        os.unlink(socketfile)
214
215
@contextmanager
def bind_inetsocket(port):
    """Context manager yielding an INET socket bound to ``port``.

    The socket is bound on all interfaces (empty host string) with
    SO_REUSEADDR enabled, and is closed on exit.
    """
    serversocket = socket.socket(socket.AF_INET)
    # Configure the socket inside the ``with`` so that its descriptor is
    # released even when setsockopt() or bind() fails (e.g. port in use).
    with serversocket:
        serversocket.setsockopt(socket.SOL_SOCKET,
                                socket.SO_REUSEADDR, 1)
        serversocket.bind(('', port))
        yield serversocket
224
225
class FileIOSocketClientLauncher:
    """Callable that writes a calculator's input and starts its command."""

    def __init__(self, calc):
        self.calc = calc

    def __call__(self, atoms, properties=None, port=None, unixsocket=None):
        """Write input files via the calculator and launch its command.

        ``PREFIX`` in the calculator's command is replaced by its prefix,
        and the ``{port}`` / ``{unixsocket}`` fields are filled in.
        Returns the Popen object of the shell process, which runs in the
        calculator's directory."""
        calc = self.calc
        assert calc is not None
        template = calc.command.replace('PREFIX', calc.prefix)
        calc.write_input(atoms, properties=properties,
                         system_changes=all_changes)
        workdir = calc.directory
        socket_fields = dict(port=port, unixsocket=unixsocket)
        shell_command = template.format(**socket_fields)
        return Popen(shell_command, shell=True, cwd=workdir)
238
239
class SocketServer(IOContext):
    """Server side of the socket connection.

    Binds and listens on either an INET port or a UNIX socket, accepts
    one client lazily on the first calculate() call, and talks to it
    through an IPIProtocol object.
    """

    # Port used for INET connections when neither ``port`` nor
    # ``unixsocket`` is given.
    default_port = 31415

    def __init__(self,  # launch_client=None,
                 port=None, unixsocket=None, timeout=None,
                 log=None):
        """Create server and listen for connections.

        The server does not launch a client process itself.  The client
        connection is accepted when calculate() is first called, which
        blocks until a client has connected.  If a subprocess object is
        assigned to the ``proc`` attribute afterwards, it is polled while
        waiting for the connection and waited for in close().

        Parameters:

        port: integer or None
            Port on which to listen for INET connections.  Defaults
            to 31415 if neither this nor unixsocket is specified.
        unixsocket: string or None
            Filename for unix socket.
        timeout: float or None
            timeout in seconds, or unlimited by default.
            This parameter is passed to the Python socket object; see
            documentation thereof
        log: file object or None
            useful debug messages are written to this.

        Raises ValueError if both port and unixsocket are given."""

        if unixsocket is None and port is None:
            port = self.default_port
        elif unixsocket is not None and port is not None:
            raise ValueError('Specify only one of unixsocket and port')

        self.port = port
        self.unixsocket = unixsocket
        self.timeout = timeout
        self._closed = False

        if unixsocket is not None:
            actualsocket = actualunixsocketname(unixsocket)
            conn_name = 'UNIX-socket {}'.format(actualsocket)
            socket_context = bind_unixsocket(actualsocket)
        else:
            conn_name = 'INET port {}'.format(port)
            socket_context = bind_inetsocket(port)

        # closelater() comes from IOContext; presumably it enters the
        # context manager and arranges for it to be exited by close().
        self.serversocket = self.closelater(socket_context)

        if log:
            print('Accepting clients on {}'.format(conn_name), file=log)

        self.serversocket.settimeout(timeout)

        self.serversocket.listen(1)

        self.log = log

        # Subprocess running the client, if an owner assigns one.
        self.proc = None

        # Connection state; filled in by _accept().
        self.protocol = None
        self.clientsocket = None
        self.address = None

        #if launch_client is not None:
        #    self.proc = launch_client(port=port, unixsocket=unixsocket)

    def _accept(self):
        """Wait for client and establish connection.

        Sets ``clientsocket``, ``address`` and ``protocol``.  Raises
        OSError if ``proc`` is set and the subprocess terminates before
        connecting."""
        # It should perhaps be possible for process to be launched by user
        log = self.log
        if log:
            print('Awaiting client', file=self.log)

        # If we launched the subprocess, the process may crash.
        # We want to detect this, using loop with timeouts, and
        # raise an error rather than blocking forever.
        if self.proc is not None:
            self.serversocket.settimeout(1.0)

        # NOTE(review): a socket.timeout from accept() is always swallowed
        # and the loop retries, so when ``proc`` is None a user-supplied
        # timeout does not abort the wait for a client -- confirm intent.
        while True:
            try:
                self.clientsocket, self.address = self.serversocket.accept()
                self.closelater(self.clientsocket)
            except socket.timeout:
                if self.proc is not None:
                    status = self.proc.poll()
                    if status is not None:
                        raise OSError('Subprocess terminated unexpectedly'
                                      ' with status {}'.format(status))
            else:
                break

        # Restore the user-requested timeout (the polling loop above may
        # have replaced it) and apply it to the new connection as well.
        self.serversocket.settimeout(self.timeout)
        self.clientsocket.settimeout(self.timeout)

        if log:
            # For unix sockets, address is b''.
            source = ('client' if self.address == b'' else self.address)
            print('Accepted connection from {}'.format(source), file=log)

        self.protocol = IPIProtocol(self.clientsocket, txt=log)

    def close(self):
        """Close the server; subsequent calls do nothing.

        Calls the inherited close(), drops the protocol object and, if
        ``proc`` is set, waits for the subprocess and warns about a
        nonzero exit status."""
        if self._closed:
            return

        super().close()

        if self.log:
            print('Close socket server', file=self.log)
        self._closed = True

        # Proper way to close sockets?
        # And indeed i-pi connections...
        # if self.protocol is not None:
        #     self.protocol.end()  # Send end-of-communication string
        self.protocol = None
        if self.proc is not None:
            # Blocks until the subprocess has exited.
            exitcode = self.proc.wait()
            if exitcode != 0:
                import warnings
                # Quantum Espresso seems to always exit with status 128,
                # even if successful.
                # Should investigate at some point
                warnings.warn('Subprocess exited with status {}'
                              .format(exitcode))
        # self.log('IPI server closed')

    def calculate(self, atoms):
        """Send geometry to client and return calculated things as dict.

        This will block until client has established connection, then
        wait for the client to finish the calculation.

        The server must not have been closed (checked with an assert)."""
        assert not self._closed

        # If we have not established connection yet, we must block
        # until the client catches up:
        if self.protocol is None:
            self._accept()
        return self.protocol.calculate(atoms.positions, atoms.cell)
380
381
class SocketClient:
    """Client side of the socket connection.

    Receives geometries from a server, evaluates energy, forces and
    (optionally) stress with the calculator attached to the Atoms
    object, and sends the results back.
    """

    def __init__(self, host='localhost', port=None,
                 unixsocket=None, timeout=None, log=None, comm=None):
        """Create client and connect to server.

        Parameters:

        host: string
            Hostname of server.  Defaults to localhost
        port: integer or None
            Port to which to connect.  By default 31415.
        unixsocket: string or None
            If specified, use corresponding UNIX socket.
            See documentation of unixsocket for SocketIOCalculator.
        timeout: float or None
            See documentation of timeout for SocketIOCalculator.
        log: file object or None
            Log events to this file
        comm: communicator or None
            MPI communicator object.  Defaults to ase.parallel.world.
            When ASE runs in parallel, only the process with world.rank == 0
            will communicate over the socket.  The received information
            will then be broadcast on the communicator.  The SocketClient
            must be created on all ranks of world, and will see the same
            Atoms objects."""
        if comm is None:
            from ase.parallel import world
            comm = world

        # Only rank0 actually does the socket work.
        # The other ranks only need to follow.
        #
        # Note: We actually refrain from assigning all the
        # socket-related things except on master
        self.comm = comm

        if self.comm.rank == 0:
            if unixsocket is not None:
                sock = socket.socket(socket.AF_UNIX)
                address = actualunixsocketname(unixsocket)
            else:
                if port is None:
                    port = SocketServer.default_port
                sock = socket.socket(socket.AF_INET)
                address = (host, port)
            try:
                sock.connect(address)
            except Exception:
                # Nobody else holds a reference to the socket yet, so
                # close it here rather than leaking the descriptor.
                sock.close()
                raise
            # The timeout is applied only after the connection exists.
            sock.settimeout(timeout)
            self.host = host
            self.port = port
            self.unixsocket = unixsocket

            self.protocol = IPIProtocol(sock, txt=log)
            self.log = self.protocol.log
            self.closed = False

            self.bead_index = 0
            self.bead_initbytes = b''
            self.state = 'READY'

    def close(self):
        """Close the socket.  Safe to call repeatedly and on any rank."""
        # Ranks other than 0 never set ``closed`` (they own no socket);
        # treat them as already closed instead of raising AttributeError.
        if getattr(self, 'closed', True):
            return
        self.log('Close SocketClient')
        self.closed = True
        self.protocol.socket.close()

    def calculate(self, atoms, use_stress):
        """Evaluate ``atoms`` and return (energy, forces, virial).

        Positions and cell are first broadcast from rank 0.  The virial
        is ``-volume * stress`` when ``use_stress`` is true, else a 3x3
        array of zeros."""
        # We should also broadcast the bead index, once we support doing
        # multiple beads.
        self.comm.broadcast(atoms.positions, 0)
        self.comm.broadcast(np.ascontiguousarray(atoms.cell), 0)

        energy = atoms.get_potential_energy()
        forces = atoms.get_forces()
        if use_stress:
            stress = atoms.get_stress(voigt=False)
            virial = -atoms.get_volume() * stress
        else:
            virial = np.zeros((3, 3))
        return energy, forces, virial

    def irun(self, atoms, use_stress=None):
        """Return a generator that yields once after each calculation.

        If ``use_stress`` is None, stress is calculated when any
        direction of ``atoms`` is periodic."""
        if use_stress is None:
            use_stress = any(atoms.pbc)

        my_irun = self.irun_rank0 if self.comm.rank == 0 else self.irun_rankN
        return my_irun(atoms, use_stress)

    def irun_rankN(self, atoms, use_stress=True):
        """Generator for ranks other than 0.

        Follows the stop flag broadcast by rank 0: calculates and yields
        while it is false, returns once it is true."""
        stop_criterion = np.zeros(1, bool)
        while True:
            self.comm.broadcast(stop_criterion, 0)
            if stop_criterion[0]:
                return

            self.calculate(atoms, use_stress)
            yield

    def irun_rank0(self, atoms, use_stress=True):
        """Generator for rank 0, which owns the socket.

        Serves server messages until EXIT is received or the server
        closes the connection.  ``self.state`` cycles through READY ->
        HAVEDATA (after POSDATA) -> NEEDINIT (after GETFORCE) -> READY
        (after INIT).  The socket is closed when the generator finishes.

        Raises KeyError for an unknown message."""
        # For every step we either calculate or quit.  We need to
        # tell other MPI processes (if this is MPI-parallel) whether they
        # should calculate or quit.
        try:
            while True:
                try:
                    msg = self.protocol.recvmsg()
                except SocketClosed:
                    # Server closed the connection, but we want to
                    # exit gracefully anyway
                    msg = 'EXIT'

                if msg == 'EXIT':
                    # Send stop signal to clients:
                    self.comm.broadcast(np.ones(1, bool), 0)
                    # (When otherwise exiting, things crashed and we should
                    # let MPI_ABORT take care of the mess instead of trying
                    # to synchronize the exit)
                    return
                elif msg == 'STATUS':
                    self.protocol.sendmsg(self.state)
                elif msg == 'POSDATA':
                    assert self.state == 'READY'
                    cell, icell, positions = self.protocol.recvposdata()
                    atoms.cell[:] = cell
                    atoms.positions[:] = positions

                    # User may wish to do something with the atoms object now.
                    # Should we provide option to yield here?
                    #
                    # (In that case we should MPI-synchronize *before*
                    #  whereas now we do it after.)

                    # Send signal for other ranks to proceed with calculation:
                    self.comm.broadcast(np.zeros(1, bool), 0)
                    energy, forces, virial = self.calculate(atoms, use_stress)

                    self.state = 'HAVEDATA'
                    yield
                elif msg == 'GETFORCE':
                    # The state check guarantees that energy, forces and
                    # virial were assigned by a preceding POSDATA step.
                    assert self.state == 'HAVEDATA', self.state
                    self.protocol.sendforce(energy, forces, virial)
                    self.state = 'NEEDINIT'
                elif msg == 'INIT':
                    assert self.state == 'NEEDINIT'
                    bead_index, initbytes = self.protocol.recvinit()
                    self.bead_index = bead_index
                    self.bead_initbytes = initbytes
                    self.state = 'READY'
                else:
                    raise KeyError('Bad message', msg)
        finally:
            self.close()

    def run(self, atoms, use_stress=False):
        """Run irun() to completion.

        Note that ``use_stress`` defaults to False here, whereas irun()
        defaults to None (decide from the periodicity of ``atoms``)."""
        for _ in self.irun(atoms, use_stress=use_stress):
            pass
537
538
class SocketIOCalculator(Calculator, IOContext):
    implemented_properties = ['energy', 'free_energy', 'forces', 'stress']
    # Only these changes can be communicated to an already-running client.
    supported_changes = {'positions', 'cell'}

    def __init__(self, calc=None, port=None,
                 unixsocket=None, timeout=None, log=None, *,
                 launch_client=None):
        """Initialize socket I/O calculator.

        This calculator launches a server which passes atomic
        coordinates and unit cells to an external code via a socket,
        and receives energy, forces, and stress in return.

        ASE integrates this with the Quantum Espresso, FHI-aims and
        Siesta calculators.  This works with any external code that
        supports running as a client over the i-PI protocol.

        Parameters:

        calc: calculator or None

            If calc is not None, a client process will be launched
            using calc.command, and the input file will be generated
            using ``calc.write_input()``.  Otherwise only the server will
            run, and it is up to the user to launch a compliant client
            process.

        port: integer

            port number for socket.  Should normally be between 1025
            and 65535.  Typical ports are 31415 (default) or 3141.

        unixsocket: str or None

            if not None, ignore host and port, creating instead a
            unix socket using this name prefixed with ``/tmp/ipi_``.
            The socket is deleted when the calculator is closed.

        timeout: float >= 0 or None

            timeout for connection, by default infinite.  See
            documentation of Python sockets.  For longer jobs it is
            recommended to set a timeout in case of undetected
            client-side failure.

        log: file object or None (default)

            logfile for communication over socket.  For debugging or
            the curious.

        launch_client: callable or None (keyword-only)

            Alternative to calc.  Called on the first calculation as
            ``launch_client(atoms, properties, port=..., unixsocket=...)``;
            its return value is handed to the server as the client
            process.  Cannot be combined with calc (ValueError).

        In order to correctly close the sockets, it is
        recommended to use this class within a with-block:

        >>> with SocketIOCalculator(...) as calc:
        ...    atoms.calc = calc
        ...    atoms.get_forces()
        ...    atoms.rattle()
        ...    atoms.get_forces()

        It is also possible to call calc.close() after
        use.  This is best done in a finally-block."""

        Calculator.__init__(self)

        if calc is not None:
            if launch_client is not None:
                raise ValueError('Cannot pass both calc and launch_client')
            launch_client = FileIOSocketClientLauncher(calc)
        self.launch_client = launch_client
        #self.calc = calc
        self.timeout = timeout
        self.server = None

        self.log = self.openfile(log)

        # We only hold these so we can pass them on to the server.
        # They may both be None as stored here.
        self._port = port
        self._unixsocket = unixsocket

        # If there is a calculator, we will launch in calculate() because
        # we are responsible for executing the external process, too, and
        # should do so before blocking.  Without a calculator we want to
        # block immediately:
        if self.launch_client is None:
            self.server = self.launch_server()

    def todict(self):
        """Return a dict with the calculator type and name."""
        d = {'type': 'calculator',
             'name': 'socket-driver'}
        #if self.calc is not None:
        #    d['calc'] = self.calc.todict()
        return d

    def launch_server(self):
        """Create a SocketServer, register it for closing, and return it."""
        return self.closelater(SocketServer(
            #launch_client=launch_client,
            port=self._port,
            unixsocket=self._unixsocket,
            timeout=self.timeout, log=self.log,
        ))

    def calculate(self, atoms=None, properties=['energy'],
                  system_changes=all_changes):
        """Send ``atoms`` to the client and store the returned results.

        On the first call with a client launcher, the server is created
        and the client is launched.  Stress is only stored for a rank-3
        cell with at least one periodic direction.

        Raises PropertyNotImplementedError if, after the first call,
        anything other than positions or cell has changed."""
        bad = [change for change in system_changes
               if change not in self.supported_changes]

        # First time calculate() is called, system_changes will be
        # all_changes.  After that, only positions and cell may change.
        if self.atoms is not None and bad:
            raise PropertyNotImplementedError(
                'Cannot change {} through IPI protocol.  '
                'Please create new socket calculator.'
                .format(bad if len(bad) > 1 else bad[0]))

        self.atoms = atoms.copy()

        if self.server is None:
            server = self.launch_server()
            try:
                proc = self.launch_client(atoms, properties,
                                          port=self._port,
                                          unixsocket=self._unixsocket)
            except Exception:
                # No client was started, so nothing would ever connect to
                # this server.  Close it and leave self.server unset so
                # that a later calculate() retries the launch instead of
                # waiting on a server without a client.
                server.close()
                raise
            server.proc = proc  # XXX nasty hack
            self.server = server

        results = self.server.calculate(atoms)
        results['free_energy'] = results['energy']
        virial = results.pop('virial')
        if self.atoms.cell.rank == 3 and any(self.atoms.pbc):
            vol = atoms.get_volume()
            results['stress'] = -full_3x3_to_voigt_6_stress(virial) / vol
        self.results.update(results)

    def close(self):
        """Forget the server and close everything registered for closing."""
        self.server = None
        super().close()
674
675
class PySocketIOClient:
    """Client launcher that runs a Python calculator in a subprocess."""

    def __init__(self, calculator_factory):
        # Stored as-is; it is pickled in __call__ and called without
        # arguments in main().
        self._calculator_factory = calculator_factory

    def __call__(self, atoms, properties=None, port=None, unixsocket=None):
        """Start the client subprocess and return its Popen object.

        The socket information, a copy of ``atoms`` and the calculator
        factory are pickled and written to the subprocess' stdin."""
        import sys
        import pickle

        # Pickle before spawning, so that a pickling failure does not
        # leave a process behind.
        socketinfo = dict(unixsocket=unixsocket, port=port)
        payload = pickle.dumps([
            socketinfo,
            atoms.copy(),
            self._calculator_factory,
        ])

        argv = [sys.executable, '-m', 'ase.calculators.socketio']
        child = Popen(argv, stdin=PIPE)

        pipe = child.stdin
        pipe.write(payload)
        pipe.close()
        return child

    @staticmethod
    def main():
        """Subprocess entry point: unpickle stdin and run a SocketClient."""
        import sys
        import pickle

        socketinfo, atoms, get_calculator = pickle.load(sys.stdin.buffer)
        atoms.calc = get_calculator()
        connection = dict(host='localhost',
                          unixsocket=socketinfo.get('unixsocket'),
                          port=socketinfo.get('port'))
        client = SocketClient(**connection)
        # XXX In principle we could avoid calculating stress until
        # someone requests the stress, could we not?
        # Which would make use_stress boolean unnecessary.
        client.run(atoms, use_stress=True)
713
714
# When executed as a script, act as the client subprocess: read the
# pickled job description from stdin and serve it.  (PySocketIOClient
# starts ``python -m ase.calculators.socketio``; presumably that is
# this module -- confirm.)
if __name__ == '__main__':
    PySocketIOClient.main()
717