"""Socket-based calculator interface using the i-PI message protocol.

The module provides the wire protocol (:class:`IPIProtocol`), a server
(:class:`SocketServer`), a client (:class:`SocketClient`) and a calculator
(:class:`SocketIOCalculator`) that drives an external code through the
server.
"""
import os
import socket
from subprocess import Popen, PIPE
from contextlib import contextmanager

import numpy as np

from ase.calculators.calculator import (Calculator, all_changes,
                                        PropertyNotImplementedError)
import ase.units as units
from ase.utils import IOContext
from ase.stress import full_3x3_to_voigt_6_stress


def actualunixsocketname(name):
    """Return the filesystem path used for the unix socket called *name*."""
    return '/tmp/ipi_{}'.format(name)


class SocketClosed(OSError):
    """Raised when the peer closed the socket while we expected data."""


class IPIProtocol:
    """Communication using IPI protocol."""

    def __init__(self, socket, txt=None):
        """Wrap a connected socket.

        Parameters:

        socket: connected socket object
            Used via ``sendall()`` and ``recv()``.
        txt: file object or None
            If given, every protocol event is logged to this file.
        """
        self.socket = socket

        if txt is None:
            def log(*args):
                pass
        else:
            def log(*args):
                print('Driver:', *args, file=txt)
                txt.flush()
        self.log = log

    def sendmsg(self, msg):
        """Send the header string *msg*, space-padded to 12 ASCII bytes."""
        self.log('  sendmsg', repr(msg))
        msg = msg.encode('ascii').ljust(12)
        self.socket.sendall(msg)

    def _recvall(self, nbytes):
        """Repeatedly read chunks until we have nbytes.

        Normally we get all bytes in one read, but that is not guaranteed.

        Raises SocketClosed if the socket yields an empty chunk."""
        remaining = nbytes
        chunks = []
        while remaining > 0:
            chunk = self.socket.recv(remaining)
            if len(chunk) == 0:
                # (If socket is still open, recv returns at least one byte)
                raise SocketClosed()
            chunks.append(chunk)
            remaining -= len(chunk)
        msg = b''.join(chunks)
        assert len(msg) == nbytes and remaining == 0
        return msg

    def recvmsg(self):
        """Receive a 12-byte header and return it as a stripped string.

        Raises SocketClosed (from ``_recvall``) if the peer hung up."""
        # _recvall() either returns exactly 12 bytes or raises SocketClosed,
        # so no separate check for an empty message is required.
        msg = self._recvall(12)
        assert len(msg) == 12, msg
        msg = msg.rstrip().decode('ascii')
        self.log('  recvmsg', repr(msg))
        return msg

    def send(self, a, dtype):
        """Send *a*, converted to an array of *dtype*, as raw bytes."""
        buf = np.asarray(a, dtype).tobytes()
        self.log('  send {} bytes of {}'.format(len(buf), dtype))
        self.socket.sendall(buf)

    def recv(self, shape, dtype):
        """Receive an array of the given *shape* and *dtype*.

        All received values are asserted to be finite."""
        a = np.empty(shape, dtype)
        # np.prod() returns a numpy scalar (a float for an empty shape);
        # the socket layer wants a plain Python integer.
        nbytes = np.dtype(dtype).itemsize * int(np.prod(shape))
        buf = self._recvall(nbytes)
        assert len(buf) == nbytes, (len(buf), nbytes)
        self.log('  recv {} bytes of {}'.format(len(buf), dtype))
        a.flat[:] = np.frombuffer(buf, dtype=dtype)
        assert np.isfinite(a).all()
        return a

    def sendposdata(self, cell, icell, positions):
        """Send a POSDATA message.

        The payload is the transposed cell divided by ``units.Bohr``, the
        transposed inverse cell multiplied by ``units.Bohr``, the number of
        atoms as int32, and the positions divided by ``units.Bohr``."""
        assert cell.size == 9
        assert icell.size == 9
        assert positions.size % 3 == 0

        self.log(' sendposdata')
        self.sendmsg('POSDATA')
        self.send(cell.T / units.Bohr, np.float64)
        self.send(icell.T * units.Bohr, np.float64)
        self.send(len(positions), np.int32)
        self.send(positions / units.Bohr, np.float64)

    def recvposdata(self):
        """Receive the payload of a POSDATA message.

        Returns the tuple ``(cell, icell, positions)`` with the inverse of
        the unit conversions applied by :meth:`sendposdata`."""
        cell = self.recv((3, 3), np.float64).T.copy()
        icell = self.recv((3, 3), np.float64).T.copy()
        # Index before int(): converting a 1-element array directly to a
        # Python scalar is deprecated in recent numpy versions.
        natoms = int(self.recv(1, np.int32)[0])
        positions = self.recv((natoms, 3), np.float64)
        return cell * units.Bohr, icell / units.Bohr, positions * units.Bohr

    def sendrecv_force(self):
        """Send GETFORCE and receive the FORCEREADY reply.

        Returns the tuple ``(energy, forces, virial, morebytes)``.  Energy
        and virial are multiplied by ``units.Ha`` and forces by
        ``units.Ha / units.Bohr``.  ``morebytes`` is a byte array, or
        ``b''`` if the peer announced no extra bytes."""
        self.log(' sendrecv_force')
        self.sendmsg('GETFORCE')
        msg = self.recvmsg()
        assert msg == 'FORCEREADY', msg
        e = self.recv(1, np.float64)[0]
        natoms = int(self.recv(1, np.int32)[0])
        assert natoms >= 0
        forces = self.recv((natoms, 3), np.float64)
        virial = self.recv((3, 3), np.float64).T.copy()
        nmorebytes = int(self.recv(1, np.int32)[0])
        if nmorebytes > 0:
            # Receiving 0 bytes will block forever on python2.
            morebytes = self.recv(nmorebytes, np.byte)
        else:
            morebytes = b''
        return (e * units.Ha, (units.Ha / units.Bohr) * forces,
                units.Ha * virial, morebytes)

    def sendforce(self, energy, forces, virial,
                  morebytes=np.zeros(1, dtype=np.byte)):
        """Send a FORCEREADY message with energy, forces and virial.

        The values are converted with the inverse of the factors used in
        :meth:`sendrecv_force`.  *morebytes* is sent last, preceded by its
        length; the default is a single zero byte."""
        assert np.array([energy]).size == 1
        assert forces.shape[1] == 3
        assert virial.shape == (3, 3)

        self.log(' sendforce')
        self.sendmsg('FORCEREADY')  # mind the units
        self.send(np.array([energy / units.Ha]), np.float64)
        natoms = len(forces)
        self.send(np.array([natoms]), np.int32)
        self.send(units.Bohr / units.Ha * forces, np.float64)
        self.send(1.0 / units.Ha * virial.T, np.float64)
        # We prefer to always send at least one byte due to trouble with
        # empty messages.  Reading a closed socket yields 0 bytes
        # and thus can be confused with a 0-length bytestring.
        self.send(np.array([len(morebytes)]), np.int32)
        self.send(morebytes, np.byte)

    def status(self):
        """Send STATUS and return the reply string."""
        self.log(' status')
        self.sendmsg('STATUS')
        msg = self.recvmsg()
        return msg

    def end(self):
        """Send the EXIT message."""
        self.log(' end')
        self.sendmsg('EXIT')

    def recvinit(self):
        """Receive the payload of an INIT message.

        Returns ``(bead_index, initbytes)``; ``bead_index`` is the
        1-element int32 array as received."""
        self.log(' recvinit')
        bead_index = self.recv(1, np.int32)
        nbytes = int(self.recv(1, np.int32)[0])
        initbytes = self.recv(nbytes, np.byte)
        return bead_index, initbytes

    def sendinit(self):
        """Send an INIT message with bead index 0 and a single zero byte."""
        # XXX Not sure what this function is supposed to send.
        # It 'works' with QE, but for now we try not to call it.
        self.log(' sendinit')
        self.sendmsg('INIT')
        self.send(0, np.int32)  # 'bead index' always zero for now
        # We send one byte, which is zero, since things may not work
        # with 0 bytes.  Apparently implementations ignore the
        # initialization string anyway.
        self.send(1, np.int32)
        self.send(np.zeros(1), np.byte)  # initialization string

    def calculate(self, positions, cell):
        """Run one full request cycle and return the results as a dict.

        The dict has the keys ``energy``, ``forces`` and ``virial``, plus
        ``morebytes`` if the peer sent extra bytes that are not all zero."""
        self.log('calculate')
        msg = self.status()
        # We don't know how NEEDINIT is supposed to work, but some codes
        # seem to be okay if we skip it and send the positions instead.
        if msg == 'NEEDINIT':
            self.sendinit()
            msg = self.status()
        assert msg == 'READY', msg
        icell = np.linalg.pinv(cell).transpose()
        self.sendposdata(cell, icell, positions)
        msg = self.status()
        assert msg == 'HAVEDATA', msg
        e, forces, virial, morebytes = self.sendrecv_force()
        r = dict(energy=e,
                 forces=forces,
                 virial=virial)
        # morebytes is a numpy array when present, so plain truth-testing
        # raises ValueError for payloads longer than one byte.  Testing
        # the length first also covers the empty b'' case.
        if len(morebytes) > 0 and np.any(morebytes):
            r['morebytes'] = morebytes
        return r


@contextmanager
def bind_unixsocket(socketfile):
    """Context manager yielding a unix socket bound to *socketfile*.

    On exit the socket is closed and *socketfile* is unlinked.  If binding
    fails, the socket is closed and OSError is raised with the filename
    appended to the message."""
    assert socketfile.startswith('/tmp/ipi_'), socketfile
    serversocket = socket.socket(socket.AF_UNIX)
    try:
        serversocket.bind(socketfile)
    except OSError as err:
        # Do not leak the file descriptor when the bind fails.
        serversocket.close()
        raise OSError('{}: {}'.format(err, repr(socketfile))) from err

    try:
        with serversocket:
            yield serversocket
    finally:
        os.unlink(socketfile)


@contextmanager
def bind_inetsocket(port):
    """Context manager yielding an INET socket bound to *port*.

    The socket is closed on exit, including when binding fails."""
    serversocket = socket.socket(socket.AF_INET)
    with serversocket:
        serversocket.setsockopt(socket.SOL_SOCKET,
                                socket.SO_REUSEADDR, 1)
        serversocket.bind(('', port))
        yield serversocket


class FileIOSocketClientLauncher:
    """Launch a client process from a calculator's shell command."""

    def __init__(self, calc):
        self.calc = calc

    def __call__(self, atoms, properties=None, port=None, unixsocket=None):
        """Write the calculator input and start the client process.

        ``PREFIX`` in ``calc.command`` is replaced by ``calc.prefix``, and
        the ``{port}`` and ``{unixsocket}`` format fields are filled in.
        Returns the Popen object of the launched process."""
        assert self.calc is not None
        cmd = self.calc.command.replace('PREFIX', self.calc.prefix)
        self.calc.write_input(atoms, properties=properties,
                              system_changes=all_changes)
        cwd = self.calc.directory
        cmd = cmd.format(port=port, unixsocket=unixsocket)
        # NOTE: the command string is handed to the shell as it is.
        return Popen(cmd, shell=True, cwd=cwd)


class SocketServer(IOContext):
    """Server that sends geometries to one client and collects results."""

    default_port = 31415

    def __init__(self,
                 port=None, unixsocket=None, timeout=None,
                 log=None):
        """Create server and listen for connections.

        The server does not launch a client by itself.  The connection
        is accepted once calculate() is called, which blocks to wait for
        the client.  If the ``proc`` attribute is set to a subprocess
        object, that process is polled while waiting for the connection,
        and waited for in close().

        Parameters:

        port: integer or None
            Port on which to listen for INET connections.  Defaults
            to 31415 if neither this nor unixsocket is specified.
        unixsocket: string or None
            Filename for unix socket.
        timeout: float or None
            timeout in seconds, or unlimited by default.
            This parameter is passed to the Python socket object; see
            documentation thereof
        log: file object or None
            useful debug messages are written to this.

        Raises ValueError if both port and unixsocket are given."""

        if unixsocket is None and port is None:
            port = self.default_port
        elif unixsocket is not None and port is not None:
            raise ValueError('Specify only one of unixsocket and port')

        self.port = port
        self.unixsocket = unixsocket
        self.timeout = timeout
        self._closed = False

        if unixsocket is not None:
            actualsocket = actualunixsocketname(unixsocket)
            conn_name = 'UNIX-socket {}'.format(actualsocket)
            socket_context = bind_unixsocket(actualsocket)
        else:
            conn_name = 'INET port {}'.format(port)
            socket_context = bind_inetsocket(port)

        self.serversocket = self.closelater(socket_context)

        if log:
            print('Accepting clients on {}'.format(conn_name), file=log)

        self.serversocket.settimeout(timeout)

        self.serversocket.listen(1)

        self.log = log

        self.proc = None

        self.protocol = None
        self.clientsocket = None
        self.address = None

    def _accept(self):
        """Wait for client and establish connection."""
        # It should perhaps be possible for process to be launched by user
        log = self.log
        if log:
            print('Awaiting client', file=self.log)

        # If we launched the subprocess, the process may crash.
        # We want to detect this, using loop with timeouts, and
        # raise an error rather than blocking forever.
        if self.proc is not None:
            self.serversocket.settimeout(1.0)

        # NOTE(review): when no subprocess is attached, a timeout of
        # accept() is swallowed below and the loop simply retries, so the
        # user-supplied timeout does not bound this wait.  Confirm whether
        # that is intended.
        while True:
            try:
                self.clientsocket, self.address = self.serversocket.accept()
                self.closelater(self.clientsocket)
            except socket.timeout:
                if self.proc is not None:
                    status = self.proc.poll()
                    if status is not None:
                        raise OSError('Subprocess terminated unexpectedly'
                                      ' with status {}'.format(status))
            else:
                break

        self.serversocket.settimeout(self.timeout)
        self.clientsocket.settimeout(self.timeout)

        if log:
            # For unix sockets, address is b''.
            source = ('client' if self.address == b'' else self.address)
            print('Accepted connection from {}'.format(source), file=log)

        self.protocol = IPIProtocol(self.clientsocket, txt=log)

    def close(self):
        """Close the sockets and wait for the subprocess, if any.

        Calling close() more than once is harmless.  A warning is issued
        if the subprocess exits with non-zero status."""
        if self._closed:
            return

        super().close()

        if self.log:
            print('Close socket server', file=self.log)
        self._closed = True

        # Proper way to close sockets?
        # And indeed i-pi connections...
        # (No end-of-communication message is sent to the client here.)
        self.protocol = None
        if self.proc is not None:
            exitcode = self.proc.wait()
            if exitcode != 0:
                import warnings
                # Quantum Espresso seems to always exit with status 128,
                # even if successful.
                # Should investigate at some point
                warnings.warn('Subprocess exited with status {}'
                              .format(exitcode))

    def calculate(self, atoms):
        """Send geometry to client and return calculated things as dict.

        This will block until client has established connection, then
        wait for the client to finish the calculation."""
        assert not self._closed

        # If we have not established connection yet, we must block
        # until the client catches up:
        if self.protocol is None:
            self._accept()
        return self.protocol.calculate(atoms.positions, atoms.cell)


class SocketClient:
    """Client that answers server requests using the atoms' calculator."""

    def __init__(self, host='localhost', port=None,
                 unixsocket=None, timeout=None, log=None, comm=None):
        """Create client and connect to server.

        Parameters:

        host: string
            Hostname of server.  Defaults to localhost
        port: integer or None
            Port to which to connect.  By default 31415.
        unixsocket: string or None
            If specified, use corresponding UNIX socket.
            See documentation of unixsocket for SocketIOCalculator.
        timeout: float or None
            See documentation of timeout for SocketIOCalculator.
        log: file object or None
            Log events to this file
        comm: communicator or None
            MPI communicator object.  Defaults to ase.parallel.world.
            When ASE runs in parallel, only the process with world.rank == 0
            will communicate over the socket.  The received information
            will then be broadcast on the communicator.  The SocketClient
            must be created on all ranks of world, and will see the same
            Atoms objects."""
        if comm is None:
            from ase.parallel import world
            comm = world

        # Only rank0 actually does the socket work.
        # The other ranks only need to follow.
        #
        # Note: We actually refrain from assigning all the
        # socket-related things except on master
        self.comm = comm

        if self.comm.rank == 0:
            if unixsocket is not None:
                sock = socket.socket(socket.AF_UNIX)
                address = actualunixsocketname(unixsocket)
            else:
                if port is None:
                    port = SocketServer.default_port
                sock = socket.socket(socket.AF_INET)
                address = (host, port)

            try:
                sock.connect(address)
                sock.settimeout(timeout)
            except Exception:
                # Do not leak the socket if the connection fails.
                sock.close()
                raise

            self.host = host
            self.port = port
            self.unixsocket = unixsocket

            self.protocol = IPIProtocol(sock, txt=log)
            self.log = self.protocol.log
            self.closed = False

            self.bead_index = 0
            self.bead_initbytes = b''
            self.state = 'READY'

    def close(self):
        """Close the socket unless it was closed already."""
        if not self.closed:
            self.log('Close SocketClient')
            self.closed = True
            self.protocol.socket.close()

    def calculate(self, atoms, use_stress):
        """Broadcast the geometry from rank 0 and evaluate the calculator.

        Returns ``(energy, forces, virial)``.  The virial is minus the
        volume times the stress, or zeros if *use_stress* is false."""
        # We should also broadcast the bead index, once we support doing
        # multiple beads.
        self.comm.broadcast(atoms.positions, 0)
        self.comm.broadcast(np.ascontiguousarray(atoms.cell), 0)

        energy = atoms.get_potential_energy()
        forces = atoms.get_forces()
        if use_stress:
            stress = atoms.get_stress(voigt=False)
            virial = -atoms.get_volume() * stress
        else:
            virial = np.zeros((3, 3))
        return energy, forces, virial

    def irun(self, atoms, use_stress=None):
        """Return a generator which yields after every calculation.

        If *use_stress* is None, stress is used if any direction of
        *atoms* is periodic."""
        if use_stress is None:
            use_stress = any(atoms.pbc)

        my_irun = self.irun_rank0 if self.comm.rank == 0 else self.irun_rankN
        return my_irun(atoms, use_stress)

    def irun_rankN(self, atoms, use_stress=True):
        """Generator for ranks other than 0: follow the signals of rank 0."""
        stop_criterion = np.zeros(1, bool)
        while True:
            self.comm.broadcast(stop_criterion, 0)
            if stop_criterion[0]:
                return

            self.calculate(atoms, use_stress)
            yield

    def irun_rank0(self, atoms, use_stress=True):
        """Generator for rank 0: serve requests until EXIT is received.

        The client is closed when the generator finishes or fails."""
        # For every step we either calculate or quit.  We need to
        # tell other MPI processes (if this is MPI-parallel) whether they
        # should calculate or quit.
        try:
            while True:
                try:
                    msg = self.protocol.recvmsg()
                except SocketClosed:
                    # Server closed the connection, but we want to
                    # exit gracefully anyway
                    msg = 'EXIT'

                if msg == 'EXIT':
                    # Send stop signal to clients:
                    self.comm.broadcast(np.ones(1, bool), 0)
                    # (When otherwise exiting, things crashed and we should
                    # let MPI_ABORT take care of the mess instead of trying
                    # to synchronize the exit)
                    return
                elif msg == 'STATUS':
                    self.protocol.sendmsg(self.state)
                elif msg == 'POSDATA':
                    assert self.state == 'READY'
                    cell, icell, positions = self.protocol.recvposdata()
                    atoms.cell[:] = cell
                    atoms.positions[:] = positions

                    # User may wish to do something with the atoms object now.
                    # Should we provide option to yield here?
                    #
                    # (In that case we should MPI-synchronize *before*
                    #  whereas now we do it after.)

                    # Send signal for other ranks to proceed with calculation:
                    self.comm.broadcast(np.zeros(1, bool), 0)
                    energy, forces, virial = self.calculate(atoms, use_stress)

                    self.state = 'HAVEDATA'
                    yield
                elif msg == 'GETFORCE':
                    assert self.state == 'HAVEDATA', self.state
                    self.protocol.sendforce(energy, forces, virial)
                    self.state = 'NEEDINIT'
                elif msg == 'INIT':
                    assert self.state == 'NEEDINIT'
                    bead_index, initbytes = self.protocol.recvinit()
                    self.bead_index = bead_index
                    self.bead_initbytes = initbytes
                    self.state = 'READY'
                else:
                    raise KeyError('Bad message', msg)
        finally:
            self.close()

    def run(self, atoms, use_stress=False):
        """Serve requests until the generator from irun() is exhausted."""
        for _ in self.irun(atoms, use_stress=use_stress):
            pass


class SocketIOCalculator(Calculator, IOContext):
    """Calculator which obtains its results from a socket client."""

    implemented_properties = ['energy', 'free_energy', 'forces', 'stress']
    supported_changes = {'positions', 'cell'}

    def __init__(self, calc=None, port=None,
                 unixsocket=None, timeout=None, log=None, *,
                 launch_client=None):
        """Initialize socket I/O calculator.

        This calculator launches a server which passes atomic
        coordinates and unit cells to an external code via a socket,
        and receives energy, forces, and stress in return.

        ASE integrates this with the Quantum Espresso, FHI-aims and
        Siesta calculators.  This works with any external code that
        supports running as a client over the i-PI protocol.

        Parameters:

        calc: calculator or None

            If calc is not None, a client process will be launched
            using calc.command, and the input file will be generated
            using ``calc.write_input()``.  Otherwise only the server will
            run, and it is up to the user to launch a compliant client
            process.

        port: integer

            port number for socket.  Should normally be between 1025
            and 65535.  Typical ports are 31415 (default) or 3141.

        unixsocket: str or None

            if not None, ignore host and port, creating instead a
            unix socket using this name prefixed with ``/tmp/ipi_``.
            The socket is deleted when the calculator is closed.

        timeout: float >= 0 or None

            timeout for connection, by default infinite.  See
            documentation of Python sockets.  For longer jobs it is
            recommended to set a timeout in case of undetected
            client-side failure.

        log: file object or None (default)

            logfile for communication over socket.  For debugging or
            the curious.

        launch_client: callable or None (keyword-only)

            called as ``launch_client(atoms, properties, port=...,
            unixsocket=...)`` on the first calculation; the return value
            is used as the client subprocess.  Cannot be combined with
            calc (ValueError is raised).

        In order to correctly close the sockets, it is
        recommended to use this class within a with-block:

        >>> with SocketIOCalculator(...) as calc:
        ...    atoms.calc = calc
        ...    atoms.get_forces()
        ...    atoms.rattle()
        ...    atoms.get_forces()

        It is also possible to call calc.close() after
        use.  This is best done in a finally-block."""

        Calculator.__init__(self)

        if calc is not None:
            if launch_client is not None:
                raise ValueError('Cannot pass both calc and launch_client')
            launch_client = FileIOSocketClientLauncher(calc)
        self.launch_client = launch_client
        self.timeout = timeout
        self.server = None

        self.log = self.openfile(log)

        # We only hold these so we can pass them on to the server.
        # They may both be None as stored here.
        self._port = port
        self._unixsocket = unixsocket

        # If there is a client launcher, we will launch in calculate()
        # because we are responsible for executing the external process,
        # too, and should do so before blocking.  Without a launcher we
        # start the server immediately:
        if self.launch_client is None:
            self.server = self.launch_server()

    def todict(self):
        """Return a minimal dictionary description of this calculator."""
        d = {'type': 'calculator',
             'name': 'socket-driver'}
        return d

    def launch_server(self):
        """Create the SocketServer and register it for closing."""
        return self.closelater(SocketServer(
            port=self._port,
            unixsocket=self._unixsocket,
            timeout=self.timeout, log=self.log,
        ))

    def calculate(self, atoms=None, properties=['energy'],
                  system_changes=all_changes):
        """Send *atoms* to the client and store the returned results.

        On the first call the server and the client process are launched
        if a client launcher is available.

        Raises PropertyNotImplementedError if, after the first call,
        anything other than positions or cell has changed."""
        bad = [change for change in system_changes
               if change not in self.supported_changes]

        # First time calculate() is called, system_changes will be
        # all_changes.  After that, only positions and cell may change.
        if self.atoms is not None and any(bad):
            raise PropertyNotImplementedError(
                'Cannot change {} through IPI protocol.  '
                'Please create new socket calculator.'
                .format(bad if len(bad) > 1 else bad[0]))

        self.atoms = atoms.copy()

        if self.server is None:
            self.server = self.launch_server()
            proc = self.launch_client(atoms, properties,
                                      port=self._port,
                                      unixsocket=self._unixsocket)
            self.server.proc = proc  # XXX nasty hack

        results = self.server.calculate(atoms)
        results['free_energy'] = results['energy']
        virial = results.pop('virial')
        if self.atoms.cell.rank == 3 and any(self.atoms.pbc):
            vol = atoms.get_volume()
            results['stress'] = -full_3x3_to_voigt_6_stress(virial) / vol
        self.results.update(results)

    def close(self):
        """Forget the server and close all registered resources."""
        self.server = None
        super().close()


class PySocketIOClient:
    """Launch a Python subprocess which runs a SocketClient."""

    def __init__(self, calculator_factory):
        self._calculator_factory = calculator_factory

    def __call__(self, atoms, properties=None, port=None, unixsocket=None):
        """Start the client subprocess and return its Popen object.

        The socket information, a copy of *atoms* and the calculator
        factory are pickled and written to the stdin of the subprocess."""
        import sys
        import pickle

        # We pickle everything first, so we won't need to bother with the
        # process as long as it succeeds.
        transferbytes = pickle.dumps([
            dict(unixsocket=unixsocket, port=port),
            atoms.copy(),
            self._calculator_factory,
        ])

        proc = Popen([sys.executable, '-m', 'ase.calculators.socketio'],
                     stdin=PIPE)

        proc.stdin.write(transferbytes)
        proc.stdin.close()
        return proc

    @staticmethod
    def main():
        """Entry point of the subprocess: unpickle stdin and run a client."""
        import sys
        import pickle

        # NOTE: unpickling executes arbitrary code.  The data read here is
        # whatever arrives on stdin; __call__() above is the writer visible
        # in this module.
        socketinfo, atoms, get_calculator = pickle.load(sys.stdin.buffer)
        atoms.calc = get_calculator()
        client = SocketClient(host='localhost',
                              unixsocket=socketinfo.get('unixsocket'),
                              port=socketinfo.get('port'))
        # XXX In principle we could avoid calculating stress until
        # someone requests the stress, could we not?
        # Which would make use_stress boolean unnecessary.
        client.run(atoms, use_stress=True)


if __name__ == '__main__':
    PySocketIOClient.main()